From 99f0c44d6d674acc7c55703c80d208ec808642f3 Mon Sep 17 00:00:00 2001 From: Alexander Dokuchaev Date: Tue, 11 Feb 2025 17:30:46 +0200 Subject: [PATCH] [PT] disable grad for extracted module (#3266) ### Changes Disable gradient for extracted modules ### Reason for changes Possible memory leaks ### Tests manual/job/post_training_quantization/606 --- nncf/torch/extractor.py | 2 ++ tests/torch/test_extractor.py | 17 +++++++++-------- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/nncf/torch/extractor.py b/nncf/torch/extractor.py index 640a27f5ac6..fa11cf4fe2c 100644 --- a/nncf/torch/extractor.py +++ b/nncf/torch/extractor.py @@ -153,6 +153,8 @@ def extract_bn(node: NNCFNode, model: NNCFNetwork) -> Optional[Union[nn.BatchNor for name, _ in chain(extracted_bn.named_parameters(), extracted_bn.named_buffers()): setattr(extracted_bn, name, deepcopy(getattr(bn_module, name))) extracted_bn.eval() + extracted_bn.weight.requires_grad = False + extracted_bn.bias.requires_grad = False return extracted_bn diff --git a/tests/torch/test_extractor.py b/tests/torch/test_extractor.py index 64eef988d33..60b6be33bdc 100644 --- a/tests/torch/test_extractor.py +++ b/tests/torch/test_extractor.py @@ -61,10 +61,10 @@ def test_extract_model(model_cls, input_node_name, output_node_name): model = wrap_model(model_cls().eval(), example_input=example_input, trace_parameters=True) extracted_module = extract_model(model, [input_node_name], [output_node_name]) - with torch.no_grad(): - ret1 = model(example_input) - ret2 = extracted_module(example_input) - assert torch.any(torch.isclose(ret1, ret2)) + ret1 = model(example_input) + ret2 = extracted_module(example_input) + assert not ret2.grad_fn + assert torch.any(torch.isclose(ret1, ret2)) @pytest.mark.parametrize( @@ -122,10 +122,11 @@ def test_extract_model_for_node_with_fq(model_cls, input_node_name, output_node_ q_model = transformer.transform(layout) extracted_module = extract_model(model, [input_node_name], [output_node_name]) - with torch.no_grad(): - ret1 = q_model(example_input) - ret2 = extracted_module(example_input) - assert torch.all(torch.isclose(ret1, ret2)) + + ret1 = q_model(example_input) + ret2 = extracted_module(example_input) + assert torch.all(torch.isclose(ret1, ret2)) + assert not ret2.grad_fn extracted_fn = extracted_module if isinstance(extracted_fn, nn.Sequential):