Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TorchFX] Pre-hook insertion support #2861

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 77 additions & 31 deletions nncf/experimental/torch/fx/transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,7 @@ def leaf_module_insertion_transformation(model: torch.fx.GraphModule):
# Insert call_module nodes to the model
graph = model.graph
for target_point in target_points:
target_node = _get_target_node(graph, target_point)
_insert_call_module(graph, target_node, module_attr_name)
_insert_call_module(graph, target_point, module_attr_name)

return leaf_module_insertion_transformation

Expand Down Expand Up @@ -100,13 +99,12 @@ def qdq_insertion_tranformation(model: torch.fx.GraphModule):
" Please use non shared qdq pairs for the weights quantization."
)
for target_point in target_points:
target_node = _get_target_node(model.graph, target_point)
insert_one_qdq_after_node(model, target_node, quantizer)
insert_one_qdq(model, target_point, quantizer)

return qdq_insertion_tranformation


def insert_one_qdq_after_node(model: torch.fx.GraphModule, target_node: torch.fx.Node, quantizer: FakeQuantize):
def insert_one_qdq(model: torch.fx.GraphModule, target_point: PTTargetPoint, quantizer: FakeQuantize):
"""
Inserts quantize-dequantize after the target node to the target model.

Expand Down Expand Up @@ -146,6 +144,7 @@ def insert_one_qdq_after_node(model: torch.fx.GraphModule, target_node: torch.fx

# 2. replace activation_post_process node with quantize and dequantize
graph = model.graph
target_node = get_graph_node_by_name(graph, target_point.target_node_name)
# TODO(dlyakhov): use metatype to get correct input_port_id
# Do not quantize already quantized nodes
# inserting_before handle only order in the graph generated code.
Expand All @@ -170,51 +169,98 @@ def insert_one_qdq_after_node(model: torch.fx.GraphModule, target_node: torch.fx
# for qparams that are not scale/zero_point (like axis, dtype) we store
# them as literals in the graph.
quantize_op_inputs.append(value_or_node)
with graph.inserting_after(target_node):

input_node = get_input_node(target_point, target_node)
quantize_op_inputs[0] = input_node

ctx_manager = get_ctx_manager(graph, target_point)
with ctx_manager(target_node):
quantized_node = graph.create_node(node_type, quantize_op, tuple(quantize_op_inputs), {})
# use the same qparams from quantize op
dq_inputs = [quantized_node] + quantize_op_inputs[1:]
user_dq_nodes = []
with graph.inserting_after(quantized_node):
for user in target_node.users:
if user is quantized_node:
continue
user_dq_nodes.append((user, graph.call_function(dequantize_op, tuple(dq_inputs), {})))

for user, dq_node in user_dq_nodes:
user.replace_input_with(target_node, dq_node)
# use the same qparams from quantize op
dq_inputs = [quantized_node] + quantize_op_inputs[1:]
if target_point.target_type == TargetType.OPERATOR_POST_HOOK:
AlexanderDokuchaev marked this conversation as resolved.
Show resolved Hide resolved
user_dq_nodes = []
with graph.inserting_after(quantized_node):
for user in target_node.users:
if user is quantized_node:
continue
user_dq_nodes.append((user, graph.call_function(dequantize_op, tuple(dq_inputs), {})))

for user, dq_node in user_dq_nodes:
user.replace_input_with(target_node, dq_node)
elif target_point.target_type in [TargetType.OPERATOR_PRE_HOOK, TargetType.OPERATION_WITH_WEIGHTS]:
with graph.inserting_after(quantized_node):
dq_node = graph.call_function(dequantize_op, tuple(dq_inputs), {})

args = list(target_node.args)
args[target_point.input_port_id] = dq_node
target_node.args = tuple(args)
else:
raise nncf.InternalError(f"Unexpected target type: {target_point.target_type}")


def _insert_call_module(graph: torch.fx.Graph, target_node: torch.fx.Node, module_attr_name: str):
def _insert_call_module(graph: torch.fx.Graph, target_point: PTTargetPoint, module_attr_name: str):
"""
Inserts module call node to the graph after the target node.

:param graph: Graph to insert module call node.
:param target_node: Target node, module call node is being iserted just after the target node.
:param module_attr_name: The name of the graph attribute which keeps the target module.
"""
with graph.inserting_after(target_node):
target_node = get_graph_node_by_name(graph, target_point.target_node_name)
input_node = get_input_node(target_point, target_node)
ctx_manager = get_ctx_manager(graph, target_point)
with ctx_manager(target_node):
return graph.create_node(
"call_module", module_attr_name, (target_node,), {}, name=module_attr_name + "_graph_node"
"call_module",
module_attr_name,
(input_node,),
{},
name=f"{module_attr_name}_{str(target_point.target_type)}_graph_node",
)


def _get_target_node(graph: torch.fx.Graph, target_point: PTTargetPoint) -> torch.fx.Node:
def get_input_node(target_point: PTTargetPoint, target_node: torch.fx.Node) -> torch.fx.Node:
"""
Returns TorchFX graph node correspondent to the target point.
Returns an input node according to the given target point.

:param graph: Target torch.fx.Graph.
:param target_point: A target point to find the target node.
:return: TorchFX graph node correspondent to the target point.
:param target_point: Given target point.
:param target_node: The target node of the given target point.
:return: An input node according to the given target point.
"""
# TODO(dlyakhov): Support node insertion on a specific input port id.
target_type = target_point.target_type
target_node = get_graph_node_by_name(graph, target_point.target_node_name)
if target_type in [TargetType.OPERATOR_PRE_HOOK, TargetType.OPERATION_WITH_WEIGHTS]:
target_node = target_node.all_input_nodes[target_point.input_port_id]
elif target_type != TargetType.OPERATOR_POST_HOOK:
raise RuntimeError(f"Unsupported target type: {target_type} for target_point: {target_point}")
return target_node
if target_type not in [
TargetType.OPERATOR_PRE_HOOK,
TargetType.OPERATOR_POST_HOOK,
TargetType.OPERATION_WITH_WEIGHTS,
]:
raise nncf.InternalError(f"Unexpected target type: {target_type}")
if target_type == TargetType.OPERATOR_POST_HOOK:
return target_node
return target_node.args[target_point.input_port_id]


def get_ctx_manager(graph: torch.fx.Graph, target_point: PTTargetPoint) -> Callable:
"""
Return insertion context manager according to the given target point.
An insertion context manager sets the point at which create_node and
companion methods will insert into the torch.fx.Graph.

:param graph: torch.fx.Graph instance.
:param target_point: Given target point.
:return: Insertion context manager according to the given target point.
"""
if target_point.target_type not in [
TargetType.OPERATOR_PRE_HOOK,
TargetType.OPERATOR_POST_HOOK,
TargetType.OPERATION_WITH_WEIGHTS,
]:
raise nncf.InternalError(f"Unexpected target type: {target_point.target_type}")

if target_point.target_type == TargetType.OPERATOR_POST_HOOK:
return graph.inserting_after
return graph.inserting_before


def _set_module_to_the_graph_module(
Expand Down
Loading
Loading