Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updating Semi-supervised image classification using contrastive pretraining with SimCLR Keras 3 example (TF-Only) #1777

Merged
merged 5 commits into from
Mar 4, 2024

Conversation

sitamgithub-MSIT
Copy link
Contributor

Corresponding Issue

This PR updates the Semi-supervised image classification using contrastive pretraining with SimCLR Keras 3.0 example [TF Only Backend]. Many TF ops are replaced with corresponding Keras ops.

For example, here is the notebook link provided:
https://colab.research.google.com/drive/1rLAKohQaybEuUR-Nxgl3VDjZ8g5qmzid?usp=sharing

cc: @fchollet @divyashreepathihalli

The following describes the Git difference for the changed files:

Changes:
diff --git a/examples/vision/semisupervised_simclr.py b/examples/vision/semisupervised_simclr.py
index 5bb49a64..9edd7d6d 100644
--- a/examples/vision/semisupervised_simclr.py
+++ b/examples/vision/semisupervised_simclr.py
@@ -84,6 +84,7 @@ import tensorflow as tf
 import tensorflow_datasets as tfds
 
 import keras
+from keras import ops
 from keras import layers
 
 """
@@ -199,16 +200,16 @@ class RandomColorAffine(layers.Layer):
 
     def call(self, images, training=True):
         if training:
-            batch_size = tf.shape(images)[0]
+            batch_size = ops.shape(images)[0]
 
             # Same for all colors
-            brightness_scales = 1 + tf.random.uniform(
+            brightness_scales = 1 + keras.random.uniform(
                 (batch_size, 1, 1, 1),
                 minval=-self.brightness,
                 maxval=self.brightness,
             )
             # Different for all colors
-            jitter_matrices = tf.random.uniform(
+            jitter_matrices = keras.random.uniform(
                 (batch_size, 1, 3, 3), minval=-self.jitter, maxval=self.jitter
             )
 
@@ -216,7 +217,7 @@ class RandomColorAffine(layers.Layer):
                 tf.eye(3, batch_shape=[batch_size, 1]) * brightness_scales
                 + jitter_matrices
             )
-            images = tf.clip_by_value(tf.matmul(images, color_transforms), 0, 1)
+            images = ops.clip(ops.matmul(images, color_transforms), 0, 1)
         return images
 
 
@@ -416,19 +417,19 @@ class ContrastiveModel(keras.Model):
         # NT-Xent loss (normalized temperature-scaled cross entropy)
 
         # Cosine similarity: the dot product of the l2-normalized feature vectors
-        projections_1 = tf.math.l2_normalize(projections_1, axis=1)
-        projections_2 = tf.math.l2_normalize(projections_2, axis=1)
+        projections_1 = keras.utils.normalize(projections_1, axis=1, order=2)
+        projections_2 = keras.utils.normalize(projections_2, axis=1, order=2)
         similarities = (
             tf.matmul(projections_1, projections_2, transpose_b=True) / self.temperature
         )
 
 
         # The similarity between the representations of two augmented views of the
         # same image should be higher than their similarity with other views
-        batch_size = tf.shape(projections_1)[0]
-        contrastive_labels = tf.range(batch_size)
+        batch_size = ops.shape(projections_1)[0]
+        contrastive_labels = ops.arange(batch_size)
         self.contrastive_accuracy.update_state(contrastive_labels, similarities)
         self.contrastive_accuracy.update_state(
-            contrastive_labels, tf.transpose(similarities)
+            contrastive_labels, ops.transpose(similarities)
         )
 
         # The temperature-scaled similarities are used as logits for cross-entropy
@@ -437,7 +438,7 @@ class ContrastiveModel(keras.Model):
             contrastive_labels, similarities, from_logits=True
         )
         loss_2_1 = keras.losses.sparse_categorical_crossentropy(
-            contrastive_labels, tf.transpose(similarities), from_logits=True
+            contrastive_labels, ops.transpose(similarities), from_logits=True
         )
         return (loss_1_2 + loss_2_1) / 2
 
@@ -445,7 +446,7 @@ class ContrastiveModel(keras.Model):
         (unlabeled_images, _), (labeled_images, labels) = data
 
         # Both labeled and unlabeled images are used, without labels
-        images = tf.concat((unlabeled_images, labeled_images), axis=0)
+        images = ops.concatenate((unlabeled_images, labeled_images), axis=0)
         # Each image is augmented twice, differently
         augmented_images_1 = self.contrastive_augmenter(images, training=True)
         augmented_images_2 = self.contrastive_augmenter(images, training=True)
(END)

Copy link
Member

@fchollet fchollet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the PR, the code looks good!

@@ -416,19 +417,19 @@ def contrastive_loss(self, projections_1, projections_2):
# NT-Xent loss (normalized temperature-scaled cross entropy)

# Cosine similarity: the dot product of the l2-normalized feature vectors
projections_1 = tf.math.l2_normalize(projections_1, axis=1)
projections_2 = tf.math.l2_normalize(projections_2, axis=1)
projections_1 = keras.utils.normalize(projections_1, axis=1, order=2)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can just use ops.normalize with the latest Keras version

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah. Replaced with Keras ops.

(batch_size, 1, 1, 1),
minval=-self.brightness,
maxval=self.brightness,
)
# Different for all colors
jitter_matrices = tf.random.uniform(
jitter_matrices = keras.random.uniform(
(batch_size, 1, 3, 3), minval=-self.jitter, maxval=self.jitter
)

color_transforms = (
tf.eye(3, batch_shape=[batch_size, 1]) * brightness_scales
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can this be converted to ops?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, actually, there were some issues. In spite of having the same TF ops in Keras, the batch_shape attribute that is present here is missing in Keras ops. That is why I did not change this in the first place. But I addressed this in the latest commit and replaced the TF ops.

@divyashreepathihalli
Copy link
Contributor

Thanks for the PR @sitamgithub-MSIT, there is 2 more ops that can be converted to Keras 3
tf.eye -> ops.eye
tf.matmul -> ops.matmul

@sitamgithub-MSIT
Copy link
Contributor Author

Thanks for the PR @sitamgithub-MSIT, there is 2 more ops that can be converted to Keras 3 tf.eye -> ops.eye tf.matmul -> ops.matmul

At first, this two-ops implementation in the code example did not exactly match the corresponding Keras ops, which is why I did not change those in the first place. But in the latest commit, I replaced these two TF ops. And the code example is working fine. You can check the mentioned notebook in the PR description.


# Same for all colors
brightness_scales = 1 + tf.random.uniform(
brightness_scales = 1 + keras.random.uniform(
(batch_size, 1, 1, 1),
minval=-self.brightness,
maxval=self.brightness,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add a seed=self.seed_generator arg and create self.seed_generator in __init__().

(batch_size, 1, 1, 1),
minval=-self.brightness,
maxval=self.brightness,
)
# Different for all colors
jitter_matrices = tf.random.uniform(
jitter_matrices = keras.random.uniform(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here.

@sitamgithub-MSIT
Copy link
Contributor Author

@fchollet I added seed generator arguments to random operations. Is everything OK now?

Copy link
Member

@fchollet fchollet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thank you! Please add the generated files.

@fchollet fchollet merged commit faf8ec1 into keras-team:master Mar 4, 2024
3 checks passed
sitamgithub-MSIT added a commit to sitamgithub-MSIT/keras-io that referenced this pull request May 30, 2024
…aining with SimCLR Keras 3 example (TF-Only) (keras-team#1777)

* Update to the existing keras3 example

* left tf ops replaced with keras ops

* formatting done

* seed generator added

* .md and .ipynb file added
@sitamgithub-MSIT sitamgithub-MSIT deleted the simclr branch May 30, 2024 15:03
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants