Skip to content

Commit

Permalink
[TorchFX] Model transformer tests
Browse files Browse the repository at this point in the history
  • Loading branch information
daniil-lyakhov committed Aug 27, 2024
1 parent 2925666 commit 1d42c0d
Show file tree
Hide file tree
Showing 23 changed files with 1,064 additions and 27 deletions.
32 changes: 14 additions & 18 deletions nncf/experimental/torch/fx/transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def module_insertion_transformation(model: torch.fx.GraphModule):
if target_point.target_type == TargetType.OPERATOR_POST_HOOK:
_set_new_node_meta(new_node, target_node, module_to_insert)
with graph.inserting_after(target_node):
for user in target_node.users:
for user in list(target_node.users):
if user is new_node:
continue
user.replace_input_with(target_node, new_node)
Expand Down Expand Up @@ -110,12 +110,13 @@ def leaf_module_insertion_transformation(model: torch.fx.GraphModule):
return leaf_module_insertion_transformation


def bias_update_transformation_builder(node: NNCFNode, value: torch.Tensor) -> TransformationFNType:
def bias_update_transformation_builder(node: NNCFNode, value: torch.Tensor, input_port_id: int) -> TransformationFNType:
"""
Return transformation which updates constant of the given node with bias to the given value.
:param node: Node with bias which requires bias constant update.
:param value: New value to use as the bias constant.
:param input_port_id: Input port id to get constant node from.
:return: Transformation which updates constant of the given node with bias to the given value.
"""

Expand All @@ -131,29 +132,27 @@ def bias_update_transformation(model: torch.fx.GraphModule):
raise nncf.InternalError(f"Node {graph_node.name} has {len(add_nodes)} outputs with adds, 1 expected")

bias_node = add_nodes[0]
with graph.inserting_before(bias_node):
new_constant = create_getattr_from_value(model, graph, target_node_name + "_shifted_bias", value)

args = list(bias_node.args)
# A bias node suppose to have constant on the second input port.
args[1] = new_constant
bias_node.args = tuple(args)
graph.eliminate_dead_code()
constant_update_fn(model, bias_node, value, input_port_id=input_port_id)

return bias_update_transformation


def constant_update_transformation_builder(node: NNCFNode, value: torch.Tensor) -> TransformationFNType:
def constant_update_transformation_builder(
node: NNCFNode, value: torch.Tensor, input_port_id: int
) -> TransformationFNType:
"""
Return transformation which updates constant of the given node to the given value.
:param node: Node which requires bias constant update.
:param value: New value to use as the node constant.
:param input_port_id: Input port id to get constant node from.
:return: Transformation which updates constant of the given node to the given value.
"""

def constant_update_transformation(model: torch.fx.GraphModule):
constant_update_fn(model, get_graph_node_by_name(model.graph, node.node_name), value, input_port_id=1)
constant_update_fn(
model, get_graph_node_by_name(model.graph, node.node_name), value, input_port_id=input_port_id
)

return constant_update_transformation

Expand Down Expand Up @@ -204,7 +203,7 @@ def qdq_insertion_transformation_builder(

def qdq_insertion_transformation(model: torch.fx.GraphModule):
if any(tp.target_type != TargetType.OPERATION_WITH_WEIGHTS for tp in target_points) and len(target_points) > 1:
raise RuntimeError(
raise nncf.InternalError(
"Insertion of shared qdq pair for the weights is not supported."
" Please use non shared qdq pairs for the weights quantization."
)
Expand Down Expand Up @@ -267,11 +266,8 @@ def output_insertion_transformation(model: torch.fx.GraphModule):
output_node = output_nodes[0]

args = output_node.args
if isinstance(args, tuple):
assert len(args) == 1
args = args[0] + (cloned_input,)
else:
args += (cloned_input,)
assert len(args) == 1
args = tuple(args[0]) + (cloned_input,)
output_node.args = (args,)

return output_insertion_transformation
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def target_point(target_type: TargetType, target_node_name: str, port_id: int) -
def create_bias_correction_command(
node: NNCFNode, bias_value: Tensor, nncf_graph: NNCFGraph
) -> FXApplyTransformationCommand:
return FXApplyTransformationCommand(bias_update_transformation_builder(node, bias_value.data))
return FXApplyTransformationCommand(bias_update_transformation_builder(node, bias_value.data, input_port_id=1))

@staticmethod
def model_extraction_command(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def target_point(target_type: TargetType, target_node_name: str, port_id: int) -
def create_bias_correction_command(
node: NNCFNode, bias_value: Tensor, nncf_graph: NNCFGraph
) -> FXApplyTransformationCommand:
return FXApplyTransformationCommand(bias_update_transformation_builder(node, bias_value.data))
return FXApplyTransformationCommand(bias_update_transformation_builder(node, bias_value.data, input_port_id=1))

@staticmethod
def model_extraction_command(
Expand Down
3 changes: 2 additions & 1 deletion nncf/quantization/algorithms/min_max/torch_fx_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from typing import Dict, List, Optional, Set, Tuple

import torch
from torch.quantization.fake_quantize import FakeQuantize

import nncf
import nncf.torch.graph.operator_metatypes as om
Expand Down Expand Up @@ -230,7 +231,7 @@ def _create_quantizer(
scale_shape: Tuple,
parameters: FakeQuantizeParameters,
target_type: TargetType,
) -> BaseQuantizer:
) -> FakeQuantize:
mode = quantizer_config.mode
quantizer_cls = QUANTIZATION_MODULES.get(mode)
narrow_range = target_type == TargetType.OPERATION_WITH_WEIGHTS and mode == QuantizationMode.SYMMETRIC
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
strict digraph {
"0 arg0_1" [id=0, type=input];
"1 _param_constant0" [id=1, type=get_attr];
"2 _param_constant1" [id=2, type=get_attr];
"3 conv2d" [id=3, type=conv2d];
"4 _param_constant2" [id=4, type=get_attr];
"5 _param_constant3" [id=5, type=get_attr];
"6 conv2d_1" [id=6, type=conv2d];
"7 add__updated_constant0" [id=7, type=get_attr];
"8 add_" [id=8, type=add_];
"9 _tensor_constant0_1" [id=9, type=get_attr];
"10 add__1" [id=10, type=add_];
"11 add" [id=11, type=add];
"12 _param_constant4" [id=12, type=get_attr];
"13 _param_constant5" [id=13, type=get_attr];
"14 conv2d_2" [id=14, type=conv2d];
"15 _tensor_constant0_2" [id=15, type=get_attr];
"16 add_1" [id=16, type=add];
"17 output" [id=17, type=output];
"0 arg0_1" -> "3 conv2d" [label="(1, 3, 3, 3)", style=solid];
"1 _param_constant0" -> "3 conv2d" [label="(3, 3, 1, 1)", style=solid];
"2 _param_constant1" -> "3 conv2d" [label="(3,)", style=solid];
"3 conv2d" -> "6 conv2d_1" [label="(1, 3, 3, 3)", style=solid];
"3 conv2d" -> "8 add_" [label="(1, 3, 3, 3)", style=solid];
"4 _param_constant2" -> "6 conv2d_1" [label="(3, 3, 1, 1)", style=solid];
"5 _param_constant3" -> "6 conv2d_1" [label="(3,)", style=solid];
"6 conv2d_1" -> "10 add__1" [label="(1, 3, 3, 3)", style=solid];
"7 add__updated_constant0" -> "8 add_" [label="(1,)", style=solid];
"8 add_" -> "11 add" [label="(1, 3, 3, 3)", style=solid];
"9 _tensor_constant0_1" -> "10 add__1" [label="(1,)", style=solid];
"10 add__1" -> "11 add" [label="(1, 3, 3, 3)", style=solid];
"11 add" -> "14 conv2d_2" [label="(1, 3, 3, 3)", style=solid];
"12 _param_constant4" -> "14 conv2d_2" [label="(3, 3, 1, 1)", style=solid];
"13 _param_constant5" -> "14 conv2d_2" [label="(3,)", style=solid];
"14 conv2d_2" -> "16 add_1" [label="(1, 3, 3, 3)", style=solid];
"15 _tensor_constant0_2" -> "16 add_1" [label="(1,)", style=solid];
"16 add_1" -> "17 output" [label="(1, 3, 3, 3)", style=solid];
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
strict digraph {
"0 arg0_1" [id=0, type=input];
"1 _param_constant0" [id=1, type=get_attr];
"2 _param_constant1" [id=2, type=get_attr];
"3 TEST_MODULE_0" [id=3, type=call_module];
"4 TEST_MODULE_1" [id=4, type=call_module];
"5 conv2d" [id=5, type=conv2d];
"6 TEST_MODULE_3" [id=6, type=call_module];
"7 _param_constant2" [id=7, type=get_attr];
"8 _param_constant3" [id=8, type=get_attr];
"9 TEST_MODULE_2" [id=9, type=call_module];
"10 conv2d_1" [id=10, type=conv2d];
"11 _tensor_constant0" [id=11, type=get_attr];
"12 add_" [id=12, type=add_];
"13 _tensor_constant0_1" [id=13, type=get_attr];
"14 add__1" [id=14, type=add_];
"15 add" [id=15, type=add];
"16 _param_constant4" [id=16, type=get_attr];
"17 _param_constant5" [id=17, type=get_attr];
"18 conv2d_2" [id=18, type=conv2d];
"19 _tensor_constant0_2" [id=19, type=get_attr];
"20 add_1" [id=20, type=add];
"21 output" [id=21, type=output];
"0 arg0_1" -> "3 TEST_MODULE_0" [label="(1, 3, 3, 3)", style=solid];
"0 arg0_1" -> "5 conv2d" [label="(1, 3, 3, 3)", style=solid];
"1 _param_constant0" -> "4 TEST_MODULE_1" [label="(3, 3, 1, 1)", style=solid];
"1 _param_constant0" -> "5 conv2d" [label="(3, 3, 1, 1)", style=solid];
"2 _param_constant1" -> "5 conv2d" [label="(3,)", style=solid];
"5 conv2d" -> "6 TEST_MODULE_3" [label="(1, 3, 3, 3)", style=solid];
"5 conv2d" -> "10 conv2d_1" [label="(1, 3, 3, 3)", style=solid];
"5 conv2d" -> "12 add_" [label="(1, 3, 3, 3)", style=solid];
"7 _param_constant2" -> "9 TEST_MODULE_2" [label="(3, 3, 1, 1)", style=solid];
"7 _param_constant2" -> "10 conv2d_1" [label="(3, 3, 1, 1)", style=solid];
"8 _param_constant3" -> "10 conv2d_1" [label="(3,)", style=solid];
"10 conv2d_1" -> "14 add__1" [label="(1, 3, 3, 3)", style=solid];
"11 _tensor_constant0" -> "12 add_" [label="(1,)", style=solid];
"12 add_" -> "15 add" [label="(1, 3, 3, 3)", style=solid];
"13 _tensor_constant0_1" -> "14 add__1" [label="(1,)", style=solid];
"14 add__1" -> "15 add" [label="(1, 3, 3, 3)", style=solid];
"15 add" -> "18 conv2d_2" [label="(1, 3, 3, 3)", style=solid];
"16 _param_constant4" -> "18 conv2d_2" [label="(3, 3, 1, 1)", style=solid];
"17 _param_constant5" -> "18 conv2d_2" [label="(3,)", style=solid];
"18 conv2d_2" -> "20 add_1" [label="(1, 3, 3, 3)", style=solid];
"19 _tensor_constant0_2" -> "20 add_1" [label="(1,)", style=solid];
"20 add_1" -> "21 output" [label="(1, 3, 3, 3)", style=solid];
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
strict digraph {
"0 arg0_1" [id=0, type=input];
"1 _param_constant0" [id=1, type=get_attr];
"2 _param_constant1" [id=2, type=get_attr];
"3 TEST_MODULE_0" [id=3, type=call_module];
"4 TEST_MODULE_1" [id=4, type=call_module];
"5 conv2d" [id=5, type=conv2d];
"6 TEST_MODULE_3" [id=6, type=call_module];
"7 _param_constant2" [id=7, type=get_attr];
"8 _param_constant3" [id=8, type=get_attr];
"9 TEST_MODULE_2" [id=9, type=call_module];
"10 conv2d_1" [id=10, type=conv2d];
"11 _tensor_constant0" [id=11, type=get_attr];
"12 add_" [id=12, type=add_];
"13 _tensor_constant0_1" [id=13, type=get_attr];
"14 add__1" [id=14, type=add_];
"15 add" [id=15, type=add];
"16 _param_constant4" [id=16, type=get_attr];
"17 _param_constant5" [id=17, type=get_attr];
"18 conv2d_2" [id=18, type=conv2d];
"19 _tensor_constant0_2" [id=19, type=get_attr];
"20 add_1" [id=20, type=add];
"21 output" [id=21, type=output];
"0 arg0_1" -> "3 TEST_MODULE_0" [label="(1, 3, 3, 3)", style=solid];
"1 _param_constant0" -> "4 TEST_MODULE_1" [label="(3, 3, 1, 1)", style=solid];
"2 _param_constant1" -> "5 conv2d" [label="(3,)", style=solid];
"3 TEST_MODULE_0" -> "5 conv2d" [label="(1, 3, 3, 3)", style=solid];
"4 TEST_MODULE_1" -> "5 conv2d" [label="(3, 3, 1, 1)", style=solid];
"5 conv2d" -> "6 TEST_MODULE_3" [label="(1, 3, 3, 3)", style=solid];
"6 TEST_MODULE_3" -> "10 conv2d_1" [label="(1, 3, 3, 3)", style=solid];
"6 TEST_MODULE_3" -> "12 add_" [label="(1, 3, 3, 3)", style=solid];
"7 _param_constant2" -> "9 TEST_MODULE_2" [label="(3, 3, 1, 1)", style=solid];
"8 _param_constant3" -> "10 conv2d_1" [label="(3,)", style=solid];
"9 TEST_MODULE_2" -> "10 conv2d_1" [label="(3, 3, 1, 1)", style=solid];
"10 conv2d_1" -> "14 add__1" [label="(1, 3, 3, 3)", style=solid];
"11 _tensor_constant0" -> "12 add_" [label="(1,)", style=solid];
"12 add_" -> "15 add" [label="(1, 3, 3, 3)", style=solid];
"13 _tensor_constant0_1" -> "14 add__1" [label="(1,)", style=solid];
"14 add__1" -> "15 add" [label="(1, 3, 3, 3)", style=solid];
"15 add" -> "18 conv2d_2" [label="(1, 3, 3, 3)", style=solid];
"16 _param_constant4" -> "18 conv2d_2" [label="(3, 3, 1, 1)", style=solid];
"17 _param_constant5" -> "18 conv2d_2" [label="(3,)", style=solid];
"18 conv2d_2" -> "20 add_1" [label="(1, 3, 3, 3)", style=solid];
"19 _tensor_constant0_2" -> "20 add_1" [label="(1,)", style=solid];
"20 add_1" -> "21 output" [label="(1, 3, 3, 3)", style=solid];
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
strict digraph {
"0 arg0_1" [id=0, type=input];
"1 _param_constant2" [id=1, type=get_attr];
"2 _param_constant3" [id=2, type=get_attr];
"3 conv2d_1" [id=3, type=conv2d];
"4 _tensor_constant0" [id=4, type=get_attr];
"5 add_" [id=5, type=add_];
"6 _tensor_constant0_1" [id=6, type=get_attr];
"7 add__1" [id=7, type=add_];
"8 add" [id=8, type=add];
"9 _param_constant4" [id=9, type=get_attr];
"10 _param_constant5" [id=10, type=get_attr];
"11 conv2d_2" [id=11, type=conv2d];
"12 _tensor_constant0_2" [id=12, type=get_attr];
"13 add_1" [id=13, type=add];
"14 output" [id=14, type=output];
"0 arg0_1" -> "3 conv2d_1" [label="(1, 3, 3, 3)", style=solid];
"0 arg0_1" -> "5 add_" [label="(1, 3, 3, 3)", style=solid];
"1 _param_constant2" -> "3 conv2d_1" [label="(3, 3, 1, 1)", style=solid];
"2 _param_constant3" -> "3 conv2d_1" [label="(3,)", style=solid];
"3 conv2d_1" -> "7 add__1" [label="(1, 3, 3, 3)", style=solid];
"4 _tensor_constant0" -> "5 add_" [label="(1,)", style=solid];
"5 add_" -> "8 add" [label="(1, 3, 3, 3)", style=solid];
"6 _tensor_constant0_1" -> "7 add__1" [label="(1,)", style=solid];
"7 add__1" -> "8 add" [label="(1, 3, 3, 3)", style=solid];
"8 add" -> "11 conv2d_2" [label="(1, 3, 3, 3)", style=solid];
"9 _param_constant4" -> "11 conv2d_2" [label="(3, 3, 1, 1)", style=solid];
"10 _param_constant5" -> "11 conv2d_2" [label="(3,)", style=solid];
"11 conv2d_2" -> "13 add_1" [label="(1, 3, 3, 3)", style=solid];
"12 _tensor_constant0_2" -> "13 add_1" [label="(1,)", style=solid];
"13 add_1" -> "14 output" [label="(1, 3, 3, 3)", style=solid];
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
strict digraph {
"0 arg0_1" [id=0, type=input];
"1 _param_constant0" [id=1, type=get_attr];
"2 _param_constant1" [id=2, type=get_attr];
"3 conv2d" [id=3, type=conv2d];
"4 _param_constant2" [id=4, type=get_attr];
"5 _param_constant2_cloned" [id=5, type=clone];
"6 _param_constant3" [id=6, type=get_attr];
"7 conv2d_1" [id=7, type=conv2d];
"8 _tensor_constant0" [id=8, type=get_attr];
"9 add_" [id=9, type=add_];
"10 _tensor_constant0_1" [id=10, type=get_attr];
"11 add__1" [id=11, type=add_];
"12 add" [id=12, type=add];
"13 _param_constant4" [id=13, type=get_attr];
"14 _param_constant5" [id=14, type=get_attr];
"15 conv2d_2" [id=15, type=conv2d];
"16 _tensor_constant0_2" [id=16, type=get_attr];
"17 add_1" [id=17, type=add];
"18 output" [id=18, type=output];
"0 arg0_1" -> "3 conv2d" [label="(1, 3, 3, 3)", style=solid];
"1 _param_constant0" -> "3 conv2d" [label="(3, 3, 1, 1)", style=solid];
"2 _param_constant1" -> "3 conv2d" [label="(3,)", style=solid];
"3 conv2d" -> "7 conv2d_1" [label="(1, 3, 3, 3)", style=solid];
"3 conv2d" -> "9 add_" [label="(1, 3, 3, 3)", style=solid];
"4 _param_constant2" -> "5 _param_constant2_cloned" [label="(3, 3, 1, 1)", style=solid];
"4 _param_constant2" -> "7 conv2d_1" [label="(3, 3, 1, 1)", style=solid];
"5 _param_constant2_cloned" -> "18 output" [label=None, style=solid];
"6 _param_constant3" -> "7 conv2d_1" [label="(3,)", style=solid];
"7 conv2d_1" -> "11 add__1" [label="(1, 3, 3, 3)", style=solid];
"8 _tensor_constant0" -> "9 add_" [label="(1,)", style=solid];
"9 add_" -> "12 add" [label="(1, 3, 3, 3)", style=solid];
"10 _tensor_constant0_1" -> "11 add__1" [label="(1,)", style=solid];
"11 add__1" -> "12 add" [label="(1, 3, 3, 3)", style=solid];
"12 add" -> "15 conv2d_2" [label="(1, 3, 3, 3)", style=solid];
"13 _param_constant4" -> "15 conv2d_2" [label="(3, 3, 1, 1)", style=solid];
"14 _param_constant5" -> "15 conv2d_2" [label="(3,)", style=solid];
"15 conv2d_2" -> "17 add_1" [label="(1, 3, 3, 3)", style=solid];
"16 _tensor_constant0_2" -> "17 add_1" [label="(1,)", style=solid];
"17 add_1" -> "18 output" [label="(1, 3, 3, 3)", style=solid];
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
strict digraph {
"0 arg0_1" [id=0, type=input];
"1 arg0_1_cloned" [id=1, type=clone];
"2 _param_constant0" [id=2, type=get_attr];
"3 _param_constant1" [id=3, type=get_attr];
"4 conv2d" [id=4, type=conv2d];
"5 _param_constant2" [id=5, type=get_attr];
"6 _param_constant3" [id=6, type=get_attr];
"7 conv2d_1" [id=7, type=conv2d];
"8 _tensor_constant0" [id=8, type=get_attr];
"9 add_" [id=9, type=add_];
"10 _tensor_constant0_1" [id=10, type=get_attr];
"11 add__1" [id=11, type=add_];
"12 add" [id=12, type=add];
"13 _param_constant4" [id=13, type=get_attr];
"14 _param_constant5" [id=14, type=get_attr];
"15 conv2d_2" [id=15, type=conv2d];
"16 _tensor_constant0_2" [id=16, type=get_attr];
"17 add_1" [id=17, type=add];
"18 output" [id=18, type=output];
"0 arg0_1" -> "1 arg0_1_cloned" [label="(1, 3, 3, 3)", style=solid];
"0 arg0_1" -> "4 conv2d" [label="(1, 3, 3, 3)", style=solid];
"1 arg0_1_cloned" -> "18 output" [label=None, style=solid];
"2 _param_constant0" -> "4 conv2d" [label="(3, 3, 1, 1)", style=solid];
"3 _param_constant1" -> "4 conv2d" [label="(3,)", style=solid];
"4 conv2d" -> "7 conv2d_1" [label="(1, 3, 3, 3)", style=solid];
"4 conv2d" -> "9 add_" [label="(1, 3, 3, 3)", style=solid];
"5 _param_constant2" -> "7 conv2d_1" [label="(3, 3, 1, 1)", style=solid];
"6 _param_constant3" -> "7 conv2d_1" [label="(3,)", style=solid];
"7 conv2d_1" -> "11 add__1" [label="(1, 3, 3, 3)", style=solid];
"8 _tensor_constant0" -> "9 add_" [label="(1,)", style=solid];
"9 add_" -> "12 add" [label="(1, 3, 3, 3)", style=solid];
"10 _tensor_constant0_1" -> "11 add__1" [label="(1,)", style=solid];
"11 add__1" -> "12 add" [label="(1, 3, 3, 3)", style=solid];
"12 add" -> "15 conv2d_2" [label="(1, 3, 3, 3)", style=solid];
"13 _param_constant4" -> "15 conv2d_2" [label="(3, 3, 1, 1)", style=solid];
"14 _param_constant5" -> "15 conv2d_2" [label="(3,)", style=solid];
"15 conv2d_2" -> "17 add_1" [label="(1, 3, 3, 3)", style=solid];
"16 _tensor_constant0_2" -> "17 add_1" [label="(1,)", style=solid];
"17 add_1" -> "18 output" [label="(1, 3, 3, 3)", style=solid];
}
Loading

0 comments on commit 1d42c0d

Please sign in to comment.