diff --git a/assets/same_class_diff_color.png b/assets/same_class_diff_color.png
new file mode 100644
index 0000000..c5d98c8
Binary files /dev/null and b/assets/same_class_diff_color.png differ
diff --git a/assets/same_color_diff_class.png b/assets/same_color_diff_class.png
new file mode 100644
index 0000000..775ce42
Binary files /dev/null and b/assets/same_color_diff_class.png differ
diff --git a/solution.py b/solution.py
index 9edd9f7..f065a11 100644
--- a/solution.py
+++ b/solution.py
@@ -421,12 +421,12 @@ def forward(self, x, y):
# %% tags=["task"]
style_size = ... # TODO choose a size for the style space
unet_depth = ... # TODO Choose a depth for the UNet
-style_mapping = DenseModel(
+style_encoder = DenseModel(
input_shape=..., num_classes=... # How big is the style space?
unet = UNet(depth=..., in_channels=..., out_channels=..., final_activation=nn.Sigmoid())
-generator = Generator(unet, style_mapping=style_mapping)
+generator = Generator(unet, style_encoder=style_encoder)
# %% tags=["solution"]
# Here is an example of a working setup! Note that you can change the hyperparameters as you experiment.
# Choose your own setup to see what works for you.
@@ -533,7 +533,7 @@ def copy_parameters(source_model, target_model):
# %%
-generator_ema = Generator(deepcopy(unet), style_mapping=deepcopy(style_mapping))
+generator_ema = Generator(deepcopy(unet), style_encoder=deepcopy(style_encoder))
generator_ema = generator_ema.to(device)
# %% [markdown] tags=[]
@@ -785,51 +785,64 @@ def copy_parameters(source_model, target_model):
# %% [markdown]
# Now we need to use these prototypes to create counterfactual images!
-# TODO make a task here!
+# %% [markdown]
+# ### Task 4.1: Create counterfactuals
+# In the cell below, we will store the counterfactual images in the `counterfactuals` array.
+# - Create a counterfactual image for each of the prototypes.
+# - Classify the counterfactual image using the classifier.
+# - Store the source and target labels; which is which?
# %% tags=["task"]
-num_images = len(test_mnist)
+num_images = 1000
+# Evaluate on a random subset of the test set (sampled without replacement).
+random_test_mnist = torch.utils.data.Subset(
+    test_mnist, np.random.choice(len(test_mnist), num_images, replace=False)
+)
 counterfactuals = np.zeros((4, num_images, 3, 28, 28))
 predictions = []
 source_labels = []
 target_labels = []
-for x, y in test_mnist:
-    for i in range(4):
-        if i == y:
-            # Store the image as is.
-            counterfactuals[i] = ...
-        # Create the counterfactual from the image and prototype
+for i, (x, y) in tqdm(enumerate(random_test_mnist), total=num_images):
+    for lbl in range(4):
+        # TODO Create the counterfactual
         x_fake = generator(x.unsqueeze(0).to(device), ...)
-        counterfactuals[i] = x_fake.cpu().detach().numpy()
+        # TODO Predict the class of the counterfactual image
         pred = model(...)
-        source_labels.append(y)
-        target_labels.append(i)
+        # TODO Store the source and target labels
+        source_labels.append(...)  # The original label of the image
+        target_labels.append(...)  # The desired label of the counterfactual image
+        # Store the counterfactual image and prediction
+        counterfactuals[lbl][i] = x_fake.cpu().detach().numpy()
+        # NOTE(review): assumes `model` returns one score per class for the
+        # single-image batch, so argmax is the predicted class -- confirm.
+        predictions.append(pred.argmax().item())
# %% tags=["solution"]
-num_images = len(test_mnist)
+num_images = 1000
+# Evaluate on a random subset of the test set (sampled without replacement).
+random_test_mnist = torch.utils.data.Subset(
+    test_mnist, np.random.choice(len(test_mnist), num_images, replace=False)
+)
 counterfactuals = np.zeros((4, num_images, 3, 28, 28))
 predictions = []
 source_labels = []
 target_labels = []
-for x, y in test_mnist:
-    for i in range(4):
-        if i == y:
-            # Store the image as is.
-            counterfactuals[i] = x
+for i, (x, y) in tqdm(enumerate(random_test_mnist), total=num_images):
+    for lbl in range(4):
         # Create the counterfactual
         x_fake = generator(
-            x.unsqueeze(0).to(device), prototypes[i].unsqueeze(0).to(device)
+            x.unsqueeze(0).to(device), prototypes[lbl].unsqueeze(0).to(device)
         )
-        counterfactuals[i] = x_fake.cpu().detach().numpy()
+        # Predict the class of the counterfactual image
         pred = model(x_fake)
-        source_labels.append(y)
-        target_labels.append(i)
+        # Store the source and target labels
+        source_labels.append(y)  # The original label of the image
+        target_labels.append(lbl)  # The desired label of the counterfactual image
+        # Store the counterfactual image and prediction
+        counterfactuals[lbl][i] = x_fake.cpu().detach().numpy()
+        # NOTE(review): assumes `model` returns one score per class for the
+        # single-image batch, so argmax is the predicted class -- confirm.
+        predictions.append(pred.argmax().item())
# %% [markdown] tags=[]
@@ -842,13 +855,14 @@ def copy_parameters(source_model, target_model):
# - How well is our GAN doing at creating counterfactual images?
-# - Do you think that the prototypes used matter? Why or why not?
+# - Does your choice of prototypes matter? Why or why not?
# %% [markdown] tags=[]
# Let's also plot some examples of the counterfactual images.
+# %%
for i in np.random.choice(range(num_images), 4):
fig, axs = plt.subplots(1, 4, figsize=(20, 4))
for j, ax in enumerate(axs):
@@ -857,7 +871,7 @@ def copy_parameters(source_model, target_model):
ax.set_title(f"Class {j}")
# %% [markdown] tags=[]
# - Can you easily tell which of these images is the original, and which ones are the counterfactuals?
# - What is your hypothesis for the features that define each class?
@@ -874,10 +888,121 @@ def copy_parameters(source_model, target_model):
# Let's try putting the two together to see if we can figure out what exactly makes a class.
+# %%
+batch_size = 4
+# Use the first `batch_size` samples of the random test subset.
+batch = [random_test_mnist[i] for i in range(batch_size)]
+x = torch.stack([image for image, _ in batch]).to(device).float()
+y = torch.tensor([label for _, label in batch]).to(device)
+# Counterfactuals stored under target index 0 serve as the baselines.
+x_fake = torch.tensor(counterfactuals[0, :batch_size]).to(device).float()
+# Generate attributions with integrated gradients, using each image's
+# counterfactual as its baseline.
+attributions = integrated_gradients.attribute(x, baselines=x_fake, target=y)
-# %% [markdown] tags=[]
+# %% Another visualization function
+def visualize_color_attribution_and_counterfactual(
+    attribution, original_image, counterfactual_image
+):
+    """Plot an image, its counterfactual and the attribution side by side.
+
+    Each input is transposed with axes ``(1, 2, 0)`` before being passed to
+    ``imshow``, so each must be a 3-D array (presumably channels-first,
+    ``(C, H, W)`` -- confirm against the caller).
+
+    Args:
+        attribution: Attribution map; its absolute value is displayed.
+        original_image: The original image.
+        counterfactual_image: The counterfactual generated for that image.
+    """
+    attribution = np.transpose(attribution, (1, 2, 0))
+    original_image = np.transpose(original_image, (1, 2, 0))
+    counterfactual_image = np.transpose(counterfactual_image, (1, 2, 0))
+
+    fig, (ax0, ax1, ax2) = plt.subplots(1, 3, figsize=(15, 5))
+
+    ax0.imshow(original_image)
+    ax0.set_title("Image")
+    ax0.axis("off")
+
+    ax1.imshow(counterfactual_image)
+    ax1.set_title("Counterfactual")
+    ax1.axis("off")
+
+    # Only the magnitude of the attribution is shown, not its sign.
+    ax2.imshow(np.abs(attribution))
+    ax2.set_title("Attribution")
+    ax2.axis("off")
+
+    plt.show()
+# %%
+for idx in range(batch_size):
+    source_class = y[idx].item()
+    print("Source class:", source_class)
+    # The baselines were taken from the counterfactuals at target index 0.
+    print("Target class:", 0)
+    panels = (attributions[idx], x[idx], x_fake[idx])
+    visualize_color_attribution_and_counterfactual(
+        *(panel.cpu().numpy() for panel in panels)
+    )
+# %% [markdown]
+# Questions
+# - Do the attributions explain the differences between the images and their counterfactuals?
+# - What happens when the "counterfactual" and the original image are of the same class? Why do you think this is?
+# - Do you have a more refined hypothesis for what makes each class unique?
+# %% [markdown]
+# Checkpoint 4
+# At this point you have:
+# - Created a StarGAN that can change the class of an image
+# - Evaluated the StarGAN on unseen data
+# - Used the StarGAN to create counterfactual images
+# - Used the counterfactual images to highlight the differences between classes
+# %% [markdown]
+# # Part 6: Exploring the Style Space, finding the answer
+# By now you have hopefully noticed that the exact color of an image is not what determines its class: two images with very similar colors can belong to different classes!
+# Here is an example of two images that are very similar in color, but are of different classes.
+# ![same_color_diff_class](assets/same_color_diff_class.png)
+# While both of the images are yellow, the attribution tells us (if you squint!) that one of the yellows has slightly more blue in it!
+# Conversely, here is an example of two images with very different colors, but that are of the same class:
+# ![same_class_diff_color](assets/same_class_diff_color.png)
+# Here the attribution is empty! Using the discriminative attribution we can see that the significant color change doesn't matter at all!
+# So color is important... but not always? What's going on!?
+# There is a final piece of information that we can use to solve the puzzle: the style space.
+# %% [markdown]
+# ### Task 6.1: Explore the style space
+# Let's take a look at the style space.
+# We will use the style encoder to encode the style of the images and then use PCA to visualize it.
+# %%
+# Encode every image of the random test subset into the style space.
+styles = []
+labels = []
+for img, label in random_test_mnist:
+    styles.append(
+        style_encoder(img.unsqueeze(0).to(device)).cpu().detach().numpy().squeeze()
+    )
+    labels.append(label)
+
+# PCA: project the style vectors onto their first two principal components.
+from sklearn.decomposition import PCA
+
+pca = PCA(n_components=2)
+styles_pca = pca.fit_transform(styles)
+
+# Plot the PCA, one scatter series per class.
+labels_array = np.array(labels)  # built once instead of twice per class
+plt.figure(figsize=(10, 10))
+for i in range(4):
+    class_mask = labels_array == i
+    plt.scatter(
+        styles_pca[class_mask, 0],
+        styles_pca[class_mask, 1],
+        label=f"Class {i}",
+    )
+# Without a legend the per-class labels set above are never displayed.
+plt.legend()
+plt.show()
+# %% [markdown]
+# ### Task 6.2: Adding color to the style space
+# We know that color is important. Does interpreting the style space as colors help us understand better?
+# Let's use the style space to color the PCA plot.
+# %% [markdown] tags=[]
# ## Going Further
# Here are some ideas for how to continue with this notebook: