From faf8ec1c12e168ec0d0274993144d64de78cb58b Mon Sep 17 00:00:00 2001
From: Sitam Meur <103279526+sitamgithub-MSIT@users.noreply.github.com>
Date: Mon, 4 Mar 2024 13:39:15 +0530
Subject: [PATCH] Updating Semi-supervised image classification using
contrastive pretraining with SimCLR Keras 3 example (TF-Only) (#1777)
* Updated the existing Keras 3 example
* Replaced the remaining tf ops with keras ops
* Applied formatting
* Added a seed generator for the random augmentations
* Added the updated .md and .ipynb files
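
For reviewers: the replacements below all follow one mechanical mapping from
TensorFlow ops to backend-agnostic Keras 3 ops. A minimal sketch of that
mapping (the tensors `x` and `y` are placeholders, not from the example),
assuming Keras 3 with the TensorFlow backend:

    import keras
    from keras import ops

    x = keras.random.normal((4, 8))
    y = keras.random.normal((4, 8))

    batch_size = ops.shape(x)[0]           # was: tf.shape(x)[0]
    labels = ops.arange(batch_size)        # was: tf.range(batch_size)
    x_n = ops.normalize(x, axis=1)         # was: tf.math.l2_normalize(x, axis=1)
    y_n = ops.normalize(y, axis=1)
    sims = ops.matmul(x_n, ops.transpose(y_n))  # was: tf.matmul(..., transpose_b=True)
    both = ops.concatenate((x, y), axis=0)      # was: tf.concat((x, y), axis=0)
    clipped = ops.clip(x, 0.0, 1.0)             # was: tf.clip_by_value(x, 0, 1)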
---
.../vision/ipynb/semisupervised_simclr.ipynb | 62 ++++++++++---------
examples/vision/md/semisupervised_simclr.md | 37 ++++++-----
examples/vision/semisupervised_simclr.py | 38 +++++++-----
3 files changed, 79 insertions(+), 58 deletions(-)
diff --git a/examples/vision/ipynb/semisupervised_simclr.ipynb b/examples/vision/ipynb/semisupervised_simclr.ipynb
index 838cd73312..1af081a2da 100644
--- a/examples/vision/ipynb/semisupervised_simclr.ipynb
+++ b/examples/vision/ipynb/semisupervised_simclr.ipynb
@@ -8,9 +8,9 @@
"source": [
"# Semi-supervised image classification using contrastive pretraining with SimCLR\n",
"\n",
- "**Author:** [Andr\u00e1s B\u00e9res](https://www.linkedin.com/in/andras-beres-789190210)
\n",
+ "**Author:** [András Béres](https://www.linkedin.com/in/andras-beres-789190210)
\n",
"**Date created:** 2021/04/24
\n",
- "**Last modified:** 2021/04/24
\n",
+ "**Last modified:** 2024/03/04
\n",
"**Description:** Contrastive pretraining with SimCLR for semi-supervised image classification on the STL-10 dataset."
]
},
@@ -86,7 +86,7 @@
},
{
"cell_type": "code",
- "execution_count": 0,
+ "execution_count": null,
"metadata": {
"colab_type": "code"
},
@@ -109,6 +109,7 @@
"import tensorflow_datasets as tfds\n",
"\n",
"import keras\n",
+ "from keras import ops\n",
"from keras import layers"
]
},
@@ -123,7 +124,7 @@
},
{
"cell_type": "code",
- "execution_count": 0,
+ "execution_count": null,
"metadata": {
"colab_type": "code"
},
@@ -162,7 +163,7 @@
},
{
"cell_type": "code",
- "execution_count": 0,
+ "execution_count": null,
"metadata": {
"colab_type": "code"
},
@@ -245,7 +246,7 @@
},
{
"cell_type": "code",
- "execution_count": 0,
+ "execution_count": null,
"metadata": {
"colab_type": "code"
},
@@ -257,6 +258,7 @@
" def __init__(self, brightness=0, jitter=0, **kwargs):\n",
" super().__init__(**kwargs)\n",
"\n",
+ " self.seed_generator = keras.random.SeedGenerator(1337)\n",
" self.brightness = brightness\n",
" self.jitter = jitter\n",
"\n",
@@ -267,24 +269,29 @@
"\n",
" def call(self, images, training=True):\n",
" if training:\n",
- " batch_size = tf.shape(images)[0]\n",
+ " batch_size = ops.shape(images)[0]\n",
"\n",
" # Same for all colors\n",
- " brightness_scales = 1 + tf.random.uniform(\n",
+ " brightness_scales = 1 + keras.random.uniform(\n",
" (batch_size, 1, 1, 1),\n",
" minval=-self.brightness,\n",
" maxval=self.brightness,\n",
+ " seed=self.seed_generator,\n",
" )\n",
" # Different for all colors\n",
- " jitter_matrices = tf.random.uniform(\n",
- " (batch_size, 1, 3, 3), minval=-self.jitter, maxval=self.jitter\n",
+ " jitter_matrices = keras.random.uniform(\n",
+ " (batch_size, 1, 3, 3), \n",
+ " minval=-self.jitter, \n",
+ " maxval=self.jitter,\n",
+ " seed=self.seed_generator,\n",
" )\n",
"\n",
" color_transforms = (\n",
- " tf.eye(3, batch_shape=[batch_size, 1]) * brightness_scales\n",
+ " ops.tile(ops.expand_dims(ops.eye(3), axis=0), (batch_size, 1, 1, 1))\n",
+ " * brightness_scales\n",
" + jitter_matrices\n",
" )\n",
- " images = tf.clip_by_value(tf.matmul(images, color_transforms), 0, 1)\n",
+ " images = ops.clip(ops.matmul(images, color_transforms), 0, 1)\n",
" return images\n",
"\n",
"\n",
@@ -344,7 +351,7 @@
},
{
"cell_type": "code",
- "execution_count": 0,
+ "execution_count": null,
"metadata": {
"colab_type": "code"
},
@@ -363,8 +370,7 @@
" layers.Dense(width, activation=\"relu\"),\n",
" ],\n",
" name=\"encoder\",\n",
- " )\n",
- ""
+ " )\n"
]
},
{
@@ -380,7 +386,7 @@
},
{
"cell_type": "code",
- "execution_count": 0,
+ "execution_count": null,
"metadata": {
"colab_type": "code"
},
@@ -461,7 +467,7 @@
},
{
"cell_type": "code",
- "execution_count": 0,
+ "execution_count": null,
"metadata": {
"colab_type": "code"
},
@@ -526,19 +532,19 @@
" # NT-Xent loss (normalized temperature-scaled cross entropy)\n",
"\n",
" # Cosine similarity: the dot product of the l2-normalized feature vectors\n",
- " projections_1 = tf.math.l2_normalize(projections_1, axis=1)\n",
- " projections_2 = tf.math.l2_normalize(projections_2, axis=1)\n",
+ " projections_1 = ops.normalize(projections_1, axis=1)\n",
+ " projections_2 = ops.normalize(projections_2, axis=1)\n",
" similarities = (\n",
- " tf.matmul(projections_1, projections_2, transpose_b=True) / self.temperature\n",
+ " ops.matmul(projections_1, ops.transpose(projections_2)) / self.temperature\n",
" )\n",
"\n",
" # The similarity between the representations of two augmented views of the\n",
" # same image should be higher than their similarity with other views\n",
- " batch_size = tf.shape(projections_1)[0]\n",
- " contrastive_labels = tf.range(batch_size)\n",
+ " batch_size = ops.shape(projections_1)[0]\n",
+ " contrastive_labels = ops.arange(batch_size)\n",
" self.contrastive_accuracy.update_state(contrastive_labels, similarities)\n",
" self.contrastive_accuracy.update_state(\n",
- " contrastive_labels, tf.transpose(similarities)\n",
+ " contrastive_labels, ops.transpose(similarities)\n",
" )\n",
"\n",
" # The temperature-scaled similarities are used as logits for cross-entropy\n",
@@ -547,7 +553,7 @@
" contrastive_labels, similarities, from_logits=True\n",
" )\n",
" loss_2_1 = keras.losses.sparse_categorical_crossentropy(\n",
- " contrastive_labels, tf.transpose(similarities), from_logits=True\n",
+ " contrastive_labels, ops.transpose(similarities), from_logits=True\n",
" )\n",
" return (loss_1_2 + loss_2_1) / 2\n",
"\n",
@@ -555,7 +561,7 @@
" (unlabeled_images, _), (labeled_images, labels) = data\n",
"\n",
" # Both labeled and unlabeled images are used, without labels\n",
- " images = tf.concat((unlabeled_images, labeled_images), axis=0)\n",
+ " images = ops.concatenate((unlabeled_images, labeled_images), axis=0)\n",
" # Each image is augmented twice, differently\n",
" augmented_images_1 = self.contrastive_augmenter(images, training=True)\n",
" augmented_images_2 = self.contrastive_augmenter(images, training=True)\n",
@@ -645,7 +651,7 @@
},
{
"cell_type": "code",
- "execution_count": 0,
+ "execution_count": null,
"metadata": {
"colab_type": "code"
},
@@ -687,7 +693,7 @@
},
{
"cell_type": "code",
- "execution_count": 0,
+ "execution_count": null,
"metadata": {
"colab_type": "code"
},
@@ -866,4 +872,4 @@
},
"nbformat": 4,
"nbformat_minor": 0
-}
\ No newline at end of file
+}
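The seed handling added above follows the Keras 3 pattern for randomness
inside layers: stateless `keras.random` calls draw from a
`keras.random.SeedGenerator`, so repeated calls yield fresh values, which
SimCLR needs to get two different views of the same image. A minimal sketch
(the `RandomScale` layer and its shapes are illustrative, not from the
example):

    import keras
    from keras import layers, ops

    class RandomScale(layers.Layer):
        def __init__(self, amount=0.1, **kwargs):
            super().__init__(**kwargs)
            # One generator per layer; its state advances on every draw.
            self.seed_generator = keras.random.SeedGenerator(1337)
            self.amount = amount

        def call(self, images, training=True):
            if training:
                batch_size = ops.shape(images)[0]
                scales = 1 + keras.random.uniform(
                    (batch_size, 1, 1, 1),
                    minval=-self.amount,
                    maxval=self.amount,
                    seed=self.seed_generator,  # fresh randomness per call
                )
                images = ops.clip(images * scales, 0, 1)
            return images

    # Two calls on the same batch give two different augmented views.
    x = keras.random.uniform((2, 4, 4, 3))
    aug = RandomScale()
    assert not bool(ops.all(ops.equal(aug(x, training=True), aug(x, training=True))))
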
diff --git a/examples/vision/md/semisupervised_simclr.md b/examples/vision/md/semisupervised_simclr.md
index 541e077ae2..b2c87aa434 100644
--- a/examples/vision/md/semisupervised_simclr.md
+++ b/examples/vision/md/semisupervised_simclr.md
@@ -2,7 +2,7 @@
**Author:** [András Béres](https://www.linkedin.com/in/andras-beres-789190210)<br>
**Date created:** 2021/04/24<br>
-**Last modified:** 2021/04/24<br>
+**Last modified:** 2024/03/04<br>
**Description:** Contrastive pretraining with SimCLR for semi-supervised image classification on the STL-10 dataset.
@@ -88,6 +88,7 @@ import tensorflow as tf
import tensorflow_datasets as tfds
import keras
+from keras import ops
from keras import layers
```
@@ -206,6 +207,7 @@ class RandomColorAffine(layers.Layer):
def __init__(self, brightness=0, jitter=0, **kwargs):
super().__init__(**kwargs)
+ self.seed_generator = keras.random.SeedGenerator(1337)
self.brightness = brightness
self.jitter = jitter
@@ -216,24 +218,29 @@ class RandomColorAffine(layers.Layer):
def call(self, images, training=True):
if training:
- batch_size = tf.shape(images)[0]
+ batch_size = ops.shape(images)[0]
# Same for all colors
- brightness_scales = 1 + tf.random.uniform(
+ brightness_scales = 1 + keras.random.uniform(
(batch_size, 1, 1, 1),
minval=-self.brightness,
maxval=self.brightness,
+ seed=self.seed_generator,
)
# Different for all colors
- jitter_matrices = tf.random.uniform(
- (batch_size, 1, 3, 3), minval=-self.jitter, maxval=self.jitter
+ jitter_matrices = keras.random.uniform(
+ (batch_size, 1, 3, 3),
+ minval=-self.jitter,
+ maxval=self.jitter,
+ seed=self.seed_generator,
)
color_transforms = (
- tf.eye(3, batch_shape=[batch_size, 1]) * brightness_scales
+ ops.tile(ops.expand_dims(ops.eye(3), axis=0), (batch_size, 1, 1, 1))
+ * brightness_scales
+ jitter_matrices
)
- images = tf.clip_by_value(tf.matmul(images, color_transforms), 0, 1)
+ images = ops.clip(ops.matmul(images, color_transforms), 0, 1)
return images
@@ -491,19 +498,19 @@ class ContrastiveModel(keras.Model):
# NT-Xent loss (normalized temperature-scaled cross entropy)
# Cosine similarity: the dot product of the l2-normalized feature vectors
- projections_1 = tf.math.l2_normalize(projections_1, axis=1)
- projections_2 = tf.math.l2_normalize(projections_2, axis=1)
+ projections_1 = ops.normalize(projections_1, axis=1)
+ projections_2 = ops.normalize(projections_2, axis=1)
similarities = (
- tf.matmul(projections_1, projections_2, transpose_b=True) / self.temperature
+ ops.matmul(projections_1, ops.transpose(projections_2)) / self.temperature
)
# The similarity between the representations of two augmented views of the
# same image should be higher than their similarity with other views
- batch_size = tf.shape(projections_1)[0]
- contrastive_labels = tf.range(batch_size)
+ batch_size = ops.shape(projections_1)[0]
+ contrastive_labels = ops.arange(batch_size)
self.contrastive_accuracy.update_state(contrastive_labels, similarities)
self.contrastive_accuracy.update_state(
- contrastive_labels, tf.transpose(similarities)
+ contrastive_labels, ops.transpose(similarities)
)
# The temperature-scaled similarities are used as logits for cross-entropy
@@ -512,7 +519,7 @@ class ContrastiveModel(keras.Model):
contrastive_labels, similarities, from_logits=True
)
loss_2_1 = keras.losses.sparse_categorical_crossentropy(
- contrastive_labels, tf.transpose(similarities), from_logits=True
+ contrastive_labels, ops.transpose(similarities), from_logits=True
)
return (loss_1_2 + loss_2_1) / 2
@@ -520,7 +527,7 @@ class ContrastiveModel(keras.Model):
(unlabeled_images, _), (labeled_images, labels) = data
# Both labeled and unlabeled images are used, without labels
- images = tf.concat((unlabeled_images, labeled_images), axis=0)
+ images = ops.concatenate((unlabeled_images, labeled_images), axis=0)
# Each image is augmented twice, differently
augmented_images_1 = self.contrastive_augmenter(images, training=True)
augmented_images_2 = self.contrastive_augmenter(images, training=True)
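
Since `ops.matmul` has no `transpose_b` argument, the patch transposes the
second operand explicitly. A quick equivalence check against NumPy, assuming
the TensorFlow backend (tensor names are placeholders, not from the example):

    import numpy as np
    import keras
    from keras import ops

    a = keras.random.normal((4, 8))
    b = keras.random.normal((4, 8))

    # Replacement used by the patch:
    sims = ops.matmul(ops.normalize(a, axis=1), ops.transpose(ops.normalize(b, axis=1)))

    # NumPy reference for tf.matmul(a_n, b_n, transpose_b=True):
    a_n = np.asarray(a) / np.linalg.norm(np.asarray(a), axis=1, keepdims=True)
    b_n = np.asarray(b) / np.linalg.norm(np.asarray(b), axis=1, keepdims=True)
    np.testing.assert_allclose(np.asarray(sims), a_n @ b_n.T, atol=1e-5)
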
diff --git a/examples/vision/semisupervised_simclr.py b/examples/vision/semisupervised_simclr.py
index 5bb49a64d5..c398210962 100644
--- a/examples/vision/semisupervised_simclr.py
+++ b/examples/vision/semisupervised_simclr.py
@@ -2,9 +2,10 @@
Title: Semi-supervised image classification using contrastive pretraining with SimCLR
Author: [András Béres](https://www.linkedin.com/in/andras-beres-789190210)
Date created: 2021/04/24
-Last modified: 2021/04/24
+Last modified: 2024/03/04
Description: Contrastive pretraining with SimCLR for semi-supervised image classification on the STL-10 dataset.
Accelerator: GPU
+Converted to Keras 3 by: [Sitam Meur](https://github.com/sitamgithub-MSIT)
"""
"""
@@ -84,6 +85,7 @@
import tensorflow_datasets as tfds
import keras
+from keras import ops
from keras import layers
"""
@@ -189,6 +191,7 @@ class RandomColorAffine(layers.Layer):
def __init__(self, brightness=0, jitter=0, **kwargs):
super().__init__(**kwargs)
+ self.seed_generator = keras.random.SeedGenerator(1337)
self.brightness = brightness
self.jitter = jitter
@@ -199,24 +202,29 @@ def get_config(self):
def call(self, images, training=True):
if training:
- batch_size = tf.shape(images)[0]
+ batch_size = ops.shape(images)[0]
# Same for all colors
- brightness_scales = 1 + tf.random.uniform(
+ brightness_scales = 1 + keras.random.uniform(
(batch_size, 1, 1, 1),
minval=-self.brightness,
maxval=self.brightness,
+ seed=self.seed_generator,
)
# Different for all colors
- jitter_matrices = tf.random.uniform(
- (batch_size, 1, 3, 3), minval=-self.jitter, maxval=self.jitter
+ jitter_matrices = keras.random.uniform(
+ (batch_size, 1, 3, 3),
+ minval=-self.jitter,
+ maxval=self.jitter,
+ seed=self.seed_generator,
)
color_transforms = (
- tf.eye(3, batch_shape=[batch_size, 1]) * brightness_scales
+ ops.tile(ops.expand_dims(ops.eye(3), axis=0), (batch_size, 1, 1, 1))
+ * brightness_scales
+ jitter_matrices
)
- images = tf.clip_by_value(tf.matmul(images, color_transforms), 0, 1)
+ images = ops.clip(ops.matmul(images, color_transforms), 0, 1)
return images
@@ -416,19 +424,19 @@ def contrastive_loss(self, projections_1, projections_2):
# NT-Xent loss (normalized temperature-scaled cross entropy)
# Cosine similarity: the dot product of the l2-normalized feature vectors
- projections_1 = tf.math.l2_normalize(projections_1, axis=1)
- projections_2 = tf.math.l2_normalize(projections_2, axis=1)
+ projections_1 = ops.normalize(projections_1, axis=1)
+ projections_2 = ops.normalize(projections_2, axis=1)
similarities = (
- tf.matmul(projections_1, projections_2, transpose_b=True) / self.temperature
+ ops.matmul(projections_1, ops.transpose(projections_2)) / self.temperature
)
# The similarity between the representations of two augmented views of the
# same image should be higher than their similarity with other views
- batch_size = tf.shape(projections_1)[0]
- contrastive_labels = tf.range(batch_size)
+ batch_size = ops.shape(projections_1)[0]
+ contrastive_labels = ops.arange(batch_size)
self.contrastive_accuracy.update_state(contrastive_labels, similarities)
self.contrastive_accuracy.update_state(
- contrastive_labels, tf.transpose(similarities)
+ contrastive_labels, ops.transpose(similarities)
)
# The temperature-scaled similarities are used as logits for cross-entropy
@@ -437,7 +445,7 @@ def contrastive_loss(self, projections_1, projections_2):
contrastive_labels, similarities, from_logits=True
)
loss_2_1 = keras.losses.sparse_categorical_crossentropy(
- contrastive_labels, tf.transpose(similarities), from_logits=True
+ contrastive_labels, ops.transpose(similarities), from_logits=True
)
return (loss_1_2 + loss_2_1) / 2
@@ -445,7 +453,7 @@ def train_step(self, data):
(unlabeled_images, _), (labeled_images, labels) = data
# Both labeled and unlabeled images are used, without labels
- images = tf.concat((unlabeled_images, labeled_images), axis=0)
+ images = ops.concatenate((unlabeled_images, labeled_images), axis=0)
# Each image is augmented twice, differently
augmented_images_1 = self.contrastive_augmenter(images, training=True)
augmented_images_2 = self.contrastive_augmenter(images, training=True)
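
The least obvious rewrite in this patch is the batched identity matrix:
`tf.eye(3, batch_shape=[batch_size, 1])` has no single-call `ops` equivalent,
so it is built with expand-and-tile. `ops.tile` follows NumPy semantics: when
`repeats` is longer than the input rank, the input is promoted by prepending
axes, so the `(1, 3, 3)` identity becomes `(batch_size, 1, 3, 3)`. A quick
shape check of that assumption:

    from keras import ops

    batch_size = 5
    eyes = ops.tile(ops.expand_dims(ops.eye(3), axis=0), (batch_size, 1, 1, 1))
    assert tuple(ops.shape(eyes)) == (batch_size, 1, 3, 3)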