diff --git a/nncf/experimental/torch/fx/nncf_graph_builder.py b/nncf/experimental/torch/fx/nncf_graph_builder.py index d6358b47c4c..946ac27ce84 100644 --- a/nncf/experimental/torch/fx/nncf_graph_builder.py +++ b/nncf/experimental/torch/fx/nncf_graph_builder.py @@ -65,11 +65,14 @@ def _get_layer_attributes( return None @staticmethod - def _get_node_type_and_metatype(node: torch.fx.Node) -> Tuple[str, om.OperatorMetatype]: + def _get_node_type_and_metatype( + node: torch.fx.Node, model: torch.fx.GraphModule + ) -> Tuple[str, om.OperatorMetatype]: """ Retrieves node's type and metatype. :param node: Given node. + :param model: Given GraphModule. :return: Node's type and metatype. """ if node.op == "placeholder": @@ -95,6 +98,11 @@ def _get_node_type_and_metatype(node: torch.fx.Node) -> Tuple[str, om.OperatorMe node_metatype = UnknownMetatype if node_metatype is UnknownMetatype: nncf_logger.debug(f"Unknown metatype for node: {node}") + + if node_metatype.get_subtypes(): + layer_attrs = GraphConverter._get_layer_attributes(node, node_metatype, model) + node_subtype = node_metatype.determine_subtype(layer_attrs) + node_metatype = node_subtype or node_metatype return node_type, node_metatype @staticmethod @@ -111,12 +119,7 @@ def create_nncf_graph(model: torch.fx.GraphModule) -> PTNNCFGraph: nncf_graph = PTNNCFGraph() for source_node in model.graph.nodes: - node_type, node_metatype = GraphConverter._get_node_type_and_metatype(source_node) - - if node_metatype.get_subtypes(): - layer_attrs = GraphConverter._get_layer_attributes(source_node, node_metatype, model) - node_subtype = node_metatype.determine_subtype(layer_attrs) - node_metatype = node_subtype or node_metatype + node_type, node_metatype = GraphConverter._get_node_type_and_metatype(source_node, model) nncf_graph.add_nncf_node( node_name=source_node.name,