Updating Semi-supervised image classification using contrastive pretraining with SimCLR Keras 3 example (TF-Only) (keras-team#1777)

* Update to the existing keras3 example

* left tf ops replaced with keras ops

* formatting done

* seed generator added

* .md and .ipynb file added
sitamgithub-MSIT committed May 30, 2024
1 parent c6657f1 commit 609ab17
Showing 3 changed files with 79 additions and 58 deletions.
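The bullet points above summarize the change: the remaining `tf.*` calls are swapped for `keras.ops` equivalents and the random augmentations now draw from a `keras.random.SeedGenerator`. Below is a minimal sketch of that pattern (not part of the commit; the toy `RandomBrightness` layer and its parameter values are made up for illustration):

```python
import keras
from keras import layers, ops


class RandomBrightness(layers.Layer):
    # Toy augmentation layer illustrating the keras.random migration pattern.

    def __init__(self, brightness=0.2, **kwargs):
        super().__init__(**kwargs)
        # A SeedGenerator keeps RNG state on the layer, so keras.random calls
        # are reproducible without reaching for tf.random directly.
        self.seed_generator = keras.random.SeedGenerator(1337)
        self.brightness = brightness

    def call(self, images, training=True):
        if training:
            batch_size = ops.shape(images)[0]
            # One brightness scale per image, shared across channels.
            scales = 1 + keras.random.uniform(
                (batch_size, 1, 1, 1),
                minval=-self.brightness,
                maxval=self.brightness,
                seed=self.seed_generator,
            )
            images = ops.clip(images * scales, 0, 1)
        return images
```

Passing the layer's seed generator to every `keras.random` call is what keeps the augmentations reproducible while moving the code to the portable `keras.ops`/`keras.random` style.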
62 changes: 34 additions & 28 deletions examples/vision/ipynb/semisupervised_simclr.ipynb
@@ -8,9 +8,9 @@
"source": [
"# Semi-supervised image classification using contrastive pretraining with SimCLR\n",
"\n",
"**Author:** [Andr\u00e1s B\u00e9res](https://www.linkedin.com/in/andras-beres-789190210)<br>\n",
"**Author:** [András Béres](https://www.linkedin.com/in/andras-beres-789190210)<br>\n",
"**Date created:** 2021/04/24<br>\n",
"**Last modified:** 2021/04/24<br>\n",
"**Last modified:** 2024/03/04<br>\n",
"**Description:** Contrastive pretraining with SimCLR for semi-supervised image classification on the STL-10 dataset."
]
},
@@ -86,7 +86,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab_type": "code"
},
@@ -109,6 +109,7 @@
"import tensorflow_datasets as tfds\n",
"\n",
"import keras\n",
"from keras import ops\n",
"from keras import layers"
]
},
@@ -123,7 +124,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab_type": "code"
},
@@ -162,7 +163,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab_type": "code"
},
@@ -245,7 +246,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab_type": "code"
},
@@ -257,6 +258,7 @@
" def __init__(self, brightness=0, jitter=0, **kwargs):\n",
" super().__init__(**kwargs)\n",
"\n",
" self.seed_generator = keras.random.SeedGenerator(1337)\n",
" self.brightness = brightness\n",
" self.jitter = jitter\n",
"\n",
@@ -267,24 +269,29 @@
"\n",
" def call(self, images, training=True):\n",
" if training:\n",
" batch_size = tf.shape(images)[0]\n",
" batch_size = ops.shape(images)[0]\n",
"\n",
" # Same for all colors\n",
" brightness_scales = 1 + tf.random.uniform(\n",
" brightness_scales = 1 + keras.random.uniform(\n",
" (batch_size, 1, 1, 1),\n",
" minval=-self.brightness,\n",
" maxval=self.brightness,\n",
" seed=self.seed_generator,\n",
" )\n",
" # Different for all colors\n",
" jitter_matrices = tf.random.uniform(\n",
" (batch_size, 1, 3, 3), minval=-self.jitter, maxval=self.jitter\n",
" jitter_matrices = keras.random.uniform(\n",
" (batch_size, 1, 3, 3), \n",
" minval=-self.jitter, \n",
" maxval=self.jitter,\n",
" seed=self.seed_generator,\n",
" )\n",
"\n",
" color_transforms = (\n",
" tf.eye(3, batch_shape=[batch_size, 1]) * brightness_scales\n",
" ops.tile(ops.expand_dims(ops.eye(3), axis=0), (batch_size, 1, 1, 1))\n",
" * brightness_scales\n",
" + jitter_matrices\n",
" )\n",
" images = tf.clip_by_value(tf.matmul(images, color_transforms), 0, 1)\n",
" images = ops.clip(ops.matmul(images, color_transforms), 0, 1)\n",
" return images\n",
"\n",
"\n",
@@ -344,7 +351,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab_type": "code"
},
@@ -363,8 +370,7 @@
" layers.Dense(width, activation=\"relu\"),\n",
" ],\n",
" name=\"encoder\",\n",
" )\n",
""
" )\n"
]
},
{
@@ -380,7 +386,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab_type": "code"
},
@@ -461,7 +467,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab_type": "code"
},
@@ -526,19 +532,19 @@
" # NT-Xent loss (normalized temperature-scaled cross entropy)\n",
"\n",
" # Cosine similarity: the dot product of the l2-normalized feature vectors\n",
" projections_1 = tf.math.l2_normalize(projections_1, axis=1)\n",
" projections_2 = tf.math.l2_normalize(projections_2, axis=1)\n",
" projections_1 = ops.normalize(projections_1, axis=1)\n",
" projections_2 = ops.normalize(projections_2, axis=1)\n",
" similarities = (\n",
" tf.matmul(projections_1, projections_2, transpose_b=True) / self.temperature\n",
" ops.matmul(projections_1, ops.transpose(projections_2)) / self.temperature\n",
" )\n",
"\n",
" # The similarity between the representations of two augmented views of the\n",
" # same image should be higher than their similarity with other views\n",
" batch_size = tf.shape(projections_1)[0]\n",
" contrastive_labels = tf.range(batch_size)\n",
" batch_size = ops.shape(projections_1)[0]\n",
" contrastive_labels = ops.arange(batch_size)\n",
" self.contrastive_accuracy.update_state(contrastive_labels, similarities)\n",
" self.contrastive_accuracy.update_state(\n",
" contrastive_labels, tf.transpose(similarities)\n",
" contrastive_labels, ops.transpose(similarities)\n",
" )\n",
"\n",
" # The temperature-scaled similarities are used as logits for cross-entropy\n",
@@ -547,15 +553,15 @@
" contrastive_labels, similarities, from_logits=True\n",
" )\n",
" loss_2_1 = keras.losses.sparse_categorical_crossentropy(\n",
" contrastive_labels, tf.transpose(similarities), from_logits=True\n",
" contrastive_labels, ops.transpose(similarities), from_logits=True\n",
" )\n",
" return (loss_1_2 + loss_2_1) / 2\n",
"\n",
" def train_step(self, data):\n",
" (unlabeled_images, _), (labeled_images, labels) = data\n",
"\n",
" # Both labeled and unlabeled images are used, without labels\n",
" images = tf.concat((unlabeled_images, labeled_images), axis=0)\n",
" images = ops.concatenate((unlabeled_images, labeled_images), axis=0)\n",
" # Each image is augmented twice, differently\n",
" augmented_images_1 = self.contrastive_augmenter(images, training=True)\n",
" augmented_images_2 = self.contrastive_augmenter(images, training=True)\n",
@@ -645,7 +651,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab_type": "code"
},
@@ -687,7 +693,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab_type": "code"
},
@@ -866,4 +872,4 @@
},
"nbformat": 4,
"nbformat_minor": 0
-}
+}
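The least obvious replacement in the notebook diff above is the batched identity matrix: `tf.eye(3, batch_shape=[batch_size, 1])` becomes `ops.tile(ops.expand_dims(ops.eye(3), axis=0), (batch_size, 1, 1, 1))`. A small sketch (not from the commit; it assumes a toy batch size and that the TensorFlow backend is available for the cross-check) confirming both produce the same `(batch_size, 1, 3, 3)` stack of identity matrices:

```python
import numpy as np
import tensorflow as tf
from keras import ops

batch_size = 4

# Backend-agnostic construction used by the updated example.
identity_keras = ops.tile(ops.expand_dims(ops.eye(3), axis=0), (batch_size, 1, 1, 1))

# Previous TF-only construction, reproduced here only to check equivalence.
identity_tf = tf.eye(3, batch_shape=[batch_size, 1])

print(ops.shape(identity_keras))  # (4, 1, 3, 3)
assert np.allclose(ops.convert_to_numpy(identity_keras), identity_tf.numpy())
```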
37 changes: 22 additions & 15 deletions examples/vision/md/semisupervised_simclr.md
@@ -2,7 +2,7 @@

**Author:** [András Béres](https://www.linkedin.com/in/andras-beres-789190210)<br>
**Date created:** 2021/04/24<br>
-**Last modified:** 2021/04/24<br>
+**Last modified:** 2024/03/04<br>
**Description:** Contrastive pretraining with SimCLR for semi-supervised image classification on the STL-10 dataset.


@@ -88,6 +88,7 @@ import tensorflow as tf
import tensorflow_datasets as tfds

import keras
+from keras import ops
from keras import layers
```

@@ -206,6 +207,7 @@ class RandomColorAffine(layers.Layer):
def __init__(self, brightness=0, jitter=0, **kwargs):
super().__init__(**kwargs)

+        self.seed_generator = keras.random.SeedGenerator(1337)
self.brightness = brightness
self.jitter = jitter

@@ -216,24 +218,29 @@ class RandomColorAffine(layers.Layer):

def call(self, images, training=True):
if training:
-            batch_size = tf.shape(images)[0]
+            batch_size = ops.shape(images)[0]

# Same for all colors
-            brightness_scales = 1 + tf.random.uniform(
+            brightness_scales = 1 + keras.random.uniform(
(batch_size, 1, 1, 1),
minval=-self.brightness,
maxval=self.brightness,
+                seed=self.seed_generator,
)
# Different for all colors
-            jitter_matrices = tf.random.uniform(
-                (batch_size, 1, 3, 3), minval=-self.jitter, maxval=self.jitter
+            jitter_matrices = keras.random.uniform(
+                (batch_size, 1, 3, 3),
+                minval=-self.jitter,
+                maxval=self.jitter,
+                seed=self.seed_generator,
)

color_transforms = (
-                tf.eye(3, batch_shape=[batch_size, 1]) * brightness_scales
+                ops.tile(ops.expand_dims(ops.eye(3), axis=0), (batch_size, 1, 1, 1))
+                * brightness_scales
+ jitter_matrices
)
-            images = tf.clip_by_value(tf.matmul(images, color_transforms), 0, 1)
+            images = ops.clip(ops.matmul(images, color_transforms), 0, 1)
return images


@@ -491,19 +498,19 @@ class ContrastiveModel(keras.Model):
# NT-Xent loss (normalized temperature-scaled cross entropy)

# Cosine similarity: the dot product of the l2-normalized feature vectors
-        projections_1 = tf.math.l2_normalize(projections_1, axis=1)
-        projections_2 = tf.math.l2_normalize(projections_2, axis=1)
+        projections_1 = ops.normalize(projections_1, axis=1)
+        projections_2 = ops.normalize(projections_2, axis=1)
similarities = (
-            tf.matmul(projections_1, projections_2, transpose_b=True) / self.temperature
+            ops.matmul(projections_1, ops.transpose(projections_2)) / self.temperature
)

# The similarity between the representations of two augmented views of the
# same image should be higher than their similarity with other views
-        batch_size = tf.shape(projections_1)[0]
-        contrastive_labels = tf.range(batch_size)
+        batch_size = ops.shape(projections_1)[0]
+        contrastive_labels = ops.arange(batch_size)
self.contrastive_accuracy.update_state(contrastive_labels, similarities)
self.contrastive_accuracy.update_state(
-            contrastive_labels, tf.transpose(similarities)
+            contrastive_labels, ops.transpose(similarities)
)

# The temperature-scaled similarities are used as logits for cross-entropy
@@ -512,15 +519,15 @@ class ContrastiveModel(keras.Model):
contrastive_labels, similarities, from_logits=True
)
loss_2_1 = keras.losses.sparse_categorical_crossentropy(
-            contrastive_labels, tf.transpose(similarities), from_logits=True
+            contrastive_labels, ops.transpose(similarities), from_logits=True
)
return (loss_1_2 + loss_2_1) / 2

def train_step(self, data):
(unlabeled_images, _), (labeled_images, labels) = data

# Both labeled and unlabeled images are used, without labels
-        images = tf.concat((unlabeled_images, labeled_images), axis=0)
+        images = ops.concatenate((unlabeled_images, labeled_images), axis=0)
# Each image is augmented twice, differently
augmented_images_1 = self.contrastive_augmenter(images, training=True)
augmented_images_2 = self.contrastive_augmenter(images, training=True)
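In the NT-Xent loss, `tf.math.l2_normalize` and `tf.matmul(..., transpose_b=True)` become `ops.normalize` and `ops.matmul(..., ops.transpose(...))`. A quick sketch (not part of the commit; it uses random toy projections and assumes the TensorFlow backend so the old formulation can be run side by side) showing that both give the same similarity logits:

```python
import numpy as np
import tensorflow as tf
import keras
from keras import ops

temperature = 0.1
projections_1 = keras.random.uniform((8, 128))
projections_2 = keras.random.uniform((8, 128))

# Updated, backend-agnostic formulation from the diff.
p1 = ops.normalize(projections_1, axis=1)
p2 = ops.normalize(projections_2, axis=1)
similarities = ops.matmul(p1, ops.transpose(p2)) / temperature

# Previous TF-only formulation, reproduced only for the comparison.
p1_tf = tf.math.l2_normalize(projections_1, axis=1)
p2_tf = tf.math.l2_normalize(projections_2, axis=1)
similarities_tf = tf.matmul(p1_tf, p2_tf, transpose_b=True) / temperature

assert np.allclose(
    ops.convert_to_numpy(similarities), similarities_tf.numpy(), atol=1e-5
)
```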
