From 2a76b9c48d4b31d1c98c765d397f082929593acc Mon Sep 17 00:00:00 2001 From: dlyakhov Date: Tue, 3 Dec 2024 15:33:37 +0100 Subject: [PATCH] [TorchFX][MicroFix] Folded constants do not require grad --- .../experimental/torch/fx/constant_folding.py | 37 ++++++++++--------- tests/torch/fx/test_model_transformer.py | 4 ++ 2 files changed, 23 insertions(+), 18 deletions(-) diff --git a/nncf/experimental/torch/fx/constant_folding.py b/nncf/experimental/torch/fx/constant_folding.py index 945f79b497c..f60db84a928 100644 --- a/nncf/experimental/torch/fx/constant_folding.py +++ b/nncf/experimental/torch/fx/constant_folding.py @@ -233,25 +233,26 @@ def constant_fold( :param constraint_fn: Constraint function which takes a node and returs the constraint: should the node be constant folded or not. """ - with torch.utils._python_dispatch._disable_current_modes(): - cf = ConstantFolder(gm) - cf.run() + with torch.no_grad(): + with torch.utils._python_dispatch._disable_current_modes(): + cf = ConstantFolder(gm) + cf.run() - for node, constant in cf.node_replacements.items(): - if constraint_fn is not None and not constraint_fn(node): - continue - _replace_node_with_constant(gm, node, constant) + for node, constant in cf.node_replacements.items(): + if constraint_fn is not None and not constraint_fn(node): + continue + _replace_node_with_constant(gm, node, constant) - erased_params = [] - for node in gm.graph.find_nodes(op="get_attr"): - if len(node.users) == 0: - if hasattr(gm, node.target): - delattr(gm, node.target) - erased_params.append(node) + erased_params = [] + for node in gm.graph.find_nodes(op="get_attr"): + if len(node.users) == 0: + if hasattr(gm, node.target): + delattr(gm, node.target) + erased_params.append(node) - for node in erased_params: - gm.graph.erase_node(node) + for node in erased_params: + gm.graph.erase_node(node) - gm.graph.eliminate_dead_code() - gm.graph.lint() - gm.recompile() + gm.graph.eliminate_dead_code() + gm.graph.lint() + gm.recompile() diff 
--git a/tests/torch/fx/test_model_transformer.py b/tests/torch/fx/test_model_transformer.py index 23039ee99ef..a3d10431845 100644 --- a/tests/torch/fx/test_model_transformer.py +++ b/tests/torch/fx/test_model_transformer.py @@ -502,6 +502,10 @@ def test_constant_folding(): captured_model = get_torch_fx_model(model, torch.ones(model.INPUT_SIZE)) folded_model = deepcopy(captured_model) constant_fold(folded_model) + + # Check the folded const does not require gradient + assert not folded_model._frozen_param0.requires_grad + ex_input = torch.ones(model.INPUT_SIZE) assert torch.allclose(captured_model(ex_input), folded_model(ex_input))