Fix a set of bugs flagged by users #20

Merged · 21 commits · Feb 27, 2025

Commits
f5de7f4  Merge pull request #13 from funkelab/dev (adjavon, Dec 9, 2024)
ad90782  docs: :memo: Update tutorials (adjavon, Jan 10, 2025)
fd85743  docs: :bug: Fix code-blocks in viz tutorial (adjavon, Jan 10, 2025)
a51bfdc  ci(pre-commit.ci): autoupdate (pre-commit-ci[bot], Feb 3, 2025)
8e158df  style(pre-commit.ci): auto fixes [...] (pre-commit-ci[bot], Feb 3, 2025)
4e07465  fix: :bug: Deal with non-absolute paths in data loaders (adjavon, Feb 26, 2025)
5241ba0  perf: :bug: Add a check for empty datasets (adjavon, Feb 26, 2025)
302124e  feat: :sparkles: Allow reading tiff files (adjavon, Feb 26, 2025)
fbc3ac4  ci: :green_heart: Add ruff to github actions (adjavon, Feb 27, 2025)
92a829e  ci: :green_heart: Remove black from github actions (adjavon, Feb 27, 2025)
9441bce  ci: :green_heart: Replace black with ruff in pre-commit (adjavon, Feb 27, 2025)
503deba  style: :rotating_light: Run ruff on failing files and fix issues (adjavon, Feb 27, 2025)
5051f76  style(pre-commit.ci): auto fixes [...] (pre-commit-ci[bot], Feb 27, 2025)
e4c7b9c  style: :rotating_light: Fix more ruff-raised issues. (adjavon, Feb 27, 2025)
d41f7cc  Merge branch 'bugfix' of https://github.com/funkelab/quac into bugfix (adjavon, Feb 27, 2025)
f1fb2bd  style: :rotating_light: Fix more ruff formatting issues (adjavon, Feb 27, 2025)
69637aa  ci: :green_heart: Remove ruff from github actions (adjavon, Feb 27, 2025)
cde06de  ci: Let ruff handle whitespace (adjavon, Feb 27, 2025)
b95cc82  Merge pull request #16 from funkelab/pre-commit-ci-update-config (adjavon, Feb 27, 2025)
bc8106d  ci: Bump up ruff version in pre-commit (adjavon, Feb 27, 2025)
cf8bd46  Merge branch 'main' into bugfix (adjavon, Feb 27, 2025)
17 changes: 0 additions & 17 deletions .github/workflows/black.yaml

This file was deleted.

18 changes: 7 additions & 11 deletions .pre-commit-config.yaml
@@ -7,19 +7,15 @@ default_install_hook_types: [pre-commit, commit-msg]

 repos:
 - repo: https://github.com/pre-commit/pre-commit-hooks
-  rev: v3.2.0
+  rev: v5.0.0
   hooks:
   - id: trailing-whitespace
   - id: end-of-file-fixer
   - id: check-docstring-first
   - id: check-yaml
   - id: check-added-large-files

-- repo: https://github.com/psf/black
-  rev: 23.1.0
+- repo: https://github.com/astral-sh/ruff-pre-commit
+  rev: v0.9.8
   hooks:
-  - id: black
-
-# - repo: https://github.com/pre-commit/mirrors-mypy
-#   rev: v1.0.1
-#   hooks:
-#   - id: mypy
+  - id: ruff
+    args: [--fix]
+  - id: ruff-format
1 change: 0 additions & 1 deletion docs/source/conf.py
@@ -6,7 +6,6 @@
 # -- Project information -----------------------------------------------------
 # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information

-from datetime import datetime

 import quac
 import tomli
43 changes: 33 additions & 10 deletions docs/source/tutorials.rst
@@ -14,7 +14,7 @@ QuAC, or Quantitative Attributions with Counterfactuals, is a method for generat
Let's assume, for instance, that you have images of cells grown in two different conditions.
To your eye, the phenotypic difference between the two conditions is hidden within the cell-to-cell variability of the dataset, but you know it is there because you've trained a classifier to differentiate the two conditions and it works. So how do you pull out the differences?

-We begin by training a generative neural network to convert your images from one class to another. Here, we'll use a StarGAN. This allows us to go from our real, **query** image, to our **generated** image.
+Assuming that you already have a classifier that does your task, we begin by training a generative neural network to convert your images from one class to another. Here, we'll use a StarGAN. This allows us to go from our real, **query** image, to our **generated** image.
Using information learned from **reference** images, the StarGAN is trained in such a way that the **generated** image will have a different class!

While very powerful, these generative networks *can potentially* make some changes that are not necessary to the classification.
@@ -28,33 +28,56 @@ It is as close as possible to the original image, with only the necessary change
:width: 800
:align: center

-Before you begin, download [the data]() and [the pre-trained models]() for an example.
-Then, make sure you've installed QuAC by following the :doc:`Installation guide <install>`.

+Before you begin, make sure you've installed QuAC by following the :doc:`Installation guide <install>`.

The classifier
==============
To use QuAC, we assume that you have a classifier trained on your data.
There are many different packages to help you do that; for this code-base, you will need the weights of your classifier as a JIT-compiled `pytorch` model.

If you just want to try QuAC as a learning experience, you can use one of the datasets `in this collection <https://doi.org/10.25378/janelia.c.7620737.v1>`_, and the pre-trained models we provide.
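
As an illustration, here is a minimal sketch of producing such a model with ``torch.jit.trace``; the ResNet backbone, weight file, and input size below are placeholders for your own setup.

.. code-block:: python
:linenos:

import torch
import torchvision

# Placeholder model: swap in your own trained classifier (any torch.nn.Module).
model = torchvision.models.resnet18(num_classes=2)
model.load_state_dict(torch.load("classifier_weights.pth"))
model.eval()

# JIT-compile by tracing with a dummy input, then save the result for QuAC.
example_input = torch.randn(1, 3, 224, 224)
torch.jit.trace(model, example_input).save("classifier_jit.pt")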

The conversion network
===============================
Once you've set up your data and your classifier, you can move on to training the conversion network.
We'll use a StarGAN for this.
-There are two options for training the StarGAN, but we recommend :doc:`training it using a YAML file <tutorials/train_yaml>`.
-This will make it easier to keep track of your experiments!
-If you prefer to define parameters directly in Python, however, you can follow the :doc:`alternative training tutorial <tutorials/train>` instead.
-Note that in both cases, you will need the JIT-compiled classifier model!

+You have two options for training the StarGAN: you can either :doc:`define parameters directly in Python <tutorials/train>` or :doc:`train it using a YAML file <tutorials/train_yaml>`.
+We recommend the latter, which will make it easier to keep track of your experiments!
-Once you've trained a decent model, generate a set of images using the :doc:`image generation tutorial <tutorials/generate>` before moving on to the next steps.
+Once you've trained a decent model, you can generate a set of images using the :doc:`image generation tutorial <tutorials/generate>`.
+We recommend taking a look at your generated images, to make sure that they look like what you expect.
+If that is the case, you can move on to the next steps!
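
For a quick look at them, a minimal sketch along these lines works; the output folder and PNG extension are assumptions, so adapt them to wherever you saved your generated images.

.. code-block:: python
:linenos:

from pathlib import Path

import matplotlib.pyplot as plt
from PIL import Image

# Placeholder layout: one folder of generated images per source/target class pair.
generated_dir = Path("/path/to/output/source_class/target_class")
paths = sorted(generated_dir.glob("*.png"))[:8]

fig, axes = plt.subplots(1, len(paths), figsize=(2 * len(paths), 2), squeeze=False)
for ax, path in zip(axes[0], paths):
    ax.imshow(Image.open(path))
    ax.axis("off")
plt.show()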

.. toctree::
:maxdepth: 1

tutorials/train
tutorials/train_yaml
Generating images <tutorials/generate>

Attribution and evaluation
==========================

-With the generated images in hand, we can now run the attribution and evaluation steps.
+With the generated images in hand, we can now run the :doc:`attribution <tutorials/attribute>` step, then the :doc:`evaluation <tutorials/evaluate>` step.
These two steps allow us to overcome the limitations of the generative network to create *truly* minimal counterfactual images, and to score the query-counterfactual pairs based on how well they explain the classifier.

Visualizing results
===================

Finally, we can visualize the results of the attribution and evaluation steps using the :doc:`visualization tutorial <tutorials/visualize>`.
This will allow you to see the quantification results, in the form of QuAC curves.
It will also help you choose the best attribution method for each example, and load the counterfactual visual explanations for these examples.

Table of Contents
=================
Here's a list of all available tutorials, in case you want to navigate directly to one of them.

.. toctree::
:maxdepth: 1

Training the generator (recommended) <tutorials/train_yaml>
Training the generator (alternative) <tutorials/train>
Generating images <tutorials/generate>
Attribution <tutorials/attribute>
Evaluation <tutorials/evaluate>
Visualizing results <tutorials/visualize>
12 changes: 4 additions & 8 deletions docs/source/tutorials/generate.rst
@@ -4,10 +4,6 @@

 How to generate images from a pre-trained network
 =================================================

-.. attention::
-   This tutorial is still under construction. Come back soon for updates!
-
-
 Defining the dataset
 ====================

@@ -22,7 +18,7 @@ For example, below, we are going to be using the validation data, and our source
from quac.generate import load_data

img_size = 224
-data_directory = Path("root_directory/val/0_No_DR")
+data_directory = Path("/path/to/directory/holding/the/data/source_class")
dataset = load_data(data_directory, img_size, grayscale=False)
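
To sanity-check the dataset, you can pull out a single item; as in the generation loop further down, each item is an ``(image, name)`` pair.

.. code-block:: python
:linenos:

x, name = next(iter(dataset))
print(name, x.shape)  # e.g. a (3, 224, 224) image tensor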


@@ -86,7 +82,7 @@ Finally, we can run the image generation.
from quac.generate import get_counterfactual
from torchvision.utils import save_image

output_directory = Path("/path/to/output/latent/0_No_DR/1_Mild/")
output_directory = Path("/path/to/output/latent/source_class/target_class/")

for x, name in tqdm(dataset):
xcf = get_counterfactual(
@@ -117,7 +113,7 @@ The first thing we need to do is to get the reference images.
.. code-block:: python
:linenos:

-reference_data_directory = Path(f"{root_directory}/val/1_Mild")
+reference_data_directory = Path("/path/to/directory/holding/the/data/target_class")
reference_dataset = load_data(reference_data_directory, img_size, grayscale=False)

Loading the StarGAN
@@ -148,7 +144,7 @@ Finally, we combine the two by changing the `kind` in our counterfactual generat

from torchvision.utils import save_image

output_directory = Path("/path/to/output/reference/0_No_DR/1_Mild/")
output_directory = Path("/path/to/output/reference/source_class/target_class/")

for x, name in tqdm(dataset):
xcf = get_counterfactual(
170 changes: 165 additions & 5 deletions docs/source/tutorials/visualize.rst
@@ -2,9 +2,169 @@
Visualizing the results
=======================

-.. attention::
-   This tutorial is still under construction. Come back soon for updates!
+In this tutorial, we will show you how to visualize the results of the attribution and evaluation steps.
+Make sure to modify the paths to the reports and the classifier to match your setup!

-.. image:: ../assets/quac.png
-   :width: 100
-   :align: center
Obtaining the QuAC curves
=========================
Let's start by loading the reports obtained in the previous step.

.. code-block:: python
:linenos:

from quac.report import Report

report_directory = "/path/to/report/directory/"
reports = {
method: Report(name=method)
for method in [
"DDeepLift",
"DIntegratedGradients",
]
}

for method, report in reports.items():
report.load(report_directory + method + "/default.json")

Next, we can plot the QuAC curves for each method.
This allows us to get an idea of how well each method is performing overall.

.. code-block:: python
:linenos:

import matplotlib.pyplot as plt

fig, ax = plt.subplots()
for method, report in reports.items():
report.plot_curve(ax=ax)
# Add the legend
plt.legend()
plt.show()


Choosing the best attribution method for each sample
====================================================

While one attribution method may be better than another on average, it is possible that the best method for a given example is different.
Therefore, we will make a list of the best method for each example by comparing the QuAC scores.

.. code-block:: python
:linenos:

import pandas as pd

quac_scores = pd.DataFrame(
    {method: report.quac_scores for method, report in reports.items()}
)
best_methods = quac_scores.idxmax(axis=1)
best_quac_scores = quac_scores.max(axis=1)
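
As a quick, optional check, you can also count how often each attribution method comes out on top:

.. code-block:: python
:linenos:

# Number of examples for which each attribution method scores best
print(best_methods.value_counts())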

We'll also want to load the classifier at this point, so we can look at the classifications of the counterfactual images.

.. code-block:: python
:linenos:

import torch

# Assumed setup: run on GPU when available; `device` is reused further down.
device = "cuda" if torch.cuda.is_available() else "cpu"
classifier = torch.jit.load("/path/to/classifier/model.pt")
classifier.to(device).eval()


Choosing the best examples
==========================
Next we want to choose the best example, given the best method.
This is done by ordering the examples by the QuAC score, and then choosing the one with the highest score.

.. code-block:: python
:linenos:

import numpy as np

order = np.argsort(best_quac_scores.values)[::-1]  # example indices, best score first

# For example, choose the 10th best example
idx = 10
# Get the corresponding report
report = reports[best_methods[order[idx]]]

We will then load that example and its counterfactual from its path, and visualize it.
We also want to see the classification of both the original and the counterfactual.

.. code-block:: python
:linenos:

# Load the images so that they can be compared to each other
from PIL import Image

image_path, generated_path = report.paths[order[idx]], report.target_paths[order[idx]]
image, generated_image = Image.open(image_path), Image.open(generated_path)

prediction = report.predictions[order[idx]]
target_prediction = report.target_predictions[order[idx]]


Loading the attribution
=======================
We next want to load the attribution for the example, and visualize it.

.. code-block:: python
:linenos:

attribution_path = report.attribution_paths[order[idx]]
attribution = np.load(attribution_path)

Getting the processor
=====================
We want to see the specific mask that was optimal in this case.
To do this, we will need to get the optimal threshold, and get the processor used for masking.

.. code-block:: python
:linenos:

from quac.evaluation import Processor

gaussian_kernel_size = 11
struc = 10
thresh = report.optimal_thresholds()[order[idx]]
print(thresh)
processor = Processor(gaussian_kernel_size=gaussian_kernel_size, struc=struc)

mask, _ = processor.create_mask(attribution, thresh)
rgb_mask = mask.transpose(1, 2, 0)  # (C, H, W) -> (H, W, C)
# zero-out the green and blue channels, so the mask is shown in red
rgb_mask[:, :, 1] = 0
rgb_mask[:, :, 2] = 0
# Masked regions come from the generated image, the rest from the query image
counterfactual = np.array(generated_image) / 255 * rgb_mask + np.array(image) / 255 * (1.0 - rgb_mask)

Let's also get the classifier output for the counterfactual image.

.. code-block:: python
:linenos:

from scipy.special import softmax  # assumed here; any numpy softmax will do

classifier_output = classifier(
    torch.tensor(counterfactual).permute(2, 0, 1).float().unsqueeze(0).to(device)
)
counterfactual_prediction = softmax(classifier_output[0].detach().cpu().numpy())

Visualizing the results
=======================
Finally, we can visualize the results.

.. code-block:: python
:linenos:

fig, axes = plt.subplots(2, 4)
axes[1, 0].imshow(image)
axes[0, 0].bar(np.arange(len(prediction)), prediction)
axes[1, 1].imshow(generated_image)
axes[0, 1].bar(np.arange(len(target_prediction)), target_prediction)
axes[0, 2].bar(np.arange(len(counterfactual_prediction)), counterfactual_prediction)
axes[1, 2].imshow(counterfactual)
axes[1, 3].imshow(rgb_mask)
axes[0, 3].axis("off")
fig.suptitle(f"QuAC Score: {report.quac_scores[order[idx]]}")
plt.show()

You can now see the original image, the generated image, the counterfactual image, and the mask.
From here, you can choose to visualize other examples, or save the images for later use.
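
For the latter, a minimal sketch (the output directory is a placeholder):

.. code-block:: python
:linenos:

from pathlib import Path

out_dir = Path("/path/to/save/explanations")
out_dir.mkdir(parents=True, exist_ok=True)

# Both arrays are floats in [0, 1], which plt.imsave handles directly.
plt.imsave(out_dir / "counterfactual.png", counterfactual)
plt.imsave(out_dir / "mask.png", rgb_mask)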
1 change: 0 additions & 1 deletion src/quac/evaluation.py
@@ -1,5 +1,4 @@
 import cv2
-from dataclasses import dataclass
 import numpy as np
 from pathlib import Path
 from quac.data import PairedImageDataset, CounterfactualDataset, PairedWithAttribution
6 changes: 3 additions & 3 deletions src/quac/generate/__init__.py
@@ -176,9 +176,9 @@ def get_counterfactual(  # type: ignore
     # Copy x batch_size times
     x_multiple = torch.stack([x] * batch_size)
     if kind == "reference":
-        assert (
-            dataset_ref is not None
-        ), "Reference dataset required for reference style."
+        assert dataset_ref is not None, (
+            "Reference dataset required for reference style."
+        )
         if len(dataset_ref) // batch_size < max_tries:
             max_tries = len(dataset_ref) // batch_size
             logger.warning(
1 change: 0 additions & 1 deletion src/quac/generate/model.py
@@ -9,7 +9,6 @@
 )
 from quac.training.checkpoint import CheckpointIO
 import torch
-from typing import Optional


 class InferenceModel(torch.nn.Module):
6 changes: 3 additions & 3 deletions src/quac/report.py
@@ -225,9 +225,9 @@ def optimal_thresholds(self, min_percentage=0.0):
         max_value = np.max(mask_scores, axis=1)
         threshold = min_value + min_percentage * (max_value - min_value)
         below_threshold = mask_scores < threshold[:, None]
-        tradeoff_scores[
-            below_threshold
-        ] = np.inf  # Ignores the points with not enough score change
+        tradeoff_scores[below_threshold] = (
+            np.inf
+        )  # Ignores the points with not enough score change
         thr_idx = np.argmin(tradeoff_scores, axis=1)

         optimal_thresholds = np.take_along_axis(