Skip to content

Commit

Permalink
automatically add ddp kwarg handler
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Feb 1, 2023
1 parent 822934a commit d719eb0
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 7 deletions.
6 changes: 1 addition & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,12 @@ First train your VAE - `VQGanVAE`
```python
import torch
from muse_maskgit_pytorch import VQGanVAE, VQGanVAETrainer
from accelerate import DistributedDataParallelKwargs

vae = VQGanVAE(
dim = 256,
vq_codebook_size = 512
)

ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) # this is for ddp

# train on folder of images, as many images as possible

trainer = VQGanVAETrainer(
Expand All @@ -34,8 +31,7 @@ trainer = VQGanVAETrainer(
folder = '/path/to/images',
batch_size = 4,
grad_accum_every = 8,
num_train_steps = 50000,
accelerate_kwargs={'kwargs_handlers': [ddp_kwargs]}
num_train_steps = 50000
).cuda()

trainer.train()
Expand Down
17 changes: 16 additions & 1 deletion muse_maskgit_pytorch/trainers.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

from einops import rearrange

from accelerate import Accelerator, DistributedType
from accelerate import Accelerator, DistributedType, DistributedDataParallelKwargs

from ema_pytorch import EMA

Expand Down Expand Up @@ -123,10 +123,23 @@ def __init__(
accelerate_kwargs: dict = dict()
):
super().__init__()

# instantiate accelerator

ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters = True)

kwargs_handlers = accelerate_kwargs.get('kwargs_handlers', [])
kwargs_handlers.append(ddp_kwargs)
accelerate_kwargs.update(kwargs_handlers = kwargs_handlers)

self.accelerator = Accelerator(**accelerate_kwargs)

# vae

self.vae = vae

# training params

self.register_buffer('steps', torch.Tensor([0]))

self.num_train_steps = num_train_steps
Expand All @@ -139,6 +152,8 @@ def __init__(

self.vae_parameters = vae_parameters

# optimizers

self.optim = Adam(vae_parameters, lr = lr)
self.discr_optim = Adam(discr_parameters, lr = lr)

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'muse-maskgit-pytorch',
packages = find_packages(exclude=[]),
version = '0.0.25',
version = '0.0.26',
license='MIT',
description = 'MUSE - Text-to-Image Generation via Masked Generative Transformers, in Pytorch',
author = 'Phil Wang',
Expand Down

0 comments on commit d719eb0

Please sign in to comment.