-
Notifications
You must be signed in to change notification settings - Fork 2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Updating Semi-supervised image classification using contrastive pretraining with SimCLR Keras 3 example (TF-Only) #1777
Changes from 1 commit
dfca98f
9a33aa7
b4a4546
0909cbd
eb7332c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -84,6 +84,7 @@ | |
import tensorflow_datasets as tfds | ||
|
||
import keras | ||
from keras import ops | ||
from keras import layers | ||
|
||
""" | ||
|
@@ -199,24 +200,24 @@ def get_config(self): | |
|
||
def call(self, images, training=True): | ||
if training: | ||
batch_size = tf.shape(images)[0] | ||
batch_size = ops.shape(images)[0] | ||
|
||
# Same for all colors | ||
brightness_scales = 1 + tf.random.uniform( | ||
brightness_scales = 1 + keras.random.uniform( | ||
(batch_size, 1, 1, 1), | ||
minval=-self.brightness, | ||
maxval=self.brightness, | ||
) | ||
# Different for all colors | ||
jitter_matrices = tf.random.uniform( | ||
jitter_matrices = keras.random.uniform( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same here. |
||
(batch_size, 1, 3, 3), minval=-self.jitter, maxval=self.jitter | ||
) | ||
|
||
color_transforms = ( | ||
tf.eye(3, batch_shape=[batch_size, 1]) * brightness_scales | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can this be converted to There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah, actually, there were some issues. In spite of having the same TF ops in Keras, the |
||
+ jitter_matrices | ||
) | ||
images = tf.clip_by_value(tf.matmul(images, color_transforms), 0, 1) | ||
images = ops.clip(ops.matmul(images, color_transforms), 0, 1) | ||
return images | ||
|
||
|
||
|
@@ -416,19 +417,19 @@ def contrastive_loss(self, projections_1, projections_2): | |
# NT-Xent loss (normalized temperature-scaled cross entropy) | ||
|
||
# Cosine similarity: the dot product of the l2-normalized feature vectors | ||
projections_1 = tf.math.l2_normalize(projections_1, axis=1) | ||
projections_2 = tf.math.l2_normalize(projections_2, axis=1) | ||
projections_1 = keras.utils.normalize(projections_1, axis=1, order=2) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You can just use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah. Replaced with Keras ops. |
||
projections_2 = keras.utils.normalize(projections_2, axis=1, order=2) | ||
similarities = ( | ||
tf.matmul(projections_1, projections_2, transpose_b=True) / self.temperature | ||
) | ||
|
||
# The similarity between the representations of two augmented views of the | ||
# same image should be higher than their similarity with other views | ||
batch_size = tf.shape(projections_1)[0] | ||
contrastive_labels = tf.range(batch_size) | ||
batch_size = ops.shape(projections_1)[0] | ||
contrastive_labels = ops.arange(batch_size) | ||
self.contrastive_accuracy.update_state(contrastive_labels, similarities) | ||
self.contrastive_accuracy.update_state( | ||
contrastive_labels, tf.transpose(similarities) | ||
contrastive_labels, ops.transpose(similarities) | ||
) | ||
|
||
# The temperature-scaled similarities are used as logits for cross-entropy | ||
|
@@ -437,15 +438,15 @@ def contrastive_loss(self, projections_1, projections_2): | |
contrastive_labels, similarities, from_logits=True | ||
) | ||
loss_2_1 = keras.losses.sparse_categorical_crossentropy( | ||
contrastive_labels, tf.transpose(similarities), from_logits=True | ||
contrastive_labels, ops.transpose(similarities), from_logits=True | ||
) | ||
return (loss_1_2 + loss_2_1) / 2 | ||
|
||
def train_step(self, data): | ||
(unlabeled_images, _), (labeled_images, labels) = data | ||
|
||
# Both labeled and unlabeled images are used, without labels | ||
images = tf.concat((unlabeled_images, labeled_images), axis=0) | ||
images = ops.concatenate((unlabeled_images, labeled_images), axis=0) | ||
# Each image is augmented twice, differently | ||
augmented_images_1 = self.contrastive_augmenter(images, training=True) | ||
augmented_images_2 = self.contrastive_augmenter(images, training=True) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please add a
seed=self.seed_generator
arg and createself.seed_generator
in__init__()
.