from dataclasses import dataclass as _dataclass
from dataclasses import field as _field
from typing import List as _List, Dict as _Dict
from datetime import date as _date
from ._inputfiles import InputFiles
from ._disease import Disease
from ._variableset import VariableSets, VariableSet
# Names exported by 'from <module> import *'
__all__ = ["Parameters", "get_repository_version", "get_repository"]

# Default sub-folder of the repository that Parameters.load searches
# for '<parameters>.json' files
_default_folder_name = "parameters"

# Cache of version information, keyed by repository path. This is filled
# in by get_repository_version so each repository is only inspected once
_repositories = {}
def get_repository(repository: str = None, error_on_missing=True):
    """Locate a MetaWardsData repository on disk and look up its version.

    Parameters
    ----------
    repository: str
      Path to the repository. If this is None then the value of the
      METAWARDSDATA environment variable is used, and if that is not
      set then $HOME/GitHub/MetaWardsData is used. Environment variables
      and '~' in the path are expanded, and the path is made absolute.
    error_on_missing: bool
      What to do if the resolved path is not an existing directory.
      If True then a FileNotFoundError is raised, otherwise
      (None, None) is returned.

    Returns
    -------
    (repository, version): tuple
      The resolved repository path and the version information returned
      by get_repository_version for that path
    """
    import os
    from pathlib import Path

    if repository is None:
        from_env = os.getenv("METAWARDSDATA", None)

        if from_env is None:
            repository = os.path.join(Path.home(),
                                      "GitHub", "MetaWardsData")
        else:
            repository = from_env

    expanded = os.path.expandvars(repository)
    expanded = os.path.expanduser(expanded)
    repository = str(Path(expanded).absolute().resolve())

    if os.path.exists(repository) and os.path.isdir(repository):
        return (repository, get_repository_version(repository))

    if not error_on_missing:
        return (None, None)

    raise FileNotFoundError(
        f"Cannot find the MetaWardsData repository "
        f"at {repository}. Please follow the instructions "
        f"at https://metawards.org/model_data.html to download "
        f"and install the model data.")
def generate_repository_version(repository):
    """Run the 'version' script found in the passed repository, which
       is expected to generate the 'version.txt' file.

       The script is run with the repository as its working directory
       and its stdout and stderr are discarded. The exit status of the
       script is not checked. Nothing is done if 'repository' is None.
    """
    import os
    import subprocess

    if repository is None:
        return

    version_script = os.path.join(repository, "version")

    subprocess.run(version_script, cwd=repository,
                   stdout=subprocess.DEVNULL,
                   stderr=subprocess.DEVNULL)
def get_repository_version(repository: str):
    """Read and return the version information of the passed repository

       The information is read from the 'version.txt' JSON file in the
       repository. If that cannot be read then an attempt is made to
       generate it by running the repository's 'version' script. If
       that also fails then an error is printed and placeholder
       "unknown" information is returned. Results are cached, so each
       repository is only inspected once.

       Parameters
       ----------
       repository: str
         The full path to the repository whose version should be obtained.
         If this is None then None is returned

       Returns
       -------
       version_data: dict
         A dictionary containing version information for the repository
    """
    global _repositories

    if repository is None:
        return None

    if repository in _repositories:
        return _repositories[repository]

    import os

    filename = os.path.join(repository, "version.txt")

    def _read_version_file():
        # parse the JSON version file; any failure propagates to the caller
        with open(filename) as handle:
            import json
            return json.load(handle)

    # first attempt - the version file already exists.
    # NOTE(review): unlike the regenerated case below, this path does not
    # add a "filepath" entry to the returned data - confirm this is intended
    try:
        info = _read_version_file()
        _repositories[repository] = info
        return info
    except Exception:
        pass

    # could not get the version, so see if we have permission
    # to run the 'version' program
    try:
        generate_repository_version(repository)
        info = _read_version_file()
        info["filepath"] = repository
        _repositories[repository] = info
        return info
    except Exception:
        from .utils._console import Console
        Console.error(f"""
Could not find the repository version info in {filename}. Please make sure
that you have run './version' in that repository to generate the version
info.""")

    # nothing worked, so cache and return placeholder information
    placeholder = {"filepath": repository,
                   "repository": "unknown",
                   "version": "unknown",
                   "branch": "unknown",
                   "is_dirty": True}

    _repositories[repository] = placeholder

    return _repositories[repository]
@_dataclass
class Parameters:
    """The full set of Parameters that are used to control the model
       outbreak over a Network. The combination of a Network and
       a Parameters defines the model outbreak.

       Load the Parameters using the Parameters.load function, and
       then add extra data using the various "set" and "add" functions,
       e.g.

       Examples
       --------
       >>> params = Parameters.load("march29")
       >>> params.set_disease("ncov")
       >>> params.set_input_files("2011Data")
       >>> params.add_seeds("ExtraSeedsBrighton.dat")
    """
    #: The set of input files that define the model Network
    input_files: InputFiles = None

    #: The set of parameters that define the disease
    disease_params: Disease = None

    #: The set of files that contain additional seeds that
    #: seed the outbreak during the model run
    additional_seeds: _List[str] = None

    #: The fraction of day considered "day" for work, e.g. 0.7 * 24 hours
    length_day: float = 0.7

    static_play_at_home: float = 0.0
    dyn_play_at_home: float = 0.0

    #: The cutoff distance in km beyond which workers or players cannot move
    dyn_dist_cutoff: float = 10000000.0

    play_to_work: float = 0.0
    work_to_play: float = 0.0

    #: proportion of daily imports if there are additional infections
    daily_imports: float = 0.0

    #: The index of the seeding ward if there are daily imports
    ward_seed_index: int = None

    #: The number of initial infections if there are daily imports
    initial_inf: int = 5

    #: how to treat the * state (stage 0). This should be a string
    #: describing the method. Currently "R", "E" and "disable" are
    #: supported. Not needed if the mapping is specified explicitly
    #: in the disease
    stage_0: str = "R"

    #: Seasonality parameter
    UV: float = 0.0

    #: Date when transmission should be at a maximum
    UV_max: _date = None

    #: The global scale_uv. This is combined with the population and
    #: per-ward level scale_uvs to give a single value
    scale_uv: float = 1.0

    #: The global background force of infection (FOI). This is combined
    #: with the per-ward level bg_foi to give a single value
    bg_foi: float = 0.0

    #: User parameters
    user_params: _Dict[str, float] = _field(default_factory=dict)

    #: All of the VariableSet adjustments that have been applied
    #: to these parameters
    adjustments: _List[VariableSet] = None

    # Metadata describing the parameter set and where it was loaded from
    _name: str = None
    _version: str = None
    _authors: str = None
    _contacts: str = None
    _references: str = None
    _filename: str = None
    _repository: str = None
    _repository_version: str = None
    _repository_branch: str = None
    _repository_dir: str = None

    #: The parameters for demographic sub-networks. If this is None then
    #: the parameters are the same as the overall parameters
    # (this has no annotation, so it is not a dataclass field)
    _subparams = None

    def __str__(self):
        """Return a summary of these parameters as one '* key: value'
           line per item. Items whose value is None are left out
        """
        items = [("Parameters", self._name),
                 ("loaded_from", self._filename),
                 ("repository", self._repository),
                 ("repository_branch", self._repository_branch),
                 ("repository_version", self._repository_version),
                 ("length_day", self.length_day),
                 ("initial_inf", self.initial_inf),
                 ("static_play_at_home", self.static_play_at_home),
                 ("dyn_play_at_home", self.dyn_play_at_home),
                 # this previously printed dyn_play_at_home by mistake
                 ("dyn_dist_cutoff", self.dyn_dist_cutoff),
                 ("play_to_work", self.play_to_work),
                 ("work_to_play", self.work_to_play),
                 ("daily_imports", self.daily_imports),
                 ("UV", self.UV),
                 ("UV_max", self.UV_max),
                 ("scale_uv", self.scale_uv),
                 ("bg_foi", self.bg_foi),
                 ("stage_0", self.stage_0)]

        return "\n".join(f"* {key}: {value}"
                         for key, value in items
                         if value is not None)

    @staticmethod
    def default():
        """Return the default set of parameters. The repository
           information is filled in if the default repository can be
           found, otherwise a plain default-constructed Parameters
           is returned
        """
        try:
            (repository, v) = Parameters.get_repository()

            repository_dir = repository
            repository = v["repository"]
            repository_branch = v["branch"]
            repository_version = v["version"]

            return Parameters(_repository=repository,
                              _repository_dir=repository_dir,
                              _repository_branch=repository_branch,
                              _repository_version=repository_version)
        except Exception:
            # deliberate best-effort: a missing or unreadable repository
            # must not stop default parameters from being created
            pass

        return Parameters()

    @staticmethod
    def get_repository(repository: str = None):
        """Return the repository location and version information
           for the passed repository

           Parameters
           ----------
           repository: str
             Location on the filesystem of the repository. If this
             is None then it will be searched for using first
             the environment variable METAWARDSDATA, then
             $HOME/GitHub/MetaWardsData

           Returns
           -------
           (repository, version): tuple
             A tuple of the location on disk of the repository,
             plus the version information (git ID etc)

           Raises
           ------
           FileNotFoundError
             If the repository cannot be found
        """
        return get_repository(repository)

    @staticmethod
    def load(parameters: str = "march29",
             repository: str = None,
             folder: str = _default_folder_name,
             filename: str = None):
        """This will return a Parameters object containing all of the
           parameters loaded from the parameters found in file
           f"{repository}/{folder}/{parameters}.json"

           By default this will load the march29 parameters from
           the 'parameters' folder of the default repository, e.g.
           $HOME/GitHub/MetaWardsData/parameters/march29.json

           Alternatively, you can provide the exact path to the
           filename via the 'filename' argument

           Parameters
           ----------
           parameters: str
             The name of the parameters to load. This is the name that
             will be searched for in the METAWARDSDATA parameters directory
           repository: str
             The location of the cloned METAWARDSDATA repository
           folder: str
             The name of the folder within the METAWARDSDATA repository
             that contains the parameters
           filename: str
             The name of the file to load the parameters from - this directly
             loads this file without searching through the METAWARDSDATA
             repository

           Returns
           -------
           params: Parameters
             The constructed parameters

           Raises
           ------
           FileNotFoundError
             If the repository cannot be found, or if the parameters
             file cannot be opened or parsed as JSON
        """
        repository_version = None
        repository_branch = None
        repository_dir = None

        if filename is None:
            import os

            (repository, v) = Parameters.get_repository(repository)
            filename = os.path.join(repository, folder, f"{parameters}.json")

            # from here on 'repository' holds the name recorded in the
            # version info, and 'repository_dir' holds the path on disk
            repository_dir = repository
            repository = v["repository"]
            repository_branch = v["branch"]
            repository_version = v["version"]

        json_file = filename

        try:
            with open(json_file, "r") as FILE:
                import json
                data = json.load(FILE)
        except Exception as e:
            from .utils._console import Console
            Console.error(f"""
Could not find the parameters file {json_file}. Either it does not exist or
was corrupted. Error was {e.__class__} {e}. "Please see
https://metawards.org/model_data for instructions on how to download and
set the model data.""")
            # chain the original error so that its traceback is not lost
            raise FileNotFoundError(f"Could not find or read {json_file}: "
                                    f"{e.__class__} {e}") from e

        # NOTE(review): the fallback for a missing "initial_inf" is 0 here,
        # while the dataclass default is 5 - confirm this is intended.
        # NOTE(review): "UV_max" is stored exactly as read from the JSON,
        # with no conversion to a date - confirm a later step converts it
        par = Parameters(
            length_day=data.get("length_day", 0.7),
            initial_inf=data.get("initial_inf", 0),
            static_play_at_home=data.get("static_play_at_home", 0.0),
            dyn_play_at_home=data.get("dyn_play_at_home", 0.0),
            dyn_dist_cutoff=data.get("dyn_dist_cutoff", 10000000.0),
            play_to_work=data.get("play_to_work", 0.0),
            work_to_play=data.get("work_to_play", 0.0),
            daily_imports=data.get("daily_imports", 0),
            UV=data.get("UV", 0.0),
            UV_max=data.get("UV_max", None),
            scale_uv=data.get("scale_uv", 1.0),
            bg_foi=data.get("bg_foi", 0.0),
            _name=data.get("name", parameters),
            _authors=data.get("author(s)", "unknown"),
            _version=data.get("version", "unknown"),
            _contacts=data.get("contact(s)", "unknown"),
            _references=data.get("reference(s)", "none"),
            _filename=json_file,
            _repository=repository,
            _repository_dir=repository_dir,
            _repository_branch=repository_branch,
            _repository_version=repository_version
        )

        return par

    def __getitem__(self, demographic: str):
        """Return the parameters that should be used for the demographic
           subnetwork called 'demographic'. If these have not been set
           specifically then a copy of the parameters for the overall
           network is created, stored for that demographic and returned.
           Asking for "overall" returns this object itself
        """
        if demographic == "overall":
            return self

        if self._subparams is None:
            self._subparams = {}

        if demographic not in self._subparams:
            from copy import deepcopy
            self._subparams[demographic] = deepcopy(self)
            # the copy must not carry its own set of sub-parameters
            self._subparams[demographic]._subparams = {}

        return self._subparams[demographic]

    def copy(self, include_subparams: bool = False):
        """Return a safe copy of these parameters, which does not
           include any subnetwork parameters if 'include_subparams' is False
        """
        from copy import deepcopy
        params = deepcopy(self)

        if not include_subparams:
            params._subparams = None

        return params

    def specialised_demographics(self) -> _List[str]:
        """Return the names of demographics that have specialised
           parameters that are different to those of the overall
           network
        """
        if self._subparams is None:
            return []

        return list(self._subparams.keys())

    def add_seeds(self, filename: str):
        """Add an 'additional seeds' file that can be used to
           seed wards with new infections at different times and
           locations. Several additional_seed files can be added

           Parameters
           ----------
           filename: str
             Name of the file containing the additional seeds. If this
             does not exist as given, then it is looked for in the
             'extra_seeds' folder of the repository (if known). The
             name is added unchanged if it cannot be found there either
        """
        # resolve the filename to the GitHub repo if possible...
        if self.additional_seeds is None:
            self.additional_seeds = []

        import os

        if not os.path.exists(filename):
            if self._repository_dir is not None:
                f = os.path.join(self._repository_dir, "extra_seeds", filename)

                if os.path.exists(f):
                    filename = f

        self.additional_seeds.append(filename)

    def set_disease(self, disease: Disease, silent: bool = True):
        """Set the disease that will be modelled

           Parameters
           ----------
           disease: Disease
             The disease to be modelled. If a string is passed
             then the disease will be loaded using that string.
             A deep copy of the disease is stored
           silent: bool
             Whether or not to suppress printing out the disease
        """
        if isinstance(disease, str):
            disease = Disease.load(disease,
                                   repository=self._repository_dir)

        if not silent:
            from .utils._console import Console
            Console.print(disease, markdown=True)

        from copy import deepcopy
        self.disease_params = deepcopy(disease)

    def set_variables(self, variables: VariableSet):
        """This function sets the adjustable variable values to those
           specified in 'variables' in A COPY OF THIS PARAMETERS OBJECT.
           This returns the copy. It does not change this object

           Parameters
           ----------
           variables: VariableSet
             The variables that will be adjusted before the model run.
             This adjusts the parameters and returns them in a deep copy.
             A dictionary is converted to a VariableSet before use

           Returns
           -------
           params: Parameters
             A copy of this set of parameters with the variables adjusted
        """
        from copy import deepcopy
        params = deepcopy(self)

        if isinstance(variables, dict):
            variables = VariableSet(variables)

        variables.adjust(params)

        return params

    @staticmethod
    def read_variables(filename: str, line_numbers: _List[int]):
        """Read in extra variable parameters from the specified line number(s)
           of the specified file, returning the list
           of the dictionaries of variables that have been
           read. You can then apply those variable parameters
           using the 'set_variables' function

           Parameters
           ----------
           filename: str
             The file from which to read the adjustable variables
           line_numbers: List[int]
             All of the line numbers from which to read. If this is
             None then all lines will be read.

           Returns
           -------
           variables: VariableSets
             The VariableSets containing all of the adjustable variables
        """
        return VariableSets.read(filename, line_numbers)