from typing import Union as _Union
from .._network import Network
from .._networks import Networks
from .._infections import Infections
from .._outputfiles import OutputFiles
from .._workspace import Workspace
from .._population import Population, Populations
from ._profiler import Profiler
from ._get_functions import get_initialise_functions, \
get_model_loop_functions, \
get_finalise_functions, \
MetaFunction, \
accepts_stage
# Public API of this module: only run_model is exported by "import *"
__all__ = ["run_model"]
def run_model(network: _Union[Network, Networks],
              infections: Infections,
              rngs,
              output_dir: OutputFiles,
              population: _Union[Population, None] = None,
              nsteps: _Union[int, None] = None,
              profiler: _Union[Profiler, None] = None,
              nthreads: _Union[int, None] = None,
              iterator: _Union[str, MetaFunction, None] = None,
              extractor: _Union[str, MetaFunction, None] = None,
              mixer: _Union[str, MetaFunction, None] = None,
              mover: _Union[str, MetaFunction, None] = None) -> Populations:
    """Actually run the model... Real work happens here. The model
    will run until completion or until 'nsteps' have been
    completed, whichever happens first.

    The main loop always runs at least one iteration. Unless 'nsteps'
    is reached or a model-loop function requests an early stop, it
    runs at least 5 iterations, even if there are no infections.

    Parameters
    ----------
    network: Network or Networks
        The network(s) on which to run the model
    infections: Infections
        The space used to record the infections
    rngs: list
        The list of random number generators to use, one per thread
    output_dir: OutputFiles
        The directory to write all of the output into
    population: Population
        The initial population at the start of the model outbreak.
        This is also used to set the date and day of the start of
        the model outbreak. If None, a new
        Population(initial=57104043) is created for this call. The
        run works on a deep copy, so the passed object is not
        modified by the run.
    nsteps: int
        The maximum number of steps to run in the outbreak. If None
        then run until the outbreak has finished. The loop body runs
        before this limit is checked, so at least one step is always
        run (even if nsteps is 0).
    profiler: Profiler
        The profiler to use to profile. If None, a NullProfiler is
        created. A fresh profiler of the same class is also created
        for each day to report per-day timings.
    nthreads: int
        Number of threads over which to parallelise this model run
    iterator: MetaFunction or string
        Function that will be used to dynamically get the functions
        that will be used at each iteration to advance the
        model. Any additional files or parameters needed by these
        functions should be included in the `network.params` object.
        If None, iterate_default is used. If a string, or a function
        for which accepts_stage() is False, it is wrapped using
        build_custom_iterator.
    extractor: MetaFunction or string
        Function that will be used to dynamically get the functions
        that will be used at each iteration to extract data from
        the model run. Resolved like 'iterator' (default
        extract_default, custom via build_custom_extractor).
    mixer: MetaFunction or string
        Function that will mix data from multiple demographics
        so that this is shared during a model run. Resolved like
        'iterator' (default mix_default, custom via build_custom_mixer).
    mover: MetaFunction or string
        Function that can move the population between different
        demographics. Resolved like 'iterator' (default move_default,
        custom via build_custom_mover).

    Returns
    -------
    trajectory: Populations
        The trajectory of the population for every day of the model
        run (day 0 plus one entry per iteration), with demographic
        detail stripped via ``strip_demographics()``. If
        ``network.params`` is None, no model is run and the starting
        ``population`` object itself (not a Populations) is returned.

    Raises
    ------
    AssertionError
        If the total population at the end of a day differs from the
        total at the start of that day, which indicates a program bug.

    Notes
    -----
    A model-loop function may raise StopIteration to request that the
    run stops. The remaining functions for that iteration are still
    called, and the run ends after the iteration completes.
    """
    # resolve each stage function: None -> the default implementation,
    # a string or a non-stage-aware function -> a custom wrapper
    if iterator is None:
        from ..iterators._iterate_default import iterate_default
        iterator = iterate_default
    elif isinstance(iterator, str) or not accepts_stage(iterator):
        from ..iterators._iterate_custom import build_custom_iterator
        iterator = build_custom_iterator(iterator, __name__)

    if extractor is None:
        from ..extractors._extract_default import extract_default
        extractor = extract_default
    elif isinstance(extractor, str) or not accepts_stage(extractor):
        from ..extractors._extract_custom import build_custom_extractor
        extractor = build_custom_extractor(extractor, __name__)

    if mixer is None:
        from ..mixers._mix_default import mix_default
        mixer = mix_default
    elif isinstance(mixer, str) or not accepts_stage(mixer):
        from ..mixers._mix_custom import build_custom_mixer
        mixer = build_custom_mixer(mixer, __name__)

    if mover is None:
        from ..movers._move_default import move_default
        mover = move_default
    elif isinstance(mover, str) or not accepts_stage(mover):
        from ..movers._move_custom import build_custom_mover
        mover = build_custom_mover(mover, __name__)

    # build the default population per call (rather than once at
    # definition time) so that no Population object is shared between
    # calls - it would otherwise escape via the early return below
    if population is None:
        population = Population(initial=57104043)

    if profiler is None:
        from ._profiler import NullProfiler
        profiler = NullProfiler()

    p = profiler.start("run_model")

    params = network.params

    if params is None:
        # nothing to run - stop the "run_model" timer so it is not left
        # running, then hand back the starting population unchanged
        p.stop()
        return population

    # work on a copy so that the caller's population is not modified
    from copy import deepcopy
    population = deepcopy(population)

    # create space to hold the population trajectory
    trajectory = Populations()

    p = p.start("clear_all_infections")
    infections.clear(nthreads=nthreads)
    p = p.stop()

    # create a workspace that is used as part of the "analyse" stage to
    # provide a scratch-pad while extracting data from the model
    workspace = Workspace.build(network=network)

    # get and call all of the functions that need to be called to
    # initialise the model run
    p = p.start("initialise_funcs")
    funcs = get_initialise_functions(network=network, population=population,
                                     infections=infections,
                                     output_dir=output_dir,
                                     workspace=workspace, rngs=rngs,
                                     iterator=iterator, extractor=extractor,
                                     mixer=mixer, mover=mover,
                                     nthreads=nthreads, profiler=p)

    # setup takes place on "day 0"
    from ._console import Console
    Console.rule(f"Day {population.day}", style="iteration")

    for func in funcs:
        p = p.start(str(func))
        func(network=network, population=population,
             infections=infections, output_dir=output_dir,
             workspace=workspace, rngs=rngs, nthreads=nthreads,
             profiler=p)
        p = p.stop()

    p = p.stop()

    infecteds = population.infecteds

    # save the initial population
    trajectory.append(population)

    p = p.start("run_model_loop")
    iteration_count = 0

    # keep looping while there are still infections, but always complete
    # at least 5 loop iterations. The loop may also exit early (checked
    # at the end of the loop body) if 'nsteps' iterations have been run
    # or a function requested an early stop
    while (infecteds != 0) or (iteration_count < 5):
        # construct a new profiler of the same type as 'profiler'
        p2 = profiler.__class__()

        # increment the day at the beginning, before anything happens.
        # This way, the statistics for "day 1" are everything that
        # happened since the end of day 0 and the end of day 1
        population.increment_day()

        p2 = p2.start(f"timing for day {population.day}")
        Console.rule(f"Day {population.day}", style="iteration")

        # total population must be conserved across the day
        start_population = population.population

        funcs = get_model_loop_functions(
                    network=network, population=population,
                    infections=infections,
                    output_dir=output_dir,
                    workspace=workspace, rngs=rngs,
                    iterator=iterator, extractor=extractor,
                    mixer=mixer, mover=mover,
                    nthreads=nthreads, profiler=p)

        should_finish_early = False

        for func in funcs:
            p2 = p2.start(str(func))
            try:
                func(network=network, population=population,
                     infections=infections, output_dir=output_dir,
                     workspace=workspace, rngs=rngs, nthreads=nthreads,
                     profiler=p2)
            except StopIteration:
                # this function has signalled that the simulation
                # should now stop - we record this request but will
                # still let the other functions complete this
                # iteration
                Console.print(f"{func} has indicated that the model run "
                              f"should stop early. Will finish the run "
                              f"at the end of this iteration")
                should_finish_early = True
            p2 = p2.stop()

        if population.population != start_population:
            # something went wrong as the population should be conserved
            # during the day
            raise AssertionError(
                f"The total population changed during the day. This "
                f"should not happen and indicates a program bug. "
                f"The starting population was {start_population}, "
                f"while the end population is {population.population}. "
                f"Detail is {population}")

        infecteds = population.infecteds
        Console.print(f"Number of infections: {infecteds}")

        iteration_count += 1

        p2 = p2.stop()

        if not p2.is_null():
            Console.print_profiler(p2)

        # save the population trajectory
        trajectory.append(population)

        if should_finish_early:
            Console.print("Exiting model run early due to function request")
            break
        elif nsteps is not None and iteration_count >= nsteps:
            Console.print(
                f"Exiting model run early as number of steps ({nsteps}) "
                f"reached.")
            break
    # end of while loop

    p = p.stop()

    # finally get and call all of the functions needed to finalise
    # the model run, e.g. closing files, performing overall analyses,
    # writing summary files etc
    p = p.start("finalise_funcs")
    funcs = get_finalise_functions(network=network, population=population,
                                   infections=infections,
                                   output_dir=output_dir,
                                   workspace=workspace, rngs=rngs,
                                   iterator=iterator, extractor=extractor,
                                   mixer=mixer, mover=mover,
                                   nthreads=nthreads, trajectory=trajectory,
                                   profiler=p)

    for func in funcs:
        p = p.start(str(func))
        func(network=network, population=population,
             infections=infections, output_dir=output_dir,
             workspace=workspace, rngs=rngs, nthreads=nthreads,
             trajectory=trajectory, profiler=p)
        p = p.stop()

    p = p.stop()

    # stop the top-level "run_model" timer. 'p' is not reassigned here,
    # so (assuming stop() returns the parent profiler) it still refers to
    # the "run_model" profiler, which is what gets printed below
    p.stop()

    if not p.is_null():
        Console.rule("Overall model timing")
        Console.print_profiler(p)

    Console.print(f"Ending on day {population.day}")

    # only send back the overall statistics
    return trajectory.strip_demographics()