-
Notifications
You must be signed in to change notification settings - Fork 3
/
main.py
144 lines (100 loc) · 4.8 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
import os
import shutil
import argparse
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset
from torchvision import datasets as dsets, transforms
from models.gamma_vae import gamma_vae, compute_gamma
from models.normal_vae import gaussian_vae, compute_gaussian
from utils import AverageMeter, Logger, Model, ReshapeTransform, save_model, binarized_mnist_fixed_binarization
from train import train
from val import val
# Default all newly created tensors to CUDA floats when a GPU is present.
# NOTE(review): torch.set_default_tensor_type is deprecated in recent PyTorch
# releases (prefer torch.set_default_dtype / torch.set_default_device) —
# confirm the torch version this project targets before changing it.
if torch.cuda.is_available(): torch.set_default_tensor_type('torch.cuda.FloatTensor')
# Command-line configuration for the VAE training run.
parser = argparse.ArgumentParser()
# `choices=` validates the value inside argparse itself.  The original code
# validated with `assert`, which is silently stripped when Python runs with
# the -O flag, letting unsupported values through.
parser.add_argument('--model', type=str, default="normal",
                    choices=['normal', 'gamma'],
                    help='VAE type (normal or gamma)')
parser.add_argument('--dataset', type=str, default="mnist",
                    choices=['cifar-10', 'mnist', 'b_mnist'],
                    help='Dataset (cifar-10, b_mnist or mnist)')
parser.add_argument('--epochs', type=int, default=1001,
                    help='number of epochs')
parser.add_argument('--b_size', type=int, default=128,
                    help='batch size')
parser.add_argument('--z_dim', type=int, default=4,
                    help='size of the latent space')
opt = parser.parse_args()
# Root folder holding datasets, logs, maps and checkpoints.
data_folder = os.path.join(".","data")
if not os.path.isdir(data_folder):
    os.makedirs(data_folder)
# Build the train (t_dataset) and validation (v_dataset) datasets for the
# requested dataset; torchvision downloads on first use.
if opt.dataset == 'cifar-10':
    t_dataset = dsets.CIFAR10(root=os.path.join(".",'data'), train=True,
                              download=True, transform=transforms.Compose([transforms.ToTensor()]))
    v_dataset = dsets.CIFAR10(root=os.path.join(".",'data'), train=False,
                              download=True, transform=transforms.Compose([transforms.ToTensor()]))
    # Delete the downloaded archive after extraction to reclaim disk space.
    if os.path.exists(os.path.join(data_folder,'cifar-10-python.tar.gz')):
        os.remove(os.path.join(data_folder,'cifar-10-python.tar.gz'))
elif opt.dataset == 'mnist':
    # ReshapeTransform((-1,)) flattens each image into a 1-D vector.
    t_dataset = dsets.MNIST(root=os.path.join(".",'data'), train=True, download=True,
                            transform=transforms.Compose([transforms.ToTensor(), ReshapeTransform((-1,))]))
    v_dataset = dsets.MNIST(root=os.path.join(".",'data'), train=False, download=True,
                            transform=transforms.Compose([transforms.ToTensor(), ReshapeTransform((-1,))]))
    # Remove the raw download directory once the processed files exist.
    if os.path.exists(os.path.join(".","data","MNIST","raw")):
        shutil.rmtree(os.path.join(".","data","MNIST","raw"))
elif opt.dataset == 'b_mnist':
    # Binarized MNIST is loaded from numpy arrays by a project helper.
    # NOTE(review): torch.from_numpy keeps the numpy dtype — confirm the
    # arrays are float32, otherwise the model may receive float64 inputs.
    bpath = os.path.join(".","data","b_mnist")
    train_data, validation_data, test_data = binarized_mnist_fixed_binarization(bpath)
    t_dataset = TensorDataset(torch.from_numpy(train_data))
    v_dataset = TensorDataset(torch.from_numpy(validation_data))
# Batched loaders consumed by train()/val().
# NOTE(review): shuffle=True on the validation loader looks unintentional;
# it does not change averaged metrics but makes per-batch logs non-reproducible — confirm.
t_generator = DataLoader(t_dataset, batch_size=opt.b_size, num_workers=8, shuffle=True)
v_generator = DataLoader(v_dataset, batch_size=opt.b_size, num_workers=8, shuffle=True)
# Running-average meters for the three loss components tracked each epoch.
recons_meter = AverageMeter()  # reconstruction term
kl_meter = AverageMeter()      # KL-divergence term
total_meter = AverageMeter()   # full loss (recons + KL); fixes 'taotal' typo
metrics = [recons_meter, kl_meter, total_meter]
# Column headers shared by both log files.
logger_list = ["Epoch","Recons","KL","Full"]
# Log file names encode the run configuration so runs are distinguishable.
suffix_logger = "{}_{}_{}_{}".format(opt.model,opt.dataset,opt.b_size,opt.z_dim)
train_logger = Logger(os.path.join(data_folder,'train_{}.log'.format(suffix_logger)),logger_list)
# Validation additionally logs the marginal log-likelihood estimate.
val_logger = Logger(os.path.join(data_folder,'val_{}.log'.format(suffix_logger)),logger_list + ["MLikeli"])
# Map each model name to its (architecture factory, loss function) pair;
# opt.model was already validated against exactly these keys.
_VAE_FACTORY = {
    'normal': (gaussian_vae, compute_gaussian),
    'gamma': (gamma_vae, compute_gamma),
}
_build_vae, compute_vae = _VAE_FACTORY[opt.model]
vae_model = _build_vae(opt.dataset)
# Output folders for reconstruction/sample maps and model checkpoints.
maps_folder = os.path.join(data_folder, "maps", opt.model)
# exist_ok=True makes this safe to rerun: the original `if not isdir(maps_folder)`
# guard skipped creating the train/val subfolders whenever the parent folder
# survived a partial previous run, crashing later writes.
os.makedirs(os.path.join(maps_folder, "train"), exist_ok=True)
os.makedirs(os.path.join(maps_folder, "val"), exist_ok=True)
models_folder = os.path.join(data_folder, "models")
os.makedirs(models_folder, exist_ok=True)
print("{} model chosen.\n".format(opt.model))
# Wrap the chosen architecture in the project's Model container (holds the
# network and its optimizer).
vae = Model(vae_model, z_dim=opt.z_dim)
# Best validation loss seen so far, used for best-checkpoint selection.
best_loss = float("inf")
best_epoch = -1
# Main loop: one training pass and one validation pass per epoch, saving a
# checkpoint every epoch and flagging the one with the best validation loss.
for epoch in range(opt.epochs):
    # Clear the per-epoch statistics before each pass.
    for meter in metrics:
        meter.reset()
    print("====== Epoch {} ======".format(epoch))
    train(epoch, vae, t_generator, compute_vae, metrics, (models_folder, maps_folder), opt, train_logger)
    vae_loss, log_p_x = val(epoch, vae, v_generator, compute_vae, metrics, (models_folder, maps_folder), opt, val_logger)
    # Did validation improve on the best loss recorded so far?
    is_best = vae_loss < best_loss
    if is_best:
        best_loss, best_epoch = vae_loss, epoch
    # Snapshot everything needed to resume or reproduce this run.
    internal_state = {
        'model': opt.model,
        'dataset': opt.dataset,
        'z_dim': opt.z_dim,
        'current_epoch': epoch,
        'best_epoch': best_epoch,
        'best_loss': best_loss,
        'model_vae_state_dict': vae.vae.state_dict(),
        'optimizer_vae_state_dict': vae.vae_optimizer.state_dict(),
    }
    save_model(internal_state, models_folder, is_best, epoch, opt.model)
train_logger.close()
val_logger.close()