From 448669160e1aa67bd8270a603f2141d55ee28318 Mon Sep 17 00:00:00 2001 From: Roller44 Date: Tue, 11 Jul 2023 10:27:12 +0800 Subject: [PATCH 1/3] Fix CUDA not being used when it is available. --- vqvae.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/vqvae.py b/vqvae.py index 25221ec..bd8fb7d 100644 --- a/vqvae.py +++ b/vqvae.py @@ -64,7 +64,8 @@ def main(args): if args.dataset in ['mnist', 'fashion-mnist', 'cifar10']: transform = transforms.Compose([ transforms.ToTensor(), - transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) + # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), + transforms.Normalize((0.5), (0.5)) ]) if args.dataset == 'mnist': # Define the train & test datasets @@ -150,9 +151,9 @@ def main(args): parser = argparse.ArgumentParser(description='VQ-VAE') # General - parser.add_argument('--data-folder', type=str, + parser.add_argument('--data-folder', type=str, default='./data', help='name of the data folder') - parser.add_argument('--dataset', type=str, + parser.add_argument('--dataset', type=str, default='mnist', help='name of the dataset (mnist, fashion-mnist, cifar10, miniimagenet)') # Latent space @@ -176,7 +177,7 @@ def main(args): help='name of the output folder (default: vqvae)') parser.add_argument('--num-workers', type=int, default=mp.cpu_count() - 1, help='number of workers for trajectories sampling (default: {0})'.format(mp.cpu_count() - 1)) - parser.add_argument('--device', type=str, default='cpu', + parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu', help='set the device (cpu or cuda, default: cpu)') args = parser.parse_args() @@ -187,8 +188,7 @@ def main(args): if not os.path.exists('./models'): os.makedirs('./models') # Device - args.device = torch.device(args.device - if torch.cuda.is_available() else 'cpu') + args.device = torch.device(args.device) # Slurm if 'SLURM_JOB_ID' in os.environ: args.output_folder += 
'-{0}'.format(os.environ['SLURM_JOB_ID']) From a39c2458d7a3d1ae9180577635252ef66fc51093 Mon Sep 17 00:00:00 2001 From: Roller44 Date: Tue, 11 Jul 2023 10:49:33 +0800 Subject: [PATCH 2/3] Fix the incorrect transform shape for the MNIST dataset. --- vqvae.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vqvae.py b/vqvae.py index bd8fb7d..589090f 100644 --- a/vqvae.py +++ b/vqvae.py @@ -64,7 +64,6 @@ def main(args): if args.dataset in ['mnist', 'fashion-mnist', 'cifar10']: transform = transforms.Compose([ transforms.ToTensor(), - # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), transforms.Normalize((0.5), (0.5)) ]) if args.dataset == 'mnist': From 9e27a9759de3cc3c06e9364fa91ea79621315515 Mon Sep 17 00:00:00 2001 From: Roller44 Date: Wed, 12 Jul 2023 15:59:03 +0800 Subject: [PATCH 3/3] Speed up distance calculation. --- functions.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/functions.py b/functions.py index 7ebd869..e76912f 100644 --- a/functions.py +++ b/functions.py @@ -8,13 +8,9 @@ def forward(ctx, inputs, codebook): embedding_size = codebook.size(1) inputs_size = inputs.size() inputs_flatten = inputs.view(-1, embedding_size) - - codebook_sqr = torch.sum(codebook ** 2, dim=1) - inputs_sqr = torch.sum(inputs_flatten ** 2, dim=1, keepdim=True) - + # Compute the distances to the codebook - distances = torch.addmm(codebook_sqr + inputs_sqr, - inputs_flatten, codebook.t(), alpha=-2.0, beta=1.0) + distances = torch.cdist(inputs_flatten, codebook, 2) _, indices_flatten = torch.min(distances, dim=1) indices = indices_flatten.view(*inputs_size[:-1])