diff --git a/.github/workflows/black.yaml b/.github/workflows/black.yaml deleted file mode 100644 index 3d6329f..0000000 --- a/.github/workflows/black.yaml +++ /dev/null @@ -1,17 +0,0 @@ -name: Python Black - -on: [push, pull_request] - -jobs: - lint: - name: Python Lint - runs-on: ubuntu-latest - steps: - - name: Setup Python - uses: actions/setup-python@v4 - - name: Setup checkout - uses: actions/checkout@master - - name: Lint with Black - run: | - pip install black - black --diff --check src/quac tests diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index dc763e9..d92b6d1 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -7,19 +7,15 @@ default_install_hook_types: [pre-commit, commit-msg] repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v3.2.0 + rev: v5.0.0 hooks: - - id: trailing-whitespace - - id: end-of-file-fixer + - id: check-docstring-first - id: check-yaml - id: check-added-large-files - - repo: https://github.com/psf/black - rev: 23.1.0 + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.9.8 hooks: - - id: black - - # - repo: https://github.com/pre-commit/mirrors-mypy - # rev: v1.0.1 - # hooks: - # - id: mypy + - id: ruff + args: [--fix] + - id: ruff-format diff --git a/docs/source/conf.py b/docs/source/conf.py index 8230a61..85e8742 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -6,7 +6,6 @@ # -- Project information ----------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information -from datetime import datetime import quac import tomli diff --git a/docs/source/tutorials.rst b/docs/source/tutorials.rst index 57c9bf0..01845e4 100644 --- a/docs/source/tutorials.rst +++ b/docs/source/tutorials.rst @@ -14,7 +14,7 @@ QuAC, or Quantitative Attributions with Counterfactuals, is a method for generat Let's assume, for instance, that you have images of cells grown in two different conditions. To your eye, the phenotypic difference between the two conditions is hidden within the cell-to-cell variability of the dataset, but you know it is there because you've trained a classifier to differentiate the two conditions and it works. So how do you pull out the differences? -We begin by training a generative neural network to convert your images from one class to another. Here, we'll use a StarGAN. This allows us to go from our real, **query** image, to our **generated** image. +Assuming that you already have a classifier that does your task, we begin by training a generative neural network to convert your images from one class to another. Here, we'll use a StarGAN. This allows us to go from our real, **query** image, to our **generated** image. Using information learned from **reference** images, the StarGAN is trained in such a way that the **generated** image will have a different class! While very powerful, these generative networks *can potentially* make some changes that are not necessary to the classification. @@ -28,33 +28,56 @@ It is as close as possible to the original image, with only the necessary change :width: 800 :align: center -Before you begin, download [the data]() and [the pre-trained models]() for an example. -Then, make sure you've installed QuAC by following the :doc:`Installation guide `. +Before you begin, make sure you've installed QuAC by following the :doc:`Installation guide `. + +The classifier +============== +To use QuAC, we assume that you have a classifier trained on your data. 
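+QuAC itself will load this classifier as a JIT-compiled (TorchScript) `pytorch` model.
+As an illustration only (`MyClassifier`, the input shape, and the file names below are placeholders, not part of QuAC), a classifier trained in plain `pytorch` could be exported like this:
+
+.. code-block:: python
+
+    import torch
+
+    # Re-build the trained classifier and load its weights (placeholders)
+    model = MyClassifier()
+    model.load_state_dict(torch.load("classifier_weights.pth"))
+    model.eval()
+
+    # Trace the model with a dummy input of the correct shape, then save it
+    example_input = torch.randn(1, 3, 224, 224)
+    torch.jit.trace(model, example_input).save("classifier_jit.pt")
+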
+There are many different packages to help you train such a classifier; whichever you use, all this code base needs is the exported, JIT-compiled model file.
+
+If you just want to try QuAC as a learning experience, you can use one of the datasets `in this collection `_, and the pre-trained models we provide.

 The conversion network
 ===============================
+Once you've set up your data and your classifier, you can move on to training the conversion network.
+We'll use a StarGAN for this.
+There are two options for training the StarGAN, but we recommend :doc:`training it using a YAML file `.
+This will make it easier to keep track of your experiments!
+If you prefer to define parameters directly in Python, however, you can follow the :doc:`alternative training tutorial ` instead.
+Note that in both cases, you will need the JIT-compiled classifier model!

-You have two options for training the StarGAN, you can either :doc:`define parameters directly in Python ` or :doc:`train it using a YAML file `.
-We recommend the latter, which will make it easier to keep track of your experiments!
-Once you've trained a decent model, generate a set of images using the :doc:`image generation tutorial ` before moving on to the next steps.
+Once you've trained a decent model, you can generate a set of images using the :doc:`image generation tutorial `.
+We recommend taking a look at your generated images to make sure that they look like what you expect.
+If that is the case, you can move on to the next steps!

 .. toctree::
    :maxdepth: 1

-   tutorials/train
-   tutorials/train_yaml
-   Generating images

 Attribution and evaluation
 ==========================
-With the generated images in hand, we can now run the attribution and evaluation steps.
+With the generated images in hand, we can now run the :doc:`attribution ` step, then the :doc:`evaluation ` step.
 These two steps allow us to overcome the limitations of the generative network to create *truly* minimal counterfactual images, and to score the query-counterfactual pairs based on how well they explain the classifier.

+Visualizing results
+===================
+
+Finally, we can visualize the results of the attribution and evaluation steps using the :doc:`visualization tutorial `.
+This will allow you to see the quantification results, in the form of QuAC curves.
+It will also help you choose the best attribution method for each example, and load the counterfactual visual explanations for these examples.
+
+Table of Contents
+=================
+Here's a list of all available tutorials, in case you want to navigate directly to one of them.
+
 .. toctree::
    :maxdepth: 1

+   Training the generator (recommended)
+   Training the generator (alternative)
+   Generating images
    Attribution
    Evaluation
    Visualizing results
diff --git a/docs/source/tutorials/generate.rst b/docs/source/tutorials/generate.rst
index 0775941..c3d2a49 100644
--- a/docs/source/tutorials/generate.rst
+++ b/docs/source/tutorials/generate.rst
@@ -4,10 +4,6 @@
 How to generate images from a pre-trained network
 =================================================

-.. attention::
-    This tutorial is still under construction. Come back soon for updates!
- - Defining the dataset ==================== @@ -22,7 +18,7 @@ For example, below, we are going to be using the validation data, and our source from quac.generate import load_data img_size = 224 - data_directory = Path("root_directory/val/0_No_DR") + data_directory = Path("/path/to/directory/holding/the/data/source_class") dataset = load_data(data_directory, img_size, grayscale=False) @@ -86,7 +82,7 @@ Finally, we can run the image generation. from quac.generate import get_counterfactual from torchvision.utils import save_image - output_directory = Path("/path/to/output/latent/0_No_DR/1_Mild/") + output_directory = Path("/path/to/output/latent/source_class/target_class/") for x, name in tqdm(dataset): xcf = get_counterfactual( @@ -117,7 +113,7 @@ The first thing we need to do is to get the reference images. .. code-block:: python :linenos: - reference_data_directory = Path(f"{root_directory}/val/1_Mild") + reference_data_directory = Path("/path/to/directory/holding/the/data/target_class") reference_dataset = load_data(reference_data_directory, img_size, grayscale=False) Loading the StarGAN @@ -148,7 +144,7 @@ Finally, we combine the two by changing the `kind` in our counterfactual generat from torchvision.utils import save_image - output_directory = Path("/path/to/output/reference/0_No_DR/1_Mild/") + output_directory = Path("/path/to/output/reference/source_class/target_class/") for x, name in tqdm(dataset): xcf = get_counterfactual( diff --git a/docs/source/tutorials/visualize.rst b/docs/source/tutorials/visualize.rst index 3bd9022..b690ebe 100644 --- a/docs/source/tutorials/visualize.rst +++ b/docs/source/tutorials/visualize.rst @@ -2,9 +2,169 @@ Visualizing the results ======================= -.. attention:: - This tutorial is still under construction. Come back soon for updates! +In this tutorial, we will show you how to visualize the results of the attribution and evaluation steps. +Make sure to modify the paths to the reports and the classifier to match your setup! - .. image:: ../assets/quac.png - :width: 100 - :align: center +Obtaining the QuAC curves +========================= +Let's start by loading the reports obtained in the previous step. + +.. code-block:: python + :linenos: + + from quac.report import Report + + report_directory = "/path/to/report/directory/" + reports = { + method: Report(name=method) + for method in [ + "DDeepLift", + "DIntegratedGradients", + ] + } + + for method, report in reports.items(): + report.load(report_directory + method + "/default.json") + +Next, we can plot the QuAC curves for each method. +This allows us to get an idea of how well each method is performing, overall. + +.. code-block:: python + :linenos: + + import matplotlib.pyplot as plt + + fig, ax = plt.subplots() + for method, report in reports.items(): + report.plot_curve(ax=ax) + # Add the legend + plt.legend() + plt.show() + + +Choosing the best attribution method for each sample +==================================================== + +While one attribution method may be better than another on average, it is possible that the best method for a given example is different. +Therefore, we will make a list of the best method for each example by comparing the quac scores. + +.. 
code-block:: python
+    :linenos:
+
+    import pandas as pd
+
+    quac_scores = pd.DataFrame(
+        {method: report.quac_scores for method, report in reports.items()}
+    )
+    best_methods = quac_scores.idxmax(axis=1)
+    best_quac_scores = quac_scores.max(axis=1)
+
+We'll also want to load the classifier at this point, so we can look at the classifications of the counterfactual images.
+
+.. code-block:: python
+    :linenos:
+
+    import torch
+
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    classifier = torch.jit.load("/path/to/classifier/model.pt")
+    classifier.to(device)
+    classifier.eval()
+
+
+Choosing the best examples
+==========================
+Next, we want to choose the best example, given the best method.
+This is done by ordering the examples by their QuAC scores, and then choosing the one with the highest score.
+
+.. code-block:: python
+    :linenos:
+
+    # Example indices, sorted from highest to lowest QuAC score
+    order = best_quac_scores.values.argsort()[::-1]
+
+    # For example, choose the 10th best example
+    idx = 10
+    # Get the corresponding report
+    report = reports[best_methods[order[idx]]]
+
+We will then load that example and its counterfactual from their paths, and visualize them.
+We also want to see the classification of both the original and the counterfactual.
+
+.. code-block:: python
+    :linenos:
+
+    # Load the images (add a transform here if your images need to match in size)
+    from PIL import Image
+
+    image_path, generated_path = report.paths[order[idx]], report.target_paths[order[idx]]
+    image, generated_image = Image.open(image_path), Image.open(generated_path)
+
+    # Classifier outputs for the query and the generated image
+    prediction = report.predictions[order[idx]]
+    target_prediction = report.target_predictions[order[idx]]
+
+Loading the attribution
+=======================
+We next want to load the attribution for the example, and visualize it.
+
+.. code-block:: python
+    :linenos:
+
+    import numpy as np
+
+    attribution_path = report.attribution_paths[order[idx]]
+    attribution = np.load(attribution_path)
+
+Getting the processor
+=====================
+We want to see the specific mask that was optimal in this case.
+To do this, we will need to get the optimal threshold, and the processor used for masking.
+
+.. code-block:: python
+    :linenos:
+
+    from quac.evaluation import Processor
+
+    gaussian_kernel_size = 11
+    struc = 10
+    thresh = report.optimal_thresholds()[order[idx]]
+    print(thresh)
+    processor = Processor(gaussian_kernel_size=gaussian_kernel_size, struc=struc)
+
+    mask, _ = processor.create_mask(attribution, thresh)
+    rgb_mask = mask.transpose(1, 2, 0)
+    # zero-out the green and blue channels
+    rgb_mask[:, :, 1] = 0
+    rgb_mask[:, :, 2] = 0
+    counterfactual = np.array(generated_image) / 255 * rgb_mask + np.array(image) / 255 * (1.0 - rgb_mask)
+
+Let's also get the classifier output for the counterfactual image.
+
+.. code-block:: python
+    :linenos:
+
+    from scipy.special import softmax
+
+    classifier_output = classifier(
+        torch.tensor(counterfactual).permute(2, 0, 1).float().unsqueeze(0).to(device)
+    )
+    counterfactual_prediction = softmax(classifier_output[0].detach().cpu().numpy())
+
+Visualizing the results
+=======================
+Finally, we can visualize the results.
+
+..
code-block:: python + :linenos: + + fig, axes = plt.subplots(2, 4) + axes[1, 0].imshow(image) + axes[0, 0].bar(np.arange(len(prediction)), prediction) + axes[1, 1].imshow(generated_image) + axes[0, 1].bar(np.arange(len(target_prediction)), target_prediction) + axes[0, 2].bar(np.arange(len(counterfactual_prediction)), counterfactual_prediction) + axes[1, 2].imshow(counterfactual) + axes[1, 3].imshow(rgb_mask) + axes[0, 3].axis("off") + fig.suptitle(f"QuAC Score: {report.quac_scores[order[idx]]}") + plt.show() + +You can now see the original image, the generated image, the counterfactual image, and the mask. +From here, you can choose to visualize other examples, of save the images for later use. diff --git a/src/quac/evaluation.py b/src/quac/evaluation.py index 4608176..31bd1c5 100644 --- a/src/quac/evaluation.py +++ b/src/quac/evaluation.py @@ -1,5 +1,4 @@ import cv2 -from dataclasses import dataclass import numpy as np from pathlib import Path from quac.data import PairedImageDataset, CounterfactualDataset, PairedWithAttribution diff --git a/src/quac/generate/__init__.py b/src/quac/generate/__init__.py index 1877589..6c8ff65 100644 --- a/src/quac/generate/__init__.py +++ b/src/quac/generate/__init__.py @@ -176,9 +176,9 @@ def get_counterfactual( # type: ignore # Copy x batch_size times x_multiple = torch.stack([x] * batch_size) if kind == "reference": - assert ( - dataset_ref is not None - ), "Reference dataset required for reference style." + assert dataset_ref is not None, ( + "Reference dataset required for reference style." + ) if len(dataset_ref) // batch_size < max_tries: max_tries = len(dataset_ref) // batch_size logger.warning( diff --git a/src/quac/generate/model.py b/src/quac/generate/model.py index 9852404..2e156c6 100644 --- a/src/quac/generate/model.py +++ b/src/quac/generate/model.py @@ -9,7 +9,6 @@ ) from quac.training.checkpoint import CheckpointIO import torch -from typing import Optional class InferenceModel(torch.nn.Module): diff --git a/src/quac/report.py b/src/quac/report.py index e20fb2a..99fc29d 100644 --- a/src/quac/report.py +++ b/src/quac/report.py @@ -225,9 +225,9 @@ def optimal_thresholds(self, min_percentage=0.0): max_value = np.max(mask_scores, axis=1) threshold = min_value + min_percentage * (max_value - min_value) below_threshold = mask_scores < threshold[:, None] - tradeoff_scores[ - below_threshold - ] = np.inf # Ignores the points with not enough score change + tradeoff_scores[below_threshold] = ( + np.inf + ) # Ignores the points with not enough score change thr_idx = np.argmin(tradeoff_scores, axis=1) optimal_thresholds = np.take_along_axis( diff --git a/src/quac/training/data_loader.py b/src/quac/training/data_loader.py index 7e944e4..7228191 100644 --- a/src/quac/training/data_loader.py +++ b/src/quac/training/data_loader.py @@ -10,10 +10,9 @@ from pathlib import Path from itertools import chain -import glob -import os import random +import imageio from munch import Munch from PIL import Image import numpy as np @@ -40,7 +39,14 @@ def listdir(dname): chain( *[ list(Path(dname).rglob("*." 
+ ext)) - for ext in ["png", "jpg", "jpeg", "JPG"] + for ext in [ + "png", + "jpg", + "jpeg", + "JPG", + "tiff", + "tif", + ] ] ) ) @@ -65,13 +71,51 @@ def __len__(self): return len(self.samples) -class AugmentedDataset(data.Dataset): - """Adds an augmented version of the input to the sample.""" +class LabelledDataset(data.Dataset): + """A base dataset for QuAC.""" def __init__(self, root, transform=None, augment=None): self.samples, self.targets = self._make_dataset(root) + # Check if empty + assert len(self.samples) > 0, "Dataset is empty, no files found." self.transform = transform - if augment is None: + + def _open_image(self, fname): + array = imageio.imread(fname) + # if no channel dimension, add it + if len(array.shape) == 2: + array = array[:, :, None] + # data will be h,w,c, switch to c,h,w + array = array.transpose(2, 0, 1) + return torch.from_numpy(array) + + def _make_dataset(self, root): + # Get all subitems, sorted, ignore hidden + domains = sorted(Path(root).glob("[!.]*")) + # only directories, absolute paths + domains = [d.absolute() for d in domains if d.is_dir()] + fnames, labels = [], [] + for idx, class_dir in enumerate(domains): + cls_fnames = listdir(class_dir) + fnames += cls_fnames + labels += [idx] * len(cls_fnames) + return fnames, labels + + def __getitem__(self, index): + fname = self.samples[index] + label = self.targets[index] + img = self._open_image(fname) + if self.transform is not None: + img = self.transform(img) + return img, label + + +class AugmentedDataset(LabelledDataset): + """Adds an augmented version of the input to the sample.""" + + def __init__(self, root, transform=None, augment=None): + super().__init__(root, transform, augment) # Creates self.samples, self.targets + if self.augment is None: # Default augmentation: random horizontal flip, random vertical flip augment = transforms.Compose( [ @@ -81,20 +125,11 @@ def __init__(self, root, transform=None, augment=None): ) self.augment = augment - def _make_dataset(self, root): - domains = glob.glob(os.path.join(root, "*")) - fnames, labels = [], [] - for idx, domain in enumerate(sorted(domains)): - class_dir = os.path.join(root, domain) - cls_fnames = listdir(class_dir) - fnames += cls_fnames - labels += [idx] * len(cls_fnames) - return fnames, labels - def __getitem__(self, index): fname = self.samples[index] label = self.targets[index] - img = Image.open(fname) + img = self._open_image(fname) + # Augment the image to create a second image img2 = self.augment(img) if self.transform is not None: img = self.transform(img) @@ -105,27 +140,27 @@ def __len__(self): return len(self.targets) -class ReferenceDataset(data.Dataset): - def __init__(self, root, transform=None): - self.samples, self.targets = self._make_dataset(root) - self.transform = transform +class ReferenceDataset(LabelledDataset): + """A dataset that returns a reference image and a target image.""" - def _make_dataset(self, root): - domains = glob.glob(os.path.join(root, "*")) - fnames, fnames2, labels = [], [], [] - for idx, domain in enumerate(sorted(domains)): - class_dir = os.path.join(root, domain) + def __init__(self, root, transform=None): + super().__init__(root, transform) # Creates self.samples, self.targets + # Create a second set of samples + fnames2 = [] + for fname in self.samples: + # Get the class of the current image + class_dir = Path(fname).parent + # Get a random image from the same class cls_fnames = listdir(class_dir) - fnames += cls_fnames - fnames2 += random.sample(cls_fnames, len(cls_fnames)) - labels += [idx] 
* len(cls_fnames) - return list(zip(fnames, fnames2)), labels + fname2 = random.choice(cls_fnames) + fnames2.append(fname2) + self.samples = list(zip(self.samples, fnames2)) def __getitem__(self, index): fname, fname2 = self.samples[index] label = self.targets[index] - img = Image.open(fname) - img2 = Image.open(fname2) + img = self._open_image(fname) + img2 = self._open_image(fname2) if self.transform is not None: img = self.transform(img) img2 = self.transform(img2) @@ -155,8 +190,7 @@ def get_train_loader( std=0.5, ): print( - "Preparing DataLoader to fetch %s images " - "during the training phase..." % which + "Preparing DataLoader to fetch %s images during the training phase..." % which ) crop = transforms.RandomResizedCrop(img_size, scale=[0.8, 1.0], ratio=[0.9, 1.1]) @@ -501,15 +535,15 @@ def available_targets(self): return self._available_targets def set_target(self, target): - assert ( - target in self.available_targets - ), f"{target} not in {self.available_targets}" + assert target in self.available_targets, ( + f"{target} not in {self.available_targets}" + ) self.target = target def set_source(self, source): - assert ( - source in self.available_sources - ), f"{source} not in {self.available_sources}" + assert source in self.available_sources, ( + f"{source} not in {self.available_sources}" + ) self.source = source @property diff --git a/src/quac/training/eval.py b/src/quac/training/eval.py index cc7066f..1dac0b7 100644 --- a/src/quac/training/eval.py +++ b/src/quac/training/eval.py @@ -69,12 +69,12 @@ def calculate_metrics( std=std, grayscale=(input_dim == 1), ) - conversion_rate_values[ - "conversion_rate_%s/%s" % (mode, task) - ] = conversion_rate - translation_rate_values[ - "translation_rate_%s/%s" % (mode, task) - ] = translation_rate + conversion_rate_values["conversion_rate_%s/%s" % (mode, task)] = ( + conversion_rate + ) + translation_rate_values["translation_rate_%s/%s" % (mode, task)] = ( + translation_rate + ) # calculate the average conversion rate for all tasks conversion_rate_mean = 0 diff --git a/src/quac/training/solver.py b/src/quac/training/solver.py index f49cf63..27ae757 100644 --- a/src/quac/training/solver.py +++ b/src/quac/training/solver.py @@ -24,7 +24,6 @@ import torch.nn.functional as F from torchvision import transforms from tqdm import tqdm -import wandb transform = transforms.Compose( @@ -345,7 +344,7 @@ def log( def evaluate( self, val_loader, - iteration=None, + iteration, num_outs_per_domain=10, mode="latent", val_config=None, @@ -357,9 +356,6 @@ def evaluate( ---------- val_loader """ - if iteration is None: # Choose the iteration to evaluate - resume_iter = resume_iter - self._load_checkpoint(resume_iter) # Generate images for evaluation eval_dir = self.root_dir / "eval" @@ -378,6 +374,7 @@ def evaluate( assert mode in ["latent", "reference"] val_loader.set_mode(mode) + iter_ref = None domains = val_loader.available_targets print("Number of domains: %d" % len(domains)) @@ -414,12 +411,11 @@ def evaluate( z_trg = torch.randn(N, self.latent_dim).to(device) s_trg = self.nets_ema.mapping_network(z_trg, y_trg) else: - # x_ref = x_trg.clone() try: # TODO don't need to re-do this every time, just use # the same set of reference images for the whole dataset! 
x_ref = next(iter_ref).to(device) - except: + except TypeError: # iter_ref is None iter_ref = iter(loader_ref) x_ref = next(iter_ref).to(device) @@ -454,12 +450,12 @@ def evaluate( translation_rate = np.mean(predictions == trg_idx) # STORE - conversion_rate_values[ - f"conversion_rate_{mode}/" + task - ] = conversion_rate - translation_rate_values[ - f"translation_rate_{mode}/" + task - ] = translation_rate + conversion_rate_values[f"conversion_rate_{mode}/" + task] = ( + conversion_rate + ) + translation_rate_values[f"translation_rate_{mode}/" + task] = ( + translation_rate + ) # Add average conversion rate and translation rate conversion_rate_values[f"conversion_rate_{mode}/average"] = np.mean( diff --git a/src/quac/training/stargan.py b/src/quac/training/stargan.py index fdf4975..dbe25eb 100644 --- a/src/quac/training/stargan.py +++ b/src/quac/training/stargan.py @@ -317,7 +317,6 @@ def forward(self, x, y): x = F.interpolate(x, size=2**self.nearest_power, mode="bilinear") h = self.shared(x) h = h.view(h.size(0), -1) - out = [] s = self.output(h) return s diff --git a/src/quac/training/utils.py b/src/quac/training/utils.py index 6d7dd2f..950a08a 100644 --- a/src/quac/training/utils.py +++ b/src/quac/training/utils.py @@ -11,7 +11,6 @@ import json import matplotlib.pyplot as plt import numpy as np -from os.path import join as ospj from pathlib import Path import pandas as pd import re @@ -101,12 +100,11 @@ def translate_using_reference(self, x_src, x_ref, y_ref, filename): wb = torch.ones(1, C, H, W).to(x_src.device) x_src_with_wb = torch.cat([wb, x_src], dim=0) - masks = nets.fan.get_heatmap(x_src) if args.w_hpf > 0 else None - s_ref = nets.style_encoder(x_ref, y_ref) + s_ref = self.nets.style_encoder(x_ref, y_ref) s_ref_list = s_ref.unsqueeze(1).repeat(1, N, 1) x_concat = [x_src_with_wb] for i, s_ref in enumerate(s_ref_list): - x_fake = nets.generator(x_src, s_ref, masks=masks) + x_fake = self.nets.generator(x_src, s_ref) x_fake_with_ref = torch.cat([x_ref[i : i + 1], x_fake], dim=0) x_concat += [x_fake_with_ref] @@ -114,36 +112,6 @@ def translate_using_reference(self, x_src, x_ref, y_ref, filename): save_image(x_concat, N + 1, filename) del x_concat - @torch.no_grad() - def debug_image(self, inputs, step): - x_src, y_src = inputs.x_src, inputs.y_src - x_ref, y_ref = inputs.x_ref, inputs.y_ref - - device = inputs.x_src.device - N = inputs.x_src.size(0) - - # translate and reconstruct (reference-guided) - filename = ospj(self.sample_dir, "%06d_cycle_consistency.jpg" % (step)) - self.translate_and_reconstruct(x_src, y_src, x_ref, y_ref, filename) - - # latent-guided image synthesis - y_trg_list = [ - torch.tensor(y).repeat(N).to(device) - for y in range(min(args.num_domains, 5)) - ] - z_trg_list = ( - torch.randn(self.num_outs_per_domain, 1, args.latent_dim) - .repeat(1, N, 1) - .to(device) - ) - for psi in [0.5, 0.7, 1.0]: - filename = ospj(self.sample_dir, "%06d_latent_psi_%.1f.jpg" % (step, psi)) - self.translate_using_latent(x_src, y_trg_list, z_trg_list, psi, filename) - - # reference-guided image synthesis - filename = ospj(self.sample_dir, "%06d_reference.jpg" % (step)) - self.translate_using_reference(x_src, x_ref, y_ref, filename) - ########################### # LOSS PLOTTING FUNCTIONS # diff --git a/tests/test_model.py b/tests/test_model.py index 3cf48ad..0483d45 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -1,4 +1,3 @@ -import pytest from quac.training.stargan import build_model from quac.training.config import ModelConfig import torch
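
For reference, here is a minimal usage sketch of the refactored datasets in `quac.training.data_loader` after this change. The data path is a placeholder, and the folder is assumed to contain one sub-directory per class, as the new `LabelledDataset._make_dataset` expects.

.. code-block:: python

    from quac.training.data_loader import LabelledDataset, ReferenceDataset

    root = "/path/to/data/train"  # placeholder: one sub-directory per class

    # Images are read with imageio and returned as (C, H, W) torch tensors
    labelled = LabelledDataset(root)
    img, label = labelled[0]
    print(img.shape, label)

    # Each sample is paired with a random image drawn from the same class
    reference = ReferenceDataset(root)
    x, x_ref, y = reference[0]
    print(x.shape, x_ref.shape, y)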