Source code for metawards._wards

from __future__ import annotations

from typing import List as _List
from typing import Union as _Union
from typing import Tuple as _Tuple

from .utils._profiler import Profiler, NullProfiler

from ._ward import Ward
from ._wardinfo import WardInfos, WardInfo

__all__ = ["Wards"]


[docs]class Wards: """This class holds an entire network of Ward objects"""
[docs] def __init__(self, wards: _List[Ward] = None): """Construct, optionally from a list of Ward objects""" self._wards = [] self._info = WardInfos() self._unresolved = [] if wards is not None: self.insert(wards)
[docs] def __str__(self): if len(self) == 0: return "Wards::null" elif len(self) < 10: return f"[ {', '.join([str(x) for x in self._wards[1:]])} ]" else: s = f"[ {', '.join([str(x) for x in self._wards[1:7]])}, ... " s += f"{', '.join([str(x) for x in self._wards[-3:]])} ]" return s
[docs] def __repr__(self): return self.__str__()
[docs] def __eq__(self, other): return self.__class__ == other.__class__ and \ self.__dict__ == other.__dict__
[docs] def insert(self, wards: _List[Ward], overwrite: bool = True, _need_deep_copy: bool = True) -> None: """Insert the passed wards onto this list. This will overwrite the existing ward if 'overwrite' is true, otherwise it will add the ward's data to the existing ward """ if not isinstance(wards, list): wards = [wards] for ward in wards: if isinstance(ward, Wards): for w in ward._wards: if w is not None: w = w.dereference(ward) self.insert(w, overwrite=overwrite, _need_deep_copy=False) elif ward is None: continue elif not isinstance(ward, Ward): raise TypeError( f"You cannot append a {ward} to a list of Ward objects!") # get the largest ID and then resize the list... largest_id = 0 for ward in wards: if isinstance(ward, Ward): if ward.id() is not None and ward.id() > largest_id: largest_id = ward.id() if largest_id >= len(self._wards): self._wards += [None] * (largest_id - len(self._wards) + 1) add_later = [] from copy import deepcopy for ward in wards: if isinstance(ward, Ward): # make sure that this is not a duplicate... info = ward._info if info in self._info: idx = self._info.index(info) if ward.id() is None: ward = deepcopy(ward) ward.set_id(idx) if idx != ward.id(): raise KeyError( f"You cannot have two different wards that have " f"the same info: {self._wards[idx]} : {ward}") # otherwise the IDs are the same, so this is an # update to an existing ward if ward.id() is None: add_later.append(ward) else: if _need_deep_copy: ward = deepcopy(ward) if overwrite or self._wards[ward.id()] is None: self._wards[ward.id()] = ward else: self._wards[ward.id()].merge(ward) self._info[ward.id()] = ward._info self._unresolved.append(ward.id()) for ward in add_later: # append this onto the end of the list idx = len(self._wards) if _need_deep_copy: ward = deepcopy(ward) ward.set_id(idx) self._wards.append(ward) self._info[ward.id()] = ward._info self._unresolved.append(ward.id()) self._resolve()
[docs] def add(self, ward: Ward) -> None: """Synonym for insert""" self.insert(ward, overwrite=False)
def __add__(self, other): from copy import deepcopy c = deepcopy(self) c.add(other) return c def __radd__(self, other): return self.__add__(other) def __iadd__(self, other): self.add(other) return self
[docs] def __mul__(self, scale: float) -> Wards: """Scale the number of workers and players by 'scale'""" return self.scale(work_ratio=scale, play_ratio=scale)
[docs] def __rmul__(self, scale: float) -> Wards: """Scale the number of workers and players by 'scale'""" return self.scale(work_ratio=scale, play_ratio=scale)
[docs] def __imul__(self, scale: float) -> Wards: """In-place multiply the number of workers and players by 'scale'""" return self.scale(work_ratio=scale, play_ratio=scale, _inplace=True)
[docs] def is_resolved(self) -> bool: """Return whether or not this is a fully resolved set of Wards (i.e. each ward only links to other wards in this set) """ return len(self._unresolved) == 0
[docs] def unresolved_wards(self) -> _List[int]: """Return the list of IDs of unresolved wards""" from copy import deepcopy return deepcopy(self._unresolved)
def _resolve(self) -> None: """Try to resolve all of the links. Note that this will only resolve as many wards as it can. Any unresolved wards will be available as 'unresolvd_wards' """ still_unresolved = [] for unresolved in self._unresolved: self._wards[unresolved].resolve(self, _inplace=True) if not self._wards[unresolved].is_resolved(): still_unresolved.append(unresolved) self._unresolved = still_unresolved
[docs] def get(self, id: _Union[int, WardInfo], dereference: bool = True) -> Ward: """Return the ward with the specified id - this can eb the integer ID of the ward or the WardInfo of the ward. If 'dereference' is True then this will dereference all of the IDs into WardInfo objects. This is useful if you want to use the resulting Ward with other Wards collections """ ward = self[id] if dereference: ward.dereference(self, _inplace=True) return ward
[docs] def getinfo(self, id: _Union[int, str, WardInfo]) -> WardInfo: """Return the WardInfo matching the ward with the passed ID""" idx = self.index(id) return self._info.wards[idx]
[docs] def __getitem__(self, id: _Union[int, str, WardInfo]) -> Ward: """Return the ward with specified id - this can be the integer ID of the ward, or the WardInfo of the ward. Note that this returns a copy of the Ward """ if isinstance(id, Ward): idx = self._info.index(id._info) if id._id is not None: if idx != id._id: return ValueError(f"No ward matching {id}") w = self._wards[idx] elif isinstance(id, WardInfo): w = self._wards[self._info.index(id)] elif isinstance(id, str): try: return self[int(id)] except Exception: pass return self[WardInfo(name=id)] else: w = self._wards[id] # must deepcopy this or else it can be changed behind our back from copy import deepcopy return deepcopy(w)
[docs] def index(self, id: _Union[int, str, WardInfo, Ward]) -> int: """Return the index of the ward that matches the passed id - which can be the integer ID or WardInfo - in this Wards object. This raises a ValueError if the ward doens't exist """ if isinstance(id, Ward): return self._info.index(id._info) elif isinstance(id, WardInfo): return self._info.index(id) elif isinstance(id, str): try: return self.index(int(id)) except Exception: pass return self.index(WardInfo(name=id)) else: try: id = int(id) except Exception: raise ValueError(f"No ward matching {id}") if id < 0: id = len(self._wards) + id if id < 0 or id >= len(self._wards): raise ValueError(f"No ward matching {id}") if self._wards[id] is None: raise ValueError(f"No ward matching {id}") return id
[docs] def __contains__(self, id: _Union[int, str, WardInfo, Ward]) -> bool: """Return whether or not the passed id - which can be an integer ID or WardInfo - is in this Wards object """ if isinstance(id, Ward): return id._info in self._info elif isinstance(id, WardInfo): return id in self._info elif isinstance(id, str): try: if self.__contains__(int(id)): return True except Exception: pass return self.__contains__(WardInfo(name=id)) else: id = int(id) if id < 0: id = len(self._wards) + id if id < 0 or id >= len(self._wards): return False else: return self._wards[id] is not None
[docs] def contains(self, id: _Union[int, str, WardInfo]) -> bool: """Return whether or not the passed id - which can be an integer ID or WardInfo - is in this Wards object """ return self.__contains__(id)
def __len__(self): return len(self._wards)
[docs] def num_players(self): """Return the total number of players in this network""" num = 0 for ward in self._wards: if ward is not None: num += ward.num_players() return num
[docs] def num_workers(self): """Return the total number of workers in this network""" num = 0 for ward in self._wards: if ward is not None: num += ward.num_workers() return num
[docs] def population(self): """Return the total population in this network""" num = 0 for ward in self._wards: if ward is not None: num += ward.population() return num
[docs] def scale(self, work_ratio: float = 1.0, play_ratio: float = 1.0, _inplace: bool = False) -> Wards: """Return a copy of these wards where the number of workers and players have been scaled by 'work_ratios' and 'play_ratios' respectively. These can be greater than 1.0, e.g. if you want to scale up the number of workers and players Parameters ---------- work_ratio: float The scaling ratio for workers play_ratio: float The scaling ratio for players Returns ------- Wards: A copy of this Wards scaled by the requested amount """ if _inplace: wards = self else: from copy import deepcopy wards = deepcopy(self) for ward in wards._wards: if ward is not None: ward.scale(work_ratio=work_ratio, play_ratio=play_ratio, _inplace=True) return wards
def _harmonise_nodes(self, other: Wards) -> None: """Make sure that this set of Wards contains all of the wards in 'other', and that all have the same ID and info. If there are any missing nodes, then zero-populated copies will be added """ errors = [] for ward in self._wards: if ward is None: pass elif ward not in other: errors.append(f"Missing ward {ward} from the overall network") else: other_ward = other[ward] if ward._id != other_ward._id or \ ward._info != other_ward._info or \ ward._pos != other_ward._pos: errors.append(f"Ward exists, but is different: {ward} " f"versus {other_ward}.") if len(errors) > 0: from .utils._console import Console Console.error("\n".join(errors)) raise ValueError("Cannot harmonise incompatible Wards") for ward in other._wards: if ward is not None and ward not in self: self.insert(ward.depopulate(zero_player_weights=True), _need_deep_copy=False) def _harmonise_links(self, other: Wards) -> None: """Make sure that the wards in this object has exactly the same links as the wards in 'other'. This is used as part of the Wards.harmonise function, and ensures that all subnet wards have identical ward and link indexes. This is always performed in-place """ if len(self) != len(other): from .utils._console import Console Console.error(f"Cannot harmonise links of incompatible wards. " f"Sizes do not match: {len(self)} versus " f"{len(other)}.\n" f"{self}\n" f"{other}") raise ValueError("Cannot harmonise incompatible Wards") for self_ward, other_ward in zip(self._wards, other._wards): if self_ward is None: assert other_ward is None elif other_ward is None: assert self_ward is None else: self_ward._harmonise_links(other_ward)
[docs] @staticmethod def harmonise(wardss: _List['Wards']) -> _Tuple['Wards', _List['Wards']]: """Harmonise the passed list of wards, returning a tuple that contains the overall sum of all of these wards, plus a new list where all Wards use IDs that are correct and valid across the entire group """ harmonised = [] for wards in wardss: if wards is not None and not isinstance(wards, Wards): raise TypeError(f"Cannot harmonise non-Wards objects") # create the overall Wards that will provide the IDs for all overall = Wards() for wards in wardss: if wards is None: continue hwards = Wards() for ward in wards._wards: if ward is not None: ward = ward.dereference(wards) overall.insert(ward, overwrite=False) ward.resolve(overall, _inplace=True) hwards.insert(ward, overwrite=False, _need_deep_copy=False) harmonised.append(hwards) for wards in harmonised: wards._harmonise_nodes(overall) wards._harmonise_links(overall) return (overall, harmonised)
[docs] def assert_sane(self): """Make sure that we don't refer to any non-existent wards""" if len(self._wards) == 0: return self._resolve() nwards = len(self._wards) - 1 for i, ward in enumerate(self._wards): if ward is None: continue if i != ward.id(): raise AssertionError(f"{ward} should have index {i}") for c in ward.work_connections(): if not isinstance(c, int): raise AssertionError(f"Unresolved connection {c}") elif c < 1 or c > nwards: raise AssertionError( f"{ward} has a work connection to an invalid " f"ward ID {c}. Range should be 1 <= n <= {nwards}") elif self._wards[c] is None: raise AssertionError( f"{ward} has a work connection to a null " f"ward ID {c}. This ward is null") for c in ward.play_connections(): if not isinstance(c, int): raise AssertionError(f"Unresolved connection {c}") elif c < 1 or c > nwards: raise AssertionError( f"{ward} has a play connection to an invalid " f"ward ID {c}. Range should be 1 <= n <= {nwards}") elif self._wards[c] is None: raise AssertionError( f"{ward} has a play connection to a null " f"ward ID {c}. This ward is null")
[docs] def to_data(self, profiler: Profiler = None): """Return a data representation of these wards that can be serialised to JSON """ if len(self) > 0: if profiler is None: profiler = NullProfiler() p = profiler.start("to_data") p = p.start("assert_sane") self.assert_sane() p = p.stop() p = p.start("convert_wards") nwards = len(self._wards) from .utils._console import Console with Console.progress(visible=(nwards > 250)) as progress: data = [] task = progress.add_task("Converting to data", total=nwards) for i, ward in enumerate(self._wards): if ward is None: continue else: data.append(ward.to_data()) if i % 250 == 0: progress.update(task, completed=i+1) progress.update(task, completed=nwards, force_update=True) p = p.stop() p = p.stop() return data else: return None
[docs] @staticmethod def from_data(data, profiler: Profiler = None): """Return the Wards constructed from a data represnetation, which may have come from deserialised JSON """ if data is None or len(data) == 0: return Wards() if profiler is None: profiler = NullProfiler() p = profiler.start("from_data") p = p.start("convert_wards") wards = Wards() nwards = len(data) from .utils._console import Console with Console.progress(visible=(nwards > 250)) as progress: task = progress.add_task("Converting from data", total=nwards) for i, x in enumerate(data): if x is not None: wards.insert(Ward.from_data(x), _need_deep_copy=False) if i % 250 == 0: progress.update(task, completed=i+1) progress.update(task, completed=nwards, force_update=True) p = p.stop() p = p.start("assert_sane") wards.assert_sane() p = p.stop() p = p.stop() return wards
[docs] def to_json(self, filename: str = None, indent: int = None, auto_bzip: bool = True) -> str: """Serialise the wards to JSON. This will write to a file if filename is set, otherwise it will return a JSON string. Parameters ---------- filename: str The name of the file to write the JSON to. The absolute path to the written file will be returned. If filename is None then this will serialise to a JSON string which will be returned. indent: int The number of spaces of indent to use when writing the json auto_bzip: bool Whether or not to automatically bzip2 the written json file Returns ------- str Returns either the absolute path to the written file, or the json-serialised string """ import json if indent is not None: indent = int(indent) if filename is None: return json.dumps(self.to_data(), indent=indent) else: from pathlib import Path filename = str(Path(filename).expanduser().resolve().absolute()) if auto_bzip: if not filename.endswith(".bz2"): filename += ".bz2" import bz2 with bz2.open(filename, "wt") as FILE: try: json.dump(self.to_data(), FILE, indent=indent) except Exception: import os FILE.close() os.unlink(filename) raise else: with open(filename, "w") as FILE: try: json.dump(self.to_data(), FILE, indent=indent) except Exception: import os FILE.close() os.unlink(filename) raise return filename
[docs] @staticmethod def from_json(s: str): """Return the Wards constructed from the passed json. This will either load from a passed json string, or from json loaded from the passed file """ import os import json if os.path.exists(s): try: import bz2 with bz2.open(s, "rt") as FILE: data = json.load(FILE) except Exception: data = None if data is None: with open(s, "rt") as FILE: data = json.load(FILE) else: try: data = json.loads(s) except Exception: data = None if data is None: from .utils._console import Console Console.error(f"Unable to load a network from '{s}'. Check that " f"this is valid JSON or that the file exists.") raise IOError(f"Cannot load Wards from '{s}'") return Wards.from_data(data)