diff --git a/nirtorch/__init__.py b/nirtorch/__init__.py
index e4ae538..87eff9d 100644
--- a/nirtorch/__init__.py
+++ b/nirtorch/__init__.py
@@ -1,5 +1,5 @@
-from .graph import extract_torch_graph  # noqa F401
 from .from_nir import load  # noqa F401
+from .graph import extract_torch_graph  # noqa F401
 from .to_nir import extract_nir_graph  # noqa F401
 
 __version__ = version = "0.2.1"
diff --git a/nirtorch/from_nir.py b/nirtorch/from_nir.py
index c8ec7d9..77a5fb6 100644
--- a/nirtorch/from_nir.py
+++ b/nirtorch/from_nir.py
@@ -1,6 +1,6 @@
 import dataclasses
 import inspect
-from typing import Callable, Dict, List, Optional, Any, Union
+from typing import Any, Callable, Dict, List, Optional, Union
 
 import nir
 import torch
@@ -13,8 +13,8 @@
 
 @dataclasses.dataclass
 class GraphExecutorState:
-    """State for the GraphExecutor that keeps track of both the state of hidden
-    units and caches the output of previous modules, for use in (future) recurrent
+    """State for the GraphExecutor that keeps track of both the state of hidden units
+    and caches the output of previous modules, for use in (future) recurrent
     computations."""
 
     state: Dict[str, Any] = dataclasses.field(default_factory=dict)
@@ -24,14 +24,14 @@ class GraphExecutorState:
 class GraphExecutor(nn.Module):
     """Executes the NIR graph in PyTorch.
 
-    By default the graph executor is stateful, since there may be recurrence or 
-    stateful modules in the graph. Specifically, that means accepting and returning a 
-    state object (`GraphExecutorState`). If that is not desired, 
+    By default the graph executor is stateful, since there may be recurrence or
+    stateful modules in the graph. Specifically, that means accepting and returning a
+    state object (`GraphExecutorState`). If that is not desired,
     set `return_state=False` in the constructor.
 
     Arguments:
         graph (Graph): The graph to execute
-        return_state (bool, optional): Whether to return the state object. 
+        return_state (bool, optional): Whether to return the state object.
             Defaults to True.
 
     Raises:
@@ -92,6 +92,7 @@ def _apply_module(
         data: Optional[torch.Tensor] = None,
     ):
         """Applies a module and keeps track of its state.
+
         TODO: Use pytree to recursively construct the state
         """
         inputs = []
@@ -205,7 +206,7 @@
     """Load a NIR graph and convert it to a torch module using the given model map.
 
     Because the graph can contain recurrence and stateful modules, the execution accepts
-    a secondary state argument and returns a tuple of [output, state], instead of just 
+    a secondary state argument and returns a tuple of [output, state], instead of just
     the output as follows
 
     >>> executor = nirtorch.load(nir_graph, model_map)
@@ -216,13 +217,13 @@
     If you do not wish to operate with state, set `return_state=False`.
 
     Args:
-        nir_graph (Union[nir.NIRNode, str]): The NIR object to load, or a string 
+        nir_graph (Union[nir.NIRNode, str]): The NIR object to load, or a string
            representing the path to the NIR object.
        model_map (Callable[[nn.NIRNode], nn.Module]): A method that returns the a
            torch module that corresponds to each NIR node.
-        return_state (bool): If True, the execution of the loaded graph will return a 
-            tuple of [output, state], where state is a GraphExecutorState object. 
-            If False, only the NIR graph output will be returned. Note that state is
+        return_state (bool): If True, the execution of the loaded graph will return a
+            tuple of [output, state], where state is a GraphExecutorState object.
+            If False, only the NIR graph output will be returned. Note that state is
             required for recurrence to work in the graphs.
 
     Returns:
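
For context on the stateful calling convention that the `load` docstring above describes, here is a minimal sketch of round-tripping a graph through the executor. The graph construction and the `model_map` are assumptions modeled on the helpers in tests/test_from_nir.py, not part of this patch:

    import nir
    import numpy as np
    import torch
    import nirtorch

    # A tiny feed-forward graph: input -> linear -> output.
    graph = nir.NIRGraph(
        {
            "in": nir.Input(np.ones(1)),
            "w": nir.Linear(np.array([[1.0]])),
            "out": nir.Output(np.ones(1)),
        },
        [("in", "w"), ("w", "out")],
    )

    def model_map(node: nir.NIRNode) -> torch.nn.Module:  # assumed mapping
        if isinstance(node, nir.Linear):
            lin = torch.nn.Linear(1, 1, bias=False)
            lin.weight.data = torch.from_numpy(node.weight).float()
            return lin
        return torch.nn.Identity()  # Input/Output carry no computation

    executor = nirtorch.load(graph, model_map)
    state = None
    for _ in range(3):  # feed a sequence; state is threaded between calls
        out, state = executor(torch.ones(1), state)

    # Without state handling (recurrence will not work), outputs come back plain:
    stateless = nirtorch.load(graph, model_map, return_state=False)
    out = stateless(torch.ones(1))
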
diff --git a/nirtorch/graph_utils.py b/nirtorch/graph_utils.py
index 1f67d71..d807d49 100644
--- a/nirtorch/graph_utils.py
+++ b/nirtorch/graph_utils.py
@@ -65,12 +65,12 @@ def find_all_ancestors(
 #     return execution_order
 #
 
+
 def trace_execution(
     node: T, edge_fn: Callable[[T], List[T]], visited: Set[T] = None
 ) -> List[T]:
-    """Traces the execution of a node by listing them in order, coloring recursive
-    nodes to avoid adding the same node twice.
-    """
+    """Traces the execution of a node by listing them in order, coloring recursive nodes
+    to avoid adding the same node twice."""
     if visited is None:
         visited = set()
@@ -83,4 +83,4 @@ def trace_execution(
     for child in edge_fn(node):
         if child not in visited:
             successors += trace_execution(child, edge_fn, visited)
-    return [node] + successors
\ No newline at end of file
+    return [node] + successors
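
The reflowed `trace_execution` docstring is easier to see with an example: the function lists nodes depth-first from a start node, and the `visited` set colors nodes so that a recurrent edge cannot add the same node twice. A sketch of the intended behavior, using a hypothetical string-keyed adjacency map:

    from nirtorch.graph_utils import trace_execution

    # A tiny graph with a self-loop: input -> lif -> output, plus lif -> lif.
    edges = {"input": ["lif"], "lif": ["lif", "output"], "output": []}

    order = trace_execution("input", lambda node: edges[node])
    print(order)  # ['input', 'lif', 'output']; the recurrent edge is visited once
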
diff --git a/nirtorch/to_nir.py b/nirtorch/to_nir.py
index 47e7d03..a819c21 100644
--- a/nirtorch/to_nir.py
+++ b/nirtorch/to_nir.py
@@ -1,5 +1,5 @@
-from typing import Any, Callable, Optional, Sequence
 import logging
+from typing import Any, Callable, Optional, Sequence
 
 import nir
 import numpy as np
@@ -15,13 +15,14 @@ def extract_nir_graph(
     model_name: Optional[str] = "model",
     ignore_submodules_of=None,
     model_fwd_args=[],
-    ignore_dims: Optional[Sequence[int]]=None,
+    ignore_dims: Optional[Sequence[int]] = None,
 ) -> nir.NIRNode:
     """Given a `model`, generate an NIR representation using the specified `model_map`.
 
     Assumptions and known issues:
-    - Cannot deal with layers like torch.nn.Identity(), since the input tensor and output
-      tensor will be the same object, and therefore lead to cyclic connections.
+    - Cannot deal with layers like torch.nn.Identity(), since the input tensor and
+      output tensor will be the same object, and therefore lead to cyclic
+      connections.
 
     Args:
         model (nn.Module): The model of interest
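
The `ignore_dims` parameter whose signature is touched above strips the listed dimensions, such as a batch dimension, from the shapes recorded in the NIR graph. A minimal sketch, modeled on the extractor used in tests/test_to_nir.py below:

    import nir
    import torch
    import torch.nn as nn
    from nirtorch import extract_nir_graph

    model = nn.Linear(3, 1)

    def extractor(module: nn.Module):
        if isinstance(module, nn.Linear):
            return nir.Affine(module.weight, module.bias)

    # Sample input has shape (batch=1, features=3); ignoring dim 0 records the
    # input type as (3,) in the resulting NIR graph.
    g = extract_nir_graph(model, extractor, torch.ones(1, 3), ignore_dims=[0])
    assert tuple(g.nodes["input"].input_type["input"]) == (3,)
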
diff --git a/tests/test_bidirectional.py b/tests/test_bidirectional.py
index 42e3df6..52f8f6d 100644
--- a/tests/test_bidirectional.py
+++ b/tests/test_bidirectional.py
@@ -1,8 +1,8 @@
 import nir
 import numpy as np
 import torch
-import nirtorch
 
+import nirtorch
 
 use_snntorch = False
 # use_snntorch = True
diff --git a/tests/test_conversion.py b/tests/test_conversion.py
index c04be4f..ee48b5d 100644
--- a/tests/test_conversion.py
+++ b/tests/test_conversion.py
@@ -1,7 +1,7 @@
+import nir
 import torch
 import torch.nn as nn
 
-import nir
 import nirtorch
diff --git a/tests/test_from_nir.py b/tests/test_from_nir.py
index d2680b9..b2907f4 100644
--- a/tests/test_from_nir.py
+++ b/tests/test_from_nir.py
@@ -1,7 +1,7 @@
 import nir
 import numpy as np
-import torch
 import pytest
+import torch
 
 from nirtorch.from_nir import load
 
@@ -56,7 +56,10 @@ def test_extract_empty():
 
 
 def test_extract_illegal_name():
-    graph = nir.NIRGraph({"a.b": nir.Input(np.ones(1)), "a.c": nir.Linear(np.array([[1.]]))}, [("a.b", "a.c")])
+    graph = nir.NIRGraph(
+        {"a.b": nir.Input(np.ones(1)), "a.c": nir.Linear(np.array([[1.0]]))},
+        [("a.b", "a.c")],
+    )
     torch_graph = load(graph, _torch_model_map)
     assert "a_c" in torch_graph._modules
 
@@ -131,6 +134,7 @@ def _map_stateful(node):
     m = load(g, _map_stateful, return_state=False)
     assert not isinstance(m(torch.ones(10)), tuple)
 
+
 def test_execute_recurrent():
     w = np.ones((1, 1))
     g = nir.NIRGraph(
diff --git a/tests/test_graph.py b/tests/test_graph.py
index 20d17f5..6e10625 100644
--- a/tests/test_graph.py
+++ b/tests/test_graph.py
@@ -1,10 +1,10 @@
+import nir
 import pytest
 import torch
 import torch.nn as nn
 from norse.torch import LIBoxCell, LIFCell, SequentialState
 from sinabs.layers import Merge
 
-import nir
 from nirtorch import extract_nir_graph, extract_torch_graph
 
@@ -238,6 +238,7 @@ def test_root_has_no_source():
         len(graph.find_source_nodes_of(graph.find_node(my_branched_model.relu1))) == 0
     )
 
+
 @pytest.mark.skip(reason="Root tracing is broken")
 def test_get_root():
     graph = extract_torch_graph(my_branched_model, sample_data=data, model_name=None)
@@ -282,7 +283,9 @@ def test_sequential_flatten():
     assert tuple(g.nodes["input"].input_type["input"]) == (3, 4)
 
     d = torch.empty(2, 3, 4)
-    g = extract_nir_graph(torch.nn.Flatten(1), lambda x: nir.Flatten(d.shape, 1), d, ignore_dims=[0])
+    g = extract_nir_graph(
+        torch.nn.Flatten(1), lambda x: nir.Flatten(d.shape, 1), d, ignore_dims=[0]
+    )
     assert tuple(g.nodes["input"].input_type["input"]) == (3, 4)
 
@@ -311,6 +314,7 @@ def forward(self, x, state=None):
     assert d.nodes.keys() == {"input", "l", "r", "output"}
     assert set(d.edges) == {("input", "r"), ("r", "l"), ("l", "output"), ("r", "r")}
 
+
 @pytest.mark.skip(reason="Subgraphs are currently flattened")
 def test_captures_recurrence_manually():
     def export_affine_rec_gru(module):
diff --git a/tests/test_graph_utils.py b/tests/test_graph_utils.py
index 36ec5f3..95b51b9 100644
--- a/tests/test_graph_utils.py
+++ b/tests/test_graph_utils.py
@@ -21,7 +21,7 @@ def from_string(graph):
 
     def __hash__(self) -> int:
         return self.name.__hash__()
-    
+
     def __eq__(self, other: object) -> bool:
         return self.name == other.name
diff --git a/tests/test_to_nir.py b/tests/test_to_nir.py
index 3ea108d..d35d217 100644
--- a/tests/test_to_nir.py
+++ b/tests/test_to_nir.py
@@ -115,7 +115,9 @@ def extractor(module: nn.Module):
             return nir.Affine(module.weight, module.bias)
 
     raw_input_shape = (1, 3)
-    g = extract_nir_graph(model, extractor, torch.ones(raw_input_shape), ignore_dims=[0])
+    g = extract_nir_graph(
+        model, extractor, torch.ones(raw_input_shape), ignore_dims=[0]
+    )
     exp_input_shape = (3,)
     assert np.alltrue(g.nodes["input"].input_type["input"] == np.array(exp_input_shape))
     assert g.nodes["model"].weight.shape == (1, 3)
 
@@ -129,13 +131,17 @@ def extractor(module: nn.Module):
             return nir.Affine(module.weight, module.bias)
 
     raw_input_shape = (1, 10, 3)
-    g = extract_nir_graph(model, extractor, torch.ones(raw_input_shape), ignore_dims=[0, -2])
+    g = extract_nir_graph(
+        model, extractor, torch.ones(raw_input_shape), ignore_dims=[0, -2]
+    )
     exp_input_shape = (3,)
     assert np.alltrue(g.nodes["input"].input_type["input"] == np.array(exp_input_shape))
     assert g.nodes["model"].weight.shape == (1, 3)
 
     raw_input_shape = (1, 10, 3)
-    g = extract_nir_graph(model, extractor, torch.ones(raw_input_shape), ignore_dims=[0, 1])
+    g = extract_nir_graph(
+        model, extractor, torch.ones(raw_input_shape), ignore_dims=[0, 1]
+    )
     exp_input_shape = (3,)
     assert np.alltrue(g.nodes["input"].input_type["input"] == np.array(exp_input_shape))
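
A closing note on the reformatted `test_extract_illegal_name` above: node names that are not valid Python attribute names, such as "a.c", are sanitized to "a_c" when the loaded graph registers its torch submodules. A sketch of that behavior; the `model_map` here is an assumption standing in for the tests' `_torch_model_map` helper:

    import nir
    import numpy as np
    import torch
    from nirtorch import load

    graph = nir.NIRGraph(
        {"a.b": nir.Input(np.ones(1)), "a.c": nir.Linear(np.array([[1.0]]))},
        [("a.b", "a.c")],
    )

    def model_map(node: nir.NIRNode) -> torch.nn.Module:  # assumed mapping
        if isinstance(node, nir.Linear):
            return torch.nn.Linear(*node.weight.shape[::-1], bias=False)
        return torch.nn.Identity()

    executor = load(graph, model_map)
    assert "a_c" in executor._modules  # "a.c" became the attribute-safe "a_c"
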