wip: Add discriminative attribution
adjavon committed Aug 12, 2024
1 parent 846525f commit 702c0e3
Showing 3 changed files with 153 additions and 28 deletions.
Binary file added assets/same_class_diff_color.png
Binary file added assets/same_color_diff_class.png
solution.py (181 changes: 153 additions & 28 deletions)
@@ -421,12 +421,12 @@ def forward(self, x, y):
# %% tags=["task"]
style_size = ... # TODO choose a size for the style space
unet_depth = ... # TODO Choose a depth for the UNet
-style_mapping = DenseModel(
+style_encoder = DenseModel(
    input_shape=..., num_classes=...  # How big is the style space?
)
unet = UNet(depth=..., in_channels=..., out_channels=..., final_activation=nn.Sigmoid())

-generator = Generator(unet, style_mapping=style_mapping)
+generator = Generator(unet, style_encoder=style_encoder)
# %% tags=["solution"]
# Here is an example of a working setup! Note that you can change the hyperparameters as you experiment.
# Choose your own setup to see what works for you.
@@ -533,7 +533,7 @@ def copy_parameters(source_model, target_model):


# %%
-generator_ema = Generator(deepcopy(unet), style_mapping=deepcopy(style_mapping))
+generator_ema = Generator(deepcopy(unet), style_encoder=deepcopy(style_encoder))
generator_ema = generator_ema.to(device)
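
# %% [markdown] tags=[]
# For intuition: `generator_ema` is kept as an exponential moving average (EMA) of
# the training generator via `copy_parameters` above. A typical EMA update looks
# like the sketch below; the blending factor `beta` is an assumed value here, not
# necessarily what `copy_parameters` implements.

# %%
def ema_update_sketch(source_model, target_model, beta=0.99):
    # Blend each target parameter toward the corresponding source parameter.
    for source_param, target_param in zip(
        source_model.parameters(), target_model.parameters()
    ):
        target_param.data.mul_(beta).add_(source_param.data, alpha=1.0 - beta)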

# %% [markdown] tags=[]
@@ -785,51 +785,64 @@ def copy_parameters(source_model, target_model):

# %% [markdown]
# Now we need to use these prototypes to create counterfactual images!
-# TODO make a task here!
# %% [markdown]
# <div class="alert alert-block alert-info"><h3>Task 4.1: Create counterfactuals</h3>
# Below, we will store the counterfactual images in the `counterfactuals` array.
#
# <ul>
# <li>Create a counterfactual image for each of the prototypes.</li>
# <li>Classify each counterfactual image using the classifier.</li>
# <li>Store the source and target labels; which is which?</li>
# </ul>
# </div>
# %% tags=["task"]
-num_images = len(test_mnist)
+num_images = 1000
random_test_mnist = torch.utils.data.Subset(
    test_mnist, np.random.choice(len(test_mnist), num_images, replace=False)
)
counterfactuals = np.zeros((4, num_images, 3, 28, 28))

predictions = []
source_labels = []
target_labels = []

-for x, y in test_mnist:
-    for i in range(4):
-        if i == y:
-            # Store the image as is.
-            counterfactuals[i] = ...
-        # Create the counterfactual from the image and prototype
+for i, (x, y) in tqdm(enumerate(random_test_mnist), total=num_images):
+    for lbl in range(4):
+        # TODO Create the counterfactual
+        x_fake = generator(x.unsqueeze(0).to(device), ...)
-        counterfactuals[i] = x_fake.cpu().detach().numpy()
+        # TODO Predict the class of the counterfactual image
+        pred = model(...)
+
-        source_labels.append(y)
-        target_labels.append(i)
+        # TODO Store the source and target labels
+        source_labels.append(...)  # The original label of the image
+        target_labels.append(...)  # The desired label of the counterfactual image
+        # Store the counterfactual image and prediction
+        counterfactuals[lbl][i] = x_fake.cpu().detach().numpy()
+        predictions.append(pred.argmax().item())

# %% tags=["solution"]
-num_images = len(test_mnist)
+num_images = 1000
random_test_mnist = torch.utils.data.Subset(
    test_mnist, np.random.choice(len(test_mnist), num_images, replace=False)
)
counterfactuals = np.zeros((4, num_images, 3, 28, 28))

predictions = []
source_labels = []
target_labels = []

-for x, y in test_mnist:
-    for i in range(4):
-        if i == y:
-            # Store the image as is.
-            counterfactuals[i] = x
+for i, (x, y) in tqdm(enumerate(random_test_mnist), total=num_images):
+    for lbl in range(4):
+        # Create the counterfactual
+        x_fake = generator(
-            x.unsqueeze(0).to(device), prototypes[i].unsqueeze(0).to(device)
+            x.unsqueeze(0).to(device), prototypes[lbl].unsqueeze(0).to(device)
+        )
-        counterfactuals[i] = x_fake.cpu().detach().numpy()
+        # Predict the class of the counterfactual image
+        pred = model(x_fake)
+
-        source_labels.append(y)
-        target_labels.append(i)
+        # Store the source and target labels
+        source_labels.append(y)  # The original label of the image
+        target_labels.append(lbl)  # The desired label of the counterfactual image
+        # Store the counterfactual image and prediction
+        counterfactuals[lbl][i] = x_fake.cpu().detach().numpy()
+        predictions.append(pred.argmax().item())
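
# %% [markdown]
# Before looking at individual examples, a quick sanity check (a sketch, not part
# of the evaluation below): how often does the classifier assign a counterfactual
# to its target class?

# %%
success_rate = np.mean(np.array(predictions) == np.array(target_labels))
print(f"Counterfactuals classified as their target class: {success_rate:.2%}")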

# %% [markdown] tags=[]
@@ -842,13 +855,14 @@ def copy_parameters(source_model, target_model):
# <div class="alert alert-block alert-warning"><h3>Questions</h3>
# <ul>
# <li> How well is our GAN doing at creating counterfactual images? </li>
-# <li> Do you think that the prototypes used matter? Why or why not? </li>
+# <li> Does your choice of prototypes matter? Why or why not? </li>
# </ul>
# </div>

# %% [markdown] tags=[]
# Let's also plot some examples of the counterfactual images.

# %%
for i in np.random.choice(range(num_images), 4):
    fig, axs = plt.subplots(1, 4, figsize=(20, 4))
    for j, ax in enumerate(axs):
@@ -857,7 +871,7 @@ def copy_parameters(source_model, target_model):
        ax.set_title(f"Class {j}")

# %% [markdown] tags=[]
-# <div class="alert alert-block alert-info"><h3>Questions</h3>
+# <div class="alert alert-block alert-warning"><h3>Questions</h3>
# <ul>
# <li>Can you easily tell which of these images is the original, and which ones are the counterfactuals?</li>
# <li>What is your hypothesis for the features that define each class?</li>
@@ -874,10 +888,121 @@ def copy_parameters(source_model, target_model):
#
# Let's try putting the two together to see if we can figure out what exactly makes a class.
#
# %%
batch_size = 4
batch = [random_test_mnist[i] for i in range(batch_size)]
x = torch.stack([b[0] for b in batch])
y = torch.tensor([b[1] for b in batch])
x_fake = torch.tensor(counterfactuals[0, :batch_size])
x = x.to(device).float()
y = y.to(device)
x_fake = x_fake.to(device).float()
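
# `integrated_gradients` was created in the attribution section earlier in the
# notebook; if it is not in scope, it can be rebuilt from the classifier with
# Captum (an assumption about that earlier setup, consistent with the call below).
from captum.attr import IntegratedGradients

integrated_gradients = IntegratedGradients(model)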

# Get discriminative attributions: integrated gradients with the counterfactuals as baselines
attributions = integrated_gradients.attribute(x, baselines=x_fake, target=y)

# %% [markdown] tags=[]

# %% Another visualization function
def visualize_color_attribution_and_counterfactual(
    attribution, original_image, counterfactual_image
):
    attribution = np.transpose(attribution, (1, 2, 0))
    original_image = np.transpose(original_image, (1, 2, 0))
    counterfactual_image = np.transpose(counterfactual_image, (1, 2, 0))

    fig, (ax0, ax1, ax2) = plt.subplots(1, 3, figsize=(15, 5))
    ax0.imshow(original_image)
    ax0.set_title("Image")
    ax0.axis("off")
    ax1.imshow(counterfactual_image)
    ax1.set_title("Counterfactual")
    ax1.axis("off")
    ax2.imshow(np.abs(attribution))
    ax2.set_title("Attribution")
    ax2.axis("off")
    plt.show()


# %%
for idx in range(batch_size):
    print("Source class:", y[idx].item())
    print("Target class:", 0)
    visualize_color_attribution_and_counterfactual(
        attributions[idx].cpu().numpy(), x[idx].cpu().numpy(), x_fake[idx].cpu().numpy()
    )
# %% [markdown]
# <div class="alert alert-block alert-warning"><h3>Questions</h3>
# <ul>
# <li> Do the attributions explain the differences between the images and their counterfactuals? </li>
# <li> What happens when the "counterfactual" and the original image are of the same class? Why do you think this is? </li>
# <li> Do you have a more refined hypothesis for what makes each class unique? </li>
# </ul>
# </div>
# %% [markdown]
# <div class="alert alert-block alert-success"><h2>Checkpoint 4</h2>
# At this point you have:
# - Created a StarGAN that can change the class of an image
# - Evaluated the StarGAN on unseen data
# - Used the StarGAN to create counterfactual images
# - Used the counterfactual images to highlight the differences between classes
# </div>
#
# %% [markdown]
# # Part 6: Exploring the Style Space, finding the answer
# By now you have hopefully noticed that it isn't the exact color of an image that determines its class: two images with very similar colors can belong to different classes!
#
# Here is an example of two images that are very similar in color, but are of different classes.
# ![same_color_diff_class](assets/same_color_diff_class.png)
# While both of the images are yellow, the attribution tells us (if you squint!) that one of the yellows has slightly more blue in it!
#
# Conversely, here is an example of two images with very different colors, but that are of the same class:
# ![same_class_diff_color](assets/same_class_diff_color.png)
# Here the attribution is empty! Using the discriminative attribution we can see that the significant color change doesn't matter at all!
#
#
# So color is important... but not always? What's going on!?
# There is a final piece of information that we can use to solve the puzzle: the style space.
# %% [markdown]
# <div class="alert alert-block alert-info"><h3>Task 6.1: Explore the style space</h3>
# Let's take a look at the style space.
# We will use the style encoder to encode the style of the images and then use PCA to visualize it.
# </div>
# TODO

# %%
styles = []
labels = []
for img, label in random_test_mnist:
    styles.append(
        style_encoder(img.unsqueeze(0).to(device)).cpu().detach().numpy().squeeze()
    )
    labels.append(label)

# PCA
from sklearn.decomposition import PCA

pca = PCA(n_components=2)
styles_pca = pca.fit_transform(styles)

# Plot the PCA
plt.figure(figsize=(10, 10))
for i in range(4):
    plt.scatter(
        styles_pca[np.array(labels) == i, 0],
        styles_pca[np.array(labels) == i, 1],
        label=f"Class {i}",
    )

plt.legend()
plt.show()

# %% [markdown]
# <div class="alert alert-block alert-info"><h3>Task 6.2: Adding color to the style space</h3>
# We know that color is important. Does interpreting the style space as colors help us understand better?
#
# Let's use the style space to color the PCA plot.
# </div>
# TODO WIP HERE
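
# %%
# One possible approach (a sketch): color each point of the style-space PCA by the
# mean RGB value of its image, so style clusters can be compared to image color
# directly. Treating the mean RGB as "the color" of an image is an assumption.
colors = np.array([img.mean(dim=(1, 2)).numpy() for img, _ in random_test_mnist])
colors = np.clip(colors, 0, 1)

plt.figure(figsize=(10, 10))
plt.scatter(styles_pca[:, 0], styles_pca[:, 1], c=colors)
plt.title("Style space PCA, colored by mean image color")
plt.show()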

# %% [markdown] tags=[]
# ## Going Further
#
# Here are some ideas for how to continue with this notebook:
