Support import aliasing
vshampor committed Apr 19, 2023
1 parent 79d8f1c commit d8a01cc
Showing 22 changed files with 89 additions and 20 deletions.
34 changes: 28 additions & 6 deletions docs/api/source/conf.py
@@ -3,6 +3,7 @@
 import os
 import pkgutil
 import sys
+from typing import Dict
 from typing import List
 
 from sphinx.ext.autodoc import mock
@@ -50,30 +51,51 @@ def collect_api_entities() -> List[str]:
         except Exception as e:
             skipped_modules[modname] = str(e)
 
-    api_fqns = []
+    from nncf.common.api_marker import api
+    api_fqns = dict()
+    aliased_fqns = {}  # type: Dict[str, str]
     for modname, module in modules.items():
         print(f"{modname}")
         for obj_name, obj in inspect.getmembers(module):
             objects_module = getattr(obj, '__module__', None)
             if objects_module == modname:
                 if inspect.isclass(obj) or inspect.isfunction(obj):
-                    if hasattr(obj, "_nncf_api_marker"):
+                    if hasattr(obj, api.API_MARKER_ATTR):
                         marked_object_name = obj._nncf_api_marker
                         # Check the actual name of the originally marked object
                         # so that the classes derived from base API classes don't
                         # all automatically end up in API
-                        if marked_object_name == obj.__name__:
-                            print(f"\t{obj_name}")
-                            api_fqns.append(f"{modname}.{obj_name}")
+                        if marked_object_name != obj.__name__:
+                            continue
+                        fqn = f"{modname}.{obj_name}"
+                        if hasattr(obj, api.CANONICAL_ALIAS_ATTR):
+                            canonical_import_name = getattr(obj, api.CANONICAL_ALIAS_ATTR)
+                            aliased_fqns[fqn] = canonical_import_name
+                            if canonical_import_name == fqn:
+                                print(f"\t{obj_name}")
+                            else:
+                                print(f"\t{obj_name} -> {canonical_import_name}")
+                        api_fqns[fqn] = True
 
     print()
     skipped_str = '\n'.join([f"{k}: {v}" for k, v in skipped_modules.items()])
     print(f"Skipped: {skipped_str}\n")
+    for fqn, canonical_alias in aliased_fqns.items():
+        try:
+            module_name, _, function_name = canonical_alias.rpartition('.')
+            getattr(importlib.import_module(module_name), function_name)
+        except (ImportError, AttributeError) as e:
+            print(
+                f"API entity with canonical_alias={canonical_alias} not available for import as specified!\n"
+                f"Adjust the __init__.py files so that the symbol is available for import as {canonical_alias}.")
+            raise e
+        api_fqns.pop(fqn)
+        api_fqns[canonical_alias] = True
 
     print("API entities:")
    for api_fqn in api_fqns:
         print(api_fqn)
-    return api_fqns
+    return list(api_fqns.keys())
 
 
 with mock(['torch', 'torchvision', 'onnx', 'onnxruntime', 'openvino', 'tensorflow', 'tensorflow_addons']):
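
For reference, the canonical-alias check added to collect_api_entities above boils down to resolving the dotted name with importlib and getattr. A minimal standalone sketch of the same idea (the check_canonical_alias helper name is illustrative, not part of the commit):

import importlib


def check_canonical_alias(canonical_alias: str) -> None:
    # Split "pkg.sub.Symbol" into the module to import and the attribute to look up.
    module_name, _, symbol_name = canonical_alias.rpartition('.')
    try:
        getattr(importlib.import_module(module_name), symbol_name)
    except (ImportError, AttributeError) as e:
        raise RuntimeError(
            f"{canonical_alias} is not importable - adjust the __init__.py files "
            f"so that the symbol is re-exported under that name") from e


check_canonical_alias("json.dumps")  # stdlib example; "nncf.QuantizationPreset" would be checked the same way

This is why the commit also touches package __init__.py files: an alias is only valid if the symbol is actually importable from that location, and the docs build fails otherwise.
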
2 changes: 2 additions & 0 deletions nncf/__init__.py
@@ -60,3 +60,5 @@
 else:
     nncf_logger.info(f"NNCF initialized successfully. Supported frameworks detected: "
                      f"{', '.join([name for name, loaded in _LOADED_FRAMEWORKS.items() if loaded])}")
+
+
3 changes: 3 additions & 0 deletions nncf/common/accuracy_aware_training/training_loop.py
@@ -20,6 +20,7 @@
 from scipy.interpolate import interp1d
 
 from nncf.api.compression import CompressionAlgorithmController
+from nncf.common.api_marker import api
 from nncf.common.composite_compression import CompositeCompressionAlgorithmController
 from nncf.common.logging import nncf_logger
 from nncf.common.utils.registry import Registry
@@ -168,6 +169,7 @@ def _accuracy_criterion_satisfied(self):
         return accuracy_budget >= 0 and self.runner.is_model_fully_compressed(self.compression_controller)
 
 
+@api(canonical_alias="nncf.tensorflow.EarlyExitCompressionTrainingLoop")
 class EarlyExitCompressionTrainingLoop(BaseEarlyExitCompressionTrainingLoop):
     """
     Adaptive compression training loop allows an accuracy-aware training process
@@ -191,6 +193,7 @@ def __init__(self,
         self.runner = runner_factory.create_training_loop()
 
 
+@api(canonical_alias="nncf.tensorflow.AdaptiveCompressionTrainingLoop")
 class AdaptiveCompressionTrainingLoop(BaseEarlyExitCompressionTrainingLoop):
     """
     Adaptive compression training loop allows an accuracy-aware training process whereby
26 changes: 24 additions & 2 deletions nncf/common/api_marker.py
@@ -1,11 +1,33 @@
+"""
+ Copyright (c) 2023 Intel Corporation
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+      http://www.apache.org/licenses/LICENSE-2.0
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+"""
+
+
 class api:
     API_MARKER_ATTR = "_nncf_api_marker"
+    CANONICAL_ALIAS_ATTR = "_nncf_canonical_alias"
 
-    def __init__(self):
-        pass
+    def __init__(self, canonical_alias: str = None):
+        self._canonical_alias = canonical_alias
 
     def __call__(self, obj):
         # The value of the marker will be useful in determining
         # whether we are handling a base class or a derived one.
         setattr(obj, api.API_MARKER_ATTR, obj.__name__)
+        if self._canonical_alias is not None:
+            setattr(obj, api.CANONICAL_ALIAS_ATTR, self._canonical_alias)
         return obj
+
+
+def is_api(obj) -> bool:
+    return hasattr(obj, api.API_MARKER_ATTR)
+
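
A minimal usage sketch of the reworked marker, assuming the post-commit nncf.common.api_marker shown above is importable; MyQuantizer and its alias are placeholder names, not symbols from this commit:

from nncf.common.api_marker import api, is_api


@api(canonical_alias="nncf.MyQuantizer")
class MyQuantizer:
    pass


assert is_api(MyQuantizer)
assert MyQuantizer._nncf_api_marker == "MyQuantizer"            # set by api.__call__ to the marked class name
assert MyQuantizer._nncf_canonical_alias == "nncf.MyQuantizer"  # only present when an alias was given

Storing the marked object's own name as the marker value is what lets collect_api_entities in conf.py skip subclasses that merely inherit the attribute.
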
2 changes: 2 additions & 0 deletions nncf/common/quantization/structs.py
@@ -15,6 +15,7 @@
 from enum import Enum
 from typing import Dict, List, Optional, Any
 
+from nncf.common.api_marker import api
 from nncf.common.graph import NNCFNode
 from nncf.common.graph import NNCFNodeName
 from nncf.config.schemata.defaults import QUANTIZATION_BITS
@@ -308,6 +309,7 @@ class UnifiedScaleType(Enum):
     UNIFY_ALWAYS = 1
 
 
+@api(canonical_alias="nncf.QuantizationPreset")
 class QuantizationPreset(Enum):
     PERFORMANCE = 'performance'
     MIXED = 'mixed'
2 changes: 1 addition & 1 deletion nncf/config/config.py
@@ -30,7 +30,7 @@
 from nncf.config.structures import NNCFExtraConfigStruct
 
 
-@api()
+@api(canonical_alias="nncf.NNCFConfig")
 class NNCFConfig(dict):
     """A regular dictionary object extended with some utility functions."""
 
2 changes: 1 addition & 1 deletion nncf/data/dataset.py
@@ -24,7 +24,7 @@
 ModelInput = TypeVar('ModelInput')
 
 
-@api()
+@api(canonical_alias="nncf.Dataset")
 class Dataset(Generic[DataItem, ModelInput]):
     """
     The `nncf.Dataset` class defines the interface by which compression algorithms
4 changes: 2 additions & 2 deletions nncf/parameters.py
@@ -16,7 +16,7 @@
 from nncf.common.api_marker import api
 
 
-@api()
+@api(canonical_alias="nncf.TargetDevice")
 class TargetDevice(Enum):
     """
     Describes the target device the specificity of which will be taken
@@ -36,7 +36,7 @@ class TargetDevice(Enum):
     CPU_SPR = 'CPU_SPR'
 
 
-@api()
+@api(canonical_alias="nncf.ModelType")
 class ModelType(Enum):
     """
     Describes the model type the specificity of which will be taken into
3 changes: 2 additions & 1 deletion nncf/quantization/quantize.py
@@ -27,7 +27,7 @@
 from nncf.parameters import TargetDevice
 
 
-@api()
+@api(canonical_alias="nncf.quantize")
 def quantize(model: TModel,
              calibration_dataset: Dataset,
             preset: QuantizationPreset = QuantizationPreset.PERFORMANCE,
@@ -85,6 +85,7 @@ def quantize(model: TModel,
     raise RuntimeError(f'Unsupported type of backend: {backend}')
 
 
+@api(canonical_alias="nncf.quantize_with_accuracy_control")
 def quantize_with_accuracy_control(model: TModel,
                                    calibration_dataset: Dataset,
                                    validation_dataset: Dataset,
2 changes: 1 addition & 1 deletion nncf/scopes.py
@@ -21,7 +21,7 @@
 from nncf.common.graph.graph import NNCFGraph
 
 
-@api()
+@api(canonical_alias="nncf.IgnoredScope")
 @dataclass
 class IgnoredScope:
     """
1 change: 1 addition & 0 deletions nncf/tensorflow/__init__.py
@@ -19,6 +19,7 @@
 from pkg_resources import parse_version
+
 try:
     _tf_version = tensorflow.__version__
     tensorflow_version = parse_version(_tf_version).base_version
 except:
     nncf_logger.debug("Could not parse tensorflow version")
3 changes: 2 additions & 1 deletion nncf/tensorflow/helpers/callback_creation.py
@@ -10,7 +10,7 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 """
-
+from nncf.common.api_marker import api
 from nncf.common.composite_compression import CompositeCompressionAlgorithmController
 from nncf.tensorflow.pruning.base_algorithm import BasePruningAlgoController
 from nncf.tensorflow.pruning.callbacks import PruningStatisticsCallback
@@ -19,6 +19,7 @@
 from nncf.tensorflow.sparsity.base_algorithm import BaseSparsityController
 
 
+@api(canonical_alias="nncf.tensorflow.create_compression_callbacks")
 def create_compression_callbacks(compression_ctrl, log_tensorboard=True, log_text=True, log_dir=None):
     compression_controllers = compression_ctrl.child_ctrls \
         if isinstance(compression_ctrl, CompositeCompressionAlgorithmController) \
2 changes: 1 addition & 1 deletion nncf/tensorflow/helpers/model_creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def create_compression_algorithm_builder(config: NNCFConfig,

return TFCompositeCompressionAlgorithmBuilder(config, should_init)

@api()
@api(canonical_alias="nncf.tensorflow.create_compressed_model")
@tracked_function(NNCF_TF_CATEGORY, [CompressionStartedFromConfig(argname="config"), ])
def create_compressed_model(model: tf.keras.Model,
config: NNCFConfig,
Expand Down
2 changes: 1 addition & 1 deletion nncf/tensorflow/initialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def __iter__(self):
return iter(self._data_loader)


@api()
@api(canonical_alias="nncf.tensorflow.register_default_init_args")
def register_default_init_args(nncf_config: NNCFConfig,
data_loader: tf.data.Dataset,
batch_size: int,
Expand Down
3 changes: 2 additions & 1 deletion nncf/torch/checkpoint_loading.py
@@ -16,10 +16,11 @@
 
 import torch
 
+from nncf.common.api_marker import api
 from nncf.common.logging import nncf_logger
 from nncf.common.deprecation import warning_deprecated
 
 
+@api(canonical_alias="nncf.torch.load_state")
 def load_state(model: torch.nn.Module, state_dict_to_load: dict, is_resume: bool = False,
                keys_to_ignore: List[str] = None) -> int:
     """
4 changes: 4 additions & 0 deletions nncf/torch/dynamic_graph/context.py
@@ -22,6 +22,7 @@
 
 import torch
 
+from nncf.common.api_marker import api
 from nncf.common.graph.layer_attributes import BaseLayerAttributes
 from nncf.common.utils.debug import is_debug
 from nncf.torch.dynamic_graph.graph import DynamicGraph
@@ -429,6 +430,7 @@ def set_current_context(c: TracingContext):
     _CURRENT_CONTEXT.context = c
 
 
+@api(canonical_alias="nncf.torch.no_nncf_trace")
 @contextmanager
 def no_nncf_trace():
     ctx = get_current_context()
@@ -440,6 +442,7 @@ def no_nncf_trace():
         yield
 
 
+@api(canonical_alias="nncf.torch.forward_nncf_trace")
 @contextmanager
 def forward_nncf_trace():
     ctx = get_current_context()
@@ -455,6 +458,7 @@ def get_current_context() -> TracingContext:
     return _CURRENT_CONTEXT.context
 
 
+@api(canonical_alias="nncf.torch.disable_tracing")
 def disable_tracing(method):
     """
     Patch a method so that it will be executed within no_nncf_trace context
3 changes: 3 additions & 0 deletions nncf/torch/dynamic_graph/io_handling.py
@@ -5,6 +5,7 @@
 
 import torch
 
+from nncf.common.api_marker import api
 from nncf.common.graph.definitions import MODEL_INPUT_OP_NAME
 from nncf.common.graph.definitions import MODEL_OUTPUT_OP_NAME
 from nncf.torch.dynamic_graph.patch_pytorch import register_operator
@@ -16,11 +17,13 @@
 from nncf.torch.dynamic_graph.context import forward_nncf_trace
 
 
+@api(canonical_alias="nncf.torch.nncf_model_input")
 @register_operator(name=MODEL_INPUT_OP_NAME)
 def nncf_model_input(tensor: 'torch.Tensor'):
     return tensor
 
 
+@api(canonical_alias="nncf.torch.nncf_model_output")
 @register_operator(name=MODEL_OUTPUT_OP_NAME)
 def nncf_model_output(tensor: 'torch.Tensor'):
     return tensor
2 changes: 2 additions & 0 deletions nncf/torch/dynamic_graph/patch_pytorch.py
@@ -24,6 +24,7 @@
 from torch.nn.parallel import DistributedDataParallel
 
 from nncf import nncf_logger
+from nncf.common.api_marker import api
 from nncf.torch.dynamic_graph.structs import NamespaceTarget
 from nncf.torch.dynamic_graph.trace_tensor import TracedTensor
 from nncf.torch.dynamic_graph.wrappers import ignore_scope
@@ -96,6 +97,7 @@ class MagicFunctionsToPatch:
     }
 
 
+@api(canonical_alias="nncf.torch.register_operator")
 def register_operator(name=None):
     def wrap(operator):
         op_name = name
3 changes: 3 additions & 0 deletions nncf/torch/extensions/__init__.py
@@ -8,6 +8,7 @@
 
 from torch.utils.cpp_extension import _get_build_directory
 
+from nncf.common.api_marker import api
 from nncf.common.logging.logger import extension_is_loading_info_log
 from nncf.common.utils.registry import Registry
 from nncf.common.logging import nncf_logger
@@ -80,10 +81,12 @@ def _force_build_extensions(ext_type: ExtensionsType):
         class_type.load()
 
 
+@api(canonical_alias="nncf.torch.force_build_cpu_extensions")
 def force_build_cpu_extensions():
     _force_build_extensions(ExtensionsType.CPU)
 
 
+@api(canonical_alias="nncf.torch.force_build_cuda_extensions")
 def force_build_cuda_extensions():
     _force_build_extensions(ExtensionsType.CUDA)
 
2 changes: 1 addition & 1 deletion nncf/torch/initialization.py
@@ -225,7 +225,7 @@ def default_criterion_fn(outputs: Any, target: Any, criterion: Any) -> torch.Tensor:
     return criterion(outputs, target)
 
 
-@api()
+@api(canonical_alias="nncf.torch.register_default_init_args")
 def register_default_init_args(nncf_config: 'NNCFConfig',
                                train_loader: torch.utils.data.DataLoader,
                                criterion: _Loss = None,
2 changes: 2 additions & 0 deletions nncf/torch/layers.py
@@ -25,6 +25,7 @@
 from torch.nn.utils.weight_norm import WeightNorm
 
 from nncf import nncf_logger
+from nncf.common.api_marker import api
 from nncf.torch.dynamic_graph.context import forward_nncf_trace
 from nncf.torch.utils import no_jit_trace
 from nncf.torch.checkpoint_loading import OPTIONAL_PARAMETERS_REGISTRY
@@ -385,6 +386,7 @@ def from_module(module):
 NNCF_WRAPPED_USER_MODULES_DICT = {}
 
 
+@api(canonical_alias="nncf.torch.register_module")
 def register_module(*quantizable_field_names: str, ignored_algorithms: list = None):
     # quantizable_field_names will work for `weight` attributes only. Should later extend to registering
     # customly named attributes if it becomes necessary
2 changes: 1 addition & 1 deletion nncf/torch/model_creation.py
@@ -45,7 +45,7 @@
 from nncf.torch.utils import training_mode_switcher
 
 
-@api()
+@api(canonical_alias="nncf.torch.create_compressed_model")
 @tracked_function(NNCF_PT_CATEGORY, [CompressionStartedFromConfig(argname="config"), ])
 def create_compressed_model(model: Module,
                             config: NNCFConfig,
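
Taken together, the canonical aliases added in this commit let the generated API docs (and users) refer to entities by their public import path rather than by the module that defines them. A hedged sketch of the user-side view, assuming nncf and torch are installed and the corresponding __init__.py re-exports are in place:

import nncf
import nncf.torch

print(nncf.NNCFConfig)          # documented as nncf.NNCFConfig, not nncf.config.config.NNCFConfig
print(nncf.QuantizationPreset)  # aliased from nncf.common.quantization.structs
print(nncf.torch.load_state)    # aliased from nncf.torch.checkpoint_loading
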
