diff --git a/assets/same_class_diff_color.png b/assets/same_class_diff_color.png
new file mode 100644
index 0000000..c5d98c8
Binary files /dev/null and b/assets/same_class_diff_color.png differ
diff --git a/assets/same_color_diff_class.png b/assets/same_color_diff_class.png
new file mode 100644
index 0000000..775ce42
Binary files /dev/null and b/assets/same_color_diff_class.png differ
diff --git a/solution.py b/solution.py
index 9edd9f7..f065a11 100644
--- a/solution.py
+++ b/solution.py
@@ -421,12 +421,12 @@ def forward(self, x, y):
 # %% tags=["task"]
 style_size = ...  # TODO choose a size for the style space
 unet_depth = ...  # TODO choose a depth for the UNet
-style_mapping = DenseModel(
+style_encoder = DenseModel(
     input_shape=..., num_classes=...  # How big is the style space?
 )
 unet = UNet(depth=..., in_channels=..., out_channels=..., final_activation=nn.Sigmoid())
-generator = Generator(unet, style_mapping=style_mapping)
+generator = Generator(unet, style_encoder=style_encoder)
 # %% tags=["solution"]
 # Here is an example of a working setup! Note that you can change the hyperparameters as you experiment.
 # Choose your own setup to see what works for you.
@@ -533,7 +533,7 @@ def copy_parameters(source_model, target_model):
 
 
 # %%
-generator_ema = Generator(deepcopy(unet), style_mapping=deepcopy(style_mapping))
+generator_ema = Generator(deepcopy(unet), style_encoder=deepcopy(style_encoder))
 generator_ema = generator_ema.to(device)
 
 # %% [markdown] tags=[]
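A note on `generator_ema` above: keeping an exponential-moving-average copy of the generator is a standard trick for stable GAN evaluation. The diff only shows the renamed keyword, so this is a sketch of what `copy_parameters` presumably does, with an assumed smoothing factor; the notebook's own definition may differ:

```python
import torch

def copy_parameters(source_model: torch.nn.Module, target_model: torch.nn.Module):
    # Blend each target parameter toward its source counterpart (EMA update).
    beta = 0.999  # assumed smoothing factor, not taken from the notebook
    for src, tgt in zip(source_model.parameters(), target_model.parameters()):
        tgt.data.mul_(beta).add_(src.data, alpha=1 - beta)
```

Called once per training step, this keeps `generator_ema` a slowly-trailing, smoothed snapshot of `generator`.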

@@ -785,51 +785,64 @@ def copy_parameters(source_model, target_model):
 # %% [markdown]
 # Now we need to use these prototypes to create counterfactual images!
 
-# TODO make a task here!
+# %% [markdown]
+# <div class="alert alert-block alert-info"><h4>Task 4.1: Create counterfactuals</h4>
+# Below, we will store the counterfactual images in the `counterfactuals` array.
+# </div>
+#
 # %% tags=["task"]
-num_images = len(test_mnist)
+num_images = 1000
+random_test_mnist = torch.utils.data.Subset(
+    test_mnist, np.random.choice(len(test_mnist), num_images, replace=False)
+)
 counterfactuals = np.zeros((4, num_images, 3, 28, 28))
 
 predictions = []
 source_labels = []
 target_labels = []
 
-for x, y in test_mnist:
-    for i in range(4):
-        if i == y:
-            # Store the image as is.
-            counterfactuals[i] = ...
-        # Create the counterfactual from the image and prototype
+for i, (x, y) in tqdm(enumerate(random_test_mnist), total=num_images):
+    for lbl in range(4):
+        # TODO Create the counterfactual
         x_fake = generator(x.unsqueeze(0).to(device), ...)
-        counterfactuals[i] = x_fake.cpu().detach().numpy()
+        # TODO Predict the class of the counterfactual image
         pred = model(...)
-        source_labels.append(y)
-        target_labels.append(i)
+        # TODO Store the source and target labels
+        source_labels.append(...)  # The original label of the image
+        target_labels.append(...)  # The desired label of the counterfactual image
+        # Store the counterfactual image and prediction
+        counterfactuals[lbl][i] = x_fake.cpu().detach().numpy()
         predictions.append(pred.argmax().item())
-
 # %% tags=["solution"]
-num_images = len(test_mnist)
+num_images = 1000
+random_test_mnist = torch.utils.data.Subset(
+    test_mnist, np.random.choice(len(test_mnist), num_images, replace=False)
+)
 counterfactuals = np.zeros((4, num_images, 3, 28, 28))
 
 predictions = []
 source_labels = []
 target_labels = []
 
-for x, y in test_mnist:
-    for i in range(4):
-        if i == y:
-            # Store the image as is.
-            counterfactuals[i] = x
+for i, (x, y) in tqdm(enumerate(random_test_mnist), total=num_images):
+    for lbl in range(4):
         # Create the counterfactual
         x_fake = generator(
-            x.unsqueeze(0).to(device), prototypes[i].unsqueeze(0).to(device)
+            x.unsqueeze(0).to(device), prototypes[lbl].unsqueeze(0).to(device)
         )
-        counterfactuals[i] = x_fake.cpu().detach().numpy()
+        # Predict the class of the counterfactual image
         pred = model(x_fake)
-        source_labels.append(y)
-        target_labels.append(i)
+        # Store the source and target labels
+        source_labels.append(y)  # The original label of the image
+        target_labels.append(lbl)  # The desired label of the counterfactual image
+        # Store the counterfactual image and prediction
+        counterfactuals[lbl][i] = x_fake.cpu().detach().numpy()
         predictions.append(pred.argmax().item())
 
 # %% [markdown] tags=[]
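Once the task's ellipses are filled in, a quick sanity check (an aside, not part of the diff) is to measure how often the classifier actually assigns the requested class to each counterfactual:

```python
import numpy as np

# `predictions` and `target_labels` come from the loop above.
success_rate = np.mean(np.array(predictions) == np.array(target_labels))
print(f"Counterfactual success rate: {success_rate:.2%}")
```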
@@ -842,13 +855,14 @@ def copy_parameters(source_model, target_model):
 # <div class="alert alert-block alert-warning"><h3>Questions</h3>
 #
 # </div>
 # %% [markdown] tags=[]
 # Let's also plot some examples of the counterfactual images.
 
+# %%
 for i in np.random.choice(range(num_images), 4):
     fig, axs = plt.subplots(1, 4, figsize=(20, 4))
     for j, ax in enumerate(axs):
@@ -857,7 +871,7 @@ def copy_parameters(source_model, target_model):
         ax.set_title(f"Class {j}")
 
 # %% [markdown] tags=[]
-# <div class="alert alert-banner alert-warning"><h3>Questions</h3>
+# <div class="alert alert-block alert-warning"><h3>Questions</h3>
 # <ul>
 # <li>Can you easily tell which of these images is the original, and which ones are the counterfactuals?</li>
 # <li>What is your hypothesis for the features that define each class?</li>
@@ -874,10 +888,121 @@ def copy_parameters(source_model, target_model):
 #
 # Let's try putting the two together to see if we can figure out what exactly makes a class.
 #
+# %%
+batch_size = 4
+batch = [random_test_mnist[i] for i in range(batch_size)]
+x = torch.stack([b[0] for b in batch])
+y = torch.tensor([b[1] for b in batch])
+x_fake = torch.tensor(counterfactuals[0, :batch_size])
+x = x.to(device).float()
+y = y.to(device)
+x_fake = x_fake.to(device).float()
+# Generate attributions using integrated gradients, with the counterfactuals as baselines
+attributions = integrated_gradients.attribute(x, baselines=x_fake, target=y)
-# %% [markdown] tags=[]
+
+
+# %% Another visualization function
+def visualize_color_attribution_and_counterfactual(
+    attribution, original_image, counterfactual_image
+):
+    attribution = np.transpose(attribution, (1, 2, 0))
+    original_image = np.transpose(original_image, (1, 2, 0))
+    counterfactual_image = np.transpose(counterfactual_image, (1, 2, 0))
+
+    fig, (ax0, ax1, ax2) = plt.subplots(1, 3, figsize=(15, 5))
+    ax0.imshow(original_image)
+    ax0.set_title("Image")
+    ax0.axis("off")
+    ax1.imshow(counterfactual_image)
+    ax1.set_title("Counterfactual")
+    ax1.axis("off")
+    ax2.imshow(np.abs(attribution))
+    ax2.set_title("Attribution")
+    ax2.axis("off")
+    plt.show()
+
+
+# %%
+for idx in range(batch_size):
+    print("Source class:", y[idx].item())
+    print("Target class:", 0)
+    visualize_color_attribution_and_counterfactual(
+        attributions[idx].cpu().numpy(), x[idx].cpu().numpy(), x_fake[idx].cpu().numpy()
+    )
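The `integrated_gradients` object is defined in an earlier part of the notebook that this diff does not touch. Assuming it is Captum's implementation wrapped around the classifier, the setup would look roughly like this:

```python
from captum.attr import IntegratedGradients

# Wrap the trained classifier; `model` is assumed from earlier in the notebook.
integrated_gradients = IntegratedGradients(model)

# Using the counterfactual as the baseline, rather than a black image, means the
# attribution highlights only the changes that matter for the class decision.
attributions = integrated_gradients.attribute(x, baselines=x_fake, target=y)
```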

+# %% [markdown]
+# <div class="alert alert-block alert-warning"><h3>Questions</h3>
+# <ul>
+# <li>Do the attributions explain the differences between the images and their counterfactuals?</li>
+# <li>What happens when the "counterfactual" and the original image are of the same class? Why do you think this is?</li>
+# <li>Do you have a more refined hypothesis for what makes each class unique?</li>
+# </ul>
+# </div>
+# %% [markdown]
+# <div class="alert alert-block alert-success"><h3>Checkpoint 4</h3>
+# At this point you have:
+# - Created a StarGAN that can change the class of an image
+# - Evaluated the StarGAN on unseen data
+# - Used the StarGAN to create counterfactual images
+# - Used the counterfactual images to highlight the differences between classes
+# </div>
+# %% [markdown]
+# # Part 6: Exploring the Style Space, finding the answer
+# By now you will have hopefully noticed that it isn't the exact color of the image that determines its class: two images with very similar colors can be of different classes!
+#
+# Here is an example of two images that are very similar in color, but are of different classes.
+# ![same_color_diff_class](assets/same_color_diff_class.png)
+# While both images are yellow, the attribution tells us (if you squint!) that one of the yellows has slightly more blue in it!
+#
+# Conversely, here is an example of two images with very different colors, but that are of the same class:
+# ![same_class_diff_color](assets/same_class_diff_color.png)
+# Here the attribution is empty! Using the discriminative attribution, we can see that the significant color change doesn't matter at all!
+#
+# So color is important... but not always? What's going on!?
+# There is a final piece of information that we can use to solve the puzzle: the style space.
+# %% [markdown]
+# <div class="alert alert-block alert-info"><h4>Task 6.1: Explore the style space</h4>
+# Let's take a look at the style space.
+# We will use the style encoder to encode the style of the images and then use PCA to visualize it.
+# </div>
+# TODO
+
+# %%
+styles = []
+labels = []
+for img, label in random_test_mnist:
+    styles.append(
+        style_encoder(img.unsqueeze(0).to(device)).cpu().detach().numpy().squeeze()
+    )
+    labels.append(label)
+
+# PCA
+from sklearn.decomposition import PCA
+
+pca = PCA(n_components=2)
+styles_pca = pca.fit_transform(styles)
+
+# Plot the PCA
+plt.figure(figsize=(10, 10))
+for i in range(4):
+    plt.scatter(
+        styles_pca[np.array(labels) == i, 0],
+        styles_pca[np.array(labels) == i, 1],
+        label=f"Class {i}",
+    )
+plt.legend()
+plt.show()
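An aside on the cell above: the per-image loop can be batched, and since nothing here is backpropagated, `torch.no_grad()` avoids building autograd graphs. A minimal equivalent sketch, assuming the notebook's `style_encoder`, `random_test_mnist`, and `device`:

```python
import numpy as np
import torch

with torch.no_grad():
    # Encode all sampled test images in one forward pass.
    images = torch.stack([img for img, _ in random_test_mnist]).to(device)
    styles = style_encoder(images).cpu().numpy()
labels = np.array([label for _, label in random_test_mnist])
```

After fitting, `pca.explained_variance_ratio_` reports how much of the style variance the two plotted components actually capture.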

+# %% [markdown]
+# <div class="alert alert-block alert-info"><h4>Task 6.2: Adding color to the style space</h4>
+# We know that color is important. Does interpreting the style space as colors help us understand it better?
+#
+# Let's use the style space to color the PCA plot.
+# </div>
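The diff leaves this task unfinished (the TODO marker below). One plausible completion, sketched under the assumption that each image's mean RGB value is a reasonable stand-in for its color:

```python
import matplotlib.pyplot as plt
import numpy as np

# Hypothetical completion, not from the diff: color each style-space point
# by the mean RGB of its image (images are CHW tensors with values in [0, 1]).
colors = np.stack([img.numpy().mean(axis=(1, 2)) for img, _ in random_test_mnist])
colors = np.clip(colors, 0, 1)

plt.figure(figsize=(10, 10))
plt.scatter(styles_pca[:, 0], styles_pca[:, 1], c=colors)
plt.xlabel("PC 1")
plt.ylabel("PC 2")
plt.show()
```

If the classes still separate while neighboring points share very different colors, the style space is encoding more than raw color.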
+# TODO WIP HERE
+
+# %% [markdown] tags=[]
 # ## Going Further
 #
 # Here are some ideas for how to continue with this notebook: