This repository has been archived by the owner on Jul 3, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 58
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit 7db02d8
Showing
24 changed files
with
7,184 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
orig_inception_moments/ | ||
__pycache__ | ||
logs/ | ||
*.npz | ||
**.npz | ||
**.tgz | ||
|
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,316 @@ | ||
|
||
from PIL import Image | ||
import os | ||
import pickle | ||
import torch | ||
import random | ||
import torchvision | ||
|
||
from torchvision.datasets.vision import VisionDataset | ||
from torchvision import datasets, transforms, utils | ||
import numpy as np | ||
from matplotlib import pyplot as plt | ||
|
||
class CocoAnimals(VisionDataset): | ||
def __init__(self, root=None, batch_size = 80, classes = None, transform=None, return_all=False , test_mode = False, imsize=128): | ||
self.num_classes = len(classes) | ||
if root==None: | ||
root = os.path.join(os.environ["SSD"],"animals") | ||
|
||
self.root = root | ||
self.return_all = return_all | ||
self.names = classes | ||
self.nclasses = len(classes) | ||
print(self.names) | ||
self.mask_trans =transforms.Compose( | ||
[ transforms.Resize(imsize), | ||
transforms.CenterCrop(imsize), | ||
transforms.ToTensor(), | ||
]) | ||
self.fixed_transform = transforms.Compose( | ||
[ transforms.Resize(imsize), | ||
transforms.CenterCrop(imsize), | ||
transforms.ToTensor(), | ||
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), | ||
]) | ||
|
||
self.image_files = [os.listdir(os.path.join(self.root, n )) for n in self.names] | ||
self.lenghts = [len(f) for f in self.image_files] | ||
|
||
print("images per class ", self.lenghts) | ||
self.transform = transform | ||
self.totensor = transforms.ToTensor() | ||
self.batch_size = batch_size | ||
|
||
self.k = int(batch_size/self.num_classes) | ||
self.length = sum([len(folder) for folder in self.image_files]) | ||
|
||
all_files = [ [os.path.join(self.root, n , f ) for f in files] for n, files in zip(self.names,self.image_files)] | ||
self.all_files = [] | ||
self.all_labels = [] | ||
for i, f in enumerate(all_files): | ||
self.all_files += f | ||
self.all_labels += [i]*len(f) | ||
|
||
with open( os.path.join(root, "merged_bbdict_v2.p"), "rb") as h: | ||
self.bbox = pickle.load(h) | ||
|
||
|
||
with open(os.path.join(root,"coco_ann_dict.p"),"rb") as h: | ||
self.mask_coco = pickle.load(h) | ||
|
||
self.fixed_files = [] | ||
self.fixed_impaths = [] | ||
self.fixed_labels = [] | ||
|
||
for _ in range(batch_size): | ||
id = np.random.randint(self.nclasses) | ||
file = random.choice(self.image_files[id]) | ||
image_path = os.path.join(self.root, self.names[id], file) | ||
self.fixed_files.append(file) | ||
self.fixed_impaths.append(image_path) | ||
self.fixed_labels.append(id) | ||
self.fixed_labels = torch.tensor(self.fixed_labels).long().cuda() | ||
|
||
|
||
def __len__(self): | ||
# Here, we need to return the number of samples in this dataset. | ||
return self.length | ||
|
||
def fixed_batch(self, return_labels = False): | ||
if return_labels == True: | ||
images = torch.stack([self.random_batch(0,0,file=fi,image_path=im)[0].cuda() for fi,im in zip(self.fixed_files,self.fixed_impaths) ]) | ||
labels = self.fixed_labels | ||
return images, labels | ||
else: | ||
return torch.stack([self.random_batch(0,0,file=fi,image_path=im)[0].cuda() for fi,im in zip(self.fixed_files,self.fixed_impaths) ]) | ||
|
||
def single_batch(self): | ||
w = [self.__getitem__(np.random.randint(self.length)) for _ in range(self.batch_size) ] | ||
x = torch.stack([e[0].cuda() for e in w]) | ||
y = torch.stack([e[1].long().cuda() for e in w]) | ||
return x, y | ||
|
||
def random_batch(self,index, id=0, file=None, image_path=None): | ||
# this function adds some data augmentation by cropping and resizing the | ||
# images with the help of the bounding boxes. In particular we make sure | ||
# that the animal is still in the frame if we crop randomly and resize. | ||
|
||
if image_path==None: | ||
id = np.random.randint(self.nclasses) | ||
file = random.choice(self.image_files[id]) | ||
image_path = os.path.join(self.root, self.names[id], file) | ||
|
||
#image_name = self.random.choice(self.image_files[random.choice(self.names)]["image_files"]) | ||
fixed = False | ||
usebbx = True | ||
else: | ||
fixed = True | ||
usebbx = False | ||
|
||
img = Image.open( image_path ).convert('RGB') | ||
|
||
w,h = img.size | ||
im_id = file.strip(".jpg") | ||
if usebbx: | ||
if im_id.isdigit(): | ||
origin = "coco" | ||
bbox = {"bbox":self.mask_coco[file]["bbox"], "label": self.mask_coco[file]["category_id"]} | ||
else: | ||
origin = "oi" | ||
bbox = self.bbox[file.strip(".jpg")] | ||
|
||
if len(bbox["bbox"])>0: | ||
usebbx = True | ||
else: | ||
usebbx = False | ||
|
||
if usebbx: | ||
if isinstance(bbox["bbox"][0], list): | ||
idx = random.choice(np.arange(len(bbox["bbox"]))) | ||
bbox = bbox["bbox"][idx] # choose a random bbox from the list | ||
else: | ||
bbox = bbox["bbox"] | ||
|
||
if usebbx: | ||
if origin == "coco": | ||
a = bbox[0] | ||
b = bbox[1] | ||
c = bbox[2] | ||
d = bbox[3] | ||
else: | ||
a = float(bbox[0])*w | ||
b = float(bbox[1])*w | ||
c = float(bbox[2])*h | ||
d = float(bbox[3])*h | ||
|
||
a, b, c, d = a, c, b-a, d-c | ||
|
||
eps = 0.0001 | ||
longer_side = max(h,d) | ||
r_max = min(float(longer_side)/(d+eps), float(longer_side)/(c+eps)) | ||
r_min = 1.5 | ||
|
||
if r_max > r_min and w > 200 and h > 200 and c*d > 30*30: | ||
r = 1 + np.random.rand()*(r_max-1) | ||
d_new = r*d | ||
c_new = r*c | ||
|
||
longer_side = min ( max(c_new ,d_new ) , h, w) | ||
|
||
d_new = max(longer_side, 150) | ||
c_new = max(longer_side, 150) | ||
|
||
a_new = max(0, a - 0.5*(c_new - c) ) | ||
b_new = max(0, b - 0.5*(d_new - d) ) | ||
|
||
if c_new+a_new > w: | ||
a_new = max(0,a_new - ((c_new+a_new)-w)) | ||
if d_new+b_new>h: | ||
b_new = max(0,b_new - ((d_new+b_new)-h)) | ||
|
||
c_new = c_new + a_new | ||
d_new = d_new + b_new | ||
|
||
img = img.crop((a_new,b_new,c_new,d_new)) | ||
|
||
idx = image_path | ||
|
||
if not fixed: | ||
img = self.transform(img) | ||
elif fixed: | ||
img = self.fixed_transform(img) | ||
|
||
label = torch.LongTensor([id]) | ||
|
||
bbox = self.bbox[file.strip(".jpg")] | ||
out = (img, label , idx) | ||
|
||
return out | ||
|
||
def exact_batch(self,index): | ||
|
||
image_path = self.all_files[index] | ||
img = Image.open( image_path ).convert('RGB') | ||
img = self.transform(img) | ||
id = self.all_labels[index] | ||
label = torch.LongTensor([id]) | ||
return img, label , image_path | ||
|
||
def __getitem__(self,index): | ||
|
||
if self.return_all: | ||
return self.exact_batch(index) | ||
else: | ||
return self.random_batch(index) | ||
|
||
|
||
class FFHQ(VisionDataset): | ||
|
||
def __init__(self, root, transform, batch_size = 60, test_mode = False, return_all = False, imsize=256): | ||
|
||
self.root = root | ||
self.transform = transform | ||
self.return_all = return_all | ||
|
||
print("root:",self.root) | ||
all_folders = os.listdir(self.root) | ||
|
||
self.length = sum([len(os.listdir(os.path.join(self.root,folder))) for folder in all_folders]) # = 70000 | ||
self.fixed_transform = transforms.Compose( | ||
[ transforms.Resize(imsize), | ||
transforms.CenterCrop(imsize), | ||
#transforms.RandomHorizontalFlip(), | ||
#transforms.ColorJitter(brightness=0.01, contrast=0.01, saturation=0.01, hue=0.01), | ||
transforms.ToTensor(), | ||
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), | ||
]) | ||
|
||
self.fixed_indices = [] | ||
|
||
for _ in range(batch_size): | ||
id = np.random.randint(self.length) | ||
self.fixed_indices.append(id) | ||
|
||
def __len__(self): | ||
return self.length | ||
|
||
|
||
def fixed_batch(self, random = False): | ||
if random == False: | ||
return torch.stack([self.random_batch(idx, True)[0].cuda() for idx in self.fixed_indices]) | ||
else: | ||
return torch.stack([self.random_batch(np.random.randint(self.length), True)[0].cuda() for _ in range(len(self.fixed_indices))]) | ||
|
||
def random_batch(self,index, fixed=False): | ||
|
||
folder = str(int(np.floor(index/1000)*1000)).zfill(5) | ||
file = str(index).zfill(5) + ".png" | ||
image_path = os.path.join(self.root, folder , file ) | ||
img = Image.open( image_path).convert('RGB') | ||
if fixed: | ||
img = self.fixed_transform(img) | ||
else: | ||
img = self.transform(img) | ||
|
||
|
||
return img, torch.zeros(1).long(), image_path | ||
|
||
def __getitem__(self,index): | ||
|
||
if self.return_all: | ||
return self.exact_batch(index) | ||
else: | ||
return self.random_batch(index) | ||
|
||
|
||
class Celeba(VisionDataset): | ||
|
||
def __init__(self, root, transform, batch_size = 60, test_mode = False, return_all = False, imsize=128): | ||
|
||
self.root = root | ||
self.transform = transform | ||
self.return_all = return_all | ||
all_files = os.listdir(self.root) | ||
self.length = len(all_files) | ||
self.fixed_transform = transforms.Compose( | ||
[ transforms.Resize(imsize), | ||
transforms.CenterCrop(imsize), | ||
#transforms.RandomHorizontalFlip(), | ||
#transforms.ColorJitter(brightness=0.01, contrast=0.01, saturation=0.01, hue=0.01), | ||
transforms.ToTensor(), | ||
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), | ||
]) | ||
|
||
self.fixed_indices = [] | ||
|
||
for _ in range(batch_size): | ||
id = np.random.randint(self.length) | ||
self.fixed_indices.append(id) | ||
|
||
def __len__(self): | ||
return self.length | ||
|
||
|
||
def fixed_batch(self): | ||
return torch.stack([self.random_batch(idx, True)[0].cuda() for idx in self.fixed_indices]) | ||
|
||
|
||
def random_batch(self,index, fixed=False): | ||
|
||
file = str(index+1).zfill(6) + '.png' | ||
image_path = os.path.join(self.root, file ) | ||
img = Image.open( image_path).convert('RGB') | ||
if fixed: | ||
img = self.fixed_transform(img) | ||
else: | ||
img = self.transform(img) | ||
|
||
return img, torch.zeros(1).long(), image_path | ||
|
||
def __getitem__(self,index): | ||
|
||
if self.return_all: | ||
return self.exact_batch(index) | ||
else: | ||
return self.random_batch(index) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
# U-Net GAN PyTorch | ||
|
||
<p align="center"> | ||
<img src="teaser_final.png"> | ||
</p> | ||
|
||
PyTorch implementation of the CVPR 2020 paper "A U-Net Based Discriminator for Generative Adversarial Networks". The paper can | ||
be found [here](https://openaccess.thecvf.com/content_CVPR_2020/html/Schonfeld_A_U-Net_Based_Discriminator_for_Generative_Adversarial_Networks_CVPR_2020_paper.html). The code allows the users to | ||
reproduce and extend the results reported in the study.Please cite the | ||
above paper when reporting, reproducing or extending the results. | ||
|
||
## Purpose of the project | ||
|
||
This software is a research prototype, solely developed for and published as | ||
part of the publication. It will neither be | ||
maintained nor monitored in any way. | ||
|
||
## Setup | ||
|
||
Create the conda environment "unetgan" from the provided unetgan.yml file. The experiments can be reproduced with the scripts provided in the folder training_scripts (the experiment folder and dataset folder has to be set manually). | ||
|
||
|Argument|Explanation| | ||
|---|---| | ||
|--unconditional | Use this if the dataset does not have classes (e.g. CelebA).| | ||
|--unet_mixup | Use CutMix. | | ||
|--slow_mixup | Use warmup for the CutMix-augmentation loss.| | ||
|--slow_mixup_epochs | Number of epochs for the warmup | | ||
|--full_batch_mixup | If True, a coin is tossed at every training step. With a certain probability the whole batch is mixed and the CutMix augmentation loss and consistency_loss is the only loss that is computed for this batch. The probability increases from 0 to 0.5 over the course of the specified warmup epochs. If False, the CutMix augmentation and consistency loss are computed for every batch and added to the default GAN loss. In the case of a warmup, the augmentation loss is multiplied with a factor that increases from 0 to 1 over the course of the specified warmup epochs.| | ||
|--consistency_loss | Compute only the CutMix consistency loss, but not the CutMix augmentation loss (Can increase stability but might perform worse). | | ||
|--consistency_loss_and_augmentation | Compute both CutMix augmentation and consistency loss.| | ||
|
||
|
||
## Details | ||
|
||
This implementation of U-Net GAN is based on the PyTorch code for BigGAN (https://github.com/ajbrock/BigGAN-PyTorch). The main differences are that (1) we use our own data-loader which does not require HDF5 pre-processing, (2) applied changes in the generator and discriminator class in BigGAN.py, and (3) modified train.py and train_fns.py. | ||
|
||
## License | ||
|
||
U-Net GAN PyTorch is open-sourced under the AGPL-3.0 license. See the | ||
[LICENSE](LICENSE) file for details. | ||
|
||
For a list of other open source components included in PROJECT-NAME, see the | ||
file [3rd-party-licenses.txt](3rd-party-licenses.txt). |
Oops, something went wrong.