Skip to content

Commit

Permalink
[TorchFX] Fix depthwise weights quantization
Browse files Browse the repository at this point in the history
  • Loading branch information
daniil-lyakhov committed Aug 23, 2024
1 parent 1104f1b commit 4397b9e
Show file tree
Hide file tree
Showing 4 changed files with 1,489 additions and 1,313 deletions.
25 changes: 25 additions & 0 deletions nncf/experimental/torch/fx/groups.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# Copyright (c) 2024 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import nncf.torch.graph.operator_metatypes as om

FX_OPERATORS_WEIGHTS_METATYPES = (
om.PTConv1dMetatype,
om.PTConv2dMetatype,
om.PTConv3dMetatype,
om.PTLinearMetatype,
om.PTDepthwiseConv1dSubtype,
om.PTDepthwiseConv2dSubtype,
om.PTDepthwiseConv3dSubtype,
om.PTConvTranspose1dMetatype,
om.PTConvTranspose2dMetatype,
om.PTConvTranspose3dMetatype,
)
3 changes: 2 additions & 1 deletion nncf/quantization/algorithms/min_max/torch_fx_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from nncf.experimental.common.tensor_statistics.collectors import TensorCollector
from nncf.experimental.common.tensor_statistics.statistics import MinMaxTensorStatistic
from nncf.experimental.torch.fx.commands import FXApplyTransformationCommand
from nncf.experimental.torch.fx.groups import FX_OPERATORS_WEIGHTS_METATYPES
from nncf.experimental.torch.fx.transformations import qdq_insertion_transformation_builder
from nncf.parameters import ModelType
from nncf.parameters import TargetDevice
Expand Down Expand Up @@ -348,6 +349,6 @@ def get_ignored_names_by_layer_attributes(nncf_graph: NNCFGraph) -> Set[str]:
def get_weight_nodes(nncf_graph: NNCFGraph) -> List[NNCFNode]:
retval = set()
for node in nncf_graph.get_all_nodes():
if node.metatype in [om.PTConv1dMetatype, om.PTConv2dMetatype, om.PTConv3dMetatype, om.PTLinearMetatype]:
if node.metatype in FX_OPERATORS_WEIGHTS_METATYPES:
retval.add(node)
return list(retval)
Loading

0 comments on commit 4397b9e

Please sign in to comment.