multi_gpu, train lsun datasets
sxhxliang committed Nov 15, 2018
1 parent 480196c commit 9f65bd1
Showing 9 changed files with 66 additions and 22 deletions.
4 changes: 4 additions & 0 deletions README.md
@@ -8,6 +8,10 @@ for 128\*128\*3 resolution
 
 python main.py --batch_size 64 --dataset imagenet --adv_loss hinge --version biggan_imagenet --image_path /data/datasets
 
+python main.py --batch_size 64 --dataset lsun --adv_loss hinge --version biggan_lsun --image_path /data1/datasets/lsun/lsun
+
+python main.py --batch_size 64 --dataset lsun --adv_loss hinge --version biggan_lsun --image_path ./data
+
 ## Different
 
 * not use cross-replica BatchNorm (Ioffe & Szegedy, 2015) in G
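Note: for --dataset lsun, --image_path should point at an extracted LSUN root containing the per-category lmdb folders that torchvision's LSUN dataset expects (e.g. church_outdoor_train_lmdb). A minimal sanity check, assuming that layout (paths are illustrative, not from this commit):

    import os

    lsun_root = './data'  # hypothetical LSUN root
    for cls in ['church_outdoor_train', 'classroom_train']:
        lmdb_dir = os.path.join(lsun_root, cls + '_lmdb')
        # torchvision's dsets.LSUN opens '<class>_lmdb' under the root
        print(lmdb_dir, 'found' if os.path.isdir(lmdb_dir) else 'MISSING')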
Binary file modified __pycache__/model_resnet.cpython-35.pyc
4 changes: 2 additions & 2 deletions data_loader.py
@@ -25,9 +25,9 @@ def transform(self, resize, totensor, normalize, centercrop):
         transform = transforms.Compose(options)
         return transform
 
-    def load_lsun(self, classes='church_outdoor_train'):
+    def load_lsun(self, classes=['church_outdoor_train','classroom_train']):
         transforms = self.transform(True, True, True, False)
-        dataset = dsets.LSUN(self.path, classes=[classes], transform=transforms)
+        dataset = dsets.LSUN(self.path, classes=classes, transform=transforms)
        return dataset
 
     def load_imagenet(self):
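For reference, a standalone sketch of what the updated load_lsun does: given a list of categories, torchvision's dsets.LSUN concatenates the per-category lmdb databases and uses each category's position in the list as its integer label (transform and paths below are illustrative, matching the resize/totensor/normalize flags):

    import torchvision.datasets as dsets
    import torchvision.transforms as transforms

    transform = transforms.Compose([
        transforms.Resize((128, 128)),  # resize flag True, imsize 128
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    dataset = dsets.LSUN('./data', classes=['church_outdoor_train', 'classroom_train'],
                         transform=transform)
    img, label = dataset[0]  # label is 0 for church_outdoor, 1 for classroom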
6 changes: 5 additions & 1 deletion debug.py
@@ -1,5 +1,6 @@
 from model_resnet import *
 from demo import *
+from utils import *
 
 dim_z = 120
 vocab_size = 1000
@@ -26,8 +27,11 @@
 # out = model(inputs,labels)
 
 # print(out.size())
+# model.apply(weights_init)
 
-torch.save(model.state_dict(),'test_model.pth')
+
+print('0,1,2,3'.split(','))
+# torch.save(model.state_dict(),'test_model.pth')
 
 
 
3 changes: 3 additions & 0 deletions main.py
@@ -15,6 +15,7 @@ def main(config):
 
 
     config.n_class = len(glob.glob(os.path.join(config.image_path, '*/')))
+    print('number class:', config.n_class)
     # Data loader
     data_loader = Data_Loader(config.train, config.dataset, config.image_path, config.imsize,
                               config.batch_size, shuf=config.train)
@@ -26,6 +27,8 @@
     make_folder(config.attn_path, config.version)
 
 
+    print('config data_loader and build logs folder')
+
     if config.train:
         if config.model=='sagan':
             trainer = Trainer(data_loader.loader(), config)
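The class count is inferred from the number of subdirectories under --image_path (the trailing '*/' makes glob match directories only), so with the LSUN layout above each *_lmdb folder counts as one class. A small illustration with hypothetical folders:

    import glob
    import os

    image_path = './data'  # hypothetical root with one folder per class
    class_dirs = glob.glob(os.path.join(image_path, '*/'))
    print('number class:', len(class_dirs))  # e.g. 2 for church_outdoor_train_lmdb/ and classroom_train_lmdb/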
18 changes: 14 additions & 4 deletions model_resnet.py
@@ -92,6 +92,11 @@ def __init__(self,in_dim,activation=F.relu):
         self.gamma = nn.Parameter(torch.zeros(1))
 
         self.softmax = nn.Softmax(dim=-1) #
+
+        init_conv(self.query_conv)
+        init_conv(self.key_conv)
+        init_conv(self.value_conv)
+
     def forward(self,x):
         """
             inputs :
@@ -235,7 +240,6 @@ def forward(self, input, class_id):
         class_emb = self.linear(class_id)  # 128
 
         out = self.G_linear(codes[0])
-        # print(out)
         # out = out.view(-1, 1536, 4, 4)
         out = out.view(-1, self.first_view, 4, 4)
         ids = 1
@@ -268,9 +272,11 @@ def conv(in_channel, out_channel, downsample=True):
                              upsample=False, downsample=downsample)
 
         gain = 2 ** 0.5
 
+
         if debug:
             chn = 8
+        self.debug = debug
 
         self.pre_conv = nn.Sequential(SpectralNorm(nn.Conv2d(3, 1*chn, 3,padding=1),),
                                       nn.ReLU(),
@@ -293,6 +299,7 @@ def conv(in_channel, out_channel, downsample=True):
         self.embed = spectral_norm(self.embed)
 
     def forward(self, input, class_id):
+
         out = self.pre_conv(input)
         out = out + self.pre_skip(F.avg_pool2d(input, 2))
         # print(out.size())
@@ -303,9 +310,12 @@ def forward(self, input, class_id):
         out_linear = self.linear(out).squeeze(1)
         embed = self.embed(class_id)
 
-        # print(out_linear.size())
-        # print(embed.size())
-
         prod = (out * embed).sum(1)
+
+        # if self.debug == debug:
+        #     print('class_id',class_id.size())
+        #     print('out_linear',out_linear.size())
+        #     print('embed', embed.size())
+        #     print('prod', prod.size())
 
         return out_linear + prod
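The first hunk explicitly initializes the self-attention query/key/value projections via the init_conv helper; the last hunk is a projection-discriminator head (Miyato & Koyama, 2018): the conditional term is the inner product between the class embedding and the pooled features, added to an unconditional linear score. A self-contained sketch of that head, with illustrative sizes rather than the repo's actual ones:

    import torch
    import torch.nn as nn

    batch, feat_dim, n_class = 4, 1024, 10   # illustrative sizes

    linear = nn.Linear(feat_dim, 1)          # unconditional score head
    embed = nn.Embedding(n_class, feat_dim)  # class embedding (spectral-normalized in the repo)

    out = torch.randn(batch, feat_dim)             # pooled discriminator features
    class_id = torch.randint(0, n_class, (batch,)) # flat LongTensor of class ids

    out_linear = linear(out).squeeze(1)        # (batch,)
    prod = (out * embed(class_id)).sum(1)      # (batch,) projection term
    score = out_linear + prod                  # conditional discriminator output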
6 changes: 4 additions & 2 deletions parameter.py
@@ -12,6 +12,7 @@ def get_parameters():
     parser.add_argument('--adv_loss', type=str, default='wgan-gp', choices=['wgan-gp', 'hinge'])
     parser.add_argument('--imsize', type=int, default=128)
     parser.add_argument('--g_num', type=int, default=5)
+    parser.add_argument('--chn', type=int, default=64)
     parser.add_argument('--z_dim', type=int, default=120)
     parser.add_argument('--g_conv_dim', type=int, default=64)
     parser.add_argument('--d_conv_dim', type=int, default=64)
@@ -22,7 +23,7 @@
     parser.add_argument('--total_step', type=int, default=1000000, help='how many times to update the generator')
     parser.add_argument('--d_iters', type=float, default=5)
     parser.add_argument('--batch_size', type=int, default=64)
-    parser.add_argument('--num_workers', type=int, default=2)
+    parser.add_argument('--num_workers', type=int, default=12)
     parser.add_argument('--g_lr', type=float, default=0.0001)
     parser.add_argument('--d_lr', type=float, default=0.0004)
     parser.add_argument('--lr_decay', type=float, default=0.95)
@@ -35,11 +36,12 @@
     # Misc
     parser.add_argument('--train', type=str2bool, default=True)
     parser.add_argument('--parallel', type=str2bool, default=False)
+    parser.add_argument('--gpus', type=str, default='0', help='gpu ids, e.g. 0,1,2,3 (use with --parallel True)')
     parser.add_argument('--dataset', type=str, default='lsun', choices=['lsun', 'celeb','off'])
     parser.add_argument('--use_tensorboard', type=str2bool, default=False)
 
     # Path
-    parser.add_argument('--image_path', type=str, default='/Users/AaronLeong/deeplearning/test_img/root/folder2')
+    parser.add_argument('--image_path', type=str, default='./data')
     parser.add_argument('--log_path', type=str, default='./logs')
     parser.add_argument('--model_save_path', type=str, default='./models')
     parser.add_argument('--sample_path', type=str, default='./samples')
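Putting the new options together, a multi-GPU LSUN run might look like this (gpu ids and paths are illustrative):

python main.py --batch_size 64 --dataset lsun --adv_loss hinge --version biggan_lsun --image_path ./data --parallel True --gpus 0,1 --chn 64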
35 changes: 23 additions & 12 deletions trainer.py
@@ -28,6 +28,7 @@ def __init__(self, data_loader, config):
         self.g_conv_dim = config.g_conv_dim
         self.d_conv_dim = config.d_conv_dim
         self.parallel = config.parallel
+        self.gpus = config.gpus
 
         self.lambda_gp = config.lambda_gp
         self.total_step = config.total_step
@@ -53,6 +54,7 @@ def __init__(self, data_loader, config):
         self.version = config.version
 
         self.n_class = config.n_class
+        self.chn = config.chn
 
         # Path
         self.log_path = os.path.join(config.log_path, self.version)
@@ -61,22 +63,22 @@
 
         self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
+        print('build_model...')
         self.build_model()
 
         if self.use_tensorboard:
             self.build_tensorboard()
 
         # Start with trained model
         if self.pretrained_model:
+            print('load_pretrained_model...')
             self.load_pretrained_model()
 
 
     def label_sampel(self):
         label = torch.LongTensor(self.batch_size, 1).random_()%self.n_class
-        one_hot= torch.zeros(self.batch_size, self.n_class).scatter_(1, label, 1).to(self.device)
-
-        return label.to(self.device), one_hot
-
+        one_hot= torch.zeros(self.batch_size, self.n_class).scatter_(1, label, 1)
+        return label.squeeze(1).to(self.device), one_hot.to(self.device)

    def train(self):
 
@@ -95,6 +97,7 @@ def train(self):
         start = 0
 
         # Start time
+        print('Start ====== training...')
         start_time = time.time()
         for step in range(start, self.total_step):
 
@@ -109,8 +112,8 @@ def train(self):
             real_images, real_labels = next(data_iter)
 
             # Compute loss with real images
-            real_labels = real_labels.unsqueeze(0)
 
+            real_labels = real_labels.to(self.device)
             real_images = real_images.to(self.device)
 
             d_out_real = self.D(real_images, real_labels)
@@ -187,12 +190,13 @@ def train(self):
             if (step + 1) % self.log_step == 0:
                 elapsed = time.time() - start_time
                 elapsed = str(datetime.timedelta(seconds=elapsed))
-                print("Elapsed [{}], G_step [{}/{}], D_step[{}/{}], d_out_real: {:.4f}".
+                print("Elapsed [{}], G_step [{}/{}], D_step[{}/{}], d_out_real: {:.4f}, d_out_fake: {:.4f}, g_loss_fake: {:.4f}".
                       format(elapsed, step + 1, self.total_step, (step + 1),
-                             self.total_step , d_loss_real.data[0]))
+                             self.total_step , d_loss_real.item(), d_loss_fake.item(), g_loss_fake.item()))
 
             # Sample images
             if (step + 1) % self.sample_step == 0:
+                print('Sample images {}_fake.png'.format(step + 1))
                 fake_images= self.G(fixed_z, z_class_one_hot)
                 save_image(denorm(fake_images.data),
                            os.path.join(self.sample_path, '{}_fake.png'.format(step + 1)))
@@ -205,11 +209,18 @@
 
     def build_model(self):
         # code_dim=100, n_class=1000
-        self.G = Generator(self.z_dim, self.n_class).to(self.device)
-        self.D = Discriminator(self.n_class).to(self.device)
+        self.G = Generator(self.z_dim, self.n_class, chn=self.chn).to(self.device)
+        self.D = Discriminator(self.n_class, chn=self.chn).to(self.device)
         if self.parallel:
-            self.G = nn.DataParallel(self.G)
-            self.D = nn.DataParallel(self.D)
+            print('use parallel...')
+            print('gpuids ', self.gpus)
+            gpus = [int(g) for g in self.gpus.split(',')]
+
+            self.G = nn.DataParallel(self.G, device_ids=gpus)
+            self.D = nn.DataParallel(self.D, device_ids=gpus)
+
+        # self.G.apply(weights_init)
+        # self.D.apply(weights_init)
 
         # Loss and optimizer
         # self.g_optimizer = torch.optim.Adam(self.G.parameters(), self.g_lr, [self.beta1, self.beta2])
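Two details worth noting. First, label_sampel now returns a flat (batch,) LongTensor of class ids (via squeeze(1)), which is what the discriminator's nn.Embedding expects, alongside the one-hot matrix fed to the generator. Second, nn.DataParallel's device_ids takes integer device indices, hence the int conversion after splitting the --gpus string. A standalone sketch of the label sampling, with illustrative sizes:

    import torch

    batch_size, n_class = 4, 3  # illustrative sizes

    label = torch.LongTensor(batch_size, 1).random_() % n_class       # (batch, 1) ids
    one_hot = torch.zeros(batch_size, n_class).scatter_(1, label, 1)  # (batch, n_class)
    label = label.squeeze(1)                                          # (batch,) for nn.Embedding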
12 changes: 11 additions & 1 deletion utils.py
@@ -1,7 +1,7 @@
 import os
 import torch
 from torch.autograd import Variable
-
+from torch.nn import init
 
 def make_folder(path, version):
     if not os.path.exists(os.path.join(path, version)):
@@ -23,3 +23,13 @@ def denorm(x):
     out = (x + 1) / 2
     return out.clamp_(0, 1)
 
+def weights_init(m):
+    classname = m.__class__.__name__
+    # print(classname)
+    if classname.find('Conv2d') != -1:
+        init.xavier_normal_(m.weight.data)
+        init.constant_(m.bias.data, 0.0)
+    elif classname.find('Linear') != -1:
+        init.xavier_normal_(m.weight.data)
+        init.constant_(m.bias.data, 0.0)

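weights_init is meant to be used with module.apply, e.g. model.apply(weights_init), as the commented-out calls in debug.py and trainer.py suggest. One caveat: layers built with bias=False carry m.bias = None, so a defensive variant (a sketch, not the repo's helper) would guard that case:

    import torch.nn as nn
    from torch.nn import init

    def weights_init_safe(m):  # hypothetical variant of weights_init
        classname = m.__class__.__name__
        if classname.find('Conv2d') != -1 or classname.find('Linear') != -1:
            init.xavier_normal_(m.weight.data)
            if m.bias is not None:          # bias=False layers have no bias tensor
                init.constant_(m.bias.data, 0.0)

    model = nn.Conv2d(3, 8, 3, bias=False)
    model.apply(weights_init_safe)          # does not crash on the missing bias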