Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Custom dependencies #178

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 62 additions & 8 deletions xsimlab/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,7 @@ def get_processes_to_validate(self):

return {k: list(v) for k, v in processes_to_validate.items()}

def get_process_dependencies(self):
def get_process_dependencies(self, custom_dependencies={}):
"""Return a dictionary where keys are each process of the model and
values are lists of the names of dependent processes (or empty
lists for processes that have no dependencies).
Expand All @@ -423,6 +423,10 @@ def get_process_dependencies(self):
]
)

# actually add custom dependencies
for p_name, deps in custom_dependencies.items():
self._dep_processes[p_name].update(deps)

for p_name, p_obj in self._processes_obj.items():
for var in filter_variables(p_obj, intent=VarIntent.OUT).values():
if var.metadata["var_type"] == VarType.ON_DEMAND:
Expand Down Expand Up @@ -534,13 +538,16 @@ class Model(AttrMapping):

active = []

def __init__(self, processes):
def __init__(self, processes, custom_dependencies={}):
"""
Parameters
----------
processes : dict
Dictionnary with process names as keys and classes (decorated with
Dictionary with process names as keys and classes (decorated with
:func:`process`) as values.
custom_dependencies : dict
Dictionary of custom dependencies.
keys are process names and values iterable of process names that it depends on

Raises
------
Expand Down Expand Up @@ -572,7 +579,15 @@ def __init__(self, processes):

self._processes_to_validate = builder.get_processes_to_validate()

self._dep_processes = builder.get_process_dependencies()
# clean custom dependencies
self._custom_dependencies = {}
for p_name, c_deps in custom_dependencies.items():
c_deps = {c_deps} if isinstance(c_deps, str) else set(c_deps)
self._custom_dependencies[p_name] = c_deps

self._dep_processes = builder.get_process_dependencies(
self._custom_dependencies
)
self._processes = builder.get_sorted_processes()

super(Model, self).__init__(self._processes)
Expand Down Expand Up @@ -1065,7 +1080,7 @@ def drop_processes(self, keys):

Parameters
----------
keys : str or list of str
keys : str or iterable of str
Name(s) of the processes to drop.

Returns
Expand All @@ -1074,13 +1089,52 @@ def drop_processes(self, keys):
New Model instance with dropped processes.

"""
if isinstance(keys, str):
keys = [keys]
keys = {keys} if isinstance(keys, str) else set(keys)

processes_cls = {
k: type(obj) for k, obj in self._processes.items() if k not in keys
}
return type(self)(processes_cls)

# we also should check for chains of deps e.g.
# a->b->c->d->e where {b,c,d} are removed
# then we have a->e left over.
# perform a depth-first search on custom dependencies
# and let the custom deps propagate forward
completed = set()
for key in self._custom_dependencies:
if key in completed:
continue
key_stack = [key]
while key_stack:
cur = key_stack[-1]
if cur in completed:
key_stack.pop()
continue

# if we have custom dependencies that are removed
# and are fully traversed, add their deps to the current
child_keys = keys.intersection(self._custom_dependencies[cur])
if child_keys.issubset(completed):
# all children are added, so we are safe
self._custom_dependencies[cur].update(
*[
self._custom_dependencies[child_key]
for child_key in child_keys
]
)
self._custom_dependencies[cur] -= child_keys
completed.add(cur)
key_stack.pop()
else: # if child_keys - completed:
# we need to search deeper: add to the stack.
key_stack.extend([k for k in child_keys - completed])

# now also remove keys from custom deps
for key in keys:
if key in self._custom_dependencies:
del self._custom_dependencies[key]

return type(self)(processes_cls, self._custom_dependencies)

def __eq__(self, other):
if not isinstance(other, self.__class__):
Expand Down
58 changes: 58 additions & 0 deletions xsimlab/tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,36 @@ def test_get_process_dependencies(self, model):
# order of dependencies is not ensured
assert set(actual[p_name]) == set(expected[p_name])

def test_get_process_dependencies_custom(self, model):
@xs.process
class A:
pass

@xs.process
class B:
pass

@xs.process
class C:
pass

actual = xs.Model(
{"a": A, "b": B}, custom_dependencies={"a": "b"}
).dependent_processes
expected = {"a": ["b"], "b": []}

for p_name in expected:
assert set(actual[p_name]) == set(expected[p_name])

# also test with a list
actual = xs.Model(
{"a": A, "b": B, "c": C}, custom_dependencies={"a": ["b", "c"]}
).dependent_processes
expected = {"a": ["b", "c"], "b": [], "c": []}

for p_name in expected:
assert set(actual[p_name]) == set(expected[p_name])

@pytest.mark.parametrize(
"p_name,dep_p_name",
[
Expand Down Expand Up @@ -294,6 +324,34 @@ def test_drop_processes(self, no_init_model, simple_model, p_names):
m = no_init_model.drop_processes(p_names)
assert m == simple_model

def test_drop_processes_custom(self):
@xs.process
class A:
pass

@xs.process
class B:
pass

@xs.process
class C:
pass

@xs.process
class D:
pass

@xs.process
class E:
pass

model = xs.Model(
{"a": A, "b": B, "c": C, "d": D, "e": E},
custom_dependencies={"d": "c", "c": "b", "b": {"a", "e"}},
)
model = model.drop_processes(["b", "c"])
assert set(model.dependent_processes["d"]) == {"a", "e"}

def test_visualize(self, model):
pytest.importorskip("graphviz")
ipydisp = pytest.importorskip("IPython.display")
Expand Down