from __future__ import annotations
from typing import List as _List
from typing import Union as _Union
from typing import Tuple as _Tuple
from .utils._profiler import Profiler, NullProfiler
from ._ward import Ward
from ._wardinfo import WardInfos, WardInfo
# Public API of this module - only the Wards collection is exported
__all__ = ["Wards"]
class Wards:
    """This class holds an entire network of Ward objects.

    Wards are stored in a list indexed by their integer ID, so
    ``self._wards[i]`` holds the ward with ID ``i``, or ``None`` if no
    ward has that ID. Index 0 is never a valid ward ID (connections are
    checked against the range ``1 <= n`` in :meth:`assert_sane`).
    A parallel ``WardInfos`` collection (``self._info``) holds the info
    of each ward at the same index. ``self._unresolved`` holds the IDs
    of wards whose links have not yet been resolved against this
    collection.
    """

    def __init__(self, wards: _List[Ward] = None):
        """Construct, optionally from a list of Ward objects

        Parameters
        ----------
        wards: Ward | Wards | list | tuple | None
            The ward(s) to insert. This is passed straight to
            :meth:`insert`, so anything accepted there is accepted here.
        """
        self._wards = []
        self._info = WardInfos()
        self._unresolved = []

        if wards is not None:
            self.insert(wards)

    def __str__(self):
        # Index 0 is skipped as it never holds a valid ward
        if len(self) == 0:
            return "Wards::null"
        elif len(self) < 10:
            return f"[ {', '.join([str(x) for x in self._wards[1:]])} ]"
        else:
            s = f"[ {', '.join([str(x) for x in self._wards[1:7]])}, ... "
            s += f"{', '.join([str(x) for x in self._wards[-3:]])} ]"
            return s

    def __repr__(self):
        return self.__str__()

    def __eq__(self, other):
        return self.__class__ == other.__class__ and \
            self.__dict__ == other.__dict__

    def insert(self, wards: _List[Ward], overwrite: bool = True,
               _need_deep_copy: bool = True) -> None:
        """Insert the passed wards onto this list. This will overwrite
        the existing ward if 'overwrite' is true, otherwise it will
        add the ward's data to the existing ward

        Parameters
        ----------
        wards: Ward | Wards | list | tuple
            The ward(s) to insert. A list/tuple may mix Ward objects,
            Wards objects and None (which is ignored). Wards objects
            are dereferenced and inserted ward-by-ward.
        overwrite: bool
            If True, a ward with the same ID replaces the existing one.
            If False, it is merged into the existing ward via
            ``Ward.merge``.
        _need_deep_copy: bool
            Internal flag. If True, wards are deep-copied before being
            stored. If False, the passed objects themselves are stored
            (and may have their ID set via ``set_id``).

        Raises
        ------
        TypeError
            If anything other than a Ward, Wards or None is passed
        KeyError
            If a ward has the same info as an existing ward, but a
            different ID
        """
        if isinstance(wards, tuple):
            wards = list(wards)
        elif not isinstance(wards, list):
            wards = [wards]

        # Validate the input, and recursively insert the contents of
        # any Wards objects one ward at a time
        for ward in wards:
            if isinstance(ward, Wards):
                for w in ward._wards:
                    if w is not None:
                        # convert the IDs in 'w' into WardInfo objects so
                        # that they are meaningful in this collection
                        w = w.dereference(ward)
                        self.insert(w, overwrite=overwrite,
                                    _need_deep_copy=False)
            elif ward is None:
                continue
            elif not isinstance(ward, Ward):
                raise TypeError(
                    f"You cannot append a {ward} to a list of Ward objects!")

        # get the largest ID and then resize the list so that every
        # ward with an explicit ID has a slot
        largest_id = 0

        for ward in wards:
            if isinstance(ward, Ward):
                if ward.id() is not None and ward.id() > largest_id:
                    largest_id = ward.id()

        if largest_id >= len(self._wards):
            self._wards += [None] * (largest_id - len(self._wards) + 1)

        # wards that still have no ID are appended to the end once all
        # of the wards with IDs have been placed
        add_later = []

        from copy import deepcopy

        for ward in wards:
            if isinstance(ward, Ward):
                # make sure that this is not a duplicate...
                info = ward._info

                if info in self._info:
                    idx = self._info.index(info)

                    if ward.id() is None:
                        # adopt the ID of the existing ward with this info
                        ward = deepcopy(ward)
                        ward.set_id(idx)

                    if idx != ward.id():
                        raise KeyError(
                            f"You cannot have two different wards that have "
                            f"the same info: {self._wards[idx]} : {ward}")

                    # otherwise the IDs are the same, so this is an
                    # update to an existing ward

                if ward.id() is None:
                    add_later.append(ward)
                else:
                    if _need_deep_copy:
                        ward = deepcopy(ward)

                    if overwrite or self._wards[ward.id()] is None:
                        self._wards[ward.id()] = ward
                    else:
                        self._wards[ward.id()].merge(ward)

                    self._info[ward.id()] = ward._info
                    self._unresolved.append(ward.id())

        for ward in add_later:
            # append this onto the end of the list
            idx = len(self._wards)

            if _need_deep_copy:
                ward = deepcopy(ward)

            ward.set_id(idx)
            self._wards.append(ward)
            self._info[ward.id()] = ward._info
            self._unresolved.append(ward.id())

        self._resolve()

    def add(self, ward: Ward) -> None:
        """Synonym for insert, except that any existing ward with the
        same ID is merged with (rather than overwritten by) the
        added ward
        """
        self.insert(ward, overwrite=False)

    def __add__(self, other):
        """Return a copy of these wards with 'other' added in"""
        from copy import deepcopy
        c = deepcopy(self)
        c.add(other)
        return c

    def __radd__(self, other):
        # note that 'self' comes first in the result
        return self.__add__(other)

    def __iadd__(self, other):
        self.add(other)
        return self

    def __mul__(self, scale: float) -> Wards:
        """Scale the number of workers and players by 'scale'"""
        return self.scale(work_ratio=scale, play_ratio=scale)

    def __rmul__(self, scale: float) -> Wards:
        """Scale the number of workers and players by 'scale'"""
        return self.scale(work_ratio=scale, play_ratio=scale)

    def __imul__(self, scale: float) -> Wards:
        """In-place multiply the number of workers and players by 'scale'"""
        return self.scale(work_ratio=scale, play_ratio=scale, _inplace=True)

    def is_resolved(self) -> bool:
        """Return whether or not this is a fully resolved set of Wards
        (i.e. each ward only links to other wards in this set)
        """
        return len(self._unresolved) == 0

    def unresolved_wards(self) -> _List[int]:
        """Return a copy of the list of IDs of unresolved wards"""
        from copy import deepcopy
        return deepcopy(self._unresolved)

    def _resolve(self) -> None:
        """Try to resolve all of the links. Note that this will only resolve
        as many wards as it can. Any unresolved wards will be available
        via 'unresolved_wards()'
        """
        still_unresolved = []

        for unresolved in self._unresolved:
            self._wards[unresolved].resolve(self, _inplace=True)

            if not self._wards[unresolved].is_resolved():
                still_unresolved.append(unresolved)

        self._unresolved = still_unresolved

    def get(self, id: _Union[int, str, WardInfo],
            dereference: bool = True) -> Ward:
        """Return the ward with the specified id - this can be the integer
        ID of the ward or the WardInfo of the ward. If 'dereference'
        is True then this will dereference all of the IDs into
        WardInfo objects. This is useful if you want to use the
        resulting Ward with other Wards collections
        """
        # self[id] already returns a deep copy, so dereferencing in-place
        # does not modify the ward stored in this collection
        ward = self[id]

        if dereference:
            ward.dereference(self, _inplace=True)

        return ward

    def getinfo(self, id: _Union[int, str, WardInfo]) -> WardInfo:
        """Return the WardInfo matching the ward with the passed ID"""
        idx = self.index(id)
        return self._info.wards[idx]

    def __getitem__(self, id: _Union[int, str, WardInfo]) -> Ward:
        """Return the ward with specified id - this can be the integer
        ID of the ward, a string (an integer ID or a ward name), the
        WardInfo of the ward, or a Ward. Note that this returns a
        copy of the Ward

        Raises
        ------
        ValueError
            If a Ward is passed whose info matches a ward in this
            collection that has a different ID
        """
        if isinstance(id, Ward):
            idx = self._info.index(id._info)

            if id._id is not None and idx != id._id:
                # the info matches a ward, but that ward has a different ID
                raise ValueError(f"No ward matching {id}")

            w = self._wards[idx]
        elif isinstance(id, WardInfo):
            w = self._wards[self._info.index(id)]
        elif isinstance(id, str):
            # try an integer ID first, then fall back to a ward name
            try:
                return self[int(id)]
            except Exception:
                pass

            return self[WardInfo(name=id)]
        else:
            w = self._wards[id]

        # must deepcopy this or else it can be changed behind our back
        from copy import deepcopy
        return deepcopy(w)

    def index(self, id: _Union[int, str, WardInfo, Ward]) -> int:
        """Return the index of the ward that matches the passed
        id - which can be the integer ID or WardInfo - in this
        Wards object. This raises a ValueError if the
        ward doesn't exist
        """
        if isinstance(id, Ward):
            return self._info.index(id._info)
        elif isinstance(id, WardInfo):
            return self._info.index(id)
        elif isinstance(id, str):
            # try an integer ID first, then fall back to a ward name
            try:
                return self.index(int(id))
            except Exception:
                pass

            return self.index(WardInfo(name=id))
        else:
            try:
                id = int(id)
            except Exception:
                raise ValueError(f"No ward matching {id}")

            # support python-style negative indexing
            if id < 0:
                id = len(self._wards) + id

            if id < 0 or id >= len(self._wards):
                raise ValueError(f"No ward matching {id}")

            if self._wards[id] is None:
                raise ValueError(f"No ward matching {id}")

            return id

    def __contains__(self, id: _Union[int, str, WardInfo, Ward]) -> bool:
        """Return whether or not the passed id - which can be an integer
        ID, a string (an integer ID or ward name), a WardInfo or a
        Ward - is in this Wards object. Values that cannot be
        interpreted as an ID are reported as not contained.
        """
        if isinstance(id, Ward):
            return id._info in self._info
        elif isinstance(id, WardInfo):
            return id in self._info
        elif isinstance(id, str):
            try:
                if self.__contains__(int(id)):
                    return True
            except Exception:
                pass

            return self.__contains__(WardInfo(name=id))
        else:
            try:
                id = int(id)
            except (TypeError, ValueError, OverflowError):
                # not something that can be an ID, so cannot be contained
                return False

            # support python-style negative indexing
            if id < 0:
                id = len(self._wards) + id

            if id < 0 or id >= len(self._wards):
                return False
            else:
                return self._wards[id] is not None

    def contains(self, id: _Union[int, str, WardInfo]) -> bool:
        """Return whether or not the passed id - which can be an integer
        ID or WardInfo - is in this Wards object
        """
        return self.__contains__(id)

    def __len__(self):
        # Note: this is the size of the ID-indexed list, which includes
        # the (null) slot at index 0 and any other null slots
        return len(self._wards)

    def num_work_links(self):
        """Return the total number of work links"""
        return sum(ward.num_work_links()
                   for ward in self._wards if ward is not None)

    def num_play_links(self):
        """Return the total number of play links"""
        return sum(ward.num_play_links()
                   for ward in self._wards if ward is not None)

    def num_players(self):
        """Return the total number of players in this network"""
        return sum(ward.num_players()
                   for ward in self._wards if ward is not None)

    def num_workers(self):
        """Return the total number of workers in this network"""
        return sum(ward.num_workers()
                   for ward in self._wards if ward is not None)

    def population(self):
        """Return the total population in this network"""
        return sum(ward.population()
                   for ward in self._wards if ward is not None)

    def scale(self, work_ratio: float = 1.0,
              play_ratio: float = 1.0, _inplace: bool = False) -> Wards:
        """Return a copy of these wards where the number of workers
        and players have been scaled by 'work_ratio' and 'play_ratio'
        respectively. These can be greater than 1.0, e.g. if you want
        to scale up the number of workers and players

        Parameters
        ----------
        work_ratio: float
            The scaling ratio for workers
        play_ratio: float
            The scaling ratio for players

        Returns
        -------
        Wards: A copy of this Wards scaled by the requested amount
               (or this Wards itself if '_inplace' is True)
        """
        if _inplace:
            wards = self
        else:
            from copy import deepcopy
            wards = deepcopy(self)

        for ward in wards._wards:
            if ward is not None:
                ward.scale(work_ratio=work_ratio, play_ratio=play_ratio,
                           _inplace=True)

        return wards

    def _harmonise_nodes(self, other: Wards) -> None:
        """Make sure that this set of Wards contains all of the wards
        in 'other', and that all have the same ID and info. If
        there are any missing nodes, then zero-populated
        copies will be added

        Raises
        ------
        ValueError
            If any ward in this set is missing from, or differs from,
            the matching ward in 'other'
        """
        errors = []

        for ward in self._wards:
            if ward is None:
                pass
            elif ward not in other:
                errors.append(f"Missing ward {ward} from the overall network")
            else:
                # Look up by info directly (read-only, so no copy needed).
                # Using other[ward] would raise if the IDs differ, which
                # would prevent the mismatch from being reported below
                other_ward = other._wards[other.index(ward._info)]

                if ward._id != other_ward._id or \
                        ward._info != other_ward._info or \
                        ward._pos != other_ward._pos:
                    errors.append(f"Ward exists, but is different: {ward} "
                                  f"versus {other_ward}.")

        if len(errors) > 0:
            from .utils._console import Console
            Console.error("\n".join(errors))
            raise ValueError("Cannot harmonise incompatible Wards")

        for ward in other._wards:
            if ward is not None and ward not in self:
                self.insert(ward.depopulate(zero_player_weights=True),
                            _need_deep_copy=False)

    def _harmonise_links(self, other: Wards) -> None:
        """Make sure that the wards in this object has exactly the same
        links as the wards in 'other'.

        This is used as part of the Wards.harmonise function, and
        ensures that all subnet wards have identical ward and link
        indexes. This is always performed in-place

        Raises
        ------
        ValueError
            If the two sets of wards do not have the same size, or
            have null wards at different indexes
        """
        if len(self) != len(other):
            from .utils._console import Console
            Console.error(f"Cannot harmonise links of incompatible wards. "
                          f"Sizes do not match: {len(self)} versus "
                          f"{len(other)}.\n"
                          f"{self}\n"
                          f"{other}")
            raise ValueError("Cannot harmonise incompatible Wards")

        for self_ward, other_ward in zip(self._wards, other._wards):
            if self_ward is None and other_ward is None:
                continue
            elif self_ward is None or other_ward is None:
                # Explicit check rather than 'assert', which is stripped
                # when python runs with -O
                raise ValueError(
                    f"Cannot harmonise incompatible Wards. A ward at the "
                    f"same index is null in only one network: {self_ward} "
                    f"versus {other_ward}")
            else:
                self_ward._harmonise_links(other_ward)

    @staticmethod
    def harmonise(wardss: _List['Wards']) -> _Tuple['Wards', _List['Wards']]:
        """Harmonise the passed list of wards, returning a tuple that
        contains the overall sum of all of these wards, plus a new list
        where all Wards use IDs that are correct and valid
        across the entire group

        Raises
        ------
        TypeError
            If any non-None item in 'wardss' is not a Wards object
        """
        harmonised = []

        for wards in wardss:
            if wards is not None and not isinstance(wards, Wards):
                raise TypeError("Cannot harmonise non-Wards objects")

        # create the overall Wards that will provide the IDs for all
        overall = Wards()

        for wards in wardss:
            if wards is None:
                continue

            hwards = Wards()

            for ward in wards._wards:
                if ward is not None:
                    ward = ward.dereference(wards)
                    overall.insert(ward, overwrite=False)
                    ward.resolve(overall, _inplace=True)
                    hwards.insert(ward, overwrite=False,
                                  _need_deep_copy=False)

            harmonised.append(hwards)

        for wards in harmonised:
            wards._harmonise_nodes(overall)
            wards._harmonise_links(overall)

        return (overall, harmonised)

    def assert_sane(self):
        """Make sure that we don't refer to any non-existent wards

        Note that this first tries to resolve any unresolved links.

        Raises
        ------
        AssertionError
            If a ward is stored at the wrong index, or has an
            unresolved connection, or a connection to an out-of-range
            or null ward ID
        """
        if len(self._wards) == 0:
            return

        self._resolve()

        nwards = len(self._wards) - 1

        for i, ward in enumerate(self._wards):
            if ward is None:
                continue

            if i != ward.id():
                raise AssertionError(f"{ward} should have index {i}")

            for c in ward.work_connections():
                if not isinstance(c, int):
                    raise AssertionError(f"Unresolved connection {c}")
                elif c < 1 or c > nwards:
                    raise AssertionError(
                        f"{ward} has a work connection to an invalid "
                        f"ward ID {c}. Range should be 1 <= n <= {nwards}")
                elif self._wards[c] is None:
                    raise AssertionError(
                        f"{ward} has a work connection to a null "
                        f"ward ID {c}. This ward is null")

            for c in ward.play_connections():
                if not isinstance(c, int):
                    raise AssertionError(f"Unresolved connection {c}")
                elif c < 1 or c > nwards:
                    raise AssertionError(
                        f"{ward} has a play connection to an invalid "
                        f"ward ID {c}. Range should be 1 <= n <= {nwards}")
                elif self._wards[c] is None:
                    raise AssertionError(
                        f"{ward} has a play connection to a null "
                        f"ward ID {c}. This ward is null")

    def to_data(self, profiler: Profiler = None):
        """Return a data representation of these wards that can
        be serialised to JSON

        Returns
        -------
        list | None
            A list containing ``Ward.to_data()`` for each non-null
            ward, or None if this Wards is empty
        """
        if len(self) > 0:
            if profiler is None:
                profiler = NullProfiler()

            p = profiler.start("to_data")
            p = p.start("assert_sane")
            self.assert_sane()
            p = p.stop()

            p = p.start("convert_wards")
            nwards = len(self._wards)

            from .utils._console import Console

            with Console.progress(visible=(nwards > 250)) as progress:
                data = []
                task = progress.add_task("Converting to data", total=nwards)

                for i, ward in enumerate(self._wards):
                    if ward is None:
                        continue

                    data.append(ward.to_data())

                    # only update the progress bar periodically
                    if i % 250 == 0:
                        progress.update(task, completed=i+1)

                progress.update(task, completed=nwards, force_update=True)

            p = p.stop()
            p = p.stop()

            return data
        else:
            return None

    @staticmethod
    def from_data(data, profiler: Profiler = None):
        """Return the Wards constructed from a data representation,
        which may have come from deserialised JSON. An empty Wards
        is returned if 'data' is None or empty.
        """
        if data is None or len(data) == 0:
            return Wards()

        if profiler is None:
            profiler = NullProfiler()

        p = profiler.start("from_data")
        p = p.start("convert_wards")
        wards = Wards()

        nwards = len(data)

        from .utils._console import Console

        with Console.progress(visible=(nwards > 250)) as progress:
            task = progress.add_task("Converting from data", total=nwards)

            for i, x in enumerate(data):
                if x is not None:
                    wards.insert(Ward.from_data(x), _need_deep_copy=False)

                # only update the progress bar periodically
                if i % 250 == 0:
                    progress.update(task, completed=i+1)

            progress.update(task, completed=nwards, force_update=True)

        p = p.stop()
        p = p.start("assert_sane")
        wards.assert_sane()
        p = p.stop()
        p = p.stop()

        return wards

    def to_json(self, filename: str = None, indent: int = None,
                auto_bzip: bool = True) -> str:
        """Serialise the wards to JSON. This will write to a file
        if filename is set, otherwise it will return a JSON string.

        Parameters
        ----------
        filename: str
            The name of the file to write the JSON to. The absolute
            path to the written file will be returned. If filename is None
            then this will serialise to a JSON string which will be
            returned.
        indent: int
            The number of spaces of indent to use when writing the json
        auto_bzip: bool
            Whether or not to automatically bzip2 the written json file

        Returns
        -------
        str
            Returns either the absolute path to the written file, or
            the json-serialised string
        """
        import json

        if indent is not None:
            indent = int(indent)

        # Convert to data *before* opening any file, so that a failure
        # here cannot truncate (or delete) an existing file at 'filename'
        data = self.to_data()

        if filename is None:
            return json.dumps(data, indent=indent)

        from pathlib import Path
        filename = str(Path(filename).expanduser().resolve().absolute())

        if auto_bzip:
            if not filename.endswith(".bz2"):
                filename += ".bz2"

            import bz2
            opener = bz2.open
        else:
            opener = open

        # open outside of the try so that a failure to open never
        # deletes a file that we could not write to
        with opener(filename, "wt") as FILE:
            try:
                json.dump(data, FILE, indent=indent)
            except Exception:
                # don't leave a partially-written file behind
                import os
                FILE.close()
                os.unlink(filename)
                raise

        return filename

    @staticmethod
    def from_json(s: str):
        """Return the Wards constructed from the passed json. This will
        either load from a passed json string, or from json loaded
        from the passed file (which may be bzip2-compressed)

        Raises
        ------
        IOError
            If 's' is neither an existing file nor valid JSON
        """
        import os
        import json

        # Sentinel for "could not parse". None cannot be used because
        # valid JSON "null" (written by to_json for an empty Wards)
        # also loads as None
        failed = object()
        data = failed

        if os.path.exists(s):
            try:
                import bz2
                with bz2.open(s, "rt") as FILE:
                    data = json.load(FILE)
            except (OSError, EOFError, ValueError):
                # not a (valid) bzip2 file - try plain text below
                data = failed

            if data is failed:
                with open(s, "rt") as FILE:
                    data = json.load(FILE)
        else:
            try:
                data = json.loads(s)
            except (TypeError, ValueError):
                data = failed

        if data is failed:
            from .utils._console import Console
            Console.error(f"Unable to load a network from '{s}'. Check that "
                          f"this is valid JSON or that the file exists.")
            raise IOError(f"Cannot load Wards from '{s}'")

        return Wards.from_data(data)