diff --git a/examples/vision/ipynb/token_learner.ipynb b/examples/vision/ipynb/token_learner.ipynb
index 25decf67b8..6be9f80614 100644
--- a/examples/vision/ipynb/token_learner.ipynb
+++ b/examples/vision/ipynb/token_learner.ipynb
@@ -8,9 +8,9 @@
"source": [
"# Learning to tokenize in Vision Transformers\n",
"\n",
- "**Authors:** [Aritra Roy Gosthipaty](https://twitter.com/ariG23498), [Sayak Paul](https://twitter.com/RisingSayak) (equal contribution)
\n",
+ "**Authors:** [Aritra Roy Gosthipaty](https://twitter.com/ariG23498), [Sayak Paul](https://twitter.com/RisingSayak) (equal contribution), converted to Keras 3 by [Muhammad Anas Raza](https://anasrz.com)
\n",
"**Date created:** 2021/12/10
\n",
- "**Last modified:** 2021/12/15
\n",
+ "**Last modified:** 2023/08/14
\n",
"**Description:** Adaptively generating a smaller number of tokens for Vision Transformers."
]
},
@@ -56,22 +56,6 @@
"* [TokenLearner slides from NeurIPS 2021](https://nips.cc/media/neurips-2021/Slides/26578.pdf)"
]
},
- {
- "cell_type": "markdown",
- "metadata": {
- "colab_type": "text"
- },
- "source": [
- "## Setup\n",
- "\n",
- "We need to install TensorFlow Addons to run this example. To install it, execute the\n",
- "following:\n",
- "\n",
- "```shell\n",
- "pip install tensorflow-addons\n",
- "```"
- ]
- },
{
"cell_type": "markdown",
"metadata": {
@@ -89,11 +73,11 @@
},
"outputs": [],
"source": [
- "import tensorflow as tf\n",
+ "import keras\n",
+ "from keras import layers\n",
+ "from keras import ops\n",
+ "from tensorflow import data as tf_data\n",
"\n",
- "from tensorflow import keras\n",
- "from tensorflow.keras import layers\n",
- "import tensorflow_addons as tfa\n",
"\n",
"from datetime import datetime\n",
"import matplotlib.pyplot as plt\n",
@@ -124,7 +108,7 @@
"source": [
"# DATA\n",
"BATCH_SIZE = 256\n",
- "AUTO = tf.data.AUTOTUNE\n",
+ "AUTO = tf_data.AUTOTUNE\n",
"INPUT_SHAPE = (32, 32, 3)\n",
"NUM_CLASSES = 10\n",
"\n",
@@ -133,7 +117,7 @@
"WEIGHT_DECAY = 1e-4\n",
"\n",
"# TRAINING\n",
- "EPOCHS = 20\n",
+ "EPOCHS = 1\n",
"\n",
"# AUGMENTATION\n",
"IMAGE_SIZE = 48 # We will resize input images to this size.\n",
@@ -182,13 +166,13 @@
"print(f\"Testing samples: {len(x_test)}\")\n",
"\n",
"# Convert to tf.data.Dataset objects.\n",
- "train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train))\n",
+ "train_ds = tf_data.Dataset.from_tensor_slices((x_train, y_train))\n",
"train_ds = train_ds.shuffle(BATCH_SIZE * 100).batch(BATCH_SIZE).prefetch(AUTO)\n",
"\n",
- "val_ds = tf.data.Dataset.from_tensor_slices((x_val, y_val))\n",
+ "val_ds = tf_data.Dataset.from_tensor_slices((x_val, y_val))\n",
"val_ds = val_ds.batch(BATCH_SIZE).prefetch(AUTO)\n",
"\n",
- "test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test))\n",
+ "test_ds = tf_data.Dataset.from_tensor_slices((x_test, y_test))\n",
"test_ds = test_ds.batch(BATCH_SIZE).prefetch(AUTO)"
]
},
@@ -266,19 +250,25 @@
"outputs": [],
"source": [
"\n",
- "def position_embedding(\n",
- " projected_patches, num_patches=NUM_PATCHES, projection_dim=PROJECTION_DIM\n",
- "):\n",
- " # Build the positions.\n",
- " positions = tf.range(start=0, limit=num_patches, delta=1)\n",
- "\n",
- " # Encode the positions with an Embedding layer.\n",
- " encoded_positions = layers.Embedding(\n",
- " input_dim=num_patches, output_dim=projection_dim\n",
- " )(positions)\n",
- "\n",
- " # Add encoded positions to the projected patches.\n",
- " return projected_patches + encoded_positions\n",
+ "class PatchEncoder(layers.Layer):\n",
+ " def __init__(self, num_patches, projection_dim):\n",
+ " super().__init__()\n",
+ " self.num_patches = num_patches\n",
+ " self.position_embedding = layers.Embedding(\n",
+ " input_dim=num_patches, output_dim=projection_dim\n",
+ " )\n",
+ "\n",
+ " def call(self, patch):\n",
+ " positions = ops.expand_dims(\n",
+ " ops.arange(start=0, stop=self.num_patches, step=1), axis=0\n",
+ " )\n",
+ " encoded = patch + self.position_embedding(positions)\n",
+ " return encoded\n",
+ "\n",
+ " def get_config(self):\n",
+ " config = super().get_config()\n",
+ " config.update({\"num_patches\": self.num_patches})\n",
+ " return config\n",
""
]
},
@@ -306,7 +296,7 @@
" # Iterate over the hidden units and\n",
" # add Dense => Dropout.\n",
" for units in hidden_units:\n",
- " x = layers.Dense(units, activation=tf.nn.gelu)(x)\n",
+ " x = layers.Dense(units, activation=ops.gelu)(x)\n",
" x = layers.Dropout(dropout_rate)(x)\n",
" return x\n",
""
@@ -360,21 +350,21 @@
" layers.Conv2D(\n",
" filters=number_of_tokens,\n",
" kernel_size=(3, 3),\n",
- " activation=tf.nn.gelu,\n",
+ " activation=ops.gelu,\n",
" padding=\"same\",\n",
" use_bias=False,\n",
" ),\n",
" layers.Conv2D(\n",
" filters=number_of_tokens,\n",
" kernel_size=(3, 3),\n",
- " activation=tf.nn.gelu,\n",
+ " activation=ops.gelu,\n",
" padding=\"same\",\n",
" use_bias=False,\n",
" ),\n",
" layers.Conv2D(\n",
" filters=number_of_tokens,\n",
" kernel_size=(3, 3),\n",
- " activation=tf.nn.gelu,\n",
+ " activation=ops.gelu,\n",
" padding=\"same\",\n",
" use_bias=False,\n",
" ),\n",
@@ -400,11 +390,11 @@
"\n",
" # Element-Wise multiplication of the attention maps and the inputs\n",
" attended_inputs = (\n",
- " attention_maps[..., tf.newaxis] * inputs\n",
+ " ops.expand_dims(attention_maps, axis=-1) * inputs\n",
" ) # (B, num_tokens, H*W, C)\n",
"\n",
" # Global average pooling the element wise multiplication result.\n",
- " outputs = tf.reduce_mean(attended_inputs, axis=2) # (B, num_tokens, C)\n",
+ " outputs = ops.mean(attended_inputs, axis=2) # (B, num_tokens, C)\n",
" return outputs\n",
""
]
@@ -488,7 +478,9 @@
" ) # (B, number_patches, projection_dim)\n",
"\n",
" # Add positional embeddings to the projected patches.\n",
- " encoded_patches = position_embedding(\n",
+ " encoded_patches = PatchEncoder(\n",
+ " num_patches=NUM_PATCHES, projection_dim=PROJECTION_DIM\n",
+ " )(\n",
" projected_patches\n",
" ) # (B, number_patches, projection_dim)\n",
" encoded_patches = layers.Dropout(0.1)(encoded_patches)\n",
@@ -556,7 +548,7 @@
"\n",
"def run_experiment(model):\n",
" # Initialize the AdamW optimizer.\n",
- " optimizer = tfa.optimizers.AdamW(\n",
+ " optimizer = keras.optimizers.AdamW(\n",
" learning_rate=LEARNING_RATE, weight_decay=WEIGHT_DECAY\n",
" )\n",
"\n",
@@ -572,7 +564,7 @@
" )\n",
"\n",
" # Define callbacks\n",
- " checkpoint_filepath = \"/tmp/checkpoint\"\n",
+ " checkpoint_filepath = \"/tmp/checkpoint.weights.h5\"\n",
" checkpoint_callback = keras.callbacks.ModelCheckpoint(\n",
" checkpoint_filepath,\n",
" monitor=\"val_accuracy\",\n",
diff --git a/examples/vision/md/token_learner.md b/examples/vision/md/token_learner.md
index 6c893fa10f..d9aa08cec2 100644
--- a/examples/vision/md/token_learner.md
+++ b/examples/vision/md/token_learner.md
@@ -1,8 +1,8 @@
# Learning to tokenize in Vision Transformers
-**Authors:** [Aritra Roy Gosthipaty](https://twitter.com/ariG23498), [Sayak Paul](https://twitter.com/RisingSayak) (equal contribution)<br>
+**Authors:** [Aritra Roy Gosthipaty](https://twitter.com/ariG23498), [Sayak Paul](https://twitter.com/RisingSayak) (equal contribution), converted to Keras 3 by [Muhammad Anas Raza](https://anasrz.com)<br>
**Date created:** 2021/12/10<br>
-**Last modified:** 2021/12/15<br>
+**Last modified:** 2023/08/14<br>
**Description:** Adaptively generating a smaller number of tokens for Vision Transformers.
@@ -46,26 +46,16 @@ references:
* [Image Classification with ViTs on keras.io](https://keras.io/examples/vision/image_classification_with_vision_transformer/)
* [TokenLearner slides from NeurIPS 2021](https://nips.cc/media/neurips-2021/Slides/26578.pdf)
----
-## Setup
-
-We need to install TensorFlow Addons to run this example. To install it, execute the
-following:
-
-```shell
-pip install tensorflow-addons
-```
-
---
## Imports
```python
-import tensorflow as tf
+import keras
+from keras import layers
+from keras import ops
+from tensorflow import data as tf_data
-from tensorflow import keras
-from tensorflow.keras import layers
-import tensorflow_addons as tfa
from datetime import datetime
import matplotlib.pyplot as plt
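
Note on the imports above: `keras.ops` exposes a backend-agnostic, NumPy-style API, so the model code no longer needs `tf.*` ops or `tensorflow-addons`; TensorFlow is retained only for the `tf.data` input pipeline. A minimal sketch of the idea (illustrative, not part of this diff):

```python
# Illustrative sketch, not part of this diff: keras.ops dispatches to the
# active backend (TensorFlow, JAX, or PyTorch) while mirroring NumPy.
import keras
from keras import ops

x = ops.arange(0, 4, 1, dtype="float32")    # backend-native tensor: [0 1 2 3]
y = ops.mean(ops.expand_dims(x, axis=-1), axis=0)
print(keras.backend.backend(), y)           # e.g. "tensorflow", [1.5]
```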
@@ -84,7 +74,7 @@ develop intuition about the architecture is to experiment with it.
```python
# DATA
BATCH_SIZE = 256
-AUTO = tf.data.AUTOTUNE
+AUTO = tf_data.AUTOTUNE
INPUT_SHAPE = (32, 32, 3)
NUM_CLASSES = 10
@@ -93,7 +83,7 @@ LEARNING_RATE = 1e-3
WEIGHT_DECAY = 1e-4
# TRAINING
-EPOCHS = 20
+EPOCHS = 1
# AUGMENTATION
IMAGE_SIZE = 48 # We will resize input images to this size.
@@ -130,13 +120,13 @@ print(f"Validation samples: {len(x_val)}")
print(f"Testing samples: {len(x_test)}")
# Convert to tf.data.Dataset objects.
-train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train))
+train_ds = tf_data.Dataset.from_tensor_slices((x_train, y_train))
train_ds = train_ds.shuffle(BATCH_SIZE * 100).batch(BATCH_SIZE).prefetch(AUTO)
-val_ds = tf.data.Dataset.from_tensor_slices((x_val, y_val))
+val_ds = tf_data.Dataset.from_tensor_slices((x_val, y_val))
val_ds = val_ds.batch(BATCH_SIZE).prefetch(AUTO)
-test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test))
+test_ds = tf_data.Dataset.from_tensor_slices((x_test, y_test))
test_ds = test_ds.batch(BATCH_SIZE).prefetch(AUTO)
```
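
The pipeline code is unchanged apart from the import path. This works because Keras 3's `model.fit()` accepts `tf.data.Dataset` objects under any backend, which is what allows the example to keep `tf.data` while dropping the rest of the TensorFlow API surface. A standalone sketch (illustrative values):

```python
# Illustrative: a tf.data pipeline built exactly like the ones above.
import numpy as np
from tensorflow import data as tf_data

ds = tf_data.Dataset.from_tensor_slices(np.arange(10, dtype="float32"))
ds = ds.shuffle(10).batch(4).prefetch(tf_data.AUTOTUNE)
for batch in ds:
    print(batch.numpy())  # three batches: sizes 4, 4, 2
```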
@@ -146,10 +136,6 @@ Training samples: 40000
Validation samples: 10000
Testing samples: 10000
-2021-12-15 13:59:48.329729: I tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 AVX512F FMA
-To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
-2021-12-15 13:59:50.627454: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1525] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 38444 MB memory: -> device: 0, name: A100-SXM4-40GB, pci bus id: 0000:00:04.0, compute capability: 8.0
-
```
---
@@ -195,19 +181,25 @@ tokens.
```python
-def position_embedding(
- projected_patches, num_patches=NUM_PATCHES, projection_dim=PROJECTION_DIM
-):
- # Build the positions.
- positions = tf.range(start=0, limit=num_patches, delta=1)
-
- # Encode the positions with an Embedding layer.
- encoded_positions = layers.Embedding(
- input_dim=num_patches, output_dim=projection_dim
- )(positions)
-
- # Add encoded positions to the projected patches.
- return projected_patches + encoded_positions
+class PatchEncoder(layers.Layer):
+ def __init__(self, num_patches, projection_dim):
+ super().__init__()
+ self.num_patches = num_patches
+ self.position_embedding = layers.Embedding(
+ input_dim=num_patches, output_dim=projection_dim
+ )
+
+ def call(self, patch):
+ positions = ops.expand_dims(
+ ops.arange(start=0, stop=self.num_patches, step=1), axis=0
+ )
+ encoded = patch + self.position_embedding(positions)
+ return encoded
+
+ def get_config(self):
+ config = super().get_config()
+ config.update({"num_patches": self.num_patches})
+ return config
```
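
One caveat with the new `PatchEncoder`: `get_config()` records `num_patches` but not `projection_dim`, and `__init__` does not forward `**kwargs`, so reconstructing the layer from a saved config would fail. A fully serializable variant might look like the following sketch (our suggestion, not part of this diff):

```python
# Hypothetical fuller version (not part of this diff): records both
# constructor arguments so keras.saving can round-trip the layer.
from keras import layers, ops


class PatchEncoder(layers.Layer):
    def __init__(self, num_patches, projection_dim, **kwargs):
        super().__init__(**kwargs)
        self.num_patches = num_patches
        self.projection_dim = projection_dim
        self.position_embedding = layers.Embedding(
            input_dim=num_patches, output_dim=projection_dim
        )

    def call(self, patch):
        positions = ops.expand_dims(
            ops.arange(start=0, stop=self.num_patches, step=1), axis=0
        )
        return patch + self.position_embedding(positions)

    def get_config(self):
        config = super().get_config()
        config.update(
            {"num_patches": self.num_patches, "projection_dim": self.projection_dim}
        )
        return config
```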
@@ -223,7 +215,7 @@ def mlp(x, dropout_rate, hidden_units):
# Iterate over the hidden units and
# add Dense => Dropout.
for units in hidden_units:
- x = layers.Dense(units, activation=tf.nn.gelu)(x)
+ x = layers.Dense(units, activation=ops.gelu)(x)
x = layers.Dropout(dropout_rate)(x)
return x
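
`layers.Dense` accepts any callable as `activation`, so `ops.gelu` is a drop-in replacement for `tf.nn.gelu`. One subtlety: `keras.ops.gelu` defaults to the tanh approximation (`approximate=True`), whereas `tf.nn.gelu` defaults to the exact form, so the converted model's activations differ very slightly. If exact parity mattered, a wrapper could pin the behavior (a sketch, not part of this diff):

```python
# Sketch (not part of this diff): pin the exact (non-approximate) GELU
# to match the old tf.nn.gelu default.
from keras import layers, ops


def exact_gelu(x):
    return ops.gelu(x, approximate=False)


dense = layers.Dense(64, activation=exact_gelu)
```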
@@ -265,21 +257,21 @@ def token_learner(inputs, number_of_tokens=NUM_TOKENS):
layers.Conv2D(
filters=number_of_tokens,
kernel_size=(3, 3),
- activation=tf.nn.gelu,
+ activation=ops.gelu,
padding="same",
use_bias=False,
),
layers.Conv2D(
filters=number_of_tokens,
kernel_size=(3, 3),
- activation=tf.nn.gelu,
+ activation=ops.gelu,
padding="same",
use_bias=False,
),
layers.Conv2D(
filters=number_of_tokens,
kernel_size=(3, 3),
- activation=tf.nn.gelu,
+ activation=ops.gelu,
padding="same",
use_bias=False,
),
@@ -305,11 +297,11 @@ def token_learner(inputs, number_of_tokens=NUM_TOKENS):
# Element-Wise multiplication of the attention maps and the inputs
attended_inputs = (
- attention_maps[..., tf.newaxis] * inputs
+ ops.expand_dims(attention_maps, axis=-1) * inputs
) # (B, num_tokens, H*W, C)
# Global average pooling the element wise multiplication result.
- outputs = tf.reduce_mean(attended_inputs, axis=2) # (B, num_tokens, C)
+ outputs = ops.mean(attended_inputs, axis=2) # (B, num_tokens, C)
return outputs
```
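
The broadcasting in this step is the crux of TokenLearner: each of the `num_tokens` attention maps weights all `H*W` positions of the input, and mean-pooling over the spatial axis then collapses each weighted copy into a single token. A quick shape sanity check (illustrative shapes, not from this diff):

```python
# Illustrative: (B, num_tokens, H*W, 1) * (B, 1, H*W, C) broadcasts to
# (B, num_tokens, H*W, C); pooling over axis 2 leaves one token per map.
from keras import ops

attention_maps = ops.ones((2, 4, 64))   # (B=2, num_tokens=4, H*W=64)
inputs = ops.ones((2, 1, 64, 128))      # (B=2, 1, H*W=64, C=128)
attended = ops.expand_dims(attention_maps, axis=-1) * inputs
tokens = ops.mean(attended, axis=2)
print(tokens.shape)                     # (2, 4, 128)
```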
@@ -369,7 +361,9 @@ def create_vit_classifier(use_token_learner=True, token_learner_units=NUM_TOKENS
) # (B, number_patches, projection_dim)
# Add positional embeddings to the projected patches.
- encoded_patches = position_embedding(
+ encoded_patches = PatchEncoder(
+ num_patches=NUM_PATCHES, projection_dim=PROJECTION_DIM
+ )(
projected_patches
) # (B, number_patches, projection_dim)
encoded_patches = layers.Dropout(0.1)(encoded_patches)
@@ -418,7 +412,7 @@ network.
def run_experiment(model):
# Initialize the AdamW optimizer.
- optimizer = tfa.optimizers.AdamW(
+ optimizer = keras.optimizers.AdamW(
learning_rate=LEARNING_RATE, weight_decay=WEIGHT_DECAY
)
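
`AdamW` with decoupled weight decay is now built into Keras, which is what lets this change drop the `tensorflow-addons` dependency entirely. Minimal usage sketch (the hyperparameter values mirror the constants defined earlier; the exclusion call is optional and not used in this example):

```python
# Sketch: keras.optimizers.AdamW applies decoupled weight decay natively.
import keras

optimizer = keras.optimizers.AdamW(learning_rate=1e-3, weight_decay=1e-4)
# Optionally exempt some variables (e.g. biases) from decay:
# optimizer.exclude_from_weight_decay(var_names=["bias"])
```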
@@ -434,7 +428,7 @@ def run_experiment(model):
)
# Define callbacks
- checkpoint_filepath = "/tmp/checkpoint"
+ checkpoint_filepath = "/tmp/checkpoint.weights.h5"
checkpoint_callback = keras.callbacks.ModelCheckpoint(
checkpoint_filepath,
monitor="val_accuracy",
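
The new suffix is required rather than cosmetic: in Keras 3, `ModelCheckpoint` with weights-only saving insists that the filepath end in `.weights.h5` (we assume `save_weights_only=True` is set here, which the suffix implies). A sketch of the convention:

```python
# Sketch (assumes save_weights_only=True, which mandates the suffix):
import keras

checkpoint_callback = keras.callbacks.ModelCheckpoint(
    "/tmp/checkpoint.weights.h5",
    monitor="val_accuracy",
    save_best_only=True,
    save_weights_only=True,
)
```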
@@ -468,53 +462,10 @@ run_experiment(vit_token_learner)