Skip to content

Commit

Permalink
Merge pull request #112 from bvogginger/metadata_for_all_nodes
Browse files Browse the repository at this point in the history
Add metadata to missing nodes
  • Loading branch information
bvogginger authored Aug 20, 2024
2 parents b9b359a + a8e3841 commit 95be0db
Show file tree
Hide file tree
Showing 7 changed files with 21 additions and 11 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.0.277
rev: v0.6.0
hooks:
- id: ruff

Expand All @@ -21,4 +21,4 @@ repos:
rev: v1.7.5
hooks:
- id: docformatter
args: [--in-place, --black, --wrap-summaries=88, --wrap-descriptions=88]
args: [--in-place, --black, --wrap-summaries=88, --wrap-descriptions=88]
1 change: 1 addition & 0 deletions nir/ir/conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ class Conv2d(NIRNode):
dilation: Union[int, Tuple[int, int]] # Dilation
groups: int # Groups
bias: np.ndarray # Bias C_out
metadata: Dict[str, Any] = field(default_factory=dict)

def __post_init__(self):
if isinstance(self.padding, str) and self.padding not in ["same", "valid"]:
Expand Down
2 changes: 2 additions & 0 deletions nir/ir/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,6 +452,7 @@ class Input(NIRNode):
# Shape of incoming data (overrrides input_type from
# NIRNode to allow for non-keyword (positional) initialization)
input_type: Types
metadata: Dict[str, Any] = field(default_factory=dict)

def __post_init__(self):
self.input_type = parse_shape_argument(self.input_type, "input")
Expand Down Expand Up @@ -479,6 +480,7 @@ class Output(NIRNode):
# Type of incoming data (overrrides input_type from
# NIRNode to allow for non-keyword (positional) initialization)
output_type: Types
metadata: Dict[str, Any] = field(default_factory=dict)

def __post_init__(self):
self.output_type = parse_shape_argument(self.output_type, "output")
Expand Down
2 changes: 2 additions & 0 deletions nir/ir/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ class Linear(NIRNode):
"""

weight: np.ndarray # Weight term
metadata: Dict[str, Any] = field(default_factory=dict)

def __post_init__(self):
assert len(self.weight.shape) >= 2, "Weight must be at least 2D"
Expand All @@ -69,6 +70,7 @@ class Scale(NIRNode):
"""

scale: np.ndarray # Scaling factor
metadata: Dict[str, Any] = field(default_factory=dict)

def __post_init__(self):
self.input_type = {"input": np.array(self.scale.shape)}
Expand Down
1 change: 1 addition & 0 deletions nir/ir/pooling.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ class AvgPool2d(NIRNode):
kernel_size: np.ndarray # (Height, Width)
stride: np.ndarray # (Height, width)
padding: np.ndarray # (Height, width)
metadata: Dict[str, Any] = field(default_factory=dict)

def __post_init__(self):
self.input_type = {"input": None}
Expand Down
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,5 @@ find={include = ["nir*"]}
line-length = 100
lint.per-file-ignores = {"docs/conf.py" = ["E402"]}
exclude = ["paper/"]
extend-exclude = ["*.ipynb"]

20 changes: 11 additions & 9 deletions tests/test_readwrite.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import tempfile
import inspect
import sys
import tempfile

import numpy as np

import nir
from tests import mock_affine, mock_conv
from tests import mock_affine, mock_conv, mock_linear

ALL_NODES = []
for name, obj in inspect.getmembers(sys.modules["nir.ir"]):
Expand Down Expand Up @@ -47,7 +47,7 @@ def factory_test_graph(ir: nir.NIRGraph):
assert_equivalence(ir, ir2)


def factory_test_metadata(node):
def factory_test_metadata(ir: nir.NIRGraph):
def compare_dicts(d1, d2):
for k, v in d1.items():
if isinstance(v, np.ndarray):
Expand All @@ -58,12 +58,14 @@ def compare_dicts(d1, d2):
assert v == d2[k]

metadata = {"some": "metadata", "with": 2, "data": np.array([1, 2, 3])}
node.metadata = metadata
compare_dicts(node.metadata, metadata)
for node in ir.nodes.values():
node.metadata = metadata
compare_dicts(node.metadata, metadata)
tmp = tempfile.mktemp()
nir.write(tmp, node)
node2 = nir.read(tmp)
compare_dicts(node2.metadata, metadata)
nir.write(tmp, ir)
ir2 = nir.read(tmp)
for node in ir2.nodes.values():
compare_dicts(node.metadata, metadata)


def test_simple():
Expand Down Expand Up @@ -146,7 +148,7 @@ def test_linear():
tau = np.array([1, 1, 1])
r = np.array([1, 1, 1])
v_leak = np.array([1, 1, 1])
ir = nir.NIRGraph.from_list(mock_affine(2, 2), nir.LI(tau, r, v_leak))
ir = nir.NIRGraph.from_list(mock_linear(2, 2), nir.LI(tau, r, v_leak))
factory_test_graph(ir)
factory_test_metadata(ir)

Expand Down

0 comments on commit 95be0db

Please sign in to comment.