Skip to content
This repository has been archived by the owner on Jul 3, 2024. It is now read-only.

Commit

Permalink
init
Browse files Browse the repository at this point in the history
  • Loading branch information
edgarschnfld committed Sep 1, 2020
0 parents commit 7db02d8
Show file tree
Hide file tree
Showing 24 changed files with 7,184 additions and 0 deletions.
7 changes: 7 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
orig_inception_moments/
__pycache__
logs/
*.npz
**.npz
**.tgz

971 changes: 971 additions & 0 deletions 3rd-party-licenses.txt

Large diffs are not rendered by default.

638 changes: 638 additions & 0 deletions BigGAN.py

Large diffs are not rendered by default.

661 changes: 661 additions & 0 deletions LICENSE

Large diffs are not rendered by default.

316 changes: 316 additions & 0 deletions PyTorchDatasets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,316 @@

from PIL import Image
import os
import pickle
import torch
import random
import torchvision

from torchvision.datasets.vision import VisionDataset
from torchvision import datasets, transforms, utils
import numpy as np
from matplotlib import pyplot as plt

class CocoAnimals(VisionDataset):
    """Image dataset built from one sub-folder of ``root`` per class name.

    Each entry of ``classes`` is used as a folder name below ``root``; every
    file in that folder is treated as an image of that class, and the class
    label is the position of the name in ``classes``.

    Two pickled dictionaries are loaded from ``root``:

    * ``merged_bbdict_v2.p`` -> ``self.bbox``, indexed by the file name
      without its ``.jpg`` suffix.
    * ``coco_ann_dict.p`` -> ``self.mask_coco``, indexed by the full file
      name.

    A file whose name (without ``.jpg``) consists only of digits is looked up
    in ``self.mask_coco``; every other file is looked up in ``self.bbox``.

    Args:
        root: Dataset directory. When ``None``, ``$SSD/animals`` is used.
        batch_size: Number of samples in the fixed batch and in
            ``single_batch``.
        classes: Non-empty sequence of class folder names.
        transform: Callable applied to randomly sampled / exact images.
        return_all: If true, ``__getitem__`` returns the image at the given
            index; otherwise it returns a randomly drawn, augmented sample.
        test_mode: Accepted for interface compatibility; not used here.
        imsize: Side length used by the deterministic fixed transform.

    Raises:
        ValueError: If ``classes`` is ``None`` or empty.
    """

    def __init__(self, root=None, batch_size=80, classes=None, transform=None,
                 return_all=False, test_mode=False, imsize=128):
        # Fail early with a clear message instead of a TypeError on
        # len(None) or a ZeroDivisionError when computing self.k below.
        if classes is None or len(classes) == 0:
            raise ValueError("classes must be a non-empty sequence of class folder names")

        self.num_classes = len(classes)
        if root is None:
            root = os.path.join(os.environ["SSD"], "animals")

        self.root = root
        self.return_all = return_all
        self.names = classes
        self.nclasses = len(classes)
        print(self.names)

        self.mask_trans = transforms.Compose([
            transforms.Resize(imsize),
            transforms.CenterCrop(imsize),
            transforms.ToTensor(),
        ])
        self.fixed_transform = transforms.Compose([
            transforms.Resize(imsize),
            transforms.CenterCrop(imsize),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

        self.image_files = [os.listdir(os.path.join(self.root, n)) for n in self.names]
        # The attribute name keeps its historical spelling because it is
        # visible to code outside this class.
        self.lenghts = [len(f) for f in self.image_files]

        print("images per class ", self.lenghts)
        self.transform = transform
        self.totensor = transforms.ToTensor()
        self.batch_size = batch_size

        self.k = int(batch_size / self.num_classes)
        self.length = sum(self.lenghts)

        # Flat, index-aligned lists of every image path and its class label.
        self.all_files = []
        self.all_labels = []
        for label, (name, files) in enumerate(zip(self.names, self.image_files)):
            self.all_files += [os.path.join(self.root, name, f) for f in files]
            self.all_labels += [label] * len(files)

        # NOTE(review): pickle.load executes arbitrary code from the file;
        # only point `root` at annotation files from a trusted source.
        with open(os.path.join(root, "merged_bbdict_v2.p"), "rb") as h:
            self.bbox = pickle.load(h)

        with open(os.path.join(root, "coco_ann_dict.p"), "rb") as h:
            self.mask_coco = pickle.load(h)

        # Draw the samples that make up the fixed batch once, up front.
        self.fixed_files = []
        self.fixed_impaths = []
        self.fixed_labels = []

        for _ in range(batch_size):
            class_idx = np.random.randint(self.nclasses)
            fname = random.choice(self.image_files[class_idx])
            self.fixed_files.append(fname)
            self.fixed_impaths.append(os.path.join(self.root, self.names[class_idx], fname))
            self.fixed_labels.append(class_idx)
        self.fixed_labels = torch.tensor(self.fixed_labels).long().cuda()

    def __len__(self):
        """Return the total number of image files over all class folders."""
        return self.length

    def fixed_batch(self, return_labels=False):
        """Return the fixed batch as a stacked CUDA tensor.

        The images chosen in ``__init__`` are loaded with the deterministic
        ``fixed_transform``. When ``return_labels`` is true the result is an
        ``(images, labels)`` tuple, otherwise only the image tensor.
        """
        images = torch.stack([
            self.random_batch(0, 0, file=fi, image_path=im)[0].cuda()
            for fi, im in zip(self.fixed_files, self.fixed_impaths)
        ])
        if return_labels:
            return images, self.fixed_labels
        return images

    def single_batch(self):
        """Return ``batch_size`` samples from ``__getitem__`` as CUDA tensors.

        Returns:
            ``(x, y)`` where ``x`` stacks the images and ``y`` the labels.
        """
        samples = [
            self.__getitem__(np.random.randint(self.length))
            for _ in range(self.batch_size)
        ]
        x = torch.stack([s[0].cuda() for s in samples])
        y = torch.stack([s[1].long().cuda() for s in samples])
        return x, y

    @staticmethod
    def _strip_jpg(filename):
        """Return ``filename`` without a trailing ``.jpg`` suffix.

        ``str.strip(".jpg")`` removes any of the characters ``.``, ``j``,
        ``p`` and ``g`` from both ends (``"pig.jpg"`` becomes ``"i"``), so the
        suffix is removed explicitly instead.
        """
        if filename.endswith(".jpg"):
            return filename[:-len(".jpg")]
        return filename

    @staticmethod
    def _crop_around_bbox(img, bbox, origin):
        """Randomly crop ``img`` to a square region around ``bbox``.

        For ``origin == "coco"`` the box values are used directly as
        ``(x, y, width, height)``. Otherwise the four values are scaled by the
        image size and read as ``(xmin, xmax, ymin, ymax)`` before being
        converted to ``(x, y, width, height)``.

        The crop is applied only when the image is larger than 200 px on both
        sides, the box covers more than 30*30 pixels, and the enlargement
        bound ``r_max`` exceeds 1.5; otherwise ``img`` is returned unchanged.
        """
        w, h = img.size

        if origin == "coco":
            a, b, c, d = bbox[0], bbox[1], bbox[2], bbox[3]
        else:
            x_min = float(bbox[0]) * w
            x_max = float(bbox[1]) * w
            y_min = float(bbox[2]) * h
            y_max = float(bbox[3]) * h
            # NOTE(review): this conversion is assumed to apply only to the
            # non-COCO branch (the original indentation was lost) - confirm.
            a, b, c, d = x_min, y_min, x_max - x_min, y_max - y_min

        eps = 0.0001  # avoids division by zero for degenerate boxes
        longer_side = max(h, d)
        r_max = min(float(longer_side) / (d + eps), float(longer_side) / (c + eps))
        r_min = 1.5

        if not (r_max > r_min and w > 200 and h > 200 and c * d > 30 * 30):
            return img

        # Enlarge the box by a random factor in [1, r_max).
        r = 1 + np.random.rand() * (r_max - 1)
        d_new = r * d
        c_new = r * c

        # Square crop: side bounded by the image size, at least 150 px.
        longer_side = min(max(c_new, d_new), h, w)
        d_new = max(longer_side, 150)
        c_new = max(longer_side, 150)

        # Centre the enlarged window on the original box, clamped at 0.
        a_new = max(0, a - 0.5 * (c_new - c))
        b_new = max(0, b - 0.5 * (d_new - d))

        # Shift the window back inside the image if it overshoots.
        if c_new + a_new > w:
            a_new = max(0, a_new - ((c_new + a_new) - w))
        if d_new + b_new > h:
            b_new = max(0, b_new - ((d_new + b_new) - h))

        # PIL expects (left, upper, right, lower).
        return img.crop((a_new, b_new, c_new + a_new, d_new + b_new))

    def random_batch(self, index, id=0, file=None, image_path=None):
        """Load one sample, either randomly drawn or from an explicit path.

        With ``image_path`` left as ``None`` a class and a file are drawn at
        random, the image is optionally cropped around one of its bounding
        boxes (data augmentation that keeps the annotated object in frame)
        and ``self.transform`` is applied. With ``image_path`` given, that
        image is loaded and ``self.fixed_transform`` is applied without any
        bounding-box cropping.

        Args:
            index: Unused; present for interface compatibility.
            id: Class label reported for an explicitly given image.
            file: File name belonging to ``image_path``; not needed when
                ``image_path`` is given.
            image_path: Path of the image to load, or ``None`` to sample.

        Returns:
            ``(image, label, image_path)`` with ``label`` a 1-element
            ``LongTensor``.
        """
        fixed = image_path is not None
        if not fixed:
            id = np.random.randint(self.nclasses)
            file = random.choice(self.image_files[id])
            image_path = os.path.join(self.root, self.names[id], file)

        # convert() returns a fully loaded copy, so the file handle can be
        # released right away instead of lingering until garbage collection.
        with Image.open(image_path) as raw:
            img = raw.convert('RGB')

        if not fixed:
            stem = self._strip_jpg(file)
            if stem.isdigit():
                origin = "coco"
                boxes = self.mask_coco[file]["bbox"]
            else:
                origin = "oi"
                boxes = self.bbox[stem]["bbox"]

            if len(boxes) > 0:
                # Either a list of boxes (pick one at random) or one box.
                box = random.choice(boxes) if isinstance(boxes[0], list) else boxes
                img = self._crop_around_bbox(img, box, origin)

        if fixed:
            img = self.fixed_transform(img)
        else:
            img = self.transform(img)

        label = torch.LongTensor([id])
        return (img, label, image_path)

    def exact_batch(self, index):
        """Return ``(image, label, image_path)`` for the image at ``index``.

        ``index`` addresses the flat ``all_files`` / ``all_labels`` lists;
        ``self.transform`` is applied to the image.
        """
        image_path = self.all_files[index]
        with Image.open(image_path) as raw:
            img = raw.convert('RGB')
        img = self.transform(img)
        label = torch.LongTensor([self.all_labels[index]])
        return img, label, image_path

    def __getitem__(self, index):
        """Return the indexed sample if ``return_all`` is set, else a random one."""
        if self.return_all:
            return self.exact_batch(index)
        return self.random_batch(index)


class FFHQ(VisionDataset):
    """Dataset of PNG images stored in sub-folders of ``root``.

    The image with index ``i`` is read from
    ``<root>/<folder>/<i zero-padded to 5 digits>.png`` where ``folder`` is
    ``i`` rounded down to a multiple of 1000, also zero-padded to 5 digits.
    The dataset length is the number of directory entries found in all
    sub-folders of ``root``.

    Args:
        root: Directory containing the image sub-folders.
        transform: Callable applied to images returned by ``__getitem__``.
        batch_size: Number of indices drawn for the fixed batch.
        test_mode: Accepted for interface compatibility; not used here.
        return_all: Selects ``exact_batch`` instead of ``random_batch`` in
            ``__getitem__``; both load the image at the requested index.
        imsize: Side length used by the deterministic fixed transform.
    """

    def __init__(self, root, transform, batch_size=60, test_mode=False,
                 return_all=False, imsize=256):
        self.root = root
        self.transform = transform
        self.return_all = return_all

        print("root:",self.root)
        all_folders = os.listdir(self.root)

        self.length = sum(
            len(os.listdir(os.path.join(self.root, folder)))
            for folder in all_folders
        )
        self.fixed_transform = transforms.Compose([
            transforms.Resize(imsize),
            transforms.CenterCrop(imsize),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

        # Indices of the images that make up the fixed batch.
        self.fixed_indices = [
            np.random.randint(self.length) for _ in range(batch_size)
        ]

    def __len__(self):
        """Return the number of files counted below ``root``."""
        return self.length

    def fixed_batch(self, random=False):
        """Return a stacked CUDA batch loaded with ``fixed_transform``.

        With ``random`` false the indices chosen in ``__init__`` are used;
        otherwise the same number of fresh random indices is drawn.
        """
        if random:
            indices = [
                np.random.randint(self.length)
                for _ in range(len(self.fixed_indices))
            ]
        else:
            indices = self.fixed_indices
        return torch.stack([self.random_batch(i, True)[0].cuda() for i in indices])

    def random_batch(self, index, fixed=False):
        """Load the image at ``index``.

        Args:
            index: Image index used to build the folder and file name.
            fixed: Apply ``fixed_transform`` instead of ``transform``.

        Returns:
            ``(image, label, image_path)``; the label is always a 1-element
            zero ``LongTensor``.
        """
        folder = str((index // 1000) * 1000).zfill(5)
        file = str(index).zfill(5) + ".png"
        image_path = os.path.join(self.root, folder, file)

        # convert() returns a fully loaded copy, so the file handle can be
        # released right away instead of lingering until garbage collection.
        with Image.open(image_path) as raw:
            img = raw.convert('RGB')

        if fixed:
            img = self.fixed_transform(img)
        else:
            img = self.transform(img)

        return img, torch.zeros(1).long(), image_path

    def exact_batch(self, index):
        """Return the sample at ``index`` using ``self.transform``.

        ``__getitem__`` calls this method when ``return_all`` is set; it was
        previously missing, which raised ``AttributeError``. Loading is
        already index-exact in ``random_batch``, so this delegates to it.
        """
        return self.random_batch(index)

    def __getitem__(self, index):
        """Return ``(image, label, image_path)`` for ``index``."""
        if self.return_all:
            return self.exact_batch(index)
        return self.random_batch(index)


class Celeba(VisionDataset):
    """Dataset of PNG images stored directly in ``root``.

    The image with index ``i`` is read from
    ``<root>/<i + 1 zero-padded to 6 digits>.png``. The dataset length is the
    number of directory entries in ``root``.

    Args:
        root: Directory containing the images.
        transform: Callable applied to images returned by ``__getitem__``.
        batch_size: Number of indices drawn for the fixed batch.
        test_mode: Accepted for interface compatibility; not used here.
        return_all: Selects ``exact_batch`` instead of ``random_batch`` in
            ``__getitem__``; both load the image at the requested index.
        imsize: Side length used by the deterministic fixed transform.
    """

    def __init__(self, root, transform, batch_size=60, test_mode=False,
                 return_all=False, imsize=128):
        self.root = root
        self.transform = transform
        self.return_all = return_all
        all_files = os.listdir(self.root)
        self.length = len(all_files)
        self.fixed_transform = transforms.Compose([
            transforms.Resize(imsize),
            transforms.CenterCrop(imsize),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

        # Indices of the images that make up the fixed batch.
        self.fixed_indices = [
            np.random.randint(self.length) for _ in range(batch_size)
        ]

    def __len__(self):
        """Return the number of directory entries in ``root``."""
        return self.length

    def fixed_batch(self):
        """Return the fixed batch as a stacked CUDA tensor.

        The images at the indices chosen in ``__init__`` are loaded with the
        deterministic ``fixed_transform``.
        """
        return torch.stack([
            self.random_batch(idx, True)[0].cuda() for idx in self.fixed_indices
        ])

    def random_batch(self, index, fixed=False):
        """Load the image at ``index``.

        Args:
            index: Zero-based image index; file names start at ``000001``.
            fixed: Apply ``fixed_transform`` instead of ``transform``.

        Returns:
            ``(image, label, image_path)``; the label is always a 1-element
            zero ``LongTensor``.
        """
        file = str(index + 1).zfill(6) + '.png'
        image_path = os.path.join(self.root, file)

        # convert() returns a fully loaded copy, so the file handle can be
        # released right away instead of lingering until garbage collection.
        with Image.open(image_path) as raw:
            img = raw.convert('RGB')

        if fixed:
            img = self.fixed_transform(img)
        else:
            img = self.transform(img)

        return img, torch.zeros(1).long(), image_path

    def exact_batch(self, index):
        """Return the sample at ``index`` using ``self.transform``.

        ``__getitem__`` calls this method when ``return_all`` is set; it was
        previously missing, which raised ``AttributeError``. Loading is
        already index-exact in ``random_batch``, so this delegates to it.
        """
        return self.random_batch(index)

    def __getitem__(self, index):
        """Return ``(image, label, image_path)`` for ``index``."""
        if self.return_all:
            return self.exact_batch(index)
        return self.random_batch(index)
43 changes: 43 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# U-Net GAN PyTorch

<p align="center">
<img src="teaser_final.png">
</p>

PyTorch implementation of the CVPR 2020 paper "A U-Net Based Discriminator for Generative Adversarial Networks". The paper can
be found [here](https://openaccess.thecvf.com/content_CVPR_2020/html/Schonfeld_A_U-Net_Based_Discriminator_for_Generative_Adversarial_Networks_CVPR_2020_paper.html). The code allows the users to
reproduce and extend the results reported in the study. Please cite the
above paper when reporting, reproducing or extending the results.

## Purpose of the project

This software is a research prototype, solely developed for and published as
part of the publication. It will neither be
maintained nor monitored in any way.

## Setup

Create the conda environment "unetgan" from the provided unetgan.yml file. The experiments can be reproduced with the scripts provided in the folder training_scripts (the experiment folder and dataset folder have to be set manually).

|Argument|Explanation|
|---|---|
|--unconditional | Use this if the dataset does not have classes (e.g. CelebA).|
|--unet_mixup | Use CutMix. |
|--slow_mixup | Use warmup for the CutMix-augmentation loss.|
|--slow_mixup_epochs | Number of epochs for the warmup |
|--full_batch_mixup | If True, a coin is tossed at every training step. With a certain probability the whole batch is mixed and the CutMix augmentation loss and consistency_loss is the only loss that is computed for this batch. The probability increases from 0 to 0.5 over the course of the specified warmup epochs. If False, the CutMix augmentation and consistency loss are computed for every batch and added to the default GAN loss. In the case of a warmup, the augmentation loss is multiplied with a factor that increases from 0 to 1 over the course of the specified warmup epochs.|
|--consistency_loss | Compute only the CutMix consistency loss, but not the CutMix augmentation loss (Can increase stability but might perform worse). |
|--consistency_loss_and_augmentation | Compute both CutMix augmentation and consistency loss.|


## Details

This implementation of U-Net GAN is based on the PyTorch code for BigGAN (https://github.com/ajbrock/BigGAN-PyTorch). The main differences are that (1) we use our own data-loader which does not require HDF5 pre-processing, (2) applied changes in the generator and discriminator class in BigGAN.py, and (3) modified train.py and train_fns.py.

## License

U-Net GAN PyTorch is open-sourced under the AGPL-3.0 license. See the
[LICENSE](LICENSE) file for details.

For a list of other open source components included in U-Net GAN PyTorch, see the
file [3rd-party-licenses.txt](3rd-party-licenses.txt).
Loading

0 comments on commit 7db02d8

Please sign in to comment.