from typing import Union as _Union
from typing import Callable as _Callable
from typing import List as _List
from .._network import Network
from .._networks import Networks
from .._population import Population
from .._population import Populations
from .._infections import Infections
from .._outputfiles import OutputFiles
from .._workspace import Workspace
from ._profiler import Profiler
# Names exported by this module: the per-stage function getters, the
# "stage"-signature check, the MetaFunction alias and the helper that
# applies a function to a network or to each of its sub-networks.
__all__ = ["get_functions", "get_model_loop_functions",
           "get_initialise_functions", "get_finalise_functions",
           "get_summary_functions",
           "accepts_stage", "MetaFunction",
           "call_function_on_network"]
# Alias used for annotations throughout this module: a callable that
# takes any arguments and is annotated as returning None.
MetaFunction = _Callable[..., None]
def accepts_stage(func: MetaFunction) -> bool:
    """Report whether 'func' declares a parameter called "stage",
       i.e. whether it is able to behave differently for the
       different stages of the day

       Parameters
       ----------
       func: function
         The function to be queried

       Returns
       -------
       result: bool
         True if "stage" is one of the parameters in the signature
         of 'func', otherwise False

       Raises
       ------
       Exception
         Whatever was raised while reading the signature of 'func'.
         The problem is reported via Console.error before the
         exception is re-raised
    """
    from inspect import signature

    try:
        parameters = signature(func).parameters
    except Exception as e:
        # report the likely cause (a compiled function that carries no
        # signature information) before passing the error on
        from ._console import Console
        Console.error(f"Could not find the signature for {func}. The error "
                      f"is {e.__class__}: {e}. This is likely because this "
                      f"is an in-built function compiled using cython. To "
                      f"fix this, make sure that the pyx file containing "
                      f"your cython starts with '# cython: binding=True`. "
                      f"This will switch on support in cython for adding "
                      f"signatures to compiled functions.")
        raise e

    return "stage" in parameters
def get_functions(stage: str,
                  network: _Union[Network, Networks],
                  population: Population,
                  infections: Infections,
                  output_dir: OutputFiles,
                  workspace: Workspace,
                  iterator: MetaFunction,
                  extractor: MetaFunction,
                  mixer: MetaFunction,
                  mover: MetaFunction,
                  rngs, nthreads, profiler: Profiler,
                  trajectory: Populations = None,
                  results=None) -> _List[MetaFunction]:
    """Return the functions that must be called for the specified
       stage of the day;

       * "initialise": model initialisation. Called once before the
                       whole model run is performed
       * "setup": day setup. Called once at the start of the day.
                  Should be used to import new seeds, move populations
                  between demographics, move infected individuals
                  through disease stages etc. There is no calculation
                  performed at this stage.
       * "foi": foi calculation. Called to recalculate the force of infection
                (foi) for each ward in the network (and subnetworks).
                This is calculated based on the populations in each state
                in each ward in each demographic
       * "infect": Called to advance the outbreak by calculating
                   the number of new infections
       * "analyse": Called at the end of the day to analyse
                    the data and extract the results
       * "finalise": Called at the end of the model run to finalise
                     any outputs or produce overall summary files
       * "summary": Called at the end of lots of model runs, to write
                    overview summary files. Only the extractor has
                    a summary stage

       For every stage except "summary" the result is the concatenation
       of the lists returned by the mover, iterator, mixer and extractor
       (called in that order). For "summary" only the extractor is
       called.

       Parameters
       ----------
       stage: str
         The stage of the day/model for which to get the functions
       network: Network or Networks
         The network(s) to be modelled
       population: Population
         The population experiencing the outbreak
       infections: Infections
         Space to record the infections through the day
       output_dir: OutputFiles
         Accepted, but not forwarded by this function to the
         mover, iterator, mixer or extractor
       workspace: Workspace
         Accepted, but not forwarded by this function to the
         mover, iterator, mixer or extractor
       iterator: function
         Iterator used to obtain the function used to advance
         the outbreak
       extractor: function
         Extractor used to analyse the data and output results
       mixer: function
         Mixer used to mix and merge data between different demographics
       mover: function
         Mover used to move populations between demographics
       rngs: list[random number generate pointers]
         Pointers to the random number generators to use for each thread
       nthreads: int
         The number of threads to use to run the model
       profiler: Profiler
         The profiler used to profile the calculation
       trajectory: Populations
         Forwarded as the "trajectory" keyword argument
         (defaults to None)
       results
         Forwarded as the "results" keyword argument
         (defaults to None)

       Returns
       -------
       functions: List[MetaFunction]
         The list of all functions that should be called for this
         stage of the day

       Raises
       ------
       ValueError
         If 'stage' is not one of the recognised stages
    """
    stages = ["initialise", "setup", "foi", "infect",
              "analyse", "finalise", "summary"]

    if stage not in stages:
        raise ValueError(
            f"Cannot recognise the stage {stage}. Available stages "
            f"are {stages}")

    # the keyword arguments handed to each of the mover, iterator,
    # mixer and extractor (note: no output_dir or workspace)
    kwargs = {"stage": stage,
              "network": network,
              "population": population,
              "infections": infections,
              "rngs": rngs,
              "nthreads": nthreads,
              "profiler": profiler,
              "trajectory": trajectory,
              "results": results}

    if stage == "summary":
        # only the extractor takes part in the summary stage, so the
        # other three may be None here
        return extractor(**kwargs)

    # parenthesised so that the expression can span lines without
    # relying on fragile backslash continuations
    return (mover(**kwargs) + iterator(**kwargs) +
            mixer(**kwargs) + extractor(**kwargs))
def get_model_loop_functions(**kwargs) -> _List[MetaFunction]:
    """Convenience function that returns all of the functions
       that should be called during the model loop
       (i.e. the "setup", "foi", "infect" and "analyse" stages)

       All keyword arguments are passed unchanged to get_functions,
       once per stage, so they must satisfy its signature
       (without 'stage', which is supplied here).

       Parameters
       ----------
       network: Network or Networks
         The network(s) to be modelled
       population: Population
         The population experiencing the outbreak
       infections: Infections
         Space to record the infections through the day
       iterator: function
         Iterator used to obtain the function used to advance
         the outbreak
       extractor: function
         Extractor used to analyse the data and output results
       mixer: function
         Mixer used to mix and merge data between different demographics
       mover: function
         Mover used to move populations between demographics
       rngs: list[random number generate pointers]
         Pointers to the random number generators to use for each thread
       nthreads: int
         The number of threads to use to run the model
       profiler: Profiler
         The profiler used to profile the calculation

       Returns
       -------
       functions: List[MetaFunction]
         The functions for the "setup", "foi", "infect" and "analyse"
         stages, concatenated in that order
    """
    funcs = []

    # the order of the stages here is the order in which the
    # functions appear in the returned list
    for stage in ("setup", "foi", "infect", "analyse"):
        funcs += get_functions(stage=stage, **kwargs)

    return funcs
def get_initialise_functions(**kwargs) -> _List[MetaFunction]:
    """Convenience function that returns all of the functions
       that should be called during the initialisation step
       of the model (i.e. the "initialise" stage)

       All keyword arguments are passed unchanged to get_functions,
       so they must satisfy its signature (without 'stage', which is
       supplied here).

       Parameters
       ----------
       network: Network or Networks
         The network(s) to be modelled
       population: Population
         The population experiencing the outbreak
       infections: Infections
         Space to record the infections through the day
       iterator: function
         Iterator used to obtain the function used to advance
         the outbreak
       extractor: function
         Extractor used to analyse the data and output results
       mixer: function
         Mixer used to mix and merge data between different demographics
       mover: function
         Mover used to move populations between demographics
       rngs: list[random number generate pointers]
         Pointers to the random number generators to use for each thread
       nthreads: int
         The number of threads to use to run the model
       profiler: Profiler
         The profiler used to profile the calculation

       Returns
       -------
       functions: List[MetaFunction]
         The list of all functions that should be called for the
         "initialise" stage
    """
    return get_functions(stage="initialise", **kwargs)
def get_finalise_functions(trajectory: Populations,
                           **kwargs) -> _List[MetaFunction]:
    """Convenience function that returns all of the functions
       that should be called during the finalisation step
       of the model (i.e. the "finalise" stage)

       'trajectory' and all other keyword arguments are passed on to
       get_functions, so together they must satisfy its signature
       (without 'stage', which is supplied here).

       Parameters
       ----------
       network: Network or Networks
         The network(s) to be modelled
       population: Population
         The population experiencing the outbreak
       trajectory: Populations
         The trajectory of populations over time
       infections: Infections
         Space to record the infections through the day
       iterator: function
         Iterator used to obtain the function used to advance
         the outbreak
       extractor: function
         Extractor used to analyse the data and output results
       mixer: function
         Mixer used to mix and merge data between different demographics
       mover: function
         Mover used to move populations between demographics
       rngs: list[random number generate pointers]
         Pointers to the random number generators to use for each thread
       nthreads: int
         The number of threads to use to run the model
       profiler: Profiler
         The profiler used to profile the calculation

       Returns
       -------
       functions: List[MetaFunction]
         The list of all functions that should be called for the
         "finalise" stage
    """
    # 'trajectory' is captured by name above and so is not part of
    # kwargs - it has to be forwarded explicitly or it would be lost
    return get_functions(stage="finalise", trajectory=trajectory, **kwargs)
def get_summary_functions(network: _Union[Network, Networks],
                          results, output_dir: OutputFiles,
                          extractor: MetaFunction,
                          nthreads: int = 1, **kwargs
                          ):
    """Convenience function that returns all of the functions
       that should be called during the summary report
       stage of the simulation (i.e. the "summary" stage)

       The remaining arguments of get_functions are filled in here:
       'workspace', 'infections' and 'population' are
       default-constructed objects, while 'iterator', 'mixer',
       'mover', 'rngs' and 'profiler' are set to None. Any values
       passed for these names via kwargs are replaced.

       Parameters
       ----------
       network: Network or Networks
         The network(s) that were modelled
       results
         Passed to get_functions as 'results'
       output_dir: OutputFiles
         Passed to get_functions as 'output_dir'
       extractor: function
         Extractor that supplies the summary functions
       nthreads: int
         The number of threads to use
       kwargs
         Any further keyword arguments for get_functions

       Returns
       -------
       The value returned by get_functions for the "summary" stage
    """
    # default-constructed stand-ins for the per-run objects
    kwargs.update({"workspace": Workspace(),
                   "infections": Infections(),
                   "population": Population()})

    # everything else that get_functions requires is set to None
    for unused in ("iterator", "mixer", "mover", "rngs", "profiler"):
        kwargs[unused] = None

    return get_functions(stage="summary", network=network,
                         results=results, output_dir=output_dir,
                         extractor=extractor, nthreads=nthreads,
                         **kwargs)
def call_function_on_network(network: _Union[Network, Networks],
                             infections: Infections,
                             workspace: Workspace,
                             population: Population,
                             func: MetaFunction = None,
                             parallel: MetaFunction = None,
                             nthreads: int = 1,
                             switch_to_parallel: int = 2,
                             call_on_overall: bool = False,
                             **kwargs):
    """Call either 'func' or 'parallel' (depending on the
       number of threads, nthreads) on the passed Network,
       or on all demographic subnetworks

       Parameters
       ----------
       network: Network or Networks
         The network that is being modelled
       infections: Infections
         The infections for 'network'. If 'network' is a Networks
         then infections.subinfs[i] is used for the ith subnetwork
       workspace: Workspace
         The workspace for 'network'. If 'network' is a Networks
         then workspace.subspaces[i] is used for the ith subnetwork
       population: Population
         The population for 'network'. If 'network' is a Networks
         then population.subpops[i] is used for the ith subnetwork
       func: MetaFunction
         Function that performs the work in serial
       parallel: MetaFunction
         Function that performs the work in parallel
       nthreads: int
         The number of threads to use to perform the work. This is
         only passed on (as the 'nthreads' keyword argument) when
         the parallel function is the one being called
       switch_to_parallel: int
         Use the parallel function when nthreads is greater or equal
         to this value
       call_on_overall: bool
         If 'network' is a Networks, then also call the function on
         network.overall (with the passed infections, population and
         workspace) after it has been called on every subnetwork
       kwargs
         Extra keyword arguments that are passed to every call
         of the function
    """
    # use the parallel function if it is the only one supplied, or if
    # enough threads have been requested
    if parallel is not None and (func is None or
                                 nthreads >= switch_to_parallel):
        kwargs["nthreads"] = nthreads
        func = parallel

    if isinstance(network, Networks):
        # call the function on all of the demographic sub-networks
        for i, subnet in enumerate(network.subnets):
            func(network=subnet,
                 infections=infections.subinfs[i],
                 population=population.subpops[i],
                 workspace=workspace.subspaces[i],
                 **kwargs)

        if call_on_overall:
            func(network=network.overall, infections=infections,
                 population=population, workspace=workspace,
                 **kwargs)
    else:
        func(network=network, infections=infections,
             population=population, workspace=workspace,
             **kwargs)