Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updating Semi-supervised image classification using contrastive pretraining with SimCLR Keras 3 example (TF-Only) #1777

Merged
merged 5 commits into from
Mar 4, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 12 additions & 11 deletions examples/vision/semisupervised_simclr.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@
import tensorflow_datasets as tfds

import keras
from keras import ops
from keras import layers

"""
Expand Down Expand Up @@ -199,24 +200,24 @@ def get_config(self):

def call(self, images, training=True):
if training:
batch_size = tf.shape(images)[0]
batch_size = ops.shape(images)[0]

# Same for all colors
brightness_scales = 1 + tf.random.uniform(
brightness_scales = 1 + keras.random.uniform(
(batch_size, 1, 1, 1),
minval=-self.brightness,
maxval=self.brightness,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add a seed=self.seed_generator arg and create self.seed_generator in __init__().

)
# Different for all colors
jitter_matrices = tf.random.uniform(
jitter_matrices = keras.random.uniform(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here.

(batch_size, 1, 3, 3), minval=-self.jitter, maxval=self.jitter
)

color_transforms = (
tf.eye(3, batch_shape=[batch_size, 1]) * brightness_scales
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can this be converted to ops?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, actually, there were some issues. In spite of having the same TF ops in Keras, the batch_shape attribute that is present here is missing in Keras ops. That is why I did not change this in the first place. But I addressed this in the latest commit and replaced the TF ops.

+ jitter_matrices
)
images = tf.clip_by_value(tf.matmul(images, color_transforms), 0, 1)
images = ops.clip(ops.matmul(images, color_transforms), 0, 1)
return images


Expand Down Expand Up @@ -416,19 +417,19 @@ def contrastive_loss(self, projections_1, projections_2):
# NT-Xent loss (normalized temperature-scaled cross entropy)

# Cosine similarity: the dot product of the l2-normalized feature vectors
projections_1 = tf.math.l2_normalize(projections_1, axis=1)
projections_2 = tf.math.l2_normalize(projections_2, axis=1)
projections_1 = keras.utils.normalize(projections_1, axis=1, order=2)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can just use ops.normalize with the latest Keras version

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah. Replaced with Keras ops.

projections_2 = keras.utils.normalize(projections_2, axis=1, order=2)
similarities = (
tf.matmul(projections_1, projections_2, transpose_b=True) / self.temperature
)

# The similarity between the representations of two augmented views of the
# same image should be higher than their similarity with other views
batch_size = tf.shape(projections_1)[0]
contrastive_labels = tf.range(batch_size)
batch_size = ops.shape(projections_1)[0]
contrastive_labels = ops.arange(batch_size)
self.contrastive_accuracy.update_state(contrastive_labels, similarities)
self.contrastive_accuracy.update_state(
contrastive_labels, tf.transpose(similarities)
contrastive_labels, ops.transpose(similarities)
)

# The temperature-scaled similarities are used as logits for cross-entropy
Expand All @@ -437,15 +438,15 @@ def contrastive_loss(self, projections_1, projections_2):
contrastive_labels, similarities, from_logits=True
)
loss_2_1 = keras.losses.sparse_categorical_crossentropy(
contrastive_labels, tf.transpose(similarities), from_logits=True
contrastive_labels, ops.transpose(similarities), from_logits=True
)
return (loss_1_2 + loss_2_1) / 2

def train_step(self, data):
(unlabeled_images, _), (labeled_images, labels) = data

# Both labeled and unlabeled images are used, without labels
images = tf.concat((unlabeled_images, labeled_images), axis=0)
images = ops.concatenate((unlabeled_images, labeled_images), axis=0)
# Each image is augmented twice, differently
augmented_images_1 = self.contrastive_augmenter(images, training=True)
augmented_images_2 = self.contrastive_augmenter(images, training=True)
Expand Down
Loading