Add arbitrary shape support for ViT series (#46)
* Improvements
1. Add arbitrary shape support to ViT and MobileViT
2. Simplify the logic in BaseModel
3. Add more model weights

* Update version number

* Improve test coverage

* Fix numpy tests
james77777778 authored May 22, 2024
1 parent 2160b3e commit e2d31cc
Showing 33 changed files with 509 additions and 240 deletions.
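The headline change lets the ViT and MobileViT families run at input resolutions other than the pretrained one, with the position embedding resized at weight-loading time (see the PositionEmbedding diff below). A minimal usage sketch of what this enables; the class name VisionTransformerTiny16 and its keyword arguments are assumptions about the kimm API, not something taken from this commit:

```python
import keras
import kimm

# Hypothetical example: build a ViT at a non-default resolution.
# Class name and keyword arguments are assumed, not verified here.
model = kimm.models.VisionTransformerTiny16(
    input_shape=[448, 448, 3],  # differs from the pretrained resolution
    weights="imagenet",         # position embedding is resized on load
)
x = keras.random.uniform([1, 448, 448, 3])
y = model(x, training=False)
print(y.shape)  # e.g. (1, 1000) with the classifier head attached
```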
2 changes: 1 addition & 1 deletion kimm/__init__.py
@@ -13,4 +13,4 @@
from kimm._src.utils.model_registry import list_models
from kimm._src.version import version

__version__ = "0.2.0"
__version__ = "0.2.1"
5 changes: 2 additions & 3 deletions kimm/_src/blocks/transformer.py
@@ -51,12 +51,12 @@ def apply_transformer_block(
num_heads: int,
mlp_ratio: float = 4.0,
use_qkv_bias: bool = False,
use_qk_norm: bool = False,
projection_dropout_rate: float = 0.0,
attention_dropout_rate: float = 0.0,
activation: str = "gelu",
name: str = "transformer_block",
):
# data_format must be "channels_last"
x = inputs
residual_1 = x

@@ -65,7 +65,6 @@
dim,
num_heads,
use_qkv_bias,
use_qk_norm,
attention_dropout_rate,
projection_dropout_rate,
name=f"{name}_attn",
@@ -79,7 +78,7 @@
int(dim * mlp_ratio),
activation=activation,
dropout_rate=projection_dropout_rate,
data_format="channels_last", # TODO: let backend decides
data_format="channels_last",
name=f"{name}_mlp",
)
x = layers.Add()([residual_2, x])
69 changes: 37 additions & 32 deletions kimm/_src/layers/attention.py
@@ -1,4 +1,5 @@
import keras
from keras import InputSpec
from keras import layers
from keras import ops

@@ -13,7 +14,6 @@ def __init__(
hidden_dim: int,
num_heads: int = 8,
use_qkv_bias: bool = False,
use_qk_norm: bool = False,
attention_dropout_rate: float = 0.0,
projection_dropout_rate: float = 0.0,
**kwargs,
@@ -24,7 +24,6 @@ def __init__(
self.head_dim = hidden_dim // num_heads
self.scale = self.head_dim ** (-0.5)
self.use_qkv_bias = use_qkv_bias
self.use_qk_norm = use_qk_norm
self.attention_dropout_rate = attention_dropout_rate
self.projection_dropout_rate = projection_dropout_rate

@@ -34,16 +33,6 @@ def __init__(
dtype=self.dtype_policy,
name=f"{self.name}_qkv",
)
if use_qk_norm:
self.q_norm = layers.LayerNormalization(
dtype=self.dtype_policy, name=f"{self.name}_q_norm"
)
self.k_norm = layers.LayerNormalization(
dtype=self.dtype_policy, name=f"{self.name}_k_norm"
)
else:
self.q_norm = layers.Identity(dtype=self.dtype_policy)
self.k_norm = layers.Identity(dtype=self.dtype_policy)

self.attention_dropout = layers.Dropout(
attention_dropout_rate,
@@ -60,11 +49,16 @@ def __init__(
)

def build(self, input_shape):
self.input_spec = InputSpec(ndim=len(input_shape))
if self.input_spec.ndim not in (3, 4):
raise ValueError(
"The ndim of the inputs must be 3 or 4. "
f"Received: input_shape={input_shape}"
)

self.qkv.build(input_shape)
qkv_output_shape = list(input_shape)
qkv_output_shape[-1] = qkv_output_shape[-1] * 3
self.q_norm.build(qkv_output_shape)
self.k_norm.build(qkv_output_shape)
attention_input_shape = [
input_shape[0],
self.num_heads,
@@ -79,30 +73,42 @@ def build(self, input_shape):
def call(self, inputs, training=None, mask=None):
input_shape = ops.shape(inputs)
qkv = self.qkv(inputs)
qkv = ops.reshape(
qkv,
[
input_shape[0],
input_shape[1],
3,
self.num_heads,
self.head_dim,
],
)
qkv = ops.transpose(qkv, [2, 0, 3, 1, 4])
q, k, v = ops.unstack(qkv, 3, axis=0)
q = self.q_norm(q)
k = self.k_norm(k)
if self.input_spec.ndim == 3:
qkv = ops.reshape(
qkv,
[
input_shape[0],
input_shape[1],
3,
self.num_heads,
self.head_dim,
],
)
qkv = ops.transpose(qkv, [0, 3, 2, 1, 4])
q, k, v = ops.unstack(qkv, 3, axis=2)
else:
# self.input_spec.ndim==4
qkv = ops.reshape(
qkv,
[
input_shape[0],
input_shape[1],
input_shape[2],
3,
self.num_heads,
self.head_dim,
],
)
qkv = ops.transpose(qkv, [0, 1, 4, 3, 2, 5])
q, k, v = ops.unstack(qkv, 3, axis=3)

# attention
q = ops.multiply(q, self.scale)
attn = ops.matmul(q, ops.swapaxes(k, -2, -1))
attn = ops.softmax(attn)
attn = self.attention_dropout(attn)
x = ops.matmul(attn, v)

x = ops.swapaxes(x, 1, 2)
x = ops.reshape(x, input_shape)
x = ops.reshape(ops.swapaxes(x, -3, -2), input_shape)
x = self.projection(x)
x = self.projection_dropout(x)
return x
@@ -114,7 +120,6 @@ def get_config(self):
"hidden_dim": self.hidden_dim,
"num_heads": self.num_heads,
"use_qkv_bias": self.use_qkv_bias,
"use_qk_norm": self.use_qk_norm,
"attention_dropout_rate": self.attention_dropout_rate,
"projection_dropout_rate": self.projection_dropout_rate,
"name": self.name,
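The new 4D branch above lets the layer run on inputs that keep a spatial grid, presumably for the MobileViT path, while the 3D branch keeps the original token-sequence behavior. A standalone sketch of the 4D reshape/transpose logic using keras.ops; the concrete shapes are illustrative, not taken from the commit:

```python
import numpy as np
from keras import ops

# Illustrative shapes: batch=2, grid 4x4, hidden_dim=16, 4 heads.
B, H, W, C, num_heads = 2, 4, 4, 16, 4
head_dim = C // num_heads
scale = head_dim**-0.5

# Stand-in for the output of the qkv dense layer: (B, H, W, 3 * C).
qkv = ops.convert_to_tensor(np.random.rand(B, H, W, 3 * C).astype("float32"))

qkv = ops.reshape(qkv, [B, H, W, 3, num_heads, head_dim])
qkv = ops.transpose(qkv, [0, 1, 4, 3, 2, 5])  # (B, H, heads, 3, W, head_dim)
q, k, v = ops.unstack(qkv, 3, axis=3)         # each (B, H, heads, W, head_dim)

# matmul contracts over the last two axes, so the attention map is (..., W, W):
# tokens along the third input axis attend to each other, independently for
# every index of the second axis.
attn = ops.softmax(ops.matmul(ops.multiply(q, scale), ops.swapaxes(k, -2, -1)))
x = ops.matmul(attn, v)                        # (B, H, heads, W, head_dim)
x = ops.reshape(ops.swapaxes(x, -3, -2), [B, H, W, C])
print(x.shape)  # (2, 4, 4, 16)
```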
31 changes: 30 additions & 1 deletion kimm/_src/layers/attention_test.py
@@ -1,3 +1,4 @@
import keras
import pytest
from absl.testing import parameterized
from keras.src import testing
@@ -7,7 +8,7 @@

class AttentionTest(testing.TestCase, parameterized.TestCase):
@pytest.mark.requires_trainable_backend
def test_attention_basic(self):
def test_basic_3d(self):
self.run_layer_test(
Attention,
init_kwargs={"hidden_dim": 20, "num_heads": 2},
@@ -18,3 +19,31 @@ def test_attention_basic(self):
expected_num_losses=0,
supports_masking=False,
)

@pytest.mark.requires_trainable_backend
def test_basic_4d(self):
self.run_layer_test(
Attention,
init_kwargs={"hidden_dim": 20, "num_heads": 2},
input_shape=(1, 2, 10, 20),
expected_output_shape=(1, 2, 10, 20),
expected_num_trainable_weights=3,
expected_num_non_trainable_weights=0,
expected_num_losses=0,
supports_masking=False,
)

def test_invalid_ndim(self):
# Test 2D
inputs = keras.Input(shape=[1])
with self.assertRaisesRegex(
ValueError, "The ndim of the inputs must be 3 or 4."
):
Attention(1, 1)(inputs)

# Test 5D
inputs = keras.Input(shape=[1, 2, 3, 4])
with self.assertRaisesRegex(
ValueError, "The ndim of the inputs must be 3 or 4."
):
Attention(1, 1)(inputs)
2 changes: 1 addition & 1 deletion kimm/_src/layers/layer_scale_test.py
@@ -7,7 +7,7 @@

class LayerScaleTest(testing.TestCase, parameterized.TestCase):
@pytest.mark.requires_trainable_backend
def test_layer_scale_basic(self):
def test_basic(self):
self.run_layer_test(
LayerScale,
init_kwargs={"axis": -1},
2 changes: 1 addition & 1 deletion kimm/_src/layers/learnable_affine_test.py
@@ -7,7 +7,7 @@

class LearnableAffineTest(testing.TestCase, parameterized.TestCase):
@pytest.mark.requires_trainable_backend
def test_layer_scale_basic(self):
def test_basic(self):
self.run_layer_test(
LearnableAffine,
init_kwargs={"scale_value": 1.0, "bias_value": 0.0},
4 changes: 2 additions & 2 deletions kimm/_src/layers/mobile_one_conv2d_test.py
@@ -73,7 +73,7 @@
class MobileOneConv2DTest(testing.TestCase, parameterized.TestCase):
@parameterized.parameters(TEST_CASES)
@pytest.mark.requires_trainable_backend
def test_mobile_one_conv2d_basic(
def test_basic(
self,
filters,
kernel_size,
@@ -113,7 +113,7 @@ def test_mobile_one_conv2d_basic(
)

@parameterized.parameters(TEST_CASES)
def test_mobile_one_conv2d_get_reparameterized_weights(
def test_get_reparameterized_weights(
self,
filters,
kernel_size,
49 changes: 47 additions & 2 deletions kimm/_src/layers/position_embedding.py
@@ -8,15 +8,24 @@
@kimm_export(parent_path=["kimm.layers"])
@keras.saving.register_keras_serializable(package="kimm")
class PositionEmbedding(layers.Layer):
def __init__(self, **kwargs):
def __init__(self, height, width, **kwargs):
super().__init__(**kwargs)
# We need height and width for saving and loading
self.height = int(height)
self.width = int(width)

def build(self, input_shape):
if len(input_shape) != 3:
raise ValueError(
"PositionEmbedding only accepts 3-dimensional input. "
f"Received: input_shape={input_shape}"
)
if self.height * self.width != input_shape[-2]:
raise ValueError(
"The embedding size doesn't match the height and width. "
f"Received: height={self.height}, width={self.width}, "
f"input_shape={input_shape}"
)
self.pos_embed = self.add_weight(
shape=[1, input_shape[-2] + 1, input_shape[-1]],
initializer="random_normal",
@@ -41,5 +50,41 @@ def compute_output_shape(self, input_shape):
output_shape[1] = output_shape[1] + 1
return output_shape

def save_own_variables(self, store):
super().save_own_variables(store)
# Add height and width information
store["height"] = self.height
store["width"] = self.width

def load_own_variables(self, store):
old_height = int(store["height"][...])
old_width = int(store["width"][...])
if old_height == self.height and old_width == self.width:
self.pos_embed.assign(store["0"])
self.cls_token.assign(store["1"])
return

# Resize the embedding if there is a shape mismatch
pos_embed = store["0"]
pos_embed_prefix, pos_embed = pos_embed[:, :1], pos_embed[:, 1:]
pos_embed_dim = pos_embed.shape[-1]
pos_embed = ops.cast(pos_embed, "float32")
pos_embed = ops.reshape(pos_embed, [1, old_height, old_width, -1])
pos_embed = ops.image.resize(
pos_embed,
size=[self.height, self.width],
interpolation="bilinear",
antialias=True,
data_format="channels_last",
)
pos_embed = ops.reshape(pos_embed, [1, -1, pos_embed_dim])
pos_embed = ops.concatenate([pos_embed_prefix, pos_embed], axis=1)
self.pos_embed.assign(pos_embed)
self.cls_token.assign(store["1"])

def get_config(self):
return super().get_config()
config = super().get_config()
config.update(
{"height": self.height, "width": self.width, "name": self.name}
)
return config
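The resizing in load_own_variables above is what makes pretrained weights usable at a new resolution: the class-token slot is kept, and the grid part of the embedding is bilinearly resampled. A self-contained sketch of the same idea; the helper name resize_pos_embed is illustrative and not part of kimm:

```python
import numpy as np
from keras import ops

def resize_pos_embed(pos_embed, old_hw, new_hw):
    """Resize a (1, 1 + old_h * old_w, dim) position embedding.

    Mirrors the idea of PositionEmbedding.load_own_variables above;
    the name and signature here are illustrative only.
    """
    old_h, old_w = old_hw
    new_h, new_w = new_hw
    prefix, grid = pos_embed[:, :1], pos_embed[:, 1:]  # keep the cls slot
    dim = grid.shape[-1]
    grid = ops.reshape(ops.cast(grid, "float32"), [1, old_h, old_w, dim])
    grid = ops.image.resize(
        grid,
        size=[new_h, new_w],
        interpolation="bilinear",
        antialias=True,
        data_format="channels_last",
    )
    grid = ops.reshape(grid, [1, new_h * new_w, dim])
    return ops.concatenate([prefix, grid], axis=1)

# Shrink a 16x16 grid embedding to 8x8, as in the resizing test below.
pos = np.zeros([1, 1 + 16 * 16, 8], dtype="float32")
print(resize_pos_embed(pos, (16, 16), (8, 8)).shape)  # (1, 65, 8)
```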
22 changes: 18 additions & 4 deletions kimm/_src/layers/position_embedding_test.py
@@ -1,17 +1,18 @@
import pytest
from absl.testing import parameterized
from keras import layers
from keras import models
from keras.src import testing

from kimm._src.layers.position_embedding import PositionEmbedding


class PositionEmbeddingTest(testing.TestCase, parameterized.TestCase):
@pytest.mark.requires_trainable_backend
def test_position_embedding_basic(self):
def test_basic(self):
self.run_layer_test(
PositionEmbedding,
init_kwargs={},
init_kwargs={"height": 2, "width": 5},
input_shape=(1, 10, 10),
expected_output_shape=(1, 11, 10),
expected_num_trainable_weights=2,
@@ -20,10 +21,23 @@ def test_position_embedding_basic(self):
supports_masking=False,
)

def test_embedding_resizing(self):
temp_dir = self.get_temp_dir()
model = models.Sequential(
[layers.Input(shape=[256, 8]), PositionEmbedding(16, 16)]
)
model.save(f"{temp_dir}/model.keras")

# Resize from (16, 16) to (8, 8)
model = models.Sequential(
[layers.Input(shape=[64, 8]), PositionEmbedding(8, 8)]
)
model.load_weights(f"{temp_dir}/model.keras")

@pytest.mark.requires_trainable_backend
def test_position_embedding_invalid_input_shape(self):
def test_invalid_input_shape(self):
inputs = layers.Input([3])
with self.assertRaisesRegex(
ValueError, "PositionEmbedding only accepts 3-dimensional input."
):
PositionEmbedding()(inputs)
PositionEmbedding(2, 2)(inputs)
4 changes: 2 additions & 2 deletions kimm/_src/layers/rep_conv2d_test.py
@@ -53,7 +53,7 @@
class RepConv2DTest(testing.TestCase, parameterized.TestCase):
@parameterized.parameters(TEST_CASES)
@pytest.mark.requires_trainable_backend
def test_rep_conv2d_basic(
def test_basic(
self,
filters,
kernel_size,
@@ -89,7 +89,7 @@ def test_rep_conv2d_basic(
)

@parameterized.parameters(TEST_CASES)
def test_rep_conv2d_get_reparameterized_weights(
def test_get_reparameterized_weights(
self,
filters,
kernel_size,