Skip to content

Commit

Permalink
[PT] disable grad for extracted module (openvinotoolkit#3266)
Browse files Browse the repository at this point in the history
### Changes

Disable gradient tracking for the parameters of extracted modules.

### Reason for changes

Leaving `requires_grad=True` on extracted module parameters can retain autograd graphs, leading to possible memory leaks.

### Tests

manual/job/post_training_quantization/606
  • Loading branch information
AlexanderDokuchaev authored Feb 11, 2025
1 parent c24bf74 commit 99f0c44
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 8 deletions.
2 changes: 2 additions & 0 deletions nncf/torch/extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,8 @@ def extract_bn(node: NNCFNode, model: NNCFNetwork) -> Optional[Union[nn.BatchNor
for name, _ in chain(extracted_bn.named_parameters(), extracted_bn.named_buffers()):
setattr(extracted_bn, name, deepcopy(getattr(bn_module, name)))
extracted_bn.eval()
extracted_bn.weight.requires_grad = False
extracted_bn.bias.requires_grad = False
return extracted_bn


Expand Down
17 changes: 9 additions & 8 deletions tests/torch/test_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,10 @@ def test_extract_model(model_cls, input_node_name, output_node_name):

model = wrap_model(model_cls().eval(), example_input=example_input, trace_parameters=True)
extracted_module = extract_model(model, [input_node_name], [output_node_name])
with torch.no_grad():
ret1 = model(example_input)
ret2 = extracted_module(example_input)
assert torch.any(torch.isclose(ret1, ret2))
ret1 = model(example_input)
ret2 = extracted_module(example_input)
assert not ret2.grad_fn
assert torch.any(torch.isclose(ret1, ret2))


@pytest.mark.parametrize(
Expand Down Expand Up @@ -122,10 +122,11 @@ def test_extract_model_for_node_with_fq(model_cls, input_node_name, output_node_
q_model = transformer.transform(layout)

extracted_module = extract_model(model, [input_node_name], [output_node_name])
with torch.no_grad():
ret1 = q_model(example_input)
ret2 = extracted_module(example_input)
assert torch.all(torch.isclose(ret1, ret2))

ret1 = q_model(example_input)
ret2 = extracted_module(example_input)
assert torch.all(torch.isclose(ret1, ret2))
assert not ret2.grad_fn

extracted_fn = extracted_module
if isinstance(extracted_fn, nn.Sequential):
Expand Down

0 comments on commit 99f0c44

Please sign in to comment.