Default mapping function from NIR to torch for simple primitives
This implements the NIR-to-torch conversion requested in issue neuromorphs#25
bvogginger committed Mar 27, 2024
1 parent 9f6eda2 commit bd6f667
Showing 2 changed files with 119 additions and 6 deletions.
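
With this change, a user-supplied model_map may simply return None for nodes it does not handle, and load() falls back to the new default mapping for simple NIR primitives. A minimal usage sketch, mirroring the added test test_load_torch_nn_primitives (assuming load is importable from the nirtorch package as in its public API; noop_model_map is a hypothetical helper):

import nir
import numpy as np
from nirtorch import load


def noop_model_map(node: nir.NIRNode):
    return None  # defer every node to the default NIR-to-torch mapping


graph = nir.NIRGraph(
    nodes={
        "i": nir.Input(np.array([10, 20])),
        "f": nir.Flatten(np.array([1])),
    },
    edges=[("i", "f")],
)
module = load(graph, noop_model_map)  # Input maps to Identity, Flatten to nn.Flatten
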
110 changes: 104 additions & 6 deletions nirtorch/from_nir.py
@@ -1,8 +1,10 @@
import dataclasses
import inspect
from numbers import Number
from typing import Any, Callable, Dict, List, Optional, Union

import nir
import numpy as np
import torch
import torch.nn as nn

@@ -180,21 +182,117 @@ def _mod_nir_to_graph(
return graph


def _switch_default_models(nir_graph: nir.NIRNode) -> Optional[torch.nn.Module]:
if isinstance(nir_graph, nir.Input) or isinstance(nir_graph, nir.Output):
def _to_tensor(tensor: Union[np.ndarray, torch.Tensor, Number]) -> torch.Tensor:
if isinstance(tensor, torch.Tensor):
return tensor
if isinstance(tensor, Number):
return torch.as_tensor(tensor, dtype=torch.float32)
return torch.from_numpy(tensor).float()


def _switch_default_models(node: nir.NIRNode) -> Optional[torch.nn.Module]:
if isinstance(node, nir.Input) or isinstance(node, nir.Output):
return torch.nn.Identity()

if isinstance(node, nir.Affine):
has_bias = node.bias is not None
module = torch.nn.Linear(
node.weight.shape[1], node.weight.shape[0], bias=has_bias
)
module.weight.data = _to_tensor(node.weight)
if has_bias:
module.bias.data = _to_tensor(node.bias)
return module

if isinstance(node, nir.Conv1d):
module = torch.nn.Conv1d(
in_channels=node.weight.shape[1],
out_channels=node.weight.shape[0],
kernel_size=node.weight.shape[2],
stride=node.stride,
padding=node.padding,
dilation=node.dilation,
groups=node.groups,
)
module.weight.data = _to_tensor(node.weight)
module.bias.data = _to_tensor(node.bias)
return module

if isinstance(node, nir.Conv2d):
module = torch.nn.Conv2d(
in_channels=node.weight.shape[1],
out_channels=node.weight.shape[0],
kernel_size=node.weight.shape[-2:],
stride=node.stride,
padding=node.padding,
dilation=node.dilation,
groups=node.groups,
)
module.weight.data = _to_tensor(node.weight)
module.bias.data = _to_tensor(node.bias)
return module

if isinstance(node, nir.Flatten):
return torch.nn.Flatten(node.start_dim, node.end_dim)

if isinstance(node, nir.Linear):
module = torch.nn.Linear(node.weight.shape[1], node.weight.shape[0], bias=False)
module.weight.data = _to_tensor(node.weight)
return module

if isinstance(node, nir.AvgPool2d):
return torch.nn.AvgPool2d(
kernel_size=tuple(node.kernel_size),
stride=tuple(node.stride),
padding=tuple(node.padding),
)

if isinstance(node, nir.SumPool2d):
return torch.nn.AvgPool2d(
kernel_size=tuple(node.kernel_size),
stride=tuple(node.stride),
padding=tuple(node.padding),
divisor_override=1, # turn AvgPool into SumPool
)


def _switch_models_with_map(
nir_graph: nir.NIRNode, model_map: Callable[[nn.Module], nn.Module]
) -> nir.NIRNode:
nir_graph: nir.NIRGraph, model_map: Callable[[nn.Module], nn.Module]
) -> nir.NIRGraph:
"""Replace the nodes of NIR graph by the respective torch modules based on mapping
function.
This function creates a new NIR graph where the NIR nodes are replaced by
the corresponding torch modules. The graph edges are copied from the
original NIR graph.
    A mapping function (argument `model_map`) is used to convert from the NIR
    nodes to the framework-specific modules.
If a NIR node is not supported by the mapping function, we try to convert
the node with the default mapping function `_switch_default_models`, which
covers simple NIR primitives with equivalent torch.nn modules.
Args:
nir_graph: The NIR graph
model_map: A method that returns the torch module that corresponds to each NIR
node.
Returns:
The new NIR graph with torch modules as nodes and edges copied from the
original NIR graph
Raises:
NotImplementedError: if there is no mapping to a torch module for
a certain NIR node
"""
nodes = {}
for name, node in nir_graph.nodes.items():
mapped_module = model_map(node)
if mapped_module is None:
mapped_module = _switch_default_models(node)

if mapped_module is None:
raise NotImplementedError(f"Module {type(node)} not supported")
nodes[name] = mapped_module
# nodes = {name: model_map(node) for name, node in nir_graph.nodes.items()}
return nir.NIRGraph(nodes, nir_graph.edges)


@@ -219,7 +317,7 @@ def load(
Args:
nir_graph (Union[nir.NIRNode, str]): The NIR object to load, or a string
representing the path to the NIR object.
model_map (Callable[[nn.NIRNode], nn.Module]): A method that returns the a torch
model_map (Callable[[nn.NIRNode], nn.Module]): A method that returns the torch
module that corresponds to each NIR node.
return_state (bool): If True, the execution of the loaded graph will return a
tuple of [output, state], where state is a GraphExecutorState object.
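
The least obvious detail in the new default mapping is the divisor_override=1 argument, which turns torch.nn.AvgPool2d into a sum pool for nir.SumPool2d. A quick illustrative sanity check of that trick (not part of the commit):

import torch

x = torch.arange(16.0).reshape(1, 1, 4, 4)
sum_pool = torch.nn.AvgPool2d(kernel_size=2, stride=2, divisor_override=1)
avg_pool = torch.nn.AvgPool2d(kernel_size=2, stride=2)

# With divisor_override=1 each window is divided by 1, i.e. summed, so the
# result equals the plain average multiplied by the kernel area (2 * 2 = 4).
assert torch.allclose(sum_pool(x), avg_pool(x) * 4)
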
15 changes: 15 additions & 0 deletions tests/test_from_nir.py
@@ -49,6 +49,10 @@ def forward(self, x, state=None):
raise NotImplementedError(f"Unsupported module {m}")


def _dummy_model_map(m: nir.NIRNode) -> torch.nn.Module:
pass


def test_extract_empty():
g = nir.NIRGraph({}, [])
with pytest.raises(ValueError):
@@ -156,3 +160,14 @@ def test_import_braille():
g = nir.read("tests/braille.nir")
m = load(g, _recurrent_model_map)
assert m(torch.empty(1, 12))[0].shape == (1, 7)


def test_load_torch_nn_primitives():
g = nir.NIRGraph(
nodes={
"i": nir.Input(np.array([10, 20])),
"f": nir.Flatten(np.array([1])),
},
edges=[("i", "f")],
) # Mock node
_ = load(g, _dummy_model_map, return_state=False)
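
Beyond the graph-level test above, one could also exercise the default mapping directly; the sketch below (illustrative only, not part of this commit) checks that an nir.Affine node becomes a torch.nn.Linear carrying the same parameters:

import nir
import numpy as np
import torch
from nirtorch.from_nir import _switch_default_models

w = np.random.rand(3, 5).astype(np.float32)
b = np.random.rand(3).astype(np.float32)
module = _switch_default_models(nir.Affine(weight=w, bias=b))

assert isinstance(module, torch.nn.Linear)
assert torch.allclose(module.weight, torch.from_numpy(w))
assert torch.allclose(module.bias, torch.from_numpy(b))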
