# [PT] Constant graph (openvinotoolkit#2316)
### Changes

Introduced tracing of model parameters.


![image](https://github.com/openvinotoolkit/nncf/assets/3229971/9eb4d341-832a-42d6-a3d5-e4722abf712d)


### Reason for changes

Provides a general approach to accessing and changing weights in a PyTorch model.

### Related tickets

ref: 119585

### Tests

test_graphs.py
test_compressed_graph.py
test_graph_building.py
test_nncf_network.py
test_tracing_context.py
alexsu52 authored Dec 19, 2023
1 parent 4660729 commit 4b903a5
Showing 48 changed files with 13,802 additions and 188 deletions.
2 changes: 2 additions & 0 deletions nncf/common/graph/definitions.py
@@ -11,8 +11,10 @@

MODEL_INPUT_OP_NAME = "nncf_model_input"
MODEL_OUTPUT_OP_NAME = "nncf_model_output"
MODEL_CONST_OP_NAME = "nncf_model_const"


class NNCFGraphNodeType:
INPUT_NODE = MODEL_INPUT_OP_NAME
OUTPUT_NODE = MODEL_OUTPUT_OP_NAME
CONST_NODE = MODEL_CONST_OP_NAME
9 changes: 9 additions & 0 deletions nncf/common/graph/layer_attributes.py
@@ -268,3 +268,12 @@ class ConvertDtypeLayerAttributes(BaseLayerAttributes):

src_dtype: Any
dst_dtype: Any


@dataclass
class ParameterLayerAttributes(BaseLayerAttributes):
"""
:param name: Parameter name.
"""

name: str
11 changes: 11 additions & 0 deletions nncf/common/graph/operator_metatypes.py
@@ -127,6 +127,7 @@ def get_operator_metatype_by_op_name(self, op_name: str) -> Type[OperatorMetatyp
NOOP_METATYPES = Registry("noop_metatypes")
INPUT_NOOP_METATYPES = Registry("input_noop_metatypes")
OUTPUT_NOOP_METATYPES = Registry("output_noop_metatypes")
CONST_NOOP_METATYPES = Registry("const_noop_metatypes")


class UnknownMetatype(OperatorMetatype):
@@ -175,3 +176,13 @@ class OutputNoopMetatype(OperatorMetatype):
@classmethod
def get_all_aliases(cls) -> List[str]:
return [NNCFGraphNodeType.OUTPUT_NODE]


@NOOP_METATYPES.register()
@CONST_NOOP_METATYPES.register()
class ConstNoopMetatype(OperatorMetatype):
name = "const_noop"

@classmethod
def get_all_aliases(cls) -> List[str]:
return [NNCFGraphNodeType.CONST_NODE]
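
For reference, the new metatype is resolved through the same alias string that names constant nodes in `nncf/common/graph/definitions.py`. A quick sanity-check sketch, using only the names defined in this diff:

```python
# Sanity check: the constant metatype's only alias is the new constant node name.
from nncf.common.graph.definitions import NNCFGraphNodeType
from nncf.common.graph.operator_metatypes import ConstNoopMetatype

assert NNCFGraphNodeType.CONST_NODE == "nncf_model_const"
assert ConstNoopMetatype.get_all_aliases() == [NNCFGraphNodeType.CONST_NODE]
```
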
43 changes: 36 additions & 7 deletions nncf/torch/dynamic_graph/context.py
@@ -14,7 +14,7 @@
from collections import defaultdict
from collections import deque
from contextlib import contextmanager
from typing import Callable, DefaultDict, List, Optional
from typing import Callable, DefaultDict, List, Optional, Union

import torch

@@ -31,7 +31,7 @@
from nncf.torch.dynamic_graph.scope import Scope
from nncf.torch.dynamic_graph.scope import ScopeElement
from nncf.torch.dynamic_graph.trace_tensor import TensorMeta
from nncf.torch.dynamic_graph.trace_tensor import TracedTensor
from nncf.torch.dynamic_graph.trace_tensor import TracedTensorMixin


class ThreadLocalGlobalContext(threading.local):
@@ -67,12 +67,14 @@ def reset(self):
self.scopes = []
self.module_call_stack = []
self.in_operator = False
self.in_parameter_trace = False
self.num_nested_hooks = 0
self.base_module_replica = None
self.operator_counters = {}
self.node_call_tracker = {}
self.traced_tensor_weakrefs = []
self.nested_contexts_stack = []
self.processed_parameters = {}


class CopySafeThreadingVars:
@@ -132,7 +134,7 @@ def __exit__(self, *args):
previous_context = self._threading.thread_local.nested_contexts_stack.pop(-1)
for traced_tensor_weakref in self._threading.thread_local.traced_tensor_weakrefs:
tt = traced_tensor_weakref()
if tt is None or not isinstance(tt, TracedTensor):
if tt is None or not isinstance(tt, TracedTensorMixin):
continue
if previous_context is None:
tt.strip()
@@ -163,17 +165,36 @@ def find_operator_node(
def register_global_buffer(self, name: str, buffer):
self.global_buffer_store[name] = buffer

def register_traced_tensor(self, tt: TracedTensor):
def register_traced_tensor(self, tt: torch.Tensor):
"""
Registers a weak reference to a traced tensor in the context so that in case
the block under context retains a reference to an intermediate tensor somewhere,
the context can mark this traced tensor reference as "expired" tracing-wise upon context
exit.
:param tt: A TracedTensor to be registered.
the context strips this traced tensor upon context exit.
:param tt: A tensor with TracedTensorMixin tracing capabilities to be registered.
"""
wr = weakref.ref(tt)
self._threading.thread_local.traced_tensor_weakrefs.append(wr)

def register_processed_parameter(self, param_name: str, tensor: torch.Tensor) -> None:
"""
Registers the processed parameter in the context to avoid double calculation of hooks
for the same parameters.
:param param_name: The parameter name.
:param tensor: The processed parameter.
"""
self._threading.thread_local.processed_parameters[param_name] = tensor

def get_processed_parameter(self, param_name: str) -> Union[torch.Tensor, None]:
"""
        Return the processed parameter by name.
:param param_name: The parameter name.
:return: The processed parameter by name if found, otherwise None.
"""
return self._threading.thread_local.processed_parameters.get(param_name, None)

def maybe_add_node(
self,
inputs: OperatorInput,
@@ -339,6 +360,14 @@ def in_operator(self):
def in_operator(self, val):
self._threading.thread_local.in_operator = val

@property
def in_parameter_trace(self):
return self._threading.thread_local.in_parameter_trace

@in_parameter_trace.setter
def in_parameter_trace(self, val):
self._threading.thread_local.in_parameter_trace = val

@property
def module_call_stack(self) -> List[torch.nn.Module]:
return self._threading.thread_local.module_call_stack
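
A hedged sketch of how the parameter cache above might be used: the helper below is hypothetical, only `register_processed_parameter` and `get_processed_parameter` come from this diff.

```python
import torch

from nncf.torch.dynamic_graph.context import TracingContext


def apply_weight_hook_once(ctx: TracingContext, param_name: str, param: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: run a weight-processing hook at most once per parameter per trace."""
    cached = ctx.get_processed_parameter(param_name)
    if cached is not None:
        return cached
    processed = param.clone()  # stand-in for an actual compression hook
    ctx.register_processed_parameter(param_name, processed)
    return processed
```
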
16 changes: 12 additions & 4 deletions nncf/torch/dynamic_graph/graph.py
@@ -23,6 +23,7 @@
from nncf.torch.dynamic_graph.operation_address import OperationAddress
from nncf.torch.dynamic_graph.scope import Scope
from nncf.torch.dynamic_graph.trace_tensor import TensorMeta
from nncf.torch.dynamic_graph.trace_tensor import TracedParameter
from nncf.torch.dynamic_graph.trace_tensor import TracedTensor


@@ -401,7 +402,7 @@ def save_first_iteration_node(self, inputs: OperatorInput, node: DynamicGraphNod
It finds and saves "starting" points of iteration for further matching with them on next iteration,
instead of adding new nodes for each iteration. "Starting" points of iteration are nodes
* that have at least one input node, which is outside of iteration scope
* or whose all inputs are not TracedTensor
    * or all of whose inputs are not tensors with TracedTensorMixin tracing capabilities.
"""
op_exec_context = node.op_exec_context
name = str(node)
@@ -414,17 +415,24 @@ def save_first_iteration_node(self, inputs: OperatorInput, node: DynamicGraphNod
has_input_outside_iteration = False
untraced_tensor_inputs = []
traced_tensor_inputs = []
traced_parameter_inputs = []
non_tensor_inputs = []
for i in inputs:
input_obj = i.getter()
if isinstance(input_obj, Tensor):
if not isinstance(input_obj, TracedTensor):
untraced_tensor_inputs.append(input_obj)
else:
if isinstance(input_obj, TracedTensor):
traced_tensor_inputs.append(input_obj)
elif isinstance(input_obj, TracedParameter):
traced_parameter_inputs.append(input_obj)
else:
untraced_tensor_inputs.append(input_obj)
else:
non_tensor_inputs.append(input_obj)

for i in traced_parameter_inputs:
if i.tensor_meta is not None:
traced_tensor_inputs.append(i)

for i in traced_tensor_inputs:
creator_id = i.tensor_meta.creator_id
creator_node = self.get_node_by_id(creator_id)
10 changes: 9 additions & 1 deletion nncf/torch/dynamic_graph/graph_tracer.py
@@ -19,6 +19,7 @@
from nncf.torch.dynamic_graph.io_handling import FillerInputInfo
from nncf.torch.dynamic_graph.io_handling import LoaderInputInfo
from nncf.torch.dynamic_graph.io_handling import ModelInputInfo
from nncf.torch.dynamic_graph.wrappers import wrap_parameters
from nncf.torch.utils import get_model_device
from nncf.torch.utils import is_multidevice

@@ -28,7 +29,11 @@ def __init__(self, custom_forward_fn: Callable[[torch.nn.Module], Any]):
self.custom_forward_fn = custom_forward_fn

def trace_graph(
self, model: torch.nn.Module, context_to_use: Optional[TracingContext] = None, as_eval: bool = False
self,
model: torch.nn.Module,
context_to_use: Optional[TracingContext] = None,
as_eval: bool = False,
trace_parameters: bool = False,
) -> DynamicGraph:
sd = deepcopy(model.state_dict())

@@ -41,6 +46,9 @@ def trace_graph(
with context_to_use as _ctx:
_ctx.base_module_thread_local_replica = model
with torch.no_grad():
if trace_parameters:
wrap_parameters(model)

if as_eval:
with training_mode_switcher(model, is_training=False):
self.custom_forward_fn(model)
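
A minimal usage sketch for the extended `trace_graph` signature: the toy model and forward function are illustrative, only `GraphTracer` and the `trace_parameters` flag come from this diff.

```python
import torch

from nncf.torch.dynamic_graph.graph_tracer import GraphTracer

model = torch.nn.Sequential(torch.nn.Linear(8, 4), torch.nn.ReLU())


def forward_fn(m: torch.nn.Module):
    return m(torch.ones(1, 8))


# With trace_parameters=True the tracer wraps the model parameters before the
# forward pass, so weights show up as constant nodes in the resulting graph.
dynamic_graph = GraphTracer(forward_fn).trace_graph(model, trace_parameters=True)
```
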
12 changes: 11 additions & 1 deletion nncf/torch/dynamic_graph/layer_attributes_handlers.py
@@ -28,9 +28,11 @@
from nncf.common.graph.layer_attributes import MultipleInputLayerAttributes
from nncf.common.graph.layer_attributes import MultipleOutputLayerAttributes
from nncf.common.graph.layer_attributes import PadLayerAttributes
from nncf.common.graph.layer_attributes import ParameterLayerAttributes
from nncf.common.graph.layer_attributes import PermuteLayerAttributes
from nncf.common.graph.layer_attributes import ReshapeLayerAttributes
from nncf.common.graph.layer_attributes import TransposeLayerAttributes
from nncf.common.graph.operator_metatypes import ConstNoopMetatype
from nncf.common.graph.utils import get_split_axis
from nncf.torch.graph.operator_metatypes import PTCatMetatype
from nncf.torch.graph.operator_metatypes import PTGroupNormMetatype
@@ -49,8 +51,9 @@
GETITEM_OP_NAMES = ["__getitem__"]
PAD_OP_NAMES = PTPadMetatype.get_all_aliases()
CONCAT_OP_NAMES = PTCatMetatype.get_all_aliases()
CONST_OP_NAMES = ConstNoopMetatype.get_all_aliases()
OP_NAMES_REQUIRING_ATTRS_FROM_ARGS_KWARGS = list(
TRANSPOSE_OP_NAMES + PERMUTE_OP_NAMES + GETITEM_OP_NAMES + PAD_OP_NAMES + CONCAT_OP_NAMES
TRANSPOSE_OP_NAMES + PERMUTE_OP_NAMES + GETITEM_OP_NAMES + PAD_OP_NAMES + CONCAT_OP_NAMES + CONST_OP_NAMES
)


@@ -121,6 +124,8 @@ def get_layer_attributes_from_args_and_kwargs(op_name: str, args, kwargs) -> Bas
layer_attrs = _get_pad_attrs_from_args_kwargs(args, kwargs)
elif op_name in CONCAT_OP_NAMES:
layer_attrs = _get_concat_attrs_from_args_kwargs(args, kwargs)
elif op_name in CONST_OP_NAMES:
layer_attrs = _get_const_attrs_from_args_kwargs(args, kwargs)
return layer_attrs


@@ -180,3 +185,8 @@ def _get_kwargs_shifted(args_names, args, kwargs, shift=1):
for idx, arg_name in enumerate(args_names):
res_kwargs[arg_name] = kwargs[arg_name] if arg_name in kwargs else args[idx + shift]
return res_kwargs


def _get_const_attrs_from_args_kwargs(args, _) -> ParameterLayerAttributes:
name = getattr(args[0], "name", "Unknown")
return ParameterLayerAttributes(name)
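
A small sketch of the const branch added above; the stand-in object is hypothetical — in practice `args[0]` is the traced parameter, which exposes a `name` attribute read by the handler.

```python
from nncf.torch.dynamic_graph.layer_attributes_handlers import get_layer_attributes_from_args_and_kwargs


class _FakeTracedParameter:
    """Stand-in for a traced parameter: only the `name` attribute is read by the handler."""

    name = "conv1.weight"


attrs = get_layer_attributes_from_args_and_kwargs("nncf_model_const", (_FakeTracedParameter(),), {})
print(attrs)  # ParameterLayerAttributes(name='conv1.weight')
```
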
19 changes: 10 additions & 9 deletions nncf/torch/dynamic_graph/patch_pytorch.py
@@ -12,7 +12,7 @@
import functools
import inspect
from contextlib import contextmanager
from typing import List
from typing import List, Tuple

import torch
import torch.utils.cpp_extension
@@ -25,19 +25,20 @@
from nncf.common.utils.api_marker import api
from nncf.torch.dynamic_graph.structs import NamespaceTarget
from nncf.torch.dynamic_graph.structs import PatchedOperatorInfo
from nncf.torch.dynamic_graph.trace_tensor import TracedParameter
from nncf.torch.dynamic_graph.trace_tensor import TracedTensor
from nncf.torch.dynamic_graph.wrappers import ignore_scope
from nncf.torch.dynamic_graph.wrappers import wrap_module_call
from nncf.torch.dynamic_graph.wrappers import wrap_operator


def get_namespace_to_patch(namespace_target: NamespaceTarget) -> object:
def get_namespaces_to_patch(namespace_target: NamespaceTarget) -> Tuple[object, ...]:
if namespace_target == NamespaceTarget.TORCH_NN_FUNCTIONAL:
return torch.nn.functional
return (torch.nn.functional,)
if namespace_target == NamespaceTarget.TORCH_TENSOR:
return TracedTensor
return (TracedTensor, TracedParameter)
if namespace_target == NamespaceTarget.TORCH:
return torch
return (torch,)
raise RuntimeError("{} namespace wasn't found in {}".format(namespace_target, NamespaceTarget))


@@ -366,17 +367,17 @@ def patch_torch_operators():
for namespace, function_names in functions_to_patch.items():
for function_name in function_names:
op_info = PatchedOperatorInfo(function_name, namespace)
patched_namespace = get_namespace_to_patch(namespace)
patch_namespace_opname(patched_namespace, op_info)
for patched_namespace in get_namespaces_to_patch(namespace):
patch_namespace_opname(patched_namespace, op_info)

# Patch operators without tracing so that
# both they and any internal calls to otherwise traced functions do not appear into the model graph.

for namespace, function_names in functions_to_patch_without_tracing.items():
for function_name in function_names:
op_info = PatchedOperatorInfo(function_name, namespace, skip_trace=True)
patched_namespace = get_namespace_to_patch(namespace)
patch_namespace_opname(patched_namespace, op_info)
for patched_namespace in get_namespaces_to_patch(namespace):
patch_namespace_opname(patched_namespace, op_info)

# Patch __repr__ twice in 'torch.Tensor' and 'TracedTensor'.
# This is done to not add operations behind print() operator for the both TracedTensor and torch.Tensor.
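
The contract change above means each namespace target now maps to a tuple of namespaces, so callers iterate instead of patching a single object. A brief sketch, using only the names shown in this diff:

```python
from nncf.torch.dynamic_graph.patch_pytorch import get_namespaces_to_patch
from nncf.torch.dynamic_graph.structs import NamespaceTarget

# TORCH_TENSOR now resolves to both traced classes, so the same patched operator
# is installed on TracedTensor and on TracedParameter.
for namespace in get_namespaces_to_patch(NamespaceTarget.TORCH_TENSOR):
    print(namespace.__name__)  # TracedTensor, TracedParameter
```
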
