From 245af931328306fc5770a52fdd3ef80a51334dee Mon Sep 17 00:00:00 2001 From: Haifeng Jin <5476582+haifeng-jin@users.noreply.github.com> Date: Thu, 9 Nov 2023 23:41:16 +0000 Subject: [PATCH] generate the files --- examples/vision/ipynb/token_learner.ipynb | 90 +++++++------- examples/vision/md/token_learner.md | 141 +++++++--------------- examples/vision/token_learner.py | 2 +- 3 files changed, 88 insertions(+), 145 deletions(-) diff --git a/examples/vision/ipynb/token_learner.ipynb b/examples/vision/ipynb/token_learner.ipynb index 25decf67b8..6be9f80614 100644 --- a/examples/vision/ipynb/token_learner.ipynb +++ b/examples/vision/ipynb/token_learner.ipynb @@ -8,9 +8,9 @@ "source": [ "# Learning to tokenize in Vision Transformers\n", "\n", - "**Authors:** [Aritra Roy Gosthipaty](https://twitter.com/ariG23498), [Sayak Paul](https://twitter.com/RisingSayak) (equal contribution)
\n",
+ "**Authors:** [Aritra Roy Gosthipaty](https://twitter.com/ariG23498), [Sayak Paul](https://twitter.com/RisingSayak) (equal contribution), converted to Keras 3 by [Muhammad Anas Raza](https://anasrz.com)<br>\n",
 "**Date created:** 2021/12/10<br>\n",
- "**Last modified:** 2021/12/15<br>\n",
+ "**Last modified:** 2023/08/14<br>\n",
 "**Description:** Adaptively generating a smaller number of tokens for Vision Transformers."
+ " x = layers.Dense(units, activation=ops.gelu)(x)\n", " x = layers.Dropout(dropout_rate)(x)\n", " return x\n", "" @@ -360,21 +350,21 @@ " layers.Conv2D(\n", " filters=number_of_tokens,\n", " kernel_size=(3, 3),\n", - " activation=tf.nn.gelu,\n", + " activation=ops.gelu,\n", " padding=\"same\",\n", " use_bias=False,\n", " ),\n", " layers.Conv2D(\n", " filters=number_of_tokens,\n", " kernel_size=(3, 3),\n", - " activation=tf.nn.gelu,\n", + " activation=ops.gelu,\n", " padding=\"same\",\n", " use_bias=False,\n", " ),\n", " layers.Conv2D(\n", " filters=number_of_tokens,\n", " kernel_size=(3, 3),\n", - " activation=tf.nn.gelu,\n", + " activation=ops.gelu,\n", " padding=\"same\",\n", " use_bias=False,\n", " ),\n", @@ -400,11 +390,11 @@ "\n", " # Element-Wise multiplication of the attention maps and the inputs\n", " attended_inputs = (\n", - " attention_maps[..., tf.newaxis] * inputs\n", + " ops.expand_dims(attention_maps, axis=-1) * inputs\n", " ) # (B, num_tokens, H*W, C)\n", "\n", " # Global average pooling the element wise multiplication result.\n", - " outputs = tf.reduce_mean(attended_inputs, axis=2) # (B, num_tokens, C)\n", + " outputs = ops.mean(attended_inputs, axis=2) # (B, num_tokens, C)\n", " return outputs\n", "" ] @@ -488,7 +478,9 @@ " ) # (B, number_patches, projection_dim)\n", "\n", " # Add positional embeddings to the projected patches.\n", - " encoded_patches = position_embedding(\n", + " encoded_patches = PatchEncoder(\n", + " num_patches=NUM_PATCHES, projection_dim=PROJECTION_DIM\n", + " )(\n", " projected_patches\n", " ) # (B, number_patches, projection_dim)\n", " encoded_patches = layers.Dropout(0.1)(encoded_patches)\n", @@ -556,7 +548,7 @@ "\n", "def run_experiment(model):\n", " # Initialize the AdamW optimizer.\n", - " optimizer = tfa.optimizers.AdamW(\n", + " optimizer = keras.optimizers.AdamW(\n", " learning_rate=LEARNING_RATE, weight_decay=WEIGHT_DECAY\n", " )\n", "\n", @@ -572,7 +564,7 @@ " )\n", "\n", " # Define callbacks\n", - " checkpoint_filepath = \"/tmp/checkpoint\"\n", + " checkpoint_filepath = \"/tmp/checkpoint.weights.h5\"\n", " checkpoint_callback = keras.callbacks.ModelCheckpoint(\n", " checkpoint_filepath,\n", " monitor=\"val_accuracy\",\n", diff --git a/examples/vision/md/token_learner.md b/examples/vision/md/token_learner.md index 6c893fa10f..d9aa08cec2 100644 --- a/examples/vision/md/token_learner.md +++ b/examples/vision/md/token_learner.md @@ -1,8 +1,8 @@ # Learning to tokenize in Vision Transformers -**Authors:** [Aritra Roy Gosthipaty](https://twitter.com/ariG23498), [Sayak Paul](https://twitter.com/RisingSayak) (equal contribution)
+**Authors:** [Aritra Roy Gosthipaty](https://twitter.com/ariG23498), [Sayak Paul](https://twitter.com/RisingSayak) (equal contribution), converted to Keras 3 by [Muhammad Anas Raza](https://anasrz.com)<br>
 **Date created:** 2021/12/10<br>
-**Last modified:** 2021/12/15<br>
+**Last modified:** 2023/08/14<br>
**Description:** Adaptively generating a smaller number of tokens for Vision Transformers. @@ -46,26 +46,16 @@ references: * [Image Classification with ViTs on keras.io](https://keras.io/examples/vision/image_classification_with_vision_transformer/) * [TokenLearner slides from NeurIPS 2021](https://nips.cc/media/neurips-2021/Slides/26578.pdf) ---- -## Setup - -We need to install TensorFlow Addons to run this example. To install it, execute the -following: - -```shell -pip install tensorflow-addons -``` - --- ## Imports ```python -import tensorflow as tf +import keras +from keras import layers +from keras import ops +from tensorflow import data as tf_data -from tensorflow import keras -from tensorflow.keras import layers -import tensorflow_addons as tfa from datetime import datetime import matplotlib.pyplot as plt @@ -84,7 +74,7 @@ develop intuition about the architecture is to experiment with it. ```python # DATA BATCH_SIZE = 256 -AUTO = tf.data.AUTOTUNE +AUTO = tf_data.AUTOTUNE INPUT_SHAPE = (32, 32, 3) NUM_CLASSES = 10 @@ -93,7 +83,7 @@ LEARNING_RATE = 1e-3 WEIGHT_DECAY = 1e-4 # TRAINING -EPOCHS = 20 +EPOCHS = 1 # AUGMENTATION IMAGE_SIZE = 48 # We will resize input images to this size. @@ -130,13 +120,13 @@ print(f"Validation samples: {len(x_val)}") print(f"Testing samples: {len(x_test)}") # Convert to tf.data.Dataset objects. -train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train)) +train_ds = tf_data.Dataset.from_tensor_slices((x_train, y_train)) train_ds = train_ds.shuffle(BATCH_SIZE * 100).batch(BATCH_SIZE).prefetch(AUTO) -val_ds = tf.data.Dataset.from_tensor_slices((x_val, y_val)) +val_ds = tf_data.Dataset.from_tensor_slices((x_val, y_val)) val_ds = val_ds.batch(BATCH_SIZE).prefetch(AUTO) -test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)) +test_ds = tf_data.Dataset.from_tensor_slices((x_test, y_test)) test_ds = test_ds.batch(BATCH_SIZE).prefetch(AUTO) ``` @@ -146,10 +136,6 @@ Training samples: 40000 Validation samples: 10000 Testing samples: 10000 -2021-12-15 13:59:48.329729: I tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 AVX512F FMA -To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags. -2021-12-15 13:59:50.627454: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1525] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 38444 MB memory: -> device: 0, name: A100-SXM4-40GB, pci bus id: 0000:00:04.0, compute capability: 8.0 - ``` --- @@ -195,19 +181,25 @@ tokens. ```python -def position_embedding( - projected_patches, num_patches=NUM_PATCHES, projection_dim=PROJECTION_DIM -): - # Build the positions. - positions = tf.range(start=0, limit=num_patches, delta=1) - - # Encode the positions with an Embedding layer. - encoded_positions = layers.Embedding( - input_dim=num_patches, output_dim=projection_dim - )(positions) - - # Add encoded positions to the projected patches. 
- return projected_patches + encoded_positions +class PatchEncoder(layers.Layer): + def __init__(self, num_patches, projection_dim): + super().__init__() + self.num_patches = num_patches + self.position_embedding = layers.Embedding( + input_dim=num_patches, output_dim=projection_dim + ) + + def call(self, patch): + positions = ops.expand_dims( + ops.arange(start=0, stop=self.num_patches, step=1), axis=0 + ) + encoded = patch + self.position_embedding(positions) + return encoded + + def get_config(self): + config = super().get_config() + config.update({"num_patches": self.num_patches}) + return config ``` @@ -223,7 +215,7 @@ def mlp(x, dropout_rate, hidden_units): # Iterate over the hidden units and # add Dense => Dropout. for units in hidden_units: - x = layers.Dense(units, activation=tf.nn.gelu)(x) + x = layers.Dense(units, activation=ops.gelu)(x) x = layers.Dropout(dropout_rate)(x) return x @@ -265,21 +257,21 @@ def token_learner(inputs, number_of_tokens=NUM_TOKENS): layers.Conv2D( filters=number_of_tokens, kernel_size=(3, 3), - activation=tf.nn.gelu, + activation=ops.gelu, padding="same", use_bias=False, ), layers.Conv2D( filters=number_of_tokens, kernel_size=(3, 3), - activation=tf.nn.gelu, + activation=ops.gelu, padding="same", use_bias=False, ), layers.Conv2D( filters=number_of_tokens, kernel_size=(3, 3), - activation=tf.nn.gelu, + activation=ops.gelu, padding="same", use_bias=False, ), @@ -305,11 +297,11 @@ def token_learner(inputs, number_of_tokens=NUM_TOKENS): # Element-Wise multiplication of the attention maps and the inputs attended_inputs = ( - attention_maps[..., tf.newaxis] * inputs + ops.expand_dims(attention_maps, axis=-1) * inputs ) # (B, num_tokens, H*W, C) # Global average pooling the element wise multiplication result. - outputs = tf.reduce_mean(attended_inputs, axis=2) # (B, num_tokens, C) + outputs = ops.mean(attended_inputs, axis=2) # (B, num_tokens, C) return outputs ``` @@ -369,7 +361,9 @@ def create_vit_classifier(use_token_learner=True, token_learner_units=NUM_TOKENS ) # (B, number_patches, projection_dim) # Add positional embeddings to the projected patches. - encoded_patches = position_embedding( + encoded_patches = PatchEncoder( + num_patches=NUM_PATCHES, projection_dim=PROJECTION_DIM + )( projected_patches ) # (B, number_patches, projection_dim) encoded_patches = layers.Dropout(0.1)(encoded_patches) @@ -418,7 +412,7 @@ network. def run_experiment(model): # Initialize the AdamW optimizer. - optimizer = tfa.optimizers.AdamW( + optimizer = keras.optimizers.AdamW( learning_rate=LEARNING_RATE, weight_decay=WEIGHT_DECAY ) @@ -434,7 +428,7 @@ def run_experiment(model): ) # Define callbacks - checkpoint_filepath = "/tmp/checkpoint" + checkpoint_filepath = "/tmp/checkpoint.weights.h5" checkpoint_callback = keras.callbacks.ModelCheckpoint( checkpoint_filepath, monitor="val_accuracy", @@ -468,53 +462,10 @@ run_experiment(vit_token_learner)
```
-Epoch 1/20
-
-2021-12-15 13:59:59.531011: I tensorflow/stream_executor/cuda/cuda_dnn.cc:366] Loaded cuDNN version 8200
-2021-12-15 14:00:04.728435: I tensorflow/stream_executor/cuda/cuda_blas.cc:1774] TensorFloat-32 will be used for the matrix multiplication. This will only be logged once.
-
-157/157 [==============================] - 20s 39ms/step - loss: 2.2716 - accuracy: 0.1396 - top-5-accuracy: 0.5908 - val_loss: 2.0672 - val_accuracy: 0.2004 - val_top-5-accuracy: 0.7632
-Epoch 2/20
-157/157 [==============================] - 5s 34ms/step - loss: 1.9780 - accuracy: 0.2488 - top-5-accuracy: 0.7917 - val_loss: 1.8621 - val_accuracy: 0.2986 - val_top-5-accuracy: 0.8391
-Epoch 3/20
-157/157 [==============================] - 5s 34ms/step - loss: 1.8168 - accuracy: 0.3138 - top-5-accuracy: 0.8437 - val_loss: 1.7044 - val_accuracy: 0.3680 - val_top-5-accuracy: 0.8793
-Epoch 4/20
-157/157 [==============================] - 5s 34ms/step - loss: 1.6765 - accuracy: 0.3701 - top-5-accuracy: 0.8820 - val_loss: 1.6490 - val_accuracy: 0.3857 - val_top-5-accuracy: 0.8809
-Epoch 5/20
-157/157 [==============================] - 5s 34ms/step - loss: 1.6091 - accuracy: 0.4058 - top-5-accuracy: 0.8978 - val_loss: 1.5899 - val_accuracy: 0.4221 - val_top-5-accuracy: 0.8989
-Epoch 6/20
-157/157 [==============================] - 5s 34ms/step - loss: 1.5386 - accuracy: 0.4340 - top-5-accuracy: 0.9097 - val_loss: 1.5434 - val_accuracy: 0.4321 - val_top-5-accuracy: 0.9098
-Epoch 7/20
-157/157 [==============================] - 5s 34ms/step - loss: 1.4944 - accuracy: 0.4481 - top-5-accuracy: 0.9171 - val_loss: 1.4914 - val_accuracy: 0.4674 - val_top-5-accuracy: 0.9146
-Epoch 8/20
-157/157 [==============================] - 5s 33ms/step - loss: 1.4767 - accuracy: 0.4586 - top-5-accuracy: 0.9179 - val_loss: 1.5280 - val_accuracy: 0.4528 - val_top-5-accuracy: 0.9090
-Epoch 9/20
-157/157 [==============================] - 5s 34ms/step - loss: 1.4331 - accuracy: 0.4751 - top-5-accuracy: 0.9248 - val_loss: 1.3996 - val_accuracy: 0.4857 - val_top-5-accuracy: 0.9298
-Epoch 10/20
-157/157 [==============================] - 5s 34ms/step - loss: 1.3990 - accuracy: 0.4925 - top-5-accuracy: 0.9291 - val_loss: 1.3888 - val_accuracy: 0.4872 - val_top-5-accuracy: 0.9308
-Epoch 11/20
-157/157 [==============================] - 5s 33ms/step - loss: 1.3646 - accuracy: 0.5019 - top-5-accuracy: 0.9355 - val_loss: 1.4330 - val_accuracy: 0.4811 - val_top-5-accuracy: 0.9208
-Epoch 12/20
-157/157 [==============================] - 5s 34ms/step - loss: 1.3607 - accuracy: 0.5037 - top-5-accuracy: 0.9354 - val_loss: 1.3242 - val_accuracy: 0.5149 - val_top-5-accuracy: 0.9415
-Epoch 13/20
-157/157 [==============================] - 5s 34ms/step - loss: 1.3303 - accuracy: 0.5170 - top-5-accuracy: 0.9384 - val_loss: 1.2934 - val_accuracy: 0.5295 - val_top-5-accuracy: 0.9437
-Epoch 14/20
-157/157 [==============================] - 5s 33ms/step - loss: 1.3038 - accuracy: 0.5259 - top-5-accuracy: 0.9426 - val_loss: 1.3102 - val_accuracy: 0.5187 - val_top-5-accuracy: 0.9422
-Epoch 15/20
-157/157 [==============================] - 5s 33ms/step - loss: 1.2926 - accuracy: 0.5304 - top-5-accuracy: 0.9441 - val_loss: 1.3220 - val_accuracy: 0.5234 - val_top-5-accuracy: 0.9428
-Epoch 16/20
-157/157 [==============================] - 5s 34ms/step - loss: 1.2724 - accuracy: 0.5346 - top-5-accuracy: 0.9458 - val_loss: 1.2670 - val_accuracy: 0.5370 - val_top-5-accuracy: 0.9491
-Epoch 17/20
-157/157 [==============================] - 5s 33ms/step - loss: 1.2515 - accuracy: 0.5450 - top-5-accuracy: 0.9462 - val_loss: 1.2837 - val_accuracy: 0.5349 - val_top-5-accuracy: 0.9474
@@ -598,4 +549,4 @@ author of TokenLearner) for fruitful discussions. | Trained Model | Demo | | :--: | :--: | -| [![Generic badge](https://img.shields.io/badge/🤗%20Model-TokenLearner-black.svg)](https://huggingface.co/keras-io/learning_to_tokenize_in_ViT) | [![Generic badge](https://img.shields.io/badge/🤗%20Spaces-TokenLearner-black.svg)](https://huggingface.co/spaces/keras-io/token_learner) | \ No newline at end of file +| [![Generic badge](https://img.shields.io/badge/🤗%20Model-TokenLearner-black.svg)](https://huggingface.co/keras-io/learning_to_tokenize_in_ViT) | [![Generic badge](https://img.shields.io/badge/🤗%20Spaces-TokenLearner-black.svg)](https://huggingface.co/spaces/keras-io/token_learner) | diff --git a/examples/vision/token_learner.py b/examples/vision/token_learner.py index b19bfc2a99..2e30a7b1e9 100644 --- a/examples/vision/token_learner.py +++ b/examples/vision/token_learner.py @@ -78,7 +78,7 @@ WEIGHT_DECAY = 1e-4 # TRAINING -EPOCHS = 20 +EPOCHS = 1 # AUGMENTATION IMAGE_SIZE = 48 # We will resize input images to this size.
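For quick reference, the API substitutions applied throughout this patch, collected as a standalone sketch. The toy tensor, its shape, and the hyperparameter values below are illustrative placeholders, not values from the example:

```python
import numpy as np
import keras
from keras import ops

# Toy (batch, tokens, channels) tensor, purely for demonstration.
x = np.random.rand(2, 4, 8).astype("float32")

# tf.range(start, limit, delta)  ->  ops.arange(start, stop, step)
positions = ops.arange(start=0, stop=4, step=1)

# tensor[..., tf.newaxis]  ->  ops.expand_dims(tensor, axis=-1)
expanded = ops.expand_dims(x, axis=-1)  # shape: (2, 4, 8, 1)

# tf.reduce_mean(tensor, axis=n)  ->  ops.mean(tensor, axis=n)
pooled = ops.mean(x, axis=2)  # shape: (2, 4)

# tf.nn.gelu  ->  ops.gelu (also usable as a layer activation)
activated = ops.gelu(x)

# tfa.optimizers.AdamW  ->  keras.optimizers.AdamW (built into Keras)
optimizer = keras.optimizers.AdamW(learning_rate=1e-3, weight_decay=1e-4)
```

Each `keras.ops` call is backend-agnostic, which is what lets the example drop the TensorFlow Addons dependency and keep only `tf.data` for the input pipeline.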