From beb25fb704ef3bb9e9928b4df56ddbee84312a01 Mon Sep 17 00:00:00 2001
From: Orient94
Date: Mon, 6 Feb 2023 14:32:59 +0800
Subject: [PATCH] Add meta ops in mmrazors/models

---
Note: the blob ids on the `index` lines of the two non-empty files predate
the review fixes below; regenerate the patch after applying it if exact
ids are required.

 .../models/architectures/meta_ops/__init__.py |   0
 .../meta_ops/meta_base/__init__.py            |   0
 .../meta_ops/meta_base/meta_mixin.py          | 264 ++++++++++++++++++
 .../meta_ops/meta_bircks/__init__.py          |   0
 .../meta_ops/meta_bircks/meta_conv.py         | 193 +++++++++++++
 5 files changed, 457 insertions(+)
 create mode 100644 mmrazor/models/architectures/meta_ops/__init__.py
 create mode 100644 mmrazor/models/architectures/meta_ops/meta_base/__init__.py
 create mode 100644 mmrazor/models/architectures/meta_ops/meta_base/meta_mixin.py
 create mode 100644 mmrazor/models/architectures/meta_ops/meta_bircks/__init__.py
 create mode 100644 mmrazor/models/architectures/meta_ops/meta_bircks/meta_conv.py

diff --git a/mmrazor/models/architectures/meta_ops/__init__.py b/mmrazor/models/architectures/meta_ops/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/mmrazor/models/architectures/meta_ops/meta_base/__init__.py b/mmrazor/models/architectures/meta_ops/meta_base/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/mmrazor/models/architectures/meta_ops/meta_base/meta_mixin.py b/mmrazor/models/architectures/meta_ops/meta_base/meta_mixin.py
new file mode 100644
index 000000000..afe86654f
--- /dev/null
+++ b/mmrazor/models/architectures/meta_ops/meta_base/meta_mixin.py
@@ -0,0 +1,264 @@
+from abc import ABC, abstractmethod
+from itertools import repeat
+from typing import Callable, Dict, Iterable, Optional, Set, Tuple
+
+import torch
+import torch.nn.functional as F
+from torch import Tensor, nn
+from torch.nn.modules.conv import _ConvNd
+
+from mmrazor.models.mutables.base_mutable import BaseMutable
+
+
+def _ntuple(n: int) -> Callable:  # pragma: no cover
+    """Build a parser that expands a value into a tuple.
+
+    The returned callable converts an iterable to a tuple as it is and
+    repeats any other value ``n`` times.
+    """
+
+    def parse(x):
+        if isinstance(x, Iterable):
+            return tuple(x)
+        return tuple(repeat(x, n))
+
+    return parse
+
+
+def _get_current_kernel_pos(source_kernel_size: int,
+                            target_kernel_size: int) -> Tuple[int, int]:
+    """Get the centered span of a smaller kernel inside a larger one.
+
+    Args:
+        source_kernel_size (int): Size of the enclosing kernel.
+        target_kernel_size (int): Size of the kernel to locate.
+
+    Returns:
+        Tuple[int, int]: (start offset, exclusive end offset).
+    """
+    assert source_kernel_size >= target_kernel_size, \
+        '`source_kernel_size` must greater or equal than `target_kernel_size`'
+
+    center = source_kernel_size >> 1
+    current_offset = target_kernel_size >> 1
+
+    return center - current_offset, center + current_offset + 1
+
+
+def _get_same_padding(kernel_size: int, n_dims: int) -> Tuple[int, ...]:
+    """Get ``kernel_size // 2`` padding per dimension for an odd kernel."""
+    assert kernel_size & 1, '`kernel_size` must be odd'
+
+    return _ntuple(n_dims)(kernel_size >> 1)
+
+
+class MetaMixin(ABC):
+    """Base class for dynamic OP.
+
+    A dynamic OP usually consists of a normal static OP and mutables, where
+    mutables are used to control the searchable (mutable) part of the dynamic
+    OP.
+
+    Note:
+        When the dynamic OP has just been initialized, its forward propagation
+        logic should be the same as the corresponding static OP. Only after
+        the searchable part accepts the specific mutable through the
+        corresponding interface does the part really become dynamic.
+
+    Note:
+        All subclasses should implement the ``to_static_op`` and
+        ``static_op_factory`` APIs, and must provide a ``mutable_attrs``
+        container holding the registered mutables.
+
+    Attributes:
+        accepted_mutable_attrs (Set[str]): Names of all accepted mutables.
+        attr_mappings (Dict[str, str]): Aliases that map an attribute name to
+            the name it is stored under in ``mutable_attrs``.
+    """
+    accepted_mutable_attrs: Set[str] = set()
+    attr_mappings: Dict[str, str] = dict()
+
+    @abstractmethod
+    def register_mutable_attr(self, attr: str, mutable: BaseMutable):
+        """Register ``mutable`` as the controller of attribute ``attr``."""
+
+    def get_mutable_attr(self, attr: str) -> Optional[BaseMutable]:
+        """Get the mutable registered for ``attr``.
+
+        Returns:
+            Optional[BaseMutable]: The registered mutable, or ``None`` when
+            nothing has been registered under that name.
+        """
+        self.check_mutable_attr_valid(attr)
+        # Resolve an alias to the name used inside ``mutable_attrs``.
+        attr = self.attr_mappings.get(attr, attr)
+        return getattr(self.mutable_attrs, attr, None)  # type:ignore
+
+    @classmethod
+    @abstractmethod
+    def convert_from(cls, module):
+        """Convert an instance of Pytorch module to a new instance of Dynamic
+        module."""
+
+    @property
+    @abstractmethod
+    def static_op_factory(self):
+        """Corresponding Pytorch OP."""
+
+    @abstractmethod
+    def to_static_op(self) -> nn.Module:
+        """Convert dynamic OP to static OP.
+
+        Note:
+            The forward result for the same input between dynamic OP and its
+            corresponding static OP must be same.
+
+        Returns:
+            nn.Module: Corresponding static OP.
+        """
+
+    def check_if_mutables_fixed(self):
+        """Check if all mutables are fixed.
+
+        Raises:
+            RuntimeError: Error if an existing mutable is not fixed.
+        """
+        for mutable in self.mutable_attrs.values():  # type: ignore
+            if mutable is not None and not mutable.is_fixed:
+                raise RuntimeError(f'Mutable {type(mutable)} is not fixed.')
+
+    def check_mutable_attr_valid(self, attr):
+        """Assert that ``attr`` is an accepted or aliased mutable name."""
+        assert attr in self.attr_mappings or \
+            attr in self.accepted_mutable_attrs, \
+            f'`{attr}` is not an accepted mutable attribute'
+
+    @staticmethod
+    def get_current_choice(mutable: BaseMutable):
+        """Get current choice of given mutable.
+
+        Args:
+            mutable (BaseMutable): Given mutable.
+
+        Raises:
+            RuntimeError: Error if `current_choice` is None.
+
+        Returns:
+            Any: Current choice of given mutable.
+        """
+        current_choice = mutable.current_choice
+        if current_choice is None:
+            raise RuntimeError(f'current choice of mutable {type(mutable)} '
+                               'can not be None at runtime')
+
+        return current_choice
+
+
+class MetaConvMixin(MetaMixin):
+    """A mixin class for Pytorch conv, which can mutate ``in_channels`` and
+    ``out_channels``.
+
+    Note:
+        All subclasses should implement the ``conv_func`` API.
+    """
+
+    accepted_mutable_attrs: Set[str] = {'in_channels', 'out_channels'}
+
+    @property
+    @abstractmethod
+    def conv_func(self: _ConvNd):
+        """The functional convolution used in the forward pass."""
+
+    @staticmethod
+    def check_mutable_channels(mutable_channels: BaseMutable) -> None:
+        """Check that a channel mutable exposes ``current_mask``.
+
+        Raises:
+            ValueError: Error if mutable has no ``current_mask`` attribute.
+        """
+        if not hasattr(mutable_channels, 'current_mask'):
+            raise ValueError(
+                'channel mutable must have attribute `current_mask`')
+
+    def register_mutable_attr(self, attr, mutable):
+        """Register ``mutable`` for ``in_channels`` or ``out_channels``.
+
+        Raises:
+            NotImplementedError: Error if ``attr`` is any other name.
+        """
+        if attr == 'in_channels':
+            self._register_mutable_in_channels(mutable)
+        elif attr == 'out_channels':
+            self._register_mutable_out_channels(mutable)
+        else:
+            raise NotImplementedError(
+                f'Unsupported mutable attribute: {attr}')
+
+    def _register_mutable_channels(self: _ConvNd, attr: str,
+                                   mutable: BaseMutable):
+        """Validate ``mutable`` against ``attr`` and store it.
+
+        Raises:
+            ValueError: Error if size of mask is not same as ``attr``.
+        """
+        assert hasattr(self, 'mutable_attrs')
+        self.check_mutable_channels(mutable)
+        expected = getattr(self, attr)
+        mask_size = mutable.current_mask.size(0)
+        if mask_size != expected:
+            raise ValueError(
+                f'Expect mask size of mutable to be {expected} as '
+                f'`{attr}`, but got: {mask_size}.')
+
+        self.mutable_attrs[attr] = mutable
+
+    def _register_mutable_in_channels(
+            self: _ConvNd, mutable_in_channels: BaseMutable):
+        """Mutate ``in_channels`` with given mutable.
+
+        Args:
+            mutable_in_channels (BaseMutable): Mutable for controlling
+                ``in_channels``.
+
+        Raises:
+            ValueError: Error if size of mask is not same as ``in_channels``.
+        """
+        self._register_mutable_channels('in_channels', mutable_in_channels)
+
+    def _register_mutable_out_channels(
+            self: _ConvNd, mutable_out_channels: BaseMutable):
+        """Mutate ``out_channels`` with given mutable.
+
+        Args:
+            mutable_out_channels (BaseMutable): Mutable for controlling
+                ``out_channels``.
+
+        Raises:
+            ValueError: Error if size of mask is not same as ``out_channels``.
+        """
+        self._register_mutable_channels('out_channels', mutable_out_channels)
+
+    @property
+    def mutable_in_channels(self: _ConvNd):
+        """Mutable related to input."""
+        assert hasattr(self, 'mutable_attrs')
+        return getattr(self.mutable_attrs, 'in_channels', None)  # type:ignore
+
+    @property
+    def mutable_out_channels(self: _ConvNd):
+        """Mutable related to output."""
+        assert hasattr(self, 'mutable_attrs')
+        return getattr(self.mutable_attrs, 'out_channels', None)  # type:ignore
+
+    def forward_inpoup(self) -> Tuple[int, int]:
+        """Get the currently active (input, output) channel numbers.
+
+        A side without a registered mutable falls back to the static
+        ``in_channels`` / ``out_channels`` of the conv.
+        """
+        inp, oup = self.in_channels, self.out_channels
+        if 'in_channels' in self.mutable_attrs:
+            inp = self.mutable_attrs['in_channels'].activated_channels
+        if 'out_channels' in self.mutable_attrs:
+            oup = self.mutable_attrs['out_channels'].activated_channels
+        return inp, oup
diff --git a/mmrazor/models/architectures/meta_ops/meta_bircks/__init__.py b/mmrazor/models/architectures/meta_ops/meta_bircks/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/mmrazor/models/architectures/meta_ops/meta_bircks/meta_conv.py b/mmrazor/models/architectures/meta_ops/meta_bircks/meta_conv.py
new file mode 100644
index 000000000..a7859139a
--- /dev/null
+++ b/mmrazor/models/architectures/meta_ops/meta_bircks/meta_conv.py
@@ -0,0 +1,193 @@
+import math
+from typing import Callable, Dict, Optional, Tuple
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch import Tensor, scalar_tensor
+
+from mmrazor.models.architectures.meta_ops.meta_base.meta_mixin import \
+    MetaConvMixin
+from mmrazor.models.mutables.base_mutable import BaseMutable
+
+
+def groups_channels(in_channels: int, groups: int) -> Tuple[int, int]:
+    """Get the number of input channels that each group receives.
+
+    A channel number that is not divisible by ``groups`` is rounded down,
+    but never below one channel per group.
+
+    Returns:
+        Tuple[int, int]: (channels per group, ``groups``).
+    """
+    per_group = in_channels // groups
+    if in_channels % groups != 0:
+        per_group = max(per_group, 1)
+    return per_group, groups
+
+
+def groups_out_channels(out_channels: int, groups: int) -> Tuple[int, int]:
+    """Round ``out_channels`` down to a multiple of ``groups``.
+
+    When rounding is needed the result is never smaller than ``groups``.
+
+    Returns:
+        Tuple[int, int]: (rounded output channels, ``groups``).
+    """
+    if out_channels % groups == 0:
+        return out_channels, groups
+    return groups * max(out_channels // groups, 1), groups
+
+
+class MetaConv2d(nn.Conv2d, MetaConvMixin):
+    """2D convolution whose kernel is generated by small FC layers.
+
+    The kernel and the optional bias are predicted from the ratio between the
+    currently active channels and the maximum channels, then cropped to the
+    active size. The ``weight`` / ``bias`` parameters created by ``nn.Conv2d``
+    itself are not used by ``forward``.
+
+    Args:
+        in_channels (int): Maximum number of input channels.
+        out_channels (int): Maximum number of output channels.
+        kernel_size (int | tuple): Size of the conv kernel.
+        stride (int | tuple): Stride of the conv. Defaults to 1.
+        padding (int | tuple | str): Padding of the conv. Defaults to 0.
+        dilation (int | tuple): Dilation of the conv. Defaults to 1.
+        groups (int): Number of groups at full width. Defaults to 1.
+        bias (bool): Whether to generate a bias. Defaults to True.
+        padding_mode (str): Padding mode of the conv. Defaults to 'zeros'.
+    """
+
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 kernel_size,
+                 stride=1,
+                 padding=0,
+                 dilation=1,
+                 groups=1,
+                 bias=True,
+                 padding_mode='zeros'):
+        super().__init__(in_channels, out_channels, kernel_size, stride,
+                         padding, dilation, groups, bias, padding_mode)
+
+        # ``nn.Conv2d`` has already stored the normalized ``stride``,
+        # ``padding``, ``dilation`` and ``kernel_size``.
+        self.mutable_attrs: Dict[str, BaseMutable] = nn.ModuleDict()
+        self.base_oup = out_channels
+        self.base_inp = in_channels
+
+        self.groups_ = groups
+        self.bias_ = bias is not False
+        self.max_oup_channel = self.base_oup
+        # A depth-wise conv only needs one input channel per filter.
+        self.max_inp_channel = 1 if in_channels == groups else self.base_inp
+
+        self.fc11 = nn.Linear(2, 64)
+        self.fc12 = nn.Linear(
+            64, self.max_oup_channel * self.max_inp_channel *
+            self.kernel_size[0] * self.kernel_size[1])
+        if self.bias_:
+            self.fc_bias = nn.Sequential(
+                nn.Linear(2, 16), nn.ReLU(),
+                nn.Linear(16, self.max_oup_channel))
+
+    @property
+    def conv_func(self) -> Callable:
+        """The functional convolution used in ``forward``."""
+        return F.conv2d
+
+    @property
+    def static_op_factory(self):
+        """Corresponding Pytorch OP."""
+        return nn.Conv2d
+
+    @classmethod
+    def convert_from(cls, module: nn.Conv2d) -> 'MetaConv2d':
+        """Build a ``MetaConv2d`` with the configuration of ``module``.
+
+        Only the configuration is copied; the kernel generator keeps its
+        fresh initialization.
+        """
+        return cls(
+            in_channels=module.in_channels,
+            out_channels=module.out_channels,
+            kernel_size=module.kernel_size,
+            stride=module.stride,
+            padding=module.padding,
+            dilation=module.dilation,
+            groups=module.groups,
+            bias=module.bias is not None,
+            padding_mode=module.padding_mode)
+
+    def to_static_op(self) -> nn.Conv2d:
+        """Convert to a plain ``nn.Conv2d`` holding the generated kernel.
+
+        Returns:
+            nn.Conv2d: Static conv built from the current channel choice.
+        """
+        self.check_if_mutables_fixed()
+
+        with torch.no_grad():
+            weight, bias, groups = self._generate_params()
+        static_conv = self.static_op_factory(
+            in_channels=weight.size(1) * groups,
+            out_channels=weight.size(0),
+            kernel_size=self.kernel_size,
+            stride=self.stride,
+            padding=self.padding,
+            dilation=self.dilation,
+            groups=groups,
+            bias=bias is not None,
+            padding_mode=self.padding_mode)
+        static_conv.weight = nn.Parameter(weight.clone())
+        if bias is not None:
+            static_conv.bias = nn.Parameter(bias.clone())
+
+        return static_conv
+
+    def _generate_params(self) -> Tuple[Tensor, Optional[Tensor], int]:
+        """Generate kernel, bias and group number for the current choice.
+
+        Returns:
+            Tuple[Tensor, Optional[Tensor], int]: The kernel cropped to
+            (output channels, input channels per group, kH, kW), the bias
+            cropped to the output channels (``None`` without bias) and the
+            number of groups.
+        """
+        inp, oup = map(int, self.forward_inpoup())
+        # Keep the per-group width of the full conv while enough input
+        # channels are active; otherwise a single group is used.
+        group_width = min(self.base_inp // self.groups_, inp)
+        groups = max(inp // group_width, 1)
+        inp, _ = groups_channels(inp, groups)
+        oup, _ = groups_out_channels(oup, groups)
+
+        fc_weight = self.fc11.weight
+        scale = torch.tensor(
+            [inp / self.max_inp_channel, oup / self.max_oup_channel],
+            dtype=fc_weight.dtype,
+            device=fc_weight.device)
+        weight = self.fc12(F.relu(self.fc11(scale))).view(
+            self.max_oup_channel, self.max_inp_channel, self.kernel_size[0],
+            self.kernel_size[1])[:oup, :inp]
+        bias = self.fc_bias(scale)[:oup] if self.bias_ else None
+
+        return weight, bias, groups
+
+    def forward(self, x: Tensor) -> Tensor:
+        """Convolve ``x`` with the kernel generated for the current choice."""
+        weight, bias, groups = self._generate_params()
+
+        padding = self.padding
+        if self.padding_mode != 'zeros':
+            # Mirror ``nn.Conv2d``: pad explicitly, then convolve unpadded.
+            x = F.pad(
+                x,
+                self._reversed_padding_repeated_twice,
+                mode=self.padding_mode)
+            padding = 0
+
+        return self.conv_func(x, weight, bias, self.stride, padding,
+                              self.dilation, groups)