from typing import List as _List
from typing import Union as _Union
from ..utils._get_functions import MetaFunction, accepts_stage
__all__ = ["iterate_custom",
           "build_custom_iterator"]
[docs]def build_custom_iterator(custom_function: _Union[str, MetaFunction],
                          parent_name="__main__") -> MetaFunction:
    """Build and return a custom iterator from the passed
       function. This will wrap 'iterate_custom' around
       the function to double-check that the custom
       function is doing everything correctly
       Parameters
       ----------
       custom_function
         This can either be a function, which will be wrapped and
         returned, or it can be a string. If it is a string then
         we will attempt to locate or import the function associated
         with that string. The search order is;
         1. Is this 'metawards.iterators.custom_function'?
         2. Is this 'custom_function' that is already imported'?
         3. Is this a file name in the current path, if yes then
            find the function in that file (either the first function
            called 'iterateXXX' or the specified function if
            custom_function is in the form module::function)
        parent_name: str
          This should be the __name__ of the calling function, e.g.
          call this function as build_custom_iterator(func, __name__)
        Returns
        -------
        iterator: MetaFunction
          The wrapped iterator
    """
    from ..utils._console import Console
    if isinstance(custom_function, str):
        Console.print(f"Importing a custom iterator from {custom_function}")
        # we need to find the function
        import metawards.iterators
        # is it metawards.iterators.{custom_function}
        try:
            func = getattr(metawards.iterators, custom_function)
            return build_custom_iterator(func)
        except Exception:
            pass
        # do we have the function in the current namespace?
        import sys
        try:
            func = getattr(sys.modules[__name__], custom_function)
            return build_custom_iterator(func)
        except Exception:
            pass
        # how about the __name__ namespace of the caller
        try:
            func = getattr(sys.modules[parent_name], custom_function)
            return build_custom_iterator(func)
        except Exception:
            pass
        # how about the __main__ namespace (e.g. if this was loaded
        # in a script)
        try:
            func = getattr(sys.modules["__main__"], custom_function)
            return build_custom_iterator(func)
        except Exception:
            pass
        # can we import this function as a file - need to check that
        # the user hasn't written this as module::function
        if custom_function.find("::") != -1:
            parts = custom_function.split("::")
            func_name = parts[-1]
            func_module = "::".join(parts[0:-1])
        else:
            func_name = None
            func_module = custom_function
        from ..utils._import_module import import_module
        module = import_module(func_module)
        if module is None:
            # we cannot find the iterator
            Console.error(
                f"Cannot find the iterator '{custom_function}'."
                f"Please make sure this is spelled correctly and "
                f"any python modules/files needed are in the "
                f"PYTHONPATH or current directory")
            raise ImportError(f"Could not import the iterator "
                              f"'{custom_function}'")
        if func_name is None:
            # find the last function that starts with 'iterate'
            import inspect
            funcs = []
            for name, value in inspect.getmembers(module):
                if name.startswith("iterate"):
                    if hasattr(value, "__call__"):
                        if value.__module__ == module.__name__:
                            # this is a function defined in this module
                            funcs.append(value)
            if len(funcs) > 0:
                func = funcs[0]
                if len(funcs) > 1:
                    Console.warning(
                        f"Multiple possible matching functions: {funcs}. "
                        f"Choosing {func}. Please use the module::function "
                        f"syntax if this is the wrong choice.")
            else:
                func = None
            if func is not None:
                return build_custom_iterator(func)
            Console.error(
                f"Could not find any function in the module "
                f"{custom_function} that has a name that starts "
                f"with 'iterate'. Please manually specify the "
                f"name using the '{custom_function}::your_function syntax")
            raise ImportError(f"Could not import the iterator "
                              f"{custom_function}")
        else:
            if hasattr(module, func_name):
                return build_custom_iterator(getattr(module, func_name))
            Console.error(
                f"Could not find the function {func_name} in the "
                f"module {func_module}. Check that the spelling "
                f"is correct and that the right version of the module "
                f"is being loaded.")
            raise ImportError(f"Could not import the iterator "
                              f"{custom_function}")
    if not hasattr(custom_function, "__call__"):
        Console.error(
            f"Cannot build an iterator for {custom_function} "
            f"as it is missing a __call__ function, i.e. it is "
            f"not a function.")
        raise ValueError(f"You can only build custom iterators for "
                         f"actual functions... {custom_function}")
    Console.print(f"Building a custom iterator for {custom_function}",
                  style="magenta")
    return lambda **kwargs: iterate_custom(custom_function=custom_function,
                                           **kwargs) 
[docs]def iterate_custom(custom_function: MetaFunction, stage: str,
                   **kwargs) -> _List[MetaFunction]:
    """This returns the default list of 'advance_XXX' functions that
       are called in sequence for each iteration of the model run.
       This iterator provides a custom iterator that uses
       'custom_function' passed from the user. This iterator makes
       sure that if 'stage' is not handled by the custom function,
       then the "iterate_default" functions for that stage
       are correctly called for all stages except "infect"
       Parameters
       ----------
       custom_function: MetaFunction
         A custom user-supplied function that returns the
         functions that the user would like to be called for
         each step.
       stage: str
         The stage of the day/model
       Returns
       -------
       funcs: List[MetaFunction]
         The list of functions that will be called in sequence
    """
    kwargs["stage"] = stage
    if custom_function is None:
        from ._iterate_default import iterate_default
        return iterate_default(**kwargs)
    elif stage == "infect" or accepts_stage(custom_function):
        # most custom functions operate at the 'infect' stage,
        # so iterators that don't specify a stage are assumed to
        # only operate here (every other stage is 'iterate_default')
        return custom_function(**kwargs)
    else:
        from ._iterate_default import iterate_default
        return iterate_default(**kwargs)