From d719eb01eff29ff83d0053e9b5c32d2f59ddb57f Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Wed, 1 Feb 2023 08:01:11 -0800 Subject: [PATCH] automatically add ddp kwarg handler --- README.md | 6 +----- muse_maskgit_pytorch/trainers.py | 17 ++++++++++++++++- setup.py | 2 +- 3 files changed, 18 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index a15668e..70544c6 100644 --- a/README.md +++ b/README.md @@ -17,15 +17,12 @@ First train your VAE - `VQGanVAE` ```python import torch from muse_maskgit_pytorch import VQGanVAE, VQGanVAETrainer -from accelerate import DistributedDataParallelKwargs vae = VQGanVAE( dim = 256, vq_codebook_size = 512 ) -ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) # this is for ddp - # train on folder of images, as many images as possible trainer = VQGanVAETrainer( @@ -34,8 +31,7 @@ trainer = VQGanVAETrainer( folder = '/path/to/images', batch_size = 4, grad_accum_every = 8, - num_train_steps = 50000, - accelerate_kwargs={'kwargs_handlers': [ddp_kwargs]} + num_train_steps = 50000 ).cuda() trainer.train() diff --git a/muse_maskgit_pytorch/trainers.py b/muse_maskgit_pytorch/trainers.py index eba0552..dde06a5 100644 --- a/muse_maskgit_pytorch/trainers.py +++ b/muse_maskgit_pytorch/trainers.py @@ -18,7 +18,7 @@ from einops import rearrange -from accelerate import Accelerator, DistributedType +from accelerate import Accelerator, DistributedType, DistributedDataParallelKwargs from ema_pytorch import EMA @@ -123,10 +123,23 @@ def __init__( accelerate_kwargs: dict = dict() ): super().__init__() + + # instantiate accelerator + + ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters = True) + + kwargs_handlers = accelerate_kwargs.get('kwargs_handlers', []) + kwargs_handlers.append(ddp_kwargs) + accelerate_kwargs.update(kwargs_handlers = kwargs_handlers) + self.accelerator = Accelerator(**accelerate_kwargs) + # vae + self.vae = vae + # training params + self.register_buffer('steps', torch.Tensor([0])) self.num_train_steps = num_train_steps @@ -139,6 +152,8 @@ def __init__( self.vae_parameters = vae_parameters + # optimizers + self.optim = Adam(vae_parameters, lr = lr) self.discr_optim = Adam(discr_parameters, lr = lr) diff --git a/setup.py b/setup.py index bbd923d..ff4cefb 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name = 'muse-maskgit-pytorch', packages = find_packages(exclude=[]), - version = '0.0.25', + version = '0.0.26', license='MIT', description = 'MUSE - Text-to-Image Generation via Masked Generative Transformers, in Pytorch', author = 'Phil Wang',