Commit

Updated Variational AutoEncoder example for Keras 3 (keras-team#1836)
* vae keras 3 example updated

* seed generator added to rng layer

* generated files are added
sitamgithub-MSIT authored Apr 24, 2024
1 parent f924d09 commit 190dbe5
Showing 3 changed files with 52 additions and 39 deletions.
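
The changes follow the standard Keras 3 migration pattern: backend-specific `tf.*` tensor math is replaced by the backend-agnostic `keras.ops` namespace, and direct `tf.random` calls are replaced by `keras.random` driven by a `SeedGenerator`. A rough sketch of the op-level mapping applied throughout this diff (the sample values are illustrative, not taken from the example):

```python
# Sketch of the tf -> keras.ops mapping used in this commit.
# Values are illustrative only.
import numpy as np
from keras import ops

x = np.array([[1.0, 2.0], [3.0, 4.0]])

print(ops.mean(x))         # tf.reduce_mean(x)        -> 2.5
print(ops.sum(x, axis=1))  # tf.reduce_sum(x, axis=1) -> [3. 7.]
print(ops.exp(x))          # tf.exp(x)
print(ops.square(x))       # tf.square(x)
print(ops.shape(x))        # tf.shape(x)              -> (2, 2)
```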
45 changes: 24 additions & 21 deletions examples/generative/ipynb/vae.ipynb
@@ -10,7 +10,7 @@
 "\n",
 "**Author:** [fchollet](https://twitter.com/fchollet)<br>\n",
 "**Date created:** 2020/05/03<br>\n",
-"**Last modified:** 2023/11/22<br>\n",
+"**Last modified:** 2024/04/24<br>\n",
 "**Description:** Convolutional Variational AutoEncoder (VAE) trained on MNIST digits."
 ]
 },
@@ -25,7 +25,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 0,
+"execution_count": null,
 "metadata": {
 "colab_type": "code"
 },
@@ -38,6 +38,7 @@
 "import numpy as np\n",
 "import tensorflow as tf\n",
 "import keras\n",
+"from keras import ops\n",
 "from keras import layers"
 ]
 },
@@ -52,7 +53,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 0,
+"execution_count": null,
 "metadata": {
 "colab_type": "code"
 },
@@ -62,13 +63,16 @@
 "class Sampling(layers.Layer):\n",
 "    \"\"\"Uses (z_mean, z_log_var) to sample z, the vector encoding a digit.\"\"\"\n",
 "\n",
+"    def __init__(self, **kwargs):\n",
+"        super().__init__(**kwargs)\n",
+"        self.seed_generator = keras.random.SeedGenerator(1337)\n",
+"\n",
 "    def call(self, inputs):\n",
 "        z_mean, z_log_var = inputs\n",
-"        batch = tf.shape(z_mean)[0]\n",
-"        dim = tf.shape(z_mean)[1]\n",
-"        epsilon = tf.random.normal(shape=(batch, dim))\n",
-"        return z_mean + tf.exp(0.5 * z_log_var) * epsilon\n",
-""
+"        batch = ops.shape(z_mean)[0]\n",
+"        dim = ops.shape(z_mean)[1]\n",
+"        epsilon = keras.random.normal(shape=(batch, dim), seed=self.seed_generator)\n",
+"        return z_mean + ops.exp(0.5 * z_log_var) * epsilon\n"
 ]
 },
 {
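
The `__init__` added here is the core Keras 3 change: `keras.random` ops are stateless, so a layer that draws random numbers in `call()` must own a `keras.random.SeedGenerator` and pass it on every draw; this keeps sampling reproducible and backend-agnostic (JAX in particular has no global RNG state). A minimal standalone sketch of the pattern, assuming any installed Keras 3 backend:

```python
# Minimal sketch of the SeedGenerator pattern used by Sampling above.
import keras

seed_gen = keras.random.SeedGenerator(1337)

# Each call advances the generator's state, so successive draws differ
# while the overall sequence stays reproducible for a fixed seed.
a = keras.random.normal(shape=(2, 3), seed=seed_gen)
b = keras.random.normal(shape=(2, 3), seed=seed_gen)
print(a.shape, b.shape)  # (2, 3) (2, 3)
```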
@@ -82,7 +86,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 0,
+"execution_count": null,
 "metadata": {
 "colab_type": "code"
 },
@@ -113,7 +117,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 0,
+"execution_count": null,
 "metadata": {
 "colab_type": "code"
 },
@@ -140,7 +144,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 0,
+"execution_count": null,
 "metadata": {
 "colab_type": "code"
 },
@@ -170,14 +174,14 @@
 "        with tf.GradientTape() as tape:\n",
 "            z_mean, z_log_var, z = self.encoder(data)\n",
 "            reconstruction = self.decoder(z)\n",
-"            reconstruction_loss = tf.reduce_mean(\n",
-"                tf.reduce_sum(\n",
+"            reconstruction_loss = ops.mean(\n",
+"                ops.sum(\n",
 "                    keras.losses.binary_crossentropy(data, reconstruction),\n",
 "                    axis=(1, 2),\n",
 "                )\n",
 "            )\n",
-"            kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))\n",
-"            kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))\n",
+"            kl_loss = -0.5 * (1 + z_log_var - ops.square(z_mean) - ops.exp(z_log_var))\n",
+"            kl_loss = ops.mean(ops.sum(kl_loss, axis=1))\n",
 "            total_loss = reconstruction_loss + kl_loss\n",
 "        grads = tape.gradient(total_loss, self.trainable_weights)\n",
 "        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))\n",
@@ -188,8 +192,7 @@
 "            \"loss\": self.total_loss_tracker.result(),\n",
 "            \"reconstruction_loss\": self.reconstruction_loss_tracker.result(),\n",
 "            \"kl_loss\": self.kl_loss_tracker.result(),\n",
-"        }\n",
-""
+"        }\n"
 ]
 },
 {
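
For reference, the `kl_loss` lines rewritten in the hunk above implement the closed-form KL divergence between the encoder's diagonal Gaussian and the standard normal prior,

$$
\mathrm{KL}\big(\mathcal{N}(\mu, \sigma^2)\,\|\,\mathcal{N}(0, 1)\big)
= -\frac{1}{2} \sum_{j=1}^{d} \left(1 + \log \sigma_j^2 - \mu_j^2 - \sigma_j^2\right)
$$

with `z_mean` as μ and `z_log_var` as log σ²: `ops.sum(..., axis=1)` sums over the `d` latent dimensions and `ops.mean` averages over the batch. Only the op namespace changes; the math is identical to the `tf.reduce_*` version.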
@@ -203,7 +206,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 0,
+"execution_count": null,
 "metadata": {
 "colab_type": "code"
 },
@@ -229,7 +232,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 0,
+"execution_count": null,
 "metadata": {
 "colab_type": "code"
 },
@@ -286,7 +289,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 0,
+"execution_count": null,
 "metadata": {
 "colab_type": "code"
 },
@@ -340,4 +343,4 @@
 },
 "nbformat": 4,
 "nbformat_minor": 0
-}
+}
23 changes: 14 additions & 9 deletions examples/generative/md/vae.md
@@ -2,7 +2,7 @@
 
 **Author:** [fchollet](https://twitter.com/fchollet)<br>
 **Date created:** 2020/05/03<br>
-**Last modified:** 2023/11/22<br>
+**Last modified:** 2024/04/24<br>
 **Description:** Convolutional Variational AutoEncoder (VAE) trained on MNIST digits.
 
 
@@ -22,6 +22,7 @@ os.environ["KERAS_BACKEND"] = "tensorflow"
 import numpy as np
 import tensorflow as tf
 import keras
+from keras import ops
 from keras import layers
 ```
 
@@ -34,12 +35,16 @@ from keras import layers
 class Sampling(layers.Layer):
     """Uses (z_mean, z_log_var) to sample z, the vector encoding a digit."""
 
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self.seed_generator = keras.random.SeedGenerator(1337)
+
     def call(self, inputs):
         z_mean, z_log_var = inputs
-        batch = tf.shape(z_mean)[0]
-        dim = tf.shape(z_mean)[1]
-        epsilon = tf.random.normal(shape=(batch, dim))
-        return z_mean + tf.exp(0.5 * z_log_var) * epsilon
+        batch = ops.shape(z_mean)[0]
+        dim = ops.shape(z_mean)[1]
+        epsilon = keras.random.normal(shape=(batch, dim), seed=self.seed_generator)
+        return z_mean + ops.exp(0.5 * z_log_var) * epsilon
 
 ```
 
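The encoder that feeds `Sampling` is untouched by this commit and collapsed out of the diff. For context, a hedged sketch of how the layer is typically wired into an encoder head — the layer sizes here are illustrative assumptions, not necessarily the example's exact architecture:

```python
# Hedged sketch: wiring the Sampling layer (defined above) into an encoder.
# Layer sizes are illustrative; the example's real conv stack is collapsed
# out of this diff.
import keras
from keras import layers

latent_dim = 2
inputs = keras.Input(shape=(28, 28, 1))
x = layers.Flatten()(inputs)
x = layers.Dense(16, activation="relu")(x)
z_mean = layers.Dense(latent_dim, name="z_mean")(x)
z_log_var = layers.Dense(latent_dim, name="z_log_var")(x)
z = Sampling()([z_mean, z_log_var])  # reuses the class from the diff above
encoder = keras.Model(inputs, [z_mean, z_log_var, z], name="encoder")
```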
@@ -204,14 +209,14 @@ class VAE(keras.Model):
         with tf.GradientTape() as tape:
             z_mean, z_log_var, z = self.encoder(data)
             reconstruction = self.decoder(z)
-            reconstruction_loss = tf.reduce_mean(
-                tf.reduce_sum(
+            reconstruction_loss = ops.mean(
+                ops.sum(
                     keras.losses.binary_crossentropy(data, reconstruction),
                     axis=(1, 2),
                 )
             )
-            kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
-            kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))
+            kl_loss = -0.5 * (1 + z_log_var - ops.square(z_mean) - ops.exp(z_log_var))
+            kl_loss = ops.mean(ops.sum(kl_loss, axis=1))
             total_loss = reconstruction_loss + kl_loss
         grads = tape.gradient(total_loss, self.trainable_weights)
         self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
23 changes: 14 additions & 9 deletions examples/generative/vae.py
@@ -2,7 +2,7 @@
 Title: Variational AutoEncoder
 Author: [fchollet](https://twitter.com/fchollet)
 Date created: 2020/05/03
-Last modified: 2023/11/22
+Last modified: 2024/04/24
 Description: Convolutional Variational AutoEncoder (VAE) trained on MNIST digits.
 Accelerator: GPU
 """
@@ -18,6 +18,7 @@
 import numpy as np
 import tensorflow as tf
 import keras
+from keras import ops
 from keras import layers
 
 """
@@ -28,12 +29,16 @@
 class Sampling(layers.Layer):
     """Uses (z_mean, z_log_var) to sample z, the vector encoding a digit."""
 
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self.seed_generator = keras.random.SeedGenerator(1337)
+
     def call(self, inputs):
         z_mean, z_log_var = inputs
-        batch = tf.shape(z_mean)[0]
-        dim = tf.shape(z_mean)[1]
-        epsilon = tf.random.normal(shape=(batch, dim))
-        return z_mean + tf.exp(0.5 * z_log_var) * epsilon
+        batch = ops.shape(z_mean)[0]
+        dim = ops.shape(z_mean)[1]
+        epsilon = keras.random.normal(shape=(batch, dim), seed=self.seed_generator)
+        return z_mean + ops.exp(0.5 * z_log_var) * epsilon
 
 
 """
@@ -94,14 +99,14 @@ def train_step(self, data):
         with tf.GradientTape() as tape:
             z_mean, z_log_var, z = self.encoder(data)
             reconstruction = self.decoder(z)
-            reconstruction_loss = tf.reduce_mean(
-                tf.reduce_sum(
+            reconstruction_loss = ops.mean(
+                ops.sum(
                     keras.losses.binary_crossentropy(data, reconstruction),
                     axis=(1, 2),
                 )
             )
-            kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
-            kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))
+            kl_loss = -0.5 * (1 + z_log_var - ops.square(z_mean) - ops.exp(z_log_var))
+            kl_loss = ops.mean(ops.sum(kl_loss, axis=1))
             total_loss = reconstruction_loss + kl_loss
         grads = tape.gradient(total_loss, self.trainable_weights)
         self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
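
Note that `train_step` still opens a `tf.GradientTape`, so the example as a whole continues to require the TensorFlow backend selected at the top of the file; the commit ports the layer and loss math to `keras.ops` without touching the custom training loop. A quick smoke test of the ported pieces, as an assumption-level sketch rather than part of the commit:

```python
# Smoke test (not part of the commit): the ported Sampling layer and KL
# math run purely on keras.ops / keras.random. Assumes the Sampling class
# from vae.py above is in scope.
import numpy as np
import keras
from keras import ops

z_mean = np.zeros((4, 2), dtype="float32")
z_log_var = np.zeros((4, 2), dtype="float32")

z = Sampling()([z_mean, z_log_var])
kl = ops.mean(
    ops.sum(-0.5 * (1 + z_log_var - ops.square(z_mean) - ops.exp(z_log_var)), axis=1)
)
print(z.shape, float(kl))  # (4, 2) 0.0
```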
