From ad907821e88cbff7fc89e5ddc0e23c48545d4ce9 Mon Sep 17 00:00:00 2001 From: Diane Adjavon Date: Fri, 10 Jan 2025 10:16:54 -0500 Subject: [PATCH 01/17] docs: :memo: Update tutorials --- docs/source/tutorials.rst | 43 ++++++-- docs/source/tutorials/generate.rst | 12 +- docs/source/tutorials/visualize.rst | 165 +++++++++++++++++++++++++++- 3 files changed, 197 insertions(+), 23 deletions(-) diff --git a/docs/source/tutorials.rst b/docs/source/tutorials.rst index 57c9bf0..01845e4 100644 --- a/docs/source/tutorials.rst +++ b/docs/source/tutorials.rst @@ -14,7 +14,7 @@ QuAC, or Quantitative Attributions with Counterfactuals, is a method for generat Let's assume, for instance, that you have images of cells grown in two different conditions. To your eye, the phenotypic difference between the two conditions is hidden within the cell-to-cell variability of the dataset, but you know it is there because you've trained a classifier to differentiate the two conditions and it works. So how do you pull out the differences? -We begin by training a generative neural network to convert your images from one class to another. Here, we'll use a StarGAN. This allows us to go from our real, **query** image, to our **generated** image. +Assuming that you already have a classifier that does your task, we begin by training a generative neural network to convert your images from one class to another. Here, we'll use a StarGAN. This allows us to go from our real, **query** image, to our **generated** image. Using information learned from **reference** images, the StarGAN is trained in such a way that the **generated** image will have a different class! While very powerful, these generative networks *can potentially* make some changes that are not necessary to the classification. @@ -28,33 +28,56 @@ It is as close as possible to the original image, with only the necessary change :width: 800 :align: center -Before you begin, download [the data]() and [the pre-trained models]() for an example. -Then, make sure you've installed QuAC by following the :doc:`Installation guide `. +Before you begin, make sure you've installed QuAC by following the :doc:`Installation guide `. + +The classifier +============== +To use QuAC, we assume that you have a classifier trained on your data. +There are already many packages that can help you do that; for this code base, you will need to provide your classifier's weights as a JIT-compiled `pytorch` model. + +If you just want to try QuAC as a learning experience, you can use one of the datasets `in this collection `_, and the pre-trained models we provide. The conversion network =============================== +Once you've set up your data and your classifier, you can move on to training the conversion network. +We'll use a StarGAN for this. +There are two options for training the StarGAN, but we recommend :doc:`training it using a YAML file `. +This will make it easier to keep track of your experiments! +If you prefer to define parameters directly in Python, however, you can follow the :doc:`alternative training tutorial ` instead. +Note that in both cases, you will need the JIT-compiled classifier model! -You have two options for training the StarGAN, you can either :doc:`define parameters directly in Python ` or :doc:`train it using a YAML file `. -We recommend the latter, which will make it easier to keep track of your experiments! -Once you've trained a decent model, generate a set of images using the :doc:`image generation tutorial ` before moving on to the next steps.
+Once you've trained a decent model, you can generate a set of images using the :doc:`image generation tutorial `. +We recommend taking a look at your generated images, to make sure that they look like what you expect. +If that is the case, you can move on to the next steps! .. toctree:: :maxdepth: 1 - tutorials/train - tutorials/train_yaml - Generating images Attribution and evaluation ========================== -With the generated images in hand, we can now run the attribution and evaluation steps. +With the generated images in hand, we can now run the :doc:`attribution ` step, then the :doc:`evaluation ` step. These two steps allow us to overcome the limitations of the generative network to create *truly* minimal counterfactual images, and to score the query-counterfactual pairs based on how well they explain the classifier. +Visualizing results +=================== + +Finally, we can visualize the results of the attribution and evaluation steps using the :doc:`visualization tutorial `. +This will allow you to see the quantification results, in the form of QuAC curves. +It will also help you choose the best attribution method for each example, and load the counterfactual visual explanations for these examples. + +Table of Contents +================= +Here's a list of all available tutorials, in case you want to navigate directly to one of them. + .. toctree:: :maxdepth: 1 + Training the generator (recommended) + Training the generator (alternative) + Generating images Attribution Evaluation Visualizing results diff --git a/docs/source/tutorials/generate.rst b/docs/source/tutorials/generate.rst index 0775941..c3d2a49 100644 --- a/docs/source/tutorials/generate.rst +++ b/docs/source/tutorials/generate.rst @@ -4,10 +4,6 @@ How to generate images from a pre-trained network ================================================= -.. attention:: - This tutorial is still under construction. Come back soon for updates! - - Defining the dataset ==================== @@ -22,7 +18,7 @@ For example, below, we are going to be using the validation data, and our source from quac.generate import load_data img_size = 224 - data_directory = Path("root_directory/val/0_No_DR") + data_directory = Path("/path/to/directory/holding/the/data/source_class") dataset = load_data(data_directory, img_size, grayscale=False) @@ -86,7 +82,7 @@ Finally, we can run the image generation. from quac.generate import get_counterfactual from torchvision.utils import save_image - output_directory = Path("/path/to/output/latent/0_No_DR/1_Mild/") + output_directory = Path("/path/to/output/latent/source_class/target_class/") for x, name in tqdm(dataset): xcf = get_counterfactual( @@ -117,7 +113,7 @@ The first thing we need to do is to get the reference images. .. 
code-block:: python :linenos: - reference_data_directory = Path(f"{root_directory}/val/1_Mild") + reference_data_directory = Path("/path/to/directory/holding/the/data/target_class") reference_dataset = load_data(reference_data_directory, img_size, grayscale=False) Loading the StarGAN @@ -148,7 +144,7 @@ Finally, we combine the two by changing the `kind` in our counterfactual generat from torchvision.utils import save_image - output_directory = Path("/path/to/output/reference/0_No_DR/1_Mild/") + output_directory = Path("/path/to/output/reference/source_class/target_class/") for x, name in tqdm(dataset): xcf = get_counterfactual( diff --git a/docs/source/tutorials/visualize.rst b/docs/source/tutorials/visualize.rst index 3bd9022..52936d0 100644 --- a/docs/source/tutorials/visualize.rst +++ b/docs/source/tutorials/visualize.rst @@ -2,9 +2,164 @@ Visualizing the results ======================= -.. attention:: - This tutorial is still under construction. Come back soon for updates! +In this tutorial, we will show you how to visualize the results of the attribution and evaluation steps. +Make sure to modify the paths to the reports and the classifier to match your setup! - .. image:: ../assets/quac.png - :width: 100 - :align: center +Obtaining the QuAC curves +========================= +Let's start by loading the reports obtained in the previous step. + +.. code-block:: python + :linenos: + report_directory = "/path/to/report/directory/" + + + from quac.report import Report + + reports = { + method: Report(name=method) + for method in [ + "DDeepLift", + "DIntegratedGradients", + ] + } + + for method, report in reports.items(): + report.load(report_directory + method + "/default.json") + +Next, we can plot the QuAC curves for each method. +This allows us to get an idea of how well each method is performing, overall. + +.. code-block:: python + :linenos: + + import matplotlib.pyplot as plt + + fig, ax = plt.subplots() + for method, report in reports.items(): + report.plot_curve(ax=ax) + # Add the legend + plt.legend() + plt.show() + + +Choosing the best attribution method for each sample +==================================================== + +While one attribution method may be better than another on average, it is possible that the best method for a given example is different. +Therefore, we will make a list of the best method for each example by comparing the quac scores. +.. code-block:: python + :linenos: + quac_scores = pd.DataFrame( + {method: report.quac_scores for method, report in reports.items()} + ) + best_methods = quac_scores.idxmax(axis=1) + best_quac_scores = quac_scores.max(axis=1) + +We'll also want to load the classifier at this point, so we can look at the classifications of the counterfactual images. + +.. code-block:: python + :linenos: + import torch + + classifier = torch.jit.load("/path/to/classifier/model.pt") + + +Choosing the best examples +========================== +Next we want to choose the best example, given the best method. +This is done by ordering the examples by the QuAC score, and then choosing the one with the highest score. + +.. code-block:: python + :linenos: + order = best_quac_scores[::-1].argsort() + + # For example, choose the 10th best example + idx = 10 + # Get the corresponding report + report = reports[best_methods[order[idx]]] + +We will then load that example and its counterfactual from its path, and visualize it. +We also want to see the classification of both the original and the counterfactual. + +.. 
code-block:: python + :linenos: + # Transform to apply to the images so they match each other + # loading + from PIL import Image + + image_path, generated_path = report.paths[order[idx]], report.target_paths[order[idx]] + image, generated_image = Image.open(image_path), Image.open(generated_path) + + prediction = report.predictions[order[idx]] + target_prediction = report.target_predictions[order[idx]] + +Loading the attribution +======================= +We next want to load the attribution for the example, and visualize it. + +.. code-block:: python + :linenos: + + attribution_path = report.attribution_paths[order[idx]] + attribution = np.load(attribution_path) + +Getting the processor +===================== +We want to see the specific mask that was optimal in this case. +To do this, we will need to get the optimal threshold, and get the processor used for masking. + +.. code-block:: python + :linenos: + from quac.evaluation import Processor + + gaussian_kernel_size = 11 + struc = 10 + thresh = report.optimal_thresholds()[order[idx]] + print(thresh) + processor = Processor(gaussian_kernel_size=gaussian_kernel_size, struc=struc) + + mask, _ = processor.create_mask(attribution, thresh) + rgb_mask = mask.transpose(1, 2, 0) + # zero-out the green and blue channels + rgb_mask[:, :, 1] = 0 + rgb_mask[:, :, 2] = 0 + counterfactual = np.array(generated_image) / 255 * rgb_mask + np.array(image) / 255 * (1.0 - rgb_mask) + +Let's also get the classifier output for the counterfactual image. + +.. code-block:: python + :linenos: + + classifier_output = classifier( + torch.tensor(counterfactual).permute(2, 0, 1).float().unsqueeze(0).to(device) + ) + counterfactual_prediction = softmax(classifier_output[0].detach().cpu().numpy()) + +Visualizing the results +======================= +Finally, we can visualize the results. + +.. code-block:: python + :linenos: + + fig, axes = plt.subplots(2, 4) + axes[1, 0].imshow(image) + axes[0, 0].bar(np.arange(len(prediction)), prediction) + axes[1, 1].imshow(generated_image) + axes[0, 1].bar(np.arange(len(target_prediction)), target_prediction) + axes[0, 2].bar(np.arange(len(counterfactual_prediction)), counterfactual_prediction) + axes[1, 2].imshow(counterfactual) + axes[1, 3].imshow(rgb_mask) + axes[0, 3].axis("off") + fig.suptitle(f"QuAC Score: {report.quac_scores[order[idx]]}") + plt.show() + +You can now see the original image, the generated image, the counterfactual image, and the mask. +From here, you can choose to visualize other examples, or save the images for later use. From fd857430f954ba1725678e2fc4d8c8e16af57d33 Mon Sep 17 00:00:00 2001 From: Diane Adjavon Date: Fri, 10 Jan 2025 14:43:15 -0500 Subject: [PATCH 02/17] docs: :bug: Fix code-blocks in viz tutorial --- docs/source/tutorials/visualize.rst | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/docs/source/tutorials/visualize.rst b/docs/source/tutorials/visualize.rst index 52936d0..b690ebe 100644 --- a/docs/source/tutorials/visualize.rst +++ b/docs/source/tutorials/visualize.rst @@ -11,11 +11,10 @@ Let's start by loading the reports obtained in the previous step. ..
code-block:: python :linenos: - report_directory = "/path/to/report/directory/" - from quac.report import Report + report_directory = "/path/to/report/directory/" reports = { method: Report(name=method) for method in [ @@ -48,8 +47,10 @@ Choosing the best attribution method for each sample While one attribution method may be better than another on average, it is possible that the best method for a given example is different. Therefore, we will make a list of the best method for each example by comparing the quac scores. + .. code-block:: python :linenos: + quac_scores = pd.DataFrame( {method: report.quac_scores for method, report in reports.items()} ) @@ -60,6 +61,7 @@ We'll also want to load the classifier at this point, so we can look at the clas .. code-block:: python :linenos: + import torch classifier = torch.jit.load("/path/to/classifier/model.pt") @@ -72,6 +74,7 @@ This is done by ordering the examples by the QuAC score, and then choosing the o .. code-block:: python :linenos: + order = best_quac_scores[::-1].argsort() # For example, choose the 10th best example @@ -84,6 +87,7 @@ We also want to see the classification of both the original and the counterfactu .. code-block:: python :linenos: + # Transform to apply to the images so they match each other # loading from PIL import Image @@ -117,6 +121,7 @@ To do this, we will need to get the optimal threshold, and get the processor use .. code-block:: python :linenos: + from quac.evaluation import Processor gaussian_kernel_size = 11 From a51bfdc1ccb19d074806b34d0354cfe18faf20b1 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 3 Feb 2025 17:24:29 +0000 Subject: [PATCH 03/17] ci(pre-commit.ci): autoupdate MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit updates: - [github.com/pre-commit/pre-commit-hooks: v3.2.0 → v5.0.0](https://github.com/pre-commit/pre-commit-hooks/compare/v3.2.0...v5.0.0) - [github.com/psf/black: 23.1.0 → 25.1.0](https://github.com/psf/black/compare/23.1.0...25.1.0) --- .pre-commit-config.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index dc763e9..bddbd1a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -7,7 +7,7 @@ default_install_hook_types: [pre-commit, commit-msg] repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v3.2.0 + rev: v5.0.0 hooks: - id: trailing-whitespace - id: end-of-file-fixer @@ -15,7 +15,7 @@ repos: - id: check-added-large-files - repo: https://github.com/psf/black - rev: 23.1.0 + rev: 25.1.0 hooks: - id: black From 8e158df9b7731e405bf091905ae51eb291be524b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 3 Feb 2025 17:24:38 +0000 Subject: [PATCH 04/17] style(pre-commit.ci): auto fixes [...] 
--- src/quac/report.py | 6 +++--- src/quac/training/eval.py | 12 ++++++------ src/quac/training/solver.py | 12 ++++++------ 3 files changed, 15 insertions(+), 15 deletions(-) diff --git a/src/quac/report.py b/src/quac/report.py index e20fb2a..99fc29d 100644 --- a/src/quac/report.py +++ b/src/quac/report.py @@ -225,9 +225,9 @@ def optimal_thresholds(self, min_percentage=0.0): max_value = np.max(mask_scores, axis=1) threshold = min_value + min_percentage * (max_value - min_value) below_threshold = mask_scores < threshold[:, None] - tradeoff_scores[ - below_threshold - ] = np.inf # Ignores the points with not enough score change + tradeoff_scores[below_threshold] = ( + np.inf + ) # Ignores the points with not enough score change thr_idx = np.argmin(tradeoff_scores, axis=1) optimal_thresholds = np.take_along_axis( diff --git a/src/quac/training/eval.py b/src/quac/training/eval.py index cc7066f..1dac0b7 100644 --- a/src/quac/training/eval.py +++ b/src/quac/training/eval.py @@ -69,12 +69,12 @@ def calculate_metrics( std=std, grayscale=(input_dim == 1), ) - conversion_rate_values[ - "conversion_rate_%s/%s" % (mode, task) - ] = conversion_rate - translation_rate_values[ - "translation_rate_%s/%s" % (mode, task) - ] = translation_rate + conversion_rate_values["conversion_rate_%s/%s" % (mode, task)] = ( + conversion_rate + ) + translation_rate_values["translation_rate_%s/%s" % (mode, task)] = ( + translation_rate + ) # calculate the average conversion rate for all tasks conversion_rate_mean = 0 diff --git a/src/quac/training/solver.py b/src/quac/training/solver.py index f49cf63..c924a2e 100644 --- a/src/quac/training/solver.py +++ b/src/quac/training/solver.py @@ -454,12 +454,12 @@ def evaluate( translation_rate = np.mean(predictions == trg_idx) # STORE - conversion_rate_values[ - f"conversion_rate_{mode}/" + task - ] = conversion_rate - translation_rate_values[ - f"translation_rate_{mode}/" + task - ] = translation_rate + conversion_rate_values[f"conversion_rate_{mode}/" + task] = ( + conversion_rate + ) + translation_rate_values[f"translation_rate_{mode}/" + task] = ( + translation_rate + ) # Add average conversion rate and translation rate conversion_rate_values[f"conversion_rate_{mode}/average"] = np.mean( From 4e0746562c7f832122290f2bdfb4ae6cdbf9350f Mon Sep 17 00:00:00 2001 From: Diane Adjavon Date: Wed, 26 Feb 2025 13:36:32 -0500 Subject: [PATCH 05/17] fix: :bug: Deal with non-absolute paths in data loaders Also adds tiff files in listdir Closes #17, Addresses #18 --- src/quac/training/data_loader.py | 28 ++++++++++++++++++++-------- 1 file changed, 20 insertions(+), 8 deletions(-) diff --git a/src/quac/training/data_loader.py b/src/quac/training/data_loader.py index 7e944e4..ef58e6a 100644 --- a/src/quac/training/data_loader.py +++ b/src/quac/training/data_loader.py @@ -40,7 +40,14 @@ def listdir(dname): chain( *[ list(Path(dname).rglob("*." 
+ ext)) - for ext in ["png", "jpg", "jpeg", "JPG"] + for ext in [ + "png", + "jpg", + "jpeg", + "JPG", + "tiff", + "tif", + ] # TODO use Bioformats for this ] ) ) @@ -65,6 +72,7 @@ def __len__(self): return len(self.samples) +# TODO should the Augmented + Reference Datasets be combined into a single class class AugmentedDataset(data.Dataset): """Adds an augmented version of the input to the sample.""" @@ -82,10 +90,12 @@ def __init__(self, root, transform=None, augment=None): self.augment = augment def _make_dataset(self, root): - domains = glob.glob(os.path.join(root, "*")) + # Get all subitems, sorted, ignore hidden + domains = sorted(Path(root).glob("[!.]*")) + # only directories, absolute paths + domains = [d.absolute() for d in domains if d.is_dir()] fnames, labels = [], [] - for idx, domain in enumerate(sorted(domains)): - class_dir = os.path.join(root, domain) + for idx, class_dir in enumerate(domains): cls_fnames = listdir(class_dir) fnames += cls_fnames labels += [idx] * len(cls_fnames) @@ -111,10 +121,12 @@ def __init__(self, root, transform=None): self.transform = transform def _make_dataset(self, root): - domains = glob.glob(os.path.join(root, "*")) - fnames, fnames2, labels = [], [], [] - for idx, domain in enumerate(sorted(domains)): - class_dir = os.path.join(root, domain) + # Get all subitems, sorted, ignore hidden + domains = sorted(Path(root).glob("[!.]*")) + # only directories, absolute paths + domains = [d.absolute() for d in domains if d.is_dir()] + fnames, labels = [], [] + for idx, class_dir in enumerate(domains): cls_fnames = listdir(class_dir) fnames += cls_fnames fnames2 += random.sample(cls_fnames, len(cls_fnames)) From 5241ba0ab02aa06ad65c9794f3ccd10e389edd98 Mon Sep 17 00:00:00 2001 From: Diane Adjavon Date: Wed, 26 Feb 2025 13:39:17 -0500 Subject: [PATCH 06/17] perf: :bug: Add a check for empty datasets Partially addresses #18 --- src/quac/training/data_loader.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/quac/training/data_loader.py b/src/quac/training/data_loader.py index ef58e6a..5742091 100644 --- a/src/quac/training/data_loader.py +++ b/src/quac/training/data_loader.py @@ -78,6 +78,8 @@ class AugmentedDataset(data.Dataset): def __init__(self, root, transform=None, augment=None): self.samples, self.targets = self._make_dataset(root) + # Check if empty + assert len(self.samples) > 0, "Dataset is empty, no files found." self.transform = transform if augment is None: # Default augmentation: random horizontal flip, random vertical flip @@ -118,6 +120,8 @@ def __len__(self): class ReferenceDataset(data.Dataset): def __init__(self, root, transform=None): self.samples, self.targets = self._make_dataset(root) + # Check if empty + assert len(self.samples) > 0, "Dataset is empty, no files found." self.transform = transform def _make_dataset(self, root): From 302124e11bd0163159d2d191edd479ce14321f11 Mon Sep 17 00:00:00 2001 From: Diane Adjavon Date: Wed, 26 Feb 2025 14:10:05 -0500 Subject: [PATCH 07/17] feat: :sparkles: Allow reading tiff files This also reduces some redundancy in the code. I've switched from pillow to imageio for reading images. Closes #18 and #15. 
--- src/quac/training/data_loader.py | 87 ++++++++++++++++++++------------ 1 file changed, 54 insertions(+), 33 deletions(-) diff --git a/src/quac/training/data_loader.py b/src/quac/training/data_loader.py index 5742091..ef71672 100644 --- a/src/quac/training/data_loader.py +++ b/src/quac/training/data_loader.py @@ -14,6 +14,7 @@ import os import random +import imageio from munch import Munch from PIL import Image import numpy as np @@ -47,7 +48,7 @@ def listdir(dname): "JPG", "tiff", "tif", - ] # TODO use Bioformats for this + ] ] ) ) @@ -72,24 +73,23 @@ def __len__(self): return len(self.samples) -# TODO should the Augmented + Reference Datasets be combined into a single class -class AugmentedDataset(data.Dataset): - """Adds an augmented version of the input to the sample.""" +class LabelledDataset(data.Dataset): + """A base dataset for QuAC.""" def __init__(self, root, transform=None, augment=None): self.samples, self.targets = self._make_dataset(root) # Check if empty assert len(self.samples) > 0, "Dataset is empty, no files found." self.transform = transform - if augment is None: - # Default augmentation: random horizontal flip, random vertical flip - augment = transforms.Compose( - [ - transforms.RandomHorizontalFlip(), - transforms.RandomVerticalFlip(), - ] - ) - self.augment = augment + + def _open_image(self, fname): + array = imageio.imread(fname) + # if no channel dimension, add it + if len(array.shape) == 2: + array = array[:, :, None] + # data will be h,w,c, switch to c,h,w + array = array.transpose(2, 0, 1) + return torch.from_numpy(array) def _make_dataset(self, root): # Get all subitems, sorted, ignore hidden @@ -106,7 +106,32 @@ def _make_dataset(self, root): def __getitem__(self, index): fname = self.samples[index] label = self.targets[index] - img = Image.open(fname) + img = self._open_image(fname) + if self.transform is not None: + img = self.transform(img) + return img, label + + +class AugmentedDataset(LabelledDataset): + """Adds an augmented version of the input to the sample.""" + + def __init__(self, root, transform=None, augment=None): + super().__init__(root, transform, augment) # Creates self.samples, self.targets + if self.augment is None: + # Default augmentation: random horizontal flip, random vertical flip + augment = transforms.Compose( + [ + transforms.RandomHorizontalFlip(), + transforms.RandomVerticalFlip(), + ] + ) + self.augment = augment + + def __getitem__(self, index): + fname = self.samples[index] + label = self.targets[index] + img = self._open_image(fname) + # Augment the image to create a second image img2 = self.augment(img) if self.transform is not None: img = self.transform(img) @@ -117,31 +142,27 @@ def __len__(self): return len(self.targets) -class ReferenceDataset(data.Dataset): - def __init__(self, root, transform=None): - self.samples, self.targets = self._make_dataset(root) - # Check if empty - assert len(self.samples) > 0, "Dataset is empty, no files found." 
- self.transform = transform +class ReferenceDataset(LabelledDataset): + """A dataset that returns a reference image and a target image.""" - def _make_dataset(self, root): - # Get all subitems, sorted, ignore hidden - domains = sorted(Path(root).glob("[!.]*")) - # only directories, absolute paths - domains = [d.absolute() for d in domains if d.is_dir()] - fnames, labels = [], [] - for idx, class_dir in enumerate(domains): + def __init__(self, root, transform=None): + super().__init__(root, transform) # Creates self.samples, self.targets + # Create a second set of samples + fnames2 = [] + for fname in self.samples: + # Get the class of the current image + class_dir = Path(fname).parent + # Get a random image from the same class cls_fnames = listdir(class_dir) - fnames += cls_fnames - fnames2 += random.sample(cls_fnames, len(cls_fnames)) - labels += [idx] * len(cls_fnames) - return list(zip(fnames, fnames2)), labels + fname2 = random.choice(cls_fnames) + fnames2.append(fname2) + self.samples = list(zip(self.samples, fnames2)) def __getitem__(self, index): fname, fname2 = self.samples[index] label = self.targets[index] - img = Image.open(fname) - img2 = Image.open(fname2) + img = self._open_image(fname) + img2 = self._open_image(fname2) if self.transform is not None: img = self.transform(img) img2 = self.transform(img2) From fbc3ac4157e039a02b78ba1bb1d91046de3a4aa4 Mon Sep 17 00:00:00 2001 From: Diane Adjavon Date: Thu, 27 Feb 2025 11:45:00 -0500 Subject: [PATCH 08/17] ci: :green_heart: Add ruff to github actions --- .github/workflows/tests.yaml | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 384aaea..54613a5 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -4,6 +4,15 @@ on: push: jobs: + ruff: + name: Ruff + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: chartboost/ruff-action@v1 + with: + args: 'format --check' + test: runs-on: ubuntu-latest strategy: From 92a829e4d35429661c6fefe1962501b1a3f4e48f Mon Sep 17 00:00:00 2001 From: Diane Adjavon Date: Thu, 27 Feb 2025 11:45:46 -0500 Subject: [PATCH 09/17] ci: :green_heart: Remove black from github actions --- .github/workflows/black.yaml | 17 ----------------- 1 file changed, 17 deletions(-) delete mode 100644 .github/workflows/black.yaml diff --git a/.github/workflows/black.yaml b/.github/workflows/black.yaml deleted file mode 100644 index 3d6329f..0000000 --- a/.github/workflows/black.yaml +++ /dev/null @@ -1,17 +0,0 @@ -name: Python Black - -on: [push, pull_request] - -jobs: - lint: - name: Python Lint - runs-on: ubuntu-latest - steps: - - name: Setup Python - uses: actions/setup-python@v4 - - name: Setup checkout - uses: actions/checkout@master - - name: Lint with Black - run: | - pip install black - black --diff --check src/quac tests From 9441bce2212f4e183b17246a7bfd83cb57fc285c Mon Sep 17 00:00:00 2001 From: Diane Adjavon Date: Thu, 27 Feb 2025 11:48:29 -0500 Subject: [PATCH 10/17] ci: :green_heart: Replace black with ruff in pre-commit --- .pre-commit-config.yaml | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index dc763e9..d67c222 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -14,12 +14,9 @@ repos: - id: check-yaml - id: check-added-large-files - - repo: https://github.com/psf/black - rev: 23.1.0 + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.4.7 hooks: - - id: black - - # - repo: 
https://github.com/pre-commit/mirrors-mypy - # rev: v1.0.1 - # hooks: - # - id: mypy + - id: ruff + args: [--fix] + - id: ruff-format From 503deba218023ba36c205724214b24e6502de7eb Mon Sep 17 00:00:00 2001 From: Diane Adjavon Date: Thu, 27 Feb 2025 11:56:01 -0500 Subject: [PATCH 11/17] style: :rotating_light: Run ruff on failing files and fix issues --- src/quac/report.py | 6 +++--- src/quac/training/data_loader.py | 5 +---- src/quac/training/eval.py | 12 ++++++------ src/quac/training/solver.py | 22 +++++++++------------- 4 files changed, 19 insertions(+), 26 deletions(-) diff --git a/src/quac/report.py b/src/quac/report.py index e20fb2a..99fc29d 100644 --- a/src/quac/report.py +++ b/src/quac/report.py @@ -225,9 +225,9 @@ def optimal_thresholds(self, min_percentage=0.0): max_value = np.max(mask_scores, axis=1) threshold = min_value + min_percentage * (max_value - min_value) below_threshold = mask_scores < threshold[:, None] - tradeoff_scores[ - below_threshold - ] = np.inf # Ignores the points with not enough score change + tradeoff_scores[below_threshold] = ( + np.inf + ) # Ignores the points with not enough score change thr_idx = np.argmin(tradeoff_scores, axis=1) optimal_thresholds = np.take_along_axis( diff --git a/src/quac/training/data_loader.py b/src/quac/training/data_loader.py index ef71672..33ff370 100644 --- a/src/quac/training/data_loader.py +++ b/src/quac/training/data_loader.py @@ -10,8 +10,6 @@ from pathlib import Path from itertools import chain -import glob -import os import random import imageio @@ -192,8 +190,7 @@ def get_train_loader( std=0.5, ): print( - "Preparing DataLoader to fetch %s images " - "during the training phase..." % which + "Preparing DataLoader to fetch %s images during the training phase..." % which ) crop = transforms.RandomResizedCrop(img_size, scale=[0.8, 1.0], ratio=[0.9, 1.1]) diff --git a/src/quac/training/eval.py b/src/quac/training/eval.py index cc7066f..1dac0b7 100644 --- a/src/quac/training/eval.py +++ b/src/quac/training/eval.py @@ -69,12 +69,12 @@ def calculate_metrics( std=std, grayscale=(input_dim == 1), ) - conversion_rate_values[ - "conversion_rate_%s/%s" % (mode, task) - ] = conversion_rate - translation_rate_values[ - "translation_rate_%s/%s" % (mode, task) - ] = translation_rate + conversion_rate_values["conversion_rate_%s/%s" % (mode, task)] = ( + conversion_rate + ) + translation_rate_values["translation_rate_%s/%s" % (mode, task)] = ( + translation_rate + ) # calculate the average conversion rate for all tasks conversion_rate_mean = 0 diff --git a/src/quac/training/solver.py b/src/quac/training/solver.py index f49cf63..27ae757 100644 --- a/src/quac/training/solver.py +++ b/src/quac/training/solver.py @@ -24,7 +24,6 @@ import torch.nn.functional as F from torchvision import transforms from tqdm import tqdm -import wandb transform = transforms.Compose( @@ -345,7 +344,7 @@ def log( def evaluate( self, val_loader, - iteration=None, + iteration, num_outs_per_domain=10, mode="latent", val_config=None, @@ -357,9 +356,6 @@ def evaluate( ---------- val_loader """ - if iteration is None: # Choose the iteration to evaluate - resume_iter = resume_iter - self._load_checkpoint(resume_iter) # Generate images for evaluation eval_dir = self.root_dir / "eval" @@ -378,6 +374,7 @@ def evaluate( assert mode in ["latent", "reference"] val_loader.set_mode(mode) + iter_ref = None domains = val_loader.available_targets print("Number of domains: %d" % len(domains)) @@ -414,12 +411,11 @@ def evaluate( z_trg = torch.randn(N, self.latent_dim).to(device) s_trg 
= self.nets_ema.mapping_network(z_trg, y_trg) else: - # x_ref = x_trg.clone() try: # TODO don't need to re-do this every time, just use # the same set of reference images for the whole dataset! x_ref = next(iter_ref).to(device) - except: + except TypeError: # iter_ref is None iter_ref = iter(loader_ref) x_ref = next(iter_ref).to(device) @@ -454,12 +450,12 @@ def evaluate( translation_rate = np.mean(predictions == trg_idx) # STORE - conversion_rate_values[ - f"conversion_rate_{mode}/" + task - ] = conversion_rate - translation_rate_values[ - f"translation_rate_{mode}/" + task - ] = translation_rate + conversion_rate_values[f"conversion_rate_{mode}/" + task] = ( + conversion_rate + ) + translation_rate_values[f"translation_rate_{mode}/" + task] = ( + translation_rate + ) # Add average conversion rate and translation rate conversion_rate_values[f"conversion_rate_{mode}/average"] = np.mean( From 5051f76627448b7c86ac16ab917ef6b2ff626bb6 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 27 Feb 2025 16:56:20 +0000 Subject: [PATCH 12/17] style(pre-commit.ci): auto fixes [...] --- docs/source/conf.py | 1 - src/quac/evaluation.py | 1 - src/quac/generate/model.py | 1 - tests/test_model.py | 1 - 4 files changed, 4 deletions(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index 8230a61..85e8742 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -6,7 +6,6 @@ # -- Project information ----------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information -from datetime import datetime import quac import tomli diff --git a/src/quac/evaluation.py b/src/quac/evaluation.py index 4608176..31bd1c5 100644 --- a/src/quac/evaluation.py +++ b/src/quac/evaluation.py @@ -1,5 +1,4 @@ import cv2 -from dataclasses import dataclass import numpy as np from pathlib import Path from quac.data import PairedImageDataset, CounterfactualDataset, PairedWithAttribution diff --git a/src/quac/generate/model.py b/src/quac/generate/model.py index 9852404..2e156c6 100644 --- a/src/quac/generate/model.py +++ b/src/quac/generate/model.py @@ -9,7 +9,6 @@ ) from quac.training.checkpoint import CheckpointIO import torch -from typing import Optional class InferenceModel(torch.nn.Module): diff --git a/tests/test_model.py b/tests/test_model.py index 3cf48ad..0483d45 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -1,4 +1,3 @@ -import pytest from quac.training.stargan import build_model from quac.training.config import ModelConfig import torch From e4c7b9c492aed5c15aba146c523e5823ed7dd180 Mon Sep 17 00:00:00 2001 From: Diane Adjavon Date: Thu, 27 Feb 2025 12:03:12 -0500 Subject: [PATCH 13/17] style: :rotating_light: Fix more ruff-raised issues. 
--- src/quac/evaluation.py | 1 - src/quac/generate/model.py | 1 - src/quac/training/stargan.py | 1 - src/quac/training/utils.py | 36 ++---------------------------------- 4 files changed, 2 insertions(+), 37 deletions(-) diff --git a/src/quac/evaluation.py b/src/quac/evaluation.py index 4608176..31bd1c5 100644 --- a/src/quac/evaluation.py +++ b/src/quac/evaluation.py @@ -1,5 +1,4 @@ import cv2 -from dataclasses import dataclass import numpy as np from pathlib import Path from quac.data import PairedImageDataset, CounterfactualDataset, PairedWithAttribution diff --git a/src/quac/generate/model.py b/src/quac/generate/model.py index 9852404..2e156c6 100644 --- a/src/quac/generate/model.py +++ b/src/quac/generate/model.py @@ -9,7 +9,6 @@ ) from quac.training.checkpoint import CheckpointIO import torch -from typing import Optional class InferenceModel(torch.nn.Module): diff --git a/src/quac/training/stargan.py b/src/quac/training/stargan.py index fdf4975..dbe25eb 100644 --- a/src/quac/training/stargan.py +++ b/src/quac/training/stargan.py @@ -317,7 +317,6 @@ def forward(self, x, y): x = F.interpolate(x, size=2**self.nearest_power, mode="bilinear") h = self.shared(x) h = h.view(h.size(0), -1) - out = [] s = self.output(h) return s diff --git a/src/quac/training/utils.py b/src/quac/training/utils.py index 6d7dd2f..950a08a 100644 --- a/src/quac/training/utils.py +++ b/src/quac/training/utils.py @@ -11,7 +11,6 @@ import json import matplotlib.pyplot as plt import numpy as np -from os.path import join as ospj from pathlib import Path import pandas as pd import re @@ -101,12 +100,11 @@ def translate_using_reference(self, x_src, x_ref, y_ref, filename): wb = torch.ones(1, C, H, W).to(x_src.device) x_src_with_wb = torch.cat([wb, x_src], dim=0) - masks = nets.fan.get_heatmap(x_src) if args.w_hpf > 0 else None - s_ref = nets.style_encoder(x_ref, y_ref) + s_ref = self.nets.style_encoder(x_ref, y_ref) s_ref_list = s_ref.unsqueeze(1).repeat(1, N, 1) x_concat = [x_src_with_wb] for i, s_ref in enumerate(s_ref_list): - x_fake = nets.generator(x_src, s_ref, masks=masks) + x_fake = self.nets.generator(x_src, s_ref) x_fake_with_ref = torch.cat([x_ref[i : i + 1], x_fake], dim=0) x_concat += [x_fake_with_ref] @@ -114,36 +112,6 @@ def translate_using_reference(self, x_src, x_ref, y_ref, filename): save_image(x_concat, N + 1, filename) del x_concat - @torch.no_grad() - def debug_image(self, inputs, step): - x_src, y_src = inputs.x_src, inputs.y_src - x_ref, y_ref = inputs.x_ref, inputs.y_ref - - device = inputs.x_src.device - N = inputs.x_src.size(0) - - # translate and reconstruct (reference-guided) - filename = ospj(self.sample_dir, "%06d_cycle_consistency.jpg" % (step)) - self.translate_and_reconstruct(x_src, y_src, x_ref, y_ref, filename) - - # latent-guided image synthesis - y_trg_list = [ - torch.tensor(y).repeat(N).to(device) - for y in range(min(args.num_domains, 5)) - ] - z_trg_list = ( - torch.randn(self.num_outs_per_domain, 1, args.latent_dim) - .repeat(1, N, 1) - .to(device) - ) - for psi in [0.5, 0.7, 1.0]: - filename = ospj(self.sample_dir, "%06d_latent_psi_%.1f.jpg" % (step, psi)) - self.translate_using_latent(x_src, y_trg_list, z_trg_list, psi, filename) - - # reference-guided image synthesis - filename = ospj(self.sample_dir, "%06d_reference.jpg" % (step)) - self.translate_using_reference(x_src, x_ref, y_ref, filename) - ########################### # LOSS PLOTTING FUNCTIONS # From f1fb2bd69abb140d9b0471ee7cbbf8bdffae7ea6 Mon Sep 17 00:00:00 2001 From: Diane Adjavon Date: Thu, 27 Feb 2025 13:03:25 
-0500 Subject: [PATCH 14/17] style: :rotating_light: Fix more ruff formatting issues --- .pre-commit-config.yaml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d67c222..7089331 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -7,8 +7,9 @@ default_install_hook_types: [pre-commit, commit-msg] repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v3.2.0 + rev: v4.2.0 hooks: + - id: check-docstring-first - id: trailing-whitespace - id: end-of-file-fixer - id: check-yaml From 69637aab26b56f0bed9def54ca2e232efaba10ae Mon Sep 17 00:00:00 2001 From: Diane Adjavon Date: Thu, 27 Feb 2025 13:12:45 -0500 Subject: [PATCH 15/17] ci: :green_heart: Remove ruff from github actions It should be handled by pre-commit. --- .github/workflows/tests.yaml | 9 --------- 1 file changed, 9 deletions(-) diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 54613a5..384aaea 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -4,15 +4,6 @@ on: push: jobs: - ruff: - name: Ruff - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - - uses: chartboost/ruff-action@v1 - with: - args: 'format --check' - test: runs-on: ubuntu-latest strategy: From cde06de3a75e8a2b0b1080de782b232d35047de1 Mon Sep 17 00:00:00 2001 From: Diane Adjavon Date: Thu, 27 Feb 2025 13:16:32 -0500 Subject: [PATCH 16/17] ci: Let ruff handle whitespace --- .pre-commit-config.yaml | 2 -- 1 file changed, 2 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7089331..cb3a192 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -10,8 +10,6 @@ repos: rev: v4.2.0 hooks: - id: check-docstring-first - - id: trailing-whitespace - - id: end-of-file-fixer - id: check-yaml - id: check-added-large-files From bc8106d6259516b62b560cde7396578cfcd4c74f Mon Sep 17 00:00:00 2001 From: Diane Adjavon Date: Thu, 27 Feb 2025 13:20:57 -0500 Subject: [PATCH 17/17] ci: Bump up ruff version in pre-commit --- .pre-commit-config.yaml | 2 +- src/quac/generate/__init__.py | 6 +++--- src/quac/training/data_loader.py | 12 ++++++------ 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index cb3a192..f7a8763 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -14,7 +14,7 @@ repos: - id: check-added-large-files - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.4.7 + rev: v0.9.8 hooks: - id: ruff args: [--fix] diff --git a/src/quac/generate/__init__.py b/src/quac/generate/__init__.py index 1877589..6c8ff65 100644 --- a/src/quac/generate/__init__.py +++ b/src/quac/generate/__init__.py @@ -176,9 +176,9 @@ def get_counterfactual( # type: ignore # Copy x batch_size times x_multiple = torch.stack([x] * batch_size) if kind == "reference": - assert ( - dataset_ref is not None - ), "Reference dataset required for reference style." + assert dataset_ref is not None, ( + "Reference dataset required for reference style." 
+ ) if len(dataset_ref) // batch_size < max_tries: max_tries = len(dataset_ref) // batch_size logger.warning( diff --git a/src/quac/training/data_loader.py b/src/quac/training/data_loader.py index 33ff370..7228191 100644 --- a/src/quac/training/data_loader.py +++ b/src/quac/training/data_loader.py @@ -535,15 +535,15 @@ def available_targets(self): return self._available_targets def set_target(self, target): - assert ( - target in self.available_targets - ), f"{target} not in {self.available_targets}" + assert target in self.available_targets, ( + f"{target} not in {self.available_targets}" + ) self.target = target def set_source(self, source): - assert ( - source in self.available_sources - ), f"{source} not in {self.available_sources}" + assert source in self.available_sources, ( + f"{source} not in {self.available_sources}" + ) self.source = source @property
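
A note on the classifier requirement introduced in PATCH 01/17: the updated tutorials assume you can provide your classifier as a JIT-compiled `pytorch` model (loaded with `torch.jit.load` in the visualization tutorial), but none of the patches show how to produce one. The sketch below is illustrative only and is not part of any commit in this series; the ResNet-18 architecture, the weight path, and the output path are hypothetical placeholders for your own model.

.. code-block:: python

    import torch
    import torchvision

    # Hypothetical classifier: substitute your own architecture and trained weights.
    model = torchvision.models.resnet18(num_classes=2)
    model.load_state_dict(torch.load("classifier_weights.pth", map_location="cpu"))
    model.eval()

    # Trace with a dummy batch matching the tutorials' image size (3 x 224 x 224),
    # then save the JIT-compiled module so it can be read with torch.jit.load.
    example_input = torch.randn(1, 3, 224, 224)
    jit_model = torch.jit.trace(model, example_input)
    jit_model.save("/path/to/classifier/model.pt")

If your classifier's forward pass contains data-dependent control flow, `torch.jit.script(model)` is the safer alternative to tracing.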