Generalize all annotators into a generic parameterizable annotator (#7298)

Change-Id: I69ecdbef9d7b83a87655e97758215303374b5f04
Tessil authored Jan 8, 2025
1 parent 9a23cff commit 08770b7
Showing 18 changed files with 339 additions and 1,098 deletions.
backends/arm/quantizer/TARGETS (12 changes: 11 additions & 1 deletion)
@@ -5,12 +5,22 @@ python_library(
     srcs = ["arm_quantizer.py"],
     deps = [
         ":arm_quantizer_utils",
+        ":quantization_annotator",
         "//caffe2:torch",
-        "//executorch/backends/arm/quantizer/quantization_annotation:quantization_annotation",
         "//executorch/exir:lib",
     ],
 )
 
+python_library(
+    name = "quantization_annotator",
+    srcs = ["quantization_annotator.py"],
+    deps = [
+        ":arm_quantizer_utils",
+        ":quantization_config",
+        "//caffe2:torch",
+    ],
+)
+
 python_library(
     name = "quantization_config",
     srcs = ["quantization_config.py"],
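The new quantization_annotator target is the heart of the change: the per-op annotation modules under quantization_annotation are folded into a single table-driven pass. The contents of quantization_annotator.py are not part of this diff view, so the sketch below is only a plausible shape inferred from the call site in arm_quantizer.py (annotate_graph(model, quantization_config, filter_fn)); the _ANNOTATED_OPS table and everything inside the loop are illustrative assumptions, not the file's actual internals.

# Hypothetical sketch of a generic, parameterizable annotator.
# Confirmed by this diff: only the annotate_graph(...) call signature.
# Assumed: the _ANNOTATED_OPS table and the loop body.
from typing import Callable, Optional

import torch
from torch.fx import GraphModule, Node

from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig

# One entry per supported aten op, parameterizing which inputs receive a
# quantization spec, instead of a dedicated annotator function per operator.
_ANNOTATED_OPS = {
    torch.ops.aten.linear.default: [0, 1, 2],  # activation, weight, bias
    torch.ops.aten.add.Tensor: [0, 1],
    torch.ops.aten.max_pool2d.default: [0],
}


def annotate_graph(
    model: GraphModule,
    quantization_config: QuantizationConfig,
    filter_fn: Optional[Callable[[Node], bool]] = None,
) -> None:
    """Single pass that annotates every supported node from the table."""
    for node in model.graph.nodes:
        if node.op != "call_function" or node.target not in _ANNOTATED_OPS:
            continue
        if filter_fn is not None and not filter_fn(node):
            continue
        # Attach input/output qspecs from quantization_config here, then
        # mark the node as annotated.
        ...

A design like this replaces registration order ("fusions before singular ops") with data: adding an operator becomes a one-line table entry rather than a new annotator module.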
backends/arm/quantizer/arm_quantizer.py (123 changes: 10 additions & 113 deletions)
@@ -13,24 +13,16 @@
 
 from __future__ import annotations
 
-import copy
 import functools
-from typing import Any, Callable, Dict, List, Optional, Set
+from typing import Any, Callable, Dict, List, Optional
 
 import torch
-import torch.nn.functional as F
 from executorch.backends.arm._passes.arm_pass_manager import ArmPassManager
 
 from executorch.backends.arm.quantizer import arm_quantizer_utils
-from executorch.backends.arm.quantizer.arm_quantizer_utils import (
-    mark_nodes_as_annotated,
-    propagate_annotation,
-)
-from executorch.backends.arm.quantizer.quantization_annotation import (
-    OP_TO_ANNOTATOR,
-    OperatorConfig,
-    OperatorPatternType,
-)
+from executorch.backends.arm.quantizer.arm_quantizer_utils import mark_node_as_annotated
+from executorch.backends.arm.quantizer.quantization_annotator import annotate_graph
 
 from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig
 from torch.ao.quantization.fake_quantize import (
     FakeQuantize,
@@ -58,44 +50,6 @@
 ]
 
 
-def _supported_symmetric_quantized_operators() -> Dict[str, List[OperatorPatternType]]:
-    supported_operators: Dict[str, List[OperatorPatternType]] = {
-        # Both conv and linear should be able to handle relu + hardtanh fusion since
-        # those are clamp ops
-        "conv2d": [
-            [torch.nn.Conv2d, torch.nn.ReLU],
-            [torch.nn.Conv2d, F.relu],
-            [F.conv2d, torch.nn.ReLU],
-            [F.conv2d, F.relu],
-        ],
-        "linear": [[torch.nn.Linear], [F.linear]],
-        "add": [[torch.add]],
-        "max_pool2d": [[torch.nn.MaxPool2d], [F.max_pool2d]],
-        "adaptive_avg_pool2d": [
-            [torch.nn.AdaptiveAvgPool2d],
-            [F.adaptive_avg_pool2d],
-        ],
-        "mul": [[torch.mul]],
-        "sub": [[torch.sub]],
-        "min_max": [[torch.min], [torch.max]],
-    }
-    return copy.deepcopy(supported_operators)
-
-
-def _get_supported_symmetric_config_and_operators() -> List[OperatorConfig]:
-    supported_config_and_operators: List[OperatorConfig] = []
-    for quantization_config in [
-        get_symmetric_quantization_config(),
-        get_symmetric_quantization_config(is_per_channel=True),
-    ]:
-        ops = _supported_symmetric_quantized_operators()
-        for pattern_list in ops.values():
-            supported_config_and_operators.append(
-                OperatorConfig(quantization_config, pattern_list)
-            )
-    return copy.deepcopy(supported_config_and_operators)
-
-
 @functools.lru_cache
 def get_symmetric_quantization_config(
     is_per_channel: bool = False,
@@ -180,10 +134,6 @@ def get_symmetric_quantization_config(
     return quantization_config
 
 
-def _get_supported_config_and_operators() -> List[OperatorConfig]:
-    return _get_supported_symmetric_config_and_operators()
-
-
 NodeFilterType = Callable[[Node], bool]
 """Type for a Node Filter used by annotators. A Node filter is a function that takes
 a Node and returns whether the node should be annotated or not.
@@ -255,26 +205,6 @@ def not_module_type_or_name_filter(n: Node) -> bool:
 
 
 class ArmQuantizer(Quantizer):
-    supported_config_and_operators = _get_supported_config_and_operators()
-
-    # A list of supported static quantization annotators, in order of application.
-    # For example, fusions come before singular ops.
-    # The name must match the name used when registering the annotator.
-    STATIC_ANNOTATION_ORDER = [
-        "linear",
-        "conv",
-        "adaptive_avg_pool2d",
-        "max_pool2d",
-        "add",
-        "sub",
-        "mul",
-        "min_max",
-        "mm",
-        "one_to_one",
-        "generic",
-        "upsample_nearest2d",
-    ]
-
     def __init__(self) -> None:
         super().__init__()
         self.global_config: Optional[QuantizationConfig] = None
@@ -331,7 +261,6 @@ def annotate(self, model: GraphModule) -> GraphModule:
            The annotated model.
        """
        model = self._annotate_for_static_quantization_config(model)
-       propagate_annotation(model)
        return model
 
    def _annotate_all_static_patterns(
@@ -353,8 +282,7 @@ def _annotate_all_static_patterns(
        if quantization_config is None:
            return model
 
-       for op in self.STATIC_ANNOTATION_ORDER:
-           OP_TO_ANNOTATOR[op](model, quantization_config, filter_fn)
+       annotate_graph(model, quantization_config, filter_fn)
        return model
 
    def _annotate_for_static_quantization_config(
@@ -363,6 +291,9 @@ def _annotate_for_static_quantization_config(
        """Matches the correct QuantizationConfig with the correct module using a filter
        when running _annotate_all_static_patterns.
        """
+       if self.io_config:
+           self._annotate_io(model, self.io_config)
+
        module_name_list = list(self.module_name_config.keys())
        for module_name, config in self.module_name_config.items():
            self._annotate_all_static_patterns(
@@ -381,9 +312,6 @@
                _get_not_module_type_or_name_filter(tp_list, module_name_list),
            )
 
-       if self.io_config:
-           self._annotate_io(model, self.io_config)
-
        return model
 
    def _annotate_io(
@@ -399,44 +327,13 @@ def _annotate_io(
                    node,
                    quantization_config.get_output_act_qspec(),
                )
-               mark_nodes_as_annotated([node])
+               mark_node_as_annotated(node)
            if node.op == "output":
                parent = node.all_input_nodes[0]
                _annotate_input_qspec_map(
                    node, parent, quantization_config.get_input_act_qspec()
                )
-               mark_nodes_as_annotated([node])
+               mark_node_as_annotated(node)
 
    def validate(self, model: GraphModule) -> None:
        pass
-
-   @classmethod
-   def get_supported_operators(cls) -> List[OperatorConfig]:
-       return cls.supported_config_and_operators
-
-   @classmethod
-   def get_supported_quantization_configs(cls) -> List[QuantizationConfig]:
-       op_configs: Set[QuantizationConfig] = set({})
-       for spec, _ in cls.supported_config_and_operators:
-           op_configs.add(spec)
-       return list(op_configs)
-
-   @classmethod
-   def get_supported_operator_for_quantization_config(
-       cls, quantization_config: Optional[QuantizationConfig]
-   ) -> List[OperatorPatternType]:
-       if quantization_config is None:
-           all_ops = []
-           for _, ops in cls.supported_config_and_operators:
-               all_ops.extend(ops)
-           return all_ops
-
-       for config, ops in cls.supported_config_and_operators:
-           # note: this assumes each entry in cls.supported_spec_and_operators
-           # corresponds to one spec, e.g. we don't have
-           # [(spec1, op_list1), (spec1, op_list2), (spec2, op_list3)]
-           # where the first and second entry have the same spec but did not
-           # merge the op list
-           if config == quantization_config:
-               return ops
-       return []
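End to end, the quantizer's public workflow is unchanged; only the annotation internals were consolidated. A minimal usage sketch follows, assuming the standard torch.ao PT2E flow and an ArmQuantizer.set_global() method, neither of which appears in this diff.

# Minimal usage sketch; set_global() and the PT2E helpers are assumptions
# from the surrounding APIs, not part of this commit.
import torch
from executorch.backends.arm.quantizer.arm_quantizer import (
    ArmQuantizer,
    get_symmetric_quantization_config,
)
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

model = torch.nn.Sequential(torch.nn.Linear(16, 8), torch.nn.ReLU()).eval()
example_inputs = (torch.randn(1, 16),)

quantizer = ArmQuantizer()
quantizer.set_global(get_symmetric_quantization_config(is_per_channel=True))

# prepare_pt2e invokes ArmQuantizer.annotate(), which after this commit
# delegates to one annotate_graph() pass instead of iterating the old
# STATIC_ANNOTATION_ORDER registry.
exported = torch.export.export_for_training(model, example_inputs).module()
prepared = prepare_pt2e(exported, quantizer)
prepared(*example_inputs)  # calibration run
quantized = convert_pt2e(prepared)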
(Diffs for the remaining 16 changed files are not shown.)
