From c98226fc53b9f9a43d794ce48f37c6d35fdb8930 Mon Sep 17 00:00:00 2001 From: HongYu <20734616+james77777778@users.noreply.github.com> Date: Thu, 1 Feb 2024 20:19:54 +0800 Subject: [PATCH] Add `MobileOne` (#36) * Update export scripts * Add `MobileOne` * Update version * Update README * Fix tests * Remove explicit naming --- README.md | 1 + kimm/__init__.py | 2 +- kimm/layers/__init__.py | 1 + kimm/layers/attention.py | 14 +- kimm/layers/layer_scale.py | 2 - kimm/layers/mobile_one_conv2d.py | 376 ++++++++++++++++++++++ kimm/layers/mobile_one_conv2d_test.py | 164 ++++++++++ kimm/layers/rep_conv2d.py | 78 +++-- kimm/layers/rep_conv2d_test.py | 2 +- kimm/models/__init__.py | 1 + kimm/models/mobileone.py | 358 ++++++++++++++++++++ kimm/models/models_test.py | 16 +- kimm/utils/timm_utils.py | 18 +- shell/export.sh | 1 + tools/convert_convmixer_from_timm.py | 1 + tools/convert_convnext_from_timm.py | 1 + tools/convert_densenet_from_timm.py | 1 + tools/convert_efficientnet_from_timm.py | 1 + tools/convert_ghostnet_from_timm.py | 1 + tools/convert_inception_next_from_timm.py | 1 + tools/convert_mobilenet_v2_from_timm.py | 1 + tools/convert_mobilenet_v3_from_timm.py | 1 + tools/convert_mobileone_from_timm.py | 166 ++++++++++ tools/convert_regnet_from_timm.py | 1 + tools/convert_resnet_from_timm.py | 1 + tools/convert_vgg_from_timm.py | 1 + tools/convert_vit_from_timm.py | 1 + tools/convert_xception_from_keras.py | 1 + 28 files changed, 1167 insertions(+), 46 deletions(-) create mode 100644 kimm/layers/mobile_one_conv2d.py create mode 100644 kimm/layers/mobile_one_conv2d_test.py create mode 100644 kimm/models/mobileone.py create mode 100644 tools/convert_mobileone_from_timm.py diff --git a/README.md b/README.md index 4e610f1..b50deda 100644 --- a/README.md +++ b/README.md @@ -193,6 +193,7 @@ Reference: [Grad-CAM class activation visualization (keras.io)](https://keras.io |LCNet|[arXiv 2021](https://arxiv.org/abs/2109.15099)|`timm`|`kimm.models.LCNet*`| |MobileNetV2|[CVPR 2018](https://arxiv.org/abs/1801.04381)|`timm`|`kimm.models.MobileNetV2*`| |MobileNetV3|[ICCV 2019](https://arxiv.org/abs/1905.02244)|`timm`|`kimm.models.MobileNetV3*`| +|MobileOne|[CVPR 2023](https://arxiv.org/abs/2206.04040)|`timm`|`kimm.models.MobileOne*`| |MobileViT|[ICLR 2022](https://arxiv.org/abs/2110.02178)|`timm`|`kimm.models.MobileViT*`| |MobileViTV2|[arXiv 2022](https://arxiv.org/abs/2206.02680)|`timm`|`kimm.models.MobileViTV2*`| |RegNet|[CVPR 2020](https://arxiv.org/abs/2003.13678)|`timm`|`kimm.models.RegNet*`| diff --git a/kimm/__init__.py b/kimm/__init__.py index b419ce9..6069ed6 100644 --- a/kimm/__init__.py +++ b/kimm/__init__.py @@ -2,4 +2,4 @@ from kimm import models # force to add models to the registry from kimm.utils.model_registry import list_models -__version__ = "0.1.6" +__version__ = "0.1.7" diff --git a/kimm/layers/__init__.py b/kimm/layers/__init__.py index 577f0aa..f21f54f 100644 --- a/kimm/layers/__init__.py +++ b/kimm/layers/__init__.py @@ -1,4 +1,5 @@ from kimm.layers.attention import Attention from kimm.layers.layer_scale import LayerScale +from kimm.layers.mobile_one_conv2d import MobileOneConv2D from kimm.layers.position_embedding import PositionEmbedding from kimm.layers.rep_conv2d import RepConv2D diff --git a/kimm/layers/attention.py b/kimm/layers/attention.py index d610557..271f10a 100644 --- a/kimm/layers/attention.py +++ b/kimm/layers/attention.py @@ -13,7 +13,6 @@ def __init__( use_qk_norm: bool = False, attention_dropout_rate: float = 0.0, projection_dropout_rate: float = 0.0, - name: str = "attention", **kwargs, ): super().__init__(**kwargs) @@ -25,20 +24,19 @@ def __init__( self.use_qk_norm = use_qk_norm self.attention_dropout_rate = attention_dropout_rate self.projection_dropout_rate = projection_dropout_rate - self.name = name self.qkv = layers.Dense( hidden_dim * 3, use_bias=use_qkv_bias, dtype=self.dtype_policy, - name=f"{name}_qkv", + name=f"{self.name}_qkv", ) if use_qk_norm: self.q_norm = layers.LayerNormalization( - dtype=self.dtype_policy, name=f"{name}_q_norm" + dtype=self.dtype_policy, name=f"{self.name}_q_norm" ) self.k_norm = layers.LayerNormalization( - dtype=self.dtype_policy, name=f"{name}_k_norm" + dtype=self.dtype_policy, name=f"{self.name}_k_norm" ) else: self.q_norm = layers.Identity(dtype=self.dtype_policy) @@ -47,15 +45,15 @@ def __init__( self.attention_dropout = layers.Dropout( attention_dropout_rate, dtype=self.dtype_policy, - name=f"{name}_attn_drop", + name=f"{self.name}_attn_drop", ) self.projection = layers.Dense( - hidden_dim, dtype=self.dtype_policy, name=f"{name}_proj" + hidden_dim, dtype=self.dtype_policy, name=f"{self.name}_proj" ) self.projection_dropout = layers.Dropout( projection_dropout_rate, dtype=self.dtype_policy, - name=f"{name}_proj_drop", + name=f"{self.name}_proj_drop", ) def build(self, input_shape): diff --git a/kimm/layers/layer_scale.py b/kimm/layers/layer_scale.py index 7ef61c7..8cb2924 100644 --- a/kimm/layers/layer_scale.py +++ b/kimm/layers/layer_scale.py @@ -11,13 +11,11 @@ def __init__( self, axis: int = -1, initializer: Initializer = initializers.Constant(1e-5), - name: str = "layer_scale", **kwargs, ): super().__init__(**kwargs) self.axis = axis self.initializer = initializer - self.name = name def build(self, input_shape): if isinstance(self.axis, list): diff --git a/kimm/layers/mobile_one_conv2d.py b/kimm/layers/mobile_one_conv2d.py new file mode 100644 index 0000000..138eadd --- /dev/null +++ b/kimm/layers/mobile_one_conv2d.py @@ -0,0 +1,376 @@ +import typing + +import keras +import numpy as np +from keras import Sequential +from keras import layers +from keras import ops +from keras.src.backend import standardize_data_format +from keras.src.layers import Layer +from keras.src.utils.argument_validation import standardize_tuple + + +@keras.saving.register_keras_serializable(package="kimm") +class MobileOneConv2D(Layer): + def __init__( + self, + filters, + kernel_size, + strides=(1, 1), + padding=None, + has_skip: bool = True, + use_depthwise: bool = False, + branch_size: int = 1, + reparameterized: bool = False, + data_format=None, + activation=None, + **kwargs, + ): + super().__init__(**kwargs) + self.filters = filters + self.kernel_size = standardize_tuple(kernel_size, 2, "kernel_size") + self.strides = standardize_tuple(strides, 2, "strides") + self.padding = padding + self.has_skip = has_skip + self.use_depthwise = use_depthwise + self.branch_size = branch_size + self._reparameterized = reparameterized + self.data_format = standardize_data_format(data_format) + self.activation = activation + + if self.kernel_size[0] != self.kernel_size[1]: + raise ValueError( + "The value of kernel_size must be the same. " + f"Received: kernel_size={kernel_size}" + ) + if self.strides[0] != self.strides[1]: + raise ValueError( + "The value of strides must be the same. " + f"Received: strides={strides}" + ) + if has_skip is True and (self.strides[0] != 1 or self.strides[1] != 1): + raise ValueError( + "strides must be `1` when `has_skip=True`. " + f"Received: has_skip={has_skip}, strides={strides}" + ) + + self.zero_padding = layers.Identity(dtype=self.dtype_policy) + if padding is None: + padding = "same" + if self.strides[0] > 1: + padding = "valid" + self.zero_padding = layers.ZeroPadding2D( + (self.kernel_size[0] // 2, self.kernel_size[1] // 2), + data_format=self.data_format, + dtype=self.dtype_policy, + name=f"{self.name}_pad", + ) + self.padding = padding + else: + self.padding = padding + + channel_axis = -1 if self.data_format == "channels_last" else -3 + + # Build layers (rep_conv2d, identity, conv_kxk, conv_scale) + self.rep_conv2d: typing.Optional[layers.Conv2D] = None + self.identity: typing.Optional[layers.BatchNormalization] = None + self.conv_kxk: typing.Optional[typing.List[Sequential]] = None + self.conv_scale: typing.Optional[Sequential] = None + if self._reparameterized: + self.rep_conv2d = self._get_conv2d( + use_depthwise, + self.filters, + self.kernel_size, + self.strides, + self.padding, + use_bias=True, + name=f"{self.name}_reparam_conv", + ) + else: + # Skip connection + if self.has_skip: + self.identity = layers.BatchNormalization( + axis=channel_axis, + momentum=0.9, + epsilon=1e-5, + dtype=self.dtype_policy, + name=f"{self.name}_identity", + ) + else: + self.identity = None + + # Convoluation branches + self.conv_kxk = [] + for i in range(self.branch_size): + self.conv_kxk.append( + Sequential( + [ + self._get_conv2d( + self.use_depthwise, + self.filters, + self.kernel_size, + self.strides, + self.padding, + use_bias=False, + ), + layers.BatchNormalization( + axis=channel_axis, + momentum=0.9, + epsilon=1e-5, + dtype=self.dtype_policy, + ), + ], + name=f"{self.name}_conv_kxk_{i}", + ) + ) + + # Scale branch + self.conv_scale = None + if self.kernel_size[0] > 1: + self.conv_scale = Sequential( + [ + self._get_conv2d( + self.use_depthwise, + self.filters, + 1, + self.strides, + self.padding, + use_bias=False, + ), + layers.BatchNormalization( + axis=channel_axis, + momentum=0.9, + epsilon=1e-5, + dtype=self.dtype_policy, + ), + ], + name=f"{self.name}_conv_scale", + ) + + if activation is None: + self.act = layers.Identity(dtype=self.dtype_policy) + else: + self.act = layers.Activation(activation, dtype=self.dtype_policy) + + # Internal parameters for `_get_reparameterized_weights_from_layer` + self._input_channels = None + self._rep_kernel_shape = None + + # Attach extra layers + self.extra_layers = [] + if self.rep_conv2d is not None: + self.extra_layers.append(self.rep_conv2d) + if self.identity is not None: + self.extra_layers.append(self.identity) + if self.conv_kxk is not None: + self.extra_layers.extend(self.conv_kxk) + if self.conv_scale is not None: + self.extra_layers.append(self.conv_scale) + self.extra_layers.append(self.act) + + def _get_conv2d( + self, + use_depthwise, + filters, + kernel_size, + strides, + padding, + use_bias, + name=None, + ): + if use_depthwise: + return layers.DepthwiseConv2D( + kernel_size, + strides, + padding, + data_format=self.data_format, + use_bias=use_bias, + dtype=self.dtype_policy, + name=name, + ) + else: + return layers.Conv2D( + filters, + kernel_size, + strides, + padding, + data_format=self.data_format, + use_bias=use_bias, + dtype=self.dtype_policy, + name=name, + ) + + def build(self, input_shape): + channel_axis = -1 if self.data_format == "channels_last" else -3 + + if isinstance(self.zero_padding, layers.ZeroPadding2D): + padded_shape = self.zero_padding.compute_output_shape(input_shape) + else: + padded_shape = input_shape + + if self.rep_conv2d is not None: + self.rep_conv2d.build(padded_shape) + if self.identity is not None: + self.identity.build(input_shape) + if self.conv_kxk is not None: + for layer in self.conv_kxk: + layer.build(padded_shape) + if self.conv_scale is not None: + self.conv_scale.build(input_shape) + + # Update internal parameters + self._input_channels = input_shape[channel_axis] + if self.conv_kxk is not None: + self._rep_kernel_shape = self.conv_kxk[0].layers[0].kernel.shape + + self.built = True + + def call(self, inputs, **kwargs): + x = ops.cast(inputs, self.compute_dtype) + padded_x = self.zero_padding(x) + + # Shortcut for reparameterized mode + if self._reparameterized: + return self.act(self.rep_conv2d(padded_x, **kwargs)) + + # Skip connection + identity_outputs = None + if self.identity is not None: + identity_outputs = self.identity(x, **kwargs) + + # Scale branch + scale_outputs = None + if self.conv_scale is not None: + scale_outputs = self.conv_scale(x, **kwargs) + + # Conv branch + conv_outputs = scale_outputs + for layer in self.conv_kxk: + if conv_outputs is None: + conv_outputs = layer(padded_x, **kwargs) + else: + conv_outputs = layers.Add()( + [conv_outputs, layer(padded_x, **kwargs)] + ) + + if identity_outputs is not None: + outputs = layers.Add()([conv_outputs, identity_outputs]) + else: + outputs = conv_outputs + return self.act(outputs) + + def get_config(self): + config = super().get_config() + config.update( + { + "filters": self.filters, + "kernel_size": self.kernel_size, + "strides": self.strides, + "padding": self.padding, + "has_skip": self.has_skip, + "use_depthwise": self.use_depthwise, + "branch_size": self.branch_size, + "reparameterized": self._reparameterized, + "data_format": self.data_format, + "activation": self.activation, + "name": self.name, + } + ) + return config + + def _get_reparameterized_weights_from_layer(self, layer): + if isinstance(layer, Sequential): + if not isinstance( + layer.layers[0], (layers.Conv2D, layers.DepthwiseConv2D) + ): + raise ValueError + if not isinstance(layer.layers[1], layers.BatchNormalization): + raise ValueError + kernel = ops.convert_to_numpy(layer.layers[0].kernel) + if self.use_depthwise: + kernel = np.swapaxes(kernel, -2, -1) + gamma = ops.convert_to_numpy(layer.layers[1].gamma) + beta = ops.convert_to_numpy(layer.layers[1].beta) + running_mean = ops.convert_to_numpy(layer.layers[1].moving_mean) + running_var = ops.convert_to_numpy(layer.layers[1].moving_variance) + eps = layer.layers[1].epsilon + elif isinstance(layer, layers.BatchNormalization): + if self._rep_kernel_shape is None: + raise ValueError( + "Remember to build the layer before performing" + "reparameterization. Failed to get valid " + "`self._rep_kernel_shape`." + ) + # Calculate identity tensor + kernel_value = ops.convert_to_numpy( + ops.zeros(self._rep_kernel_shape) + ) + kernel = kernel_value.copy() + if self.use_depthwise: + kernel = np.swapaxes(kernel, -2, -1) + for i in range(self._input_channels): + group_i = 0 if self.use_depthwise else i + kernel[ + self.kernel_size[0] // 2, + self.kernel_size[1] // 2, + group_i, + i, + ] = 1 + gamma = ops.convert_to_numpy(layer.gamma) + beta = ops.convert_to_numpy(layer.beta) + running_mean = ops.convert_to_numpy(layer.moving_mean) + running_var = ops.convert_to_numpy(layer.moving_variance) + eps = layer.epsilon + + # use float64 for better precision + kernel = kernel.astype("float64") + gamma = gamma.astype("float64") + beta = beta.astype("float64") + running_var = running_var.astype("float64") + running_var = running_var.astype("float64") + + std = np.sqrt(running_var + eps) + t = np.reshape(gamma / std, [1, 1, 1, -1]) + + kernel_final = kernel * t + if self.use_depthwise: + kernel_final = np.swapaxes(kernel_final, -2, -1) + return kernel_final, beta - running_mean * gamma / std + + def get_reparameterized_weights(self): + # Get kernels and bias from scale branch + kernel_scale = 0.0 + bias_scale = 0.0 + if self.conv_scale is not None: + ( + kernel_scale, + bias_scale, + ) = self._get_reparameterized_weights_from_layer(self.conv_scale) + pad = self.kernel_size[0] // 2 + kernel_scale = np.pad( + kernel_scale, [[pad, pad], [pad, pad], [0, 0], [0, 0]] + ) + + # Get kernels and bias from skip branch + kernel_identity = 0.0 + bias_identity = 0.0 + if self.identity is not None: + ( + kernel_identity, + bias_identity, + ) = self._get_reparameterized_weights_from_layer(self.identity) + + # Get kernels and bias from conv branch + kernel_conv = 0.0 + bias_conv = 0.0 + for i in range(self.branch_size): + ( + _kernel_conv, + _bias_conv, + ) = self._get_reparameterized_weights_from_layer(self.conv_kxk[i]) + kernel_conv += _kernel_conv + bias_conv += _bias_conv + + kernel_final = kernel_conv + kernel_scale + kernel_identity + bias_final = bias_conv + bias_scale + bias_identity + return kernel_final, bias_final diff --git a/kimm/layers/mobile_one_conv2d_test.py b/kimm/layers/mobile_one_conv2d_test.py new file mode 100644 index 0000000..f7ecc83 --- /dev/null +++ b/kimm/layers/mobile_one_conv2d_test.py @@ -0,0 +1,164 @@ +import pytest +from absl.testing import parameterized +from keras import backend +from keras import random +from keras.src import testing + +from kimm.layers.mobile_one_conv2d import MobileOneConv2D + +TEST_CASES = [ + { + "filters": 16, + "kernel_size": 3, + "has_skip": True, + "use_depthwise": False, + "branch_size": 2, + "data_format": "channels_last", + "input_shape": (1, 4, 4, 16), + "output_shape": (1, 4, 4, 16), + "num_trainable_weights": 11, + "num_non_trainable_weights": 8, + }, + { + "filters": 16, + "kernel_size": 3, + "has_skip": True, + "use_depthwise": True, + "branch_size": 3, + "data_format": "channels_last", + "input_shape": (1, 4, 4, 16), + "output_shape": (1, 4, 4, 16), + "num_trainable_weights": 14, + "num_non_trainable_weights": 10, + }, + { + "filters": 16, + "kernel_size": 3, + "has_skip": False, + "use_depthwise": False, + "branch_size": 2, + "data_format": "channels_last", + "input_shape": (1, 4, 4, 8), + "output_shape": (1, 4, 4, 16), + "num_trainable_weights": 9, + "num_non_trainable_weights": 6, + }, + { + "filters": 16, + "kernel_size": 5, + "has_skip": True, + "use_depthwise": False, + "branch_size": 2, + "data_format": "channels_last", + "input_shape": (1, 4, 4, 16), + "output_shape": (1, 4, 4, 16), + "num_trainable_weights": 11, + "num_non_trainable_weights": 8, + }, + { + "filters": 16, + "kernel_size": 3, + "has_skip": True, + "use_depthwise": False, + "branch_size": 2, + "data_format": "channels_first", + "input_shape": (1, 16, 4, 4), + "output_shape": (1, 16, 4, 4), + "num_trainable_weights": 11, + "num_non_trainable_weights": 8, + }, +] + + +class MobileOneConv2DTest(testing.TestCase, parameterized.TestCase): + @parameterized.parameters(TEST_CASES) + @pytest.mark.requires_trainable_backend + def test_mobile_one_conv2d_basic( + self, + filters, + kernel_size, + has_skip, + use_depthwise, + branch_size, + data_format, + input_shape, + output_shape, + num_trainable_weights, + num_non_trainable_weights, + ): + if ( + backend.backend() == "tensorflow" + and data_format == "channels_first" + ): + self.skipTest( + "Conv2D in tensorflow backend with 'channels_first' is limited " + "to be supported" + ) + self.run_layer_test( + MobileOneConv2D, + init_kwargs={ + "filters": filters, + "kernel_size": kernel_size, + "has_skip": has_skip, + "use_depthwise": use_depthwise, + "branch_size": branch_size, + "data_format": data_format, + }, + input_shape=input_shape, + expected_output_shape=output_shape, + expected_num_trainable_weights=num_trainable_weights, + expected_num_non_trainable_weights=num_non_trainable_weights, + expected_num_losses=0, + supports_masking=False, + ) + + @parameterized.parameters(TEST_CASES) + def test_mobile_one_conv2d_get_reparameterized_weights( + self, + filters, + kernel_size, + has_skip, + use_depthwise, + branch_size, + data_format, + input_shape, + output_shape, + num_trainable_weights, + num_non_trainable_weights, + ): + if ( + backend.backend() == "tensorflow" + and data_format == "channels_first" + ): + self.skipTest( + "Conv2D in tensorflow backend with 'channels_first' is limited " + "to be supported" + ) + layer = MobileOneConv2D( + filters=filters, + kernel_size=kernel_size, + has_skip=has_skip, + use_depthwise=use_depthwise, + branch_size=branch_size, + data_format=data_format, + ) + layer.build(input_shape) + reparameterized_layer = MobileOneConv2D( + filters=filters, + kernel_size=kernel_size, + has_skip=has_skip, + use_depthwise=use_depthwise, + branch_size=branch_size, + reparameterized=True, + data_format=data_format, + ) + reparameterized_layer.build(input_shape) + x = random.uniform(input_shape) + + kernel, bias = layer.get_reparameterized_weights() + reparameterized_layer.rep_conv2d.kernel.assign(kernel) + reparameterized_layer.rep_conv2d.bias.assign(bias) + y1 = layer(x, training=False) + y2 = reparameterized_layer(x, training=False) + + self.assertAllClose(y1, y2, atol=1e-3) diff --git a/kimm/layers/rep_conv2d.py b/kimm/layers/rep_conv2d.py index 92faef2..e3b67e8 100644 --- a/kimm/layers/rep_conv2d.py +++ b/kimm/layers/rep_conv2d.py @@ -20,7 +20,6 @@ def __init__( reparameterized: bool = False, data_format=None, activation=None, - name="rep_conv2d", **kwargs, ): super().__init__(**kwargs) @@ -32,8 +31,17 @@ def __init__( self._reparameterized = reparameterized self.data_format = standardize_data_format(data_format) self.activation = activation - self.name = name + if self.kernel_size[0] != self.kernel_size[1]: + raise ValueError( + "The value of kernel_size must be the same. " + f"Received: kernel_size={kernel_size}" + ) + if self.strides[0] != self.strides[1]: + raise ValueError( + "The value of strides must be the same. " + f"Received: strides={strides}" + ) if has_skip is True and (self.strides[0] != 1 or self.strides[1] != 1): raise ValueError( "strides must be `1` when `has_skip=True`. " @@ -49,7 +57,7 @@ def __init__( (self.kernel_size[0] // 2, self.kernel_size[1] // 2), data_format=self.data_format, dtype=self.dtype_policy, - name=f"{name}_pad", + name=f"{self.name}_pad", ) self.padding = padding else: @@ -58,14 +66,14 @@ def __init__( channel_axis = -1 if self.data_format == "channels_last" else -3 if self._reparameterized: self.rep_conv2d = layers.Conv2D( - filters, - kernel_size, - strides, - padding, + self.filters, + self.kernel_size, + self.strides, + self.padding, data_format=self.data_format, use_bias=True, dtype=self.dtype_policy, - name=f"{name}_reparam_conv", + name=f"{self.name}_reparam_conv", ) self.identity = None self.conv_kxk = None @@ -78,16 +86,16 @@ def __init__( momentum=0.9, epsilon=1e-5, dtype=self.dtype_policy, - name=f"{name}_identity", + name=f"{self.name}_identity", ) else: self.identity = None self.conv_kxk = Sequential( [ layers.Conv2D( - filters, - kernel_size, - strides, + self.filters, + self.kernel_size, + self.strides, padding=self.padding, data_format=self.data_format, use_bias=False, @@ -100,14 +108,14 @@ def __init__( dtype=self.dtype_policy, ), ], - name=f"{name}_conv_kxk", + name=f"{self.name}_conv_kxk", ) self.conv_1x1 = Sequential( [ layers.Conv2D( - filters, + self.filters, 1, - strides, + self.strides, padding=self.padding, data_format=self.data_format, use_bias=False, @@ -120,7 +128,7 @@ def __init__( dtype=self.dtype_policy, ), ], - name=f"{name}_conv_1x1", + name=f"{self.name}_conv_1x1", ) if activation is None: @@ -128,7 +136,11 @@ def __init__( else: self.act = layers.Activation(activation, dtype=self.dtype_policy) - # attach extra layers + # Internal parameters for `_get_reparameterized_weights_from_layer` + self._input_channels = None + self._rep_kernel_shape = None + + # Attach extra layers self.extra_layers = [] if self.rep_conv2d is not None: self.extra_layers.append(self.rep_conv2d) @@ -141,6 +153,8 @@ def __init__( self.extra_layers.append(self.act) def build(self, input_shape): + channel_axis = -1 if self.data_format == "channels_last" else -3 + if isinstance(self.zero_padding, layers.ZeroPadding2D): padded_shape = self.zero_padding.compute_output_shape(input_shape) else: @@ -155,13 +169,18 @@ def build(self, input_shape): if self.conv_1x1 is not None: self.conv_1x1.build(input_shape) + # Update internal parameters + self._input_channels = input_shape[channel_axis] + if self.conv_kxk is not None: + self._rep_kernel_shape = self.conv_kxk.layers[0].kernel.shape + self.built = True def call(self, inputs, **kwargs): x = ops.cast(inputs, self.compute_dtype) padded_x = self.zero_padding(x) - # Deploy mode + # Shortcut for reparameterized mode if self._reparameterized: return self.act(self.rep_conv2d(padded_x, **kwargs)) @@ -191,7 +210,6 @@ def get_config(self): return config def _get_reparameterized_weights_from_layer(self, layer): - channel_axis = -1 if self.data_format == "channels_last" else -3 if isinstance(layer, Sequential): if not isinstance(layer.layers[0], layers.Conv2D): raise ValueError @@ -204,15 +222,21 @@ def _get_reparameterized_weights_from_layer(self, layer): running_var = ops.convert_to_numpy(layer.layers[1].moving_variance) eps = layer.layers[1].epsilon elif isinstance(layer, layers.BatchNormalization): - # calculate identity tensor - in_chs = self.conv_kxk.layers[0].input.shape[channel_axis] - kernel_size = self.conv_kxk.layers[0].kernel_size + if self._rep_kernel_shape is None: + raise ValueError( + "Remember to build the layer before performing" + "reparameterization. Failed to get valid " + "`self._rep_kernel_shape`." + ) + # Calculate identity tensor kernel_value = ops.convert_to_numpy( - ops.zeros_like(self.conv_kxk.layers[0].kernel) + ops.zeros(self._rep_kernel_shape) ) kernel_value = kernel_value.copy() - for i in range(in_chs): - kernel_value[kernel_size[0] // 2, kernel_size[1] // 2, i, i] = 1 + for i in range(self._input_channels): + kernel_value[ + self.kernel_size[0] // 2, self.kernel_size[1] // 2, i, i + ] = 1 kernel = kernel_value gamma = ops.convert_to_numpy(layer.gamma) beta = ops.convert_to_numpy(layer.beta) @@ -220,7 +244,7 @@ def _get_reparameterized_weights_from_layer(self, layer): running_var = ops.convert_to_numpy(layer.moving_variance) eps = layer.epsilon - # use float64 for better precision + # Use float64 for better precision kernel = kernel.astype("float64") gamma = gamma.astype("float64") beta = beta.astype("float64") @@ -238,7 +262,7 @@ def get_reparameterized_weights(self): kernel_1x1, bias_1x1 = self._get_reparameterized_weights_from_layer( self.conv_1x1 ) - pad = self.conv_kxk.layers[0].kernel_size[0] // 2 + pad = self.kernel_size[0] // 2 kernel_1x1 = np.pad( kernel_1x1, [[pad, pad], [pad, pad], [0, 0], [0, 0]] ) diff --git a/kimm/layers/rep_conv2d_test.py b/kimm/layers/rep_conv2d_test.py index ca6c8df..7128f6a 100644 --- a/kimm/layers/rep_conv2d_test.py +++ b/kimm/layers/rep_conv2d_test.py @@ -131,4 +131,4 @@ def test_rep_conv2d_get_reparameterized_weights( y1 = layer(x, training=False) y2 = reparameterized_layer(x, training=False) - self.assertAllClose(y1, y2, atol=1e-5) + self.assertAllClose(y1, y2, atol=1e-3) diff --git a/kimm/models/__init__.py b/kimm/models/__init__.py index 2e7fa46..382b4b4 100644 --- a/kimm/models/__init__.py +++ b/kimm/models/__init__.py @@ -8,6 +8,7 @@ from kimm.models.inception_v3 import * # noqa:F403 from kimm.models.mobilenet_v2 import * # noqa:F403 from kimm.models.mobilenet_v3 import * # noqa:F403 +from kimm.models.mobileone import * # noqa:F403 from kimm.models.mobilevit import * # noqa:F403 from kimm.models.regnet import * # noqa:F403 from kimm.models.repvgg import * # noqa:F403 diff --git a/kimm/models/mobileone.py b/kimm/models/mobileone.py new file mode 100644 index 0000000..0aa22a8 --- /dev/null +++ b/kimm/models/mobileone.py @@ -0,0 +1,358 @@ +import typing + +import keras +from keras import backend + +from kimm import layers as kimm_layers +from kimm.models.base_model import BaseModel +from kimm.utils import add_model_to_registry + + +@keras.saving.register_keras_serializable(package="kimm") +class MobileOne(BaseModel): + available_feature_keys = [ + "STEM_S2", + *[f"BLOCK{i}_S{j}" for i, j in zip(range(4), [4, 8, 16, 32])], + ] + + def __init__( + self, + num_blocks: typing.Sequence[int], + num_channels: typing.Sequence[int], + stem_channels: int = 48, + branch_size: int = 1, + reparameterized: bool = False, + **kwargs, + ): + kwargs["weights_url"] = self.get_weights_url(kwargs["weights"]) + if kwargs["weights_url"] is not None and reparameterized is True: + raise ValueError( + "Weights can only be loaded with `reparameterized=False`. " + "You can first initialize the model with " + "`reparameterized=False` then use " + "`get_reparameterized_model` to get the converted model. " + f"Received: weights={kwargs['weights']}, " + f"reparameterized={reparameterized}" + ) + + input_tensor = kwargs.pop("input_tensor", None) + self.set_properties(kwargs) + channels_axis = ( + -1 if backend.image_data_format() == "channels_last" else -3 + ) + + inputs = self.determine_input_tensor( + input_tensor, + self._input_shape, + self._default_size, + ) + x = inputs + + x = self.build_preprocessing(x, "imagenet") + + # Prepare feature extraction + features = {} + + # stem + x = kimm_layers.MobileOneConv2D( + stem_channels, + 3, + 2, + has_skip=False, + reparameterized=reparameterized, + activation="relu", + name="stem", + )(x) + features["STEM_S2"] = x + + # stages + current_strides = 2 + for current_stage_idx, (c, n) in enumerate( + zip(num_channels, num_blocks) + ): + strides = 2 + current_strides *= strides + current_block_idx = 0 + # blocks + for _ in range(n): + strides = strides if current_block_idx == 0 else 1 + input_channels = x.shape[channels_axis] + has_skip1 = strides == 1 + has_skip2 = input_channels == c + name1 = f"stages_{current_stage_idx}_{current_block_idx}" + name2 = f"stages_{current_stage_idx}_{current_block_idx+1}" + # Depthwise + x = kimm_layers.MobileOneConv2D( + input_channels, + 3, + strides, + has_skip=has_skip1, + use_depthwise=True, + branch_size=branch_size, + reparameterized=reparameterized, + activation="relu", + name=name1, + )(x) + # Pointwise + x = kimm_layers.MobileOneConv2D( + c, + 1, + 1, + has_skip=has_skip2, + use_depthwise=False, + branch_size=branch_size, + reparameterized=reparameterized, + activation="relu", + name=name2, + )(x) + current_block_idx += 2 + + # add feature + features[f"BLOCK{current_stage_idx}_S{current_strides}"] = x + + # Head + x = self.build_head(x) + + super().__init__(inputs=inputs, outputs=x, features=features, **kwargs) + + # All references to `self` below this line + self.num_blocks = num_blocks + self.num_channels = num_channels + self.stem_channels = stem_channels + self.branch_size = branch_size + self.reparameterized = reparameterized + + def get_config(self): + config = super().get_config() + config.update( + { + "num_blocks": self.num_blocks, + "num_channels": self.num_channels, + "stem_channels": self.stem_channels, + "branch_size": self.branch_size, + "reparameterized": self.reparameterized, + } + ) + return config + + def fix_config(self, config): + unused_kwargs = [ + "num_blocks", + "num_channels", + "stem_channels", + "branch_size", + ] + for k in unused_kwargs: + config.pop(k, None) + return config + + def get_reparameterized_model(self): + config = self.get_config() + config["reparameterized"] = True + config["weights"] = None + model = MobileOne(**config) + for layer, reparameterized_layer in zip(self.layers, model.layers): + if hasattr(layer, "get_reparameterized_weights"): + kernel, bias = layer.get_reparameterized_weights() + reparameterized_layer.rep_conv2d.kernel.assign(kernel) + reparameterized_layer.rep_conv2d.bias.assign(bias) + else: + for weight, target_weight in zip( + layer.weights, reparameterized_layer.weights + ): + target_weight.assign(weight) + return model + + +""" +Model Definition +""" + + +class MobileOneS0(MobileOne): + available_weights = [ + ( + "imagenet", + MobileOne.default_origin, + "mobileones0_mobileone_s0.apple_in1k.keras", + ) + ] + + def __init__( + self, + reparameterized: bool = False, + input_tensor: keras.KerasTensor = None, + input_shape: typing.Optional[typing.Sequence[int]] = None, + include_preprocessing: bool = True, + include_top: bool = True, + pooling: typing.Optional[str] = None, + dropout_rate: float = 0.0, + classes: int = 1000, + classifier_activation: str = "softmax", + weights: typing.Optional[str] = "imagenet", + name: str = "MobileOneS0", + **kwargs, + ): + kwargs = self.fix_config(kwargs) + super().__init__( + [2, 8, 10, 1], + [48, 128, 256, 1024], + 48, + 4, + reparameterized=reparameterized, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, + name=name, + **kwargs, + ) + + +class MobileOneS1(MobileOne): + available_weights = [ + ( + "imagenet", + MobileOne.default_origin, + "mobileones1_mobileone_s1.apple_in1k.keras", + ) + ] + + def __init__( + self, + reparameterized: bool = False, + input_tensor: keras.KerasTensor = None, + input_shape: typing.Optional[typing.Sequence[int]] = None, + include_preprocessing: bool = True, + include_top: bool = True, + pooling: typing.Optional[str] = None, + dropout_rate: float = 0.0, + classes: int = 1000, + classifier_activation: str = "softmax", + weights: typing.Optional[str] = "imagenet", + name: str = "MobileOneS1", + **kwargs, + ): + kwargs = self.fix_config(kwargs) + super().__init__( + [2, 8, 10, 1], + [96, 192, 512, 1280], + 64, + 1, + reparameterized=reparameterized, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, + name=name, + **kwargs, + ) + + +class MobileOneS2(MobileOne): + available_weights = [ + ( + "imagenet", + MobileOne.default_origin, + "mobileones2_mobileone_s2.apple_in1k.keras", + ) + ] + + def __init__( + self, + reparameterized: bool = False, + input_tensor: keras.KerasTensor = None, + input_shape: typing.Optional[typing.Sequence[int]] = None, + include_preprocessing: bool = True, + include_top: bool = True, + pooling: typing.Optional[str] = None, + dropout_rate: float = 0.0, + classes: int = 1000, + classifier_activation: str = "softmax", + weights: typing.Optional[str] = "imagenet", + name: str = "MobileOneS2", + **kwargs, + ): + kwargs = self.fix_config(kwargs) + super().__init__( + [2, 8, 10, 1], + [96, 256, 640, 2048], + 64, + 1, + reparameterized=reparameterized, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, + name=name, + **kwargs, + ) + + +class MobileOneS3(MobileOne): + available_weights = [ + ( + "imagenet", + MobileOne.default_origin, + "mobileones3_mobileone_s3.apple_in1k.keras", + ) + ] + + def __init__( + self, + reparameterized: bool = False, + input_tensor: keras.KerasTensor = None, + input_shape: typing.Optional[typing.Sequence[int]] = None, + include_preprocessing: bool = True, + include_top: bool = True, + pooling: typing.Optional[str] = None, + dropout_rate: float = 0.0, + classes: int = 1000, + classifier_activation: str = "softmax", + weights: typing.Optional[str] = "imagenet", + name: str = "MobileOneS3", + **kwargs, + ): + kwargs = self.fix_config(kwargs) + super().__init__( + [2, 8, 10, 1], + [128, 320, 768, 2048], + 64, + 1, + reparameterized=reparameterized, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, + name=name, + **kwargs, + ) + + +# TODO: Add MobileOneS4 (w/ SE blocks) + + +add_model_to_registry(MobileOneS0, "imagenet") +add_model_to_registry(MobileOneS1, "imagenet") +add_model_to_registry(MobileOneS2, "imagenet") +add_model_to_registry(MobileOneS3, "imagenet") diff --git a/kimm/models/models_test.py b/kimm/models/models_test.py index 229f421..6985283 100644 --- a/kimm/models/models_test.py +++ b/kimm/models/models_test.py @@ -214,6 +214,19 @@ ("BLOCK4_S32", [1, 7, 7, make_divisible(96 * 1.0)]), ], ), + # mobileone + ( + kimm_models.MobileOneS0.__name__, + kimm_models.MobileOneS0, + 224, + [ + ("STEM_S2", [1, 112, 112, 48]), + ("BLOCK0_S4", [1, 56, 56, 48]), + ("BLOCK1_S8", [1, 28, 28, 128]), + ("BLOCK2_S16", [1, 14, 14, 256]), + ("BLOCK3_S32", [1, 7, 7, 1024]), + ], + ), # mobilevit ( kimm_models.MobileViTS.__name__, @@ -434,7 +447,8 @@ def test_model_feature_extractor( self.assertEqual(list(y[name].shape), shape) @parameterized.named_parameters( - (kimm_models.RepVGGA0.__name__, kimm_models.RepVGGA0, 224) + (kimm_models.RepVGGA0.__name__, kimm_models.RepVGGA0, 224), + (kimm_models.MobileOneS0.__name__, kimm_models.MobileOneS0, 224), ) def test_model_get_reparameterized_model(self, model_class, image_size): x = random.uniform([1, image_size, image_size, 3]) * 255.0 diff --git a/kimm/utils/timm_utils.py b/kimm/utils/timm_utils.py index 1fd33cb..c398140 100644 --- a/kimm/utils/timm_utils.py +++ b/kimm/utils/timm_utils.py @@ -71,12 +71,18 @@ def assign_weights( keras_name: str, keras_weight: keras.Variable, torch_weight: np.ndarray ): if len(keras_weight.shape) == 4: - if "dwconv2d" in keras_name or "depthwise" in keras_name: - # depthwise conv2d layer - keras_weight.assign(np.transpose(torch_weight, [2, 3, 0, 1])) - elif "conv" in keras_name or "pointwise" in keras_name: - # conventional conv2d layer - keras_weight.assign(np.transpose(torch_weight, [2, 3, 1, 0])) + if ( + "conv" in keras_name + or "pointwise" in keras_name + or "dwconv2d" in keras_name + or "depthwise" in keras_name + ): + try: + # conventional conv2d layer + keras_weight.assign(np.transpose(torch_weight, [2, 3, 1, 0])) + except ValueError: + # depthwise conv2d layer + keras_weight.assign(np.transpose(torch_weight, [2, 3, 0, 1])) else: raise ValueError( f"Failed to assign {keras_name}. " diff --git a/shell/export.sh b/shell/export.sh index 37613dd..3958a2e 100755 --- a/shell/export.sh +++ b/shell/export.sh @@ -13,6 +13,7 @@ python3 -m tools.convert_inception_next_from_timm python3 -m tools.convert_inception_v3_from_timm python3 -m tools.convert_mobilenet_v2_from_timm python3 -m tools.convert_mobilenet_v3_from_timm +python3 -m tools.convert_mobileone_from_timm python3 -m tools.convert_mobilevit_from_timm python3 -m tools.convert_regnet_from_timm python3 -m tools.convert_repvgg_from_timm diff --git a/tools/convert_convmixer_from_timm.py b/tools/convert_convmixer_from_timm.py index f637d2b..371e9a0 100644 --- a/tools/convert_convmixer_from_timm.py +++ b/tools/convert_convmixer_from_timm.py @@ -43,6 +43,7 @@ input_shape=input_shape, include_preprocessing=False, classifier_activation="linear", + weights=None, ) trainable_weights, non_trainable_weights = separate_keras_weights( keras_model diff --git a/tools/convert_convnext_from_timm.py b/tools/convert_convnext_from_timm.py index 4c061e8..6da0e5b 100644 --- a/tools/convert_convnext_from_timm.py +++ b/tools/convert_convnext_from_timm.py @@ -55,6 +55,7 @@ input_shape=input_shape, include_preprocessing=False, classifier_activation="linear", + weights=None, ) trainable_weights, non_trainable_weights = separate_keras_weights( keras_model diff --git a/tools/convert_densenet_from_timm.py b/tools/convert_densenet_from_timm.py index fa17d5a..2ca9a4c 100644 --- a/tools/convert_densenet_from_timm.py +++ b/tools/convert_densenet_from_timm.py @@ -45,6 +45,7 @@ input_shape=input_shape, include_preprocessing=False, classifier_activation="linear", + weights=None, ) trainable_weights, non_trainable_weights = separate_keras_weights( keras_model diff --git a/tools/convert_efficientnet_from_timm.py b/tools/convert_efficientnet_from_timm.py index 040354a..2e3e8cc 100644 --- a/tools/convert_efficientnet_from_timm.py +++ b/tools/convert_efficientnet_from_timm.py @@ -89,6 +89,7 @@ input_shape=input_shape, include_preprocessing=False, classifier_activation="linear", + weights=None, ) trainable_weights, non_trainable_weights = separate_keras_weights( keras_model diff --git a/tools/convert_ghostnet_from_timm.py b/tools/convert_ghostnet_from_timm.py index 25587ac..795fcf7 100644 --- a/tools/convert_ghostnet_from_timm.py +++ b/tools/convert_ghostnet_from_timm.py @@ -48,6 +48,7 @@ input_shape=input_shape, include_preprocessing=False, classifier_activation="linear", + weights=None, ) trainable_weights, non_trainable_weights = separate_keras_weights( keras_model diff --git a/tools/convert_inception_next_from_timm.py b/tools/convert_inception_next_from_timm.py index 893f57a..b8a9c28 100644 --- a/tools/convert_inception_next_from_timm.py +++ b/tools/convert_inception_next_from_timm.py @@ -43,6 +43,7 @@ input_shape=input_shape, include_preprocessing=False, classifier_activation="linear", + weights=None, ) trainable_weights, non_trainable_weights = separate_keras_weights( keras_model diff --git a/tools/convert_mobilenet_v2_from_timm.py b/tools/convert_mobilenet_v2_from_timm.py index 136c5a5..28a2208 100644 --- a/tools/convert_mobilenet_v2_from_timm.py +++ b/tools/convert_mobilenet_v2_from_timm.py @@ -47,6 +47,7 @@ input_shape=input_shape, include_preprocessing=False, classifier_activation="linear", + weights=None, ) trainable_weights, non_trainable_weights = separate_keras_weights( keras_model diff --git a/tools/convert_mobilenet_v3_from_timm.py b/tools/convert_mobilenet_v3_from_timm.py index df426d1..10b17b6 100644 --- a/tools/convert_mobilenet_v3_from_timm.py +++ b/tools/convert_mobilenet_v3_from_timm.py @@ -55,6 +55,7 @@ input_shape=input_shape, include_preprocessing=False, classifier_activation="linear", + weights=None, ) trainable_weights, non_trainable_weights = separate_keras_weights( keras_model diff --git a/tools/convert_mobileone_from_timm.py b/tools/convert_mobileone_from_timm.py new file mode 100644 index 0000000..84fe4cc --- /dev/null +++ b/tools/convert_mobileone_from_timm.py @@ -0,0 +1,166 @@ +""" +pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu +pip install timm +""" + +import os + +import keras +import numpy as np +import timm +import torch + +from kimm.models import mobileone +from kimm.utils.timm_utils import assign_weights +from kimm.utils.timm_utils import is_same_weights +from kimm.utils.timm_utils import separate_keras_weights +from kimm.utils.timm_utils import separate_torch_state_dict + +timm_model_names = [ + "mobileone_s0.apple_in1k", + "mobileone_s1.apple_in1k", + "mobileone_s2.apple_in1k", + "mobileone_s3.apple_in1k", + # "mobileone_s4.apple_in1k", +] +keras_model_classes = [ + mobileone.MobileOneS0, + mobileone.MobileOneS1, + mobileone.MobileOneS2, + mobileone.MobileOneS3, + # mobileone.MobileOneS4, +] + +for timm_model_name, keras_model_class in zip( + timm_model_names, keras_model_classes +): + """ + Prepare timm model and keras model + """ + input_shape = [224, 224, 3] + torch_model = timm.create_model(timm_model_name, pretrained=True) + torch_model = torch_model.eval() + trainable_state_dict, non_trainable_state_dict = separate_torch_state_dict( + torch_model.state_dict() + ) + keras_model = keras_model_class( + input_shape=input_shape, + include_preprocessing=False, + classifier_activation="linear", + weights=None, + ) + trainable_weights, non_trainable_weights = separate_keras_weights( + keras_model + ) + + # for torch_name, (_, keras_name) in zip( + # trainable_state_dict.keys(), trainable_weights + # ): + # print(f"{torch_name} {keras_name}") + + # print(len(trainable_state_dict.keys())) + # print(len(trainable_weights)) + + # for torch_name, (_, keras_name) in zip( + # non_trainable_state_dict.keys(), non_trainable_weights + # ): + # print(f"{torch_name} {keras_name}") + + # print(len(non_trainable_state_dict.keys())) + # print(len(non_trainable_weights)) + + # exit() + + """ + Assign weights + """ + for keras_weight, keras_name in trainable_weights + non_trainable_weights: + keras_name: str + torch_name = keras_name + torch_name = torch_name.replace("_", ".") + # skip reparam_conv + if "reparam_conv_conv2d" in keras_name: + continue + # mobile_one_conv2d + if "conv.kxk" in torch_name and "kernel" in torch_name: + torch_name = torch_name.replace("conv.kxk", "conv_kxk") + torch_name = torch_name.replace("kernel", "conv.kernel") + if "conv.kxk" in torch_name and "gamma" in torch_name: + torch_name = torch_name.replace("conv.kxk", "conv_kxk") + torch_name = torch_name.replace("gamma", "bn.gamma") + if "conv.kxk" in torch_name and "beta" in torch_name: + torch_name = torch_name.replace("conv.kxk", "conv_kxk") + torch_name = torch_name.replace("beta", "bn.beta") + torch_name = torch_name.replace( + "conv.scale.kernel", "conv_scale.conv.kernel" + ) + torch_name = torch_name.replace( + "conv.scale.gamma", "conv_scale.bn.gamma" + ) + torch_name = torch_name.replace("conv.scale.beta", "conv_scale.bn.beta") + # mobile_one_conv2d bn + if "conv.kxk" in torch_name and "moving.mean" in torch_name: + torch_name = torch_name.replace("conv.kxk", "conv_kxk") + torch_name = torch_name.replace("moving.mean", "bn.moving.mean") + if "conv.kxk" in torch_name and "moving.variance" in torch_name: + torch_name = torch_name.replace("conv.kxk", "conv_kxk") + torch_name = torch_name.replace( + "moving.variance", "bn.moving.variance" + ) + torch_name = torch_name.replace( + "conv.scale.moving.mean", "conv_scale.bn.moving.mean" + ) + torch_name = torch_name.replace( + "conv.scale.moving.variance", "conv_scale.bn.moving.variance" + ) + # head + torch_name = torch_name.replace("classifier", "head.fc") + + # weights naming mapping + torch_name = torch_name.replace("kernel", "weight") # conv2d + torch_name = torch_name.replace("gamma", "weight") # bn + torch_name = torch_name.replace("beta", "bias") # bn + torch_name = torch_name.replace("moving.mean", "running_mean") # bn + torch_name = torch_name.replace("moving.variance", "running_var") # bn + + # assign weights + if torch_name in trainable_state_dict: + torch_weights = trainable_state_dict[torch_name].numpy() + elif torch_name in non_trainable_state_dict: + torch_weights = non_trainable_state_dict[torch_name].numpy() + else: + raise ValueError( + "Can't find the corresponding torch weights. " + f"Got keras_name={keras_name}, torch_name={torch_name}" + ) + if is_same_weights(keras_name, keras_weight, torch_name, torch_weights): + assign_weights(keras_name, keras_weight, torch_weights) + else: + raise ValueError( + "Can't find the corresponding torch weights. The shape is " + f"mismatched. Got keras_name={keras_name}, " + f"keras_weight shape={keras_weight.shape}, " + f"torch_name={torch_name}, " + f"torch_weights shape={torch_weights.shape}" + ) + + """ + Verify model outputs + """ + np.random.seed(2023) + keras_data = np.random.uniform(size=[1] + input_shape).astype("float32") + torch_data = torch.from_numpy(np.transpose(keras_data, [0, 3, 1, 2])) + torch_y = torch_model(torch_data) + keras_y = keras_model(keras_data, training=False) + torch_y = torch_y.detach().cpu().numpy() + keras_y = keras.ops.convert_to_numpy(keras_y) + np.testing.assert_allclose(torch_y, keras_y, atol=1e-4) + print(f"{keras_model_class.__name__}: output matched!") + + """ + Save converted model + """ + os.makedirs("exported", exist_ok=True) + export_path = f"exported/{keras_model.name.lower()}_{timm_model_name}.keras" + keras_model.save(export_path) + print(f"Export to {export_path}") diff --git a/tools/convert_regnet_from_timm.py b/tools/convert_regnet_from_timm.py index 0343277..a5539b5 100644 --- a/tools/convert_regnet_from_timm.py +++ b/tools/convert_regnet_from_timm.py @@ -85,6 +85,7 @@ input_shape=input_shape, include_preprocessing=False, classifier_activation="linear", + weights=None, ) trainable_weights, non_trainable_weights = separate_keras_weights( keras_model diff --git a/tools/convert_resnet_from_timm.py b/tools/convert_resnet_from_timm.py index ff75bd7..440a3d5 100644 --- a/tools/convert_resnet_from_timm.py +++ b/tools/convert_resnet_from_timm.py @@ -47,6 +47,7 @@ input_shape=input_shape, include_preprocessing=False, classifier_activation="linear", + weights=None, ) trainable_weights, non_trainable_weights = separate_keras_weights( keras_model diff --git a/tools/convert_vgg_from_timm.py b/tools/convert_vgg_from_timm.py index 410bf9c..b505396 100644 --- a/tools/convert_vgg_from_timm.py +++ b/tools/convert_vgg_from_timm.py @@ -45,6 +45,7 @@ input_shape=input_shape, include_preprocessing=False, classifier_activation="linear", + weights=None, ) trainable_weights, non_trainable_weights = separate_keras_weights( keras_model diff --git a/tools/convert_vit_from_timm.py b/tools/convert_vit_from_timm.py index fd9ba2d..a8b78fd 100644 --- a/tools/convert_vit_from_timm.py +++ b/tools/convert_vit_from_timm.py @@ -52,6 +52,7 @@ input_shape=input_shape, include_preprocessing=False, classifier_activation="linear", + weights=None, ) trainable_weights, non_trainable_weights = separate_keras_weights( keras_model diff --git a/tools/convert_xception_from_keras.py b/tools/convert_xception_from_keras.py index 77dd946..20467b6 100644 --- a/tools/convert_xception_from_keras.py +++ b/tools/convert_xception_from_keras.py @@ -27,6 +27,7 @@ input_shape=input_shape, include_preprocessing=False, classifier_activation="linear", + weights=None, ) with tempfile.TemporaryDirectory() as temp_dir: ori_model.save_weights(temp_dir + "/model.weights.h5")