Merge pull request Sygil-Dev#19 from bzantium/fix/17-enable-ddp-training
enable distributed training with accelerate launch
lucidrains authored Feb 1, 2023
2 parents 2cc82dc + d83c86c commit 822934a
Showing 2 changed files with 25 additions and 17 deletions.
README.md (6 changes: 5 additions & 1 deletion)
@@ -17,12 +17,15 @@ First train your VAE - `VQGanVAE`
```python
import torch
from muse_maskgit_pytorch import VQGanVAE, VQGanVAETrainer
+from accelerate import DistributedDataParallelKwargs

vae = VQGanVAE(
dim = 256,
vq_codebook_size = 512
)

+ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) # this is for ddp

# train on folder of images, as many images as possible

trainer = VQGanVAETrainer(
@@ -31,7 +34,8 @@ trainer = VQGanVAETrainer(
folder = '/path/to/images',
batch_size = 4,
grad_accum_every = 8,
-num_train_steps = 50000
+num_train_steps = 50000,
+accelerate_kwargs={'kwargs_handlers': [ddp_kwargs]}
).cuda()

trainer.train()
```
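A note on what the new `accelerate_kwargs` argument presumably does: the trainer appears to forward it to Hugging Face `Accelerator`, so the README snippet above is roughly equivalent to building the accelerator yourself, as sketched below. A hypothetical script `train_vae.py` containing the code above would then be started with `accelerate config` once, followed by `accelerate launch train_vae.py`, matching the commit title.

```python
# Sketch only: how accelerate_kwargs is presumably consumed inside VQGanVAETrainer.
# DistributedDataParallelKwargs and Accelerator(kwargs_handlers=...) are
# Hugging Face accelerate APIs; the surrounding trainer wiring is assumed.
from accelerate import Accelerator, DistributedDataParallelKwargs

ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters = True)

accelerator = Accelerator(kwargs_handlers = [ddp_kwargs])

# the model, optimizers and dataloaders are then wrapped for the launch
# configuration (single process, DDP, ...) via accelerator.prepare(...),
# as done in trainers.py below
```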
muse_maskgit_pytorch/trainers.py (36 changes: 20 additions & 16 deletions)
@@ -18,7 +18,7 @@

from einops import rearrange

-from accelerate import Accelerator
+from accelerate import Accelerator, DistributedType

from ema_pytorch import EMA

@@ -127,10 +127,6 @@ def __init__(

self.vae = vae

-self.use_ema = use_ema
-if self.is_main and use_ema:
-    self.ema_vae = EMA(vae, update_after_step = ema_update_after_step, update_every = ema_update_every)

self.register_buffer('steps', torch.Tensor([0]))

self.num_train_steps = num_train_steps
@@ -194,6 +190,12 @@ def __init__(
self.valid_dl
)

+self.use_ema = use_ema
+
+if use_ema:
+    self.ema_vae = EMA(vae, update_after_step = ema_update_after_step, update_every = ema_update_every)
+    self.ema_vae = self.accelerator.prepare(self.ema_vae)

self.dl_iter = cycle(self.dl)
self.valid_dl_iter = cycle(self.valid_dl)

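The EMA copy is now built on every process rather than only on the main one, and is passed through `accelerator.prepare`, presumably so that the later `.module` unwrapping works uniformly. A minimal sketch of the `ema_pytorch` pattern, with illustrative hyperparameters and assuming `vae` and `accelerator` exist as in `__init__`:

```python
from ema_pytorch import EMA

# values are illustrative, not the trainer's defaults
ema_vae = EMA(
    vae,                       # model whose weights are shadowed
    update_after_step = 500,   # start averaging only after this many steps
    update_every = 10          # refresh the average every N .update() calls
)

ema_vae = accelerator.prepare(ema_vae)  # wrap for the current (possibly distributed) run

# after each generator optimizer step the trainer calls ema_vae.update();
# ema_vae.ema_model holds the smoothed weights used for sampling and checkpoints
```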
@@ -257,6 +259,8 @@ def train_step(self):
apply_grad_penalty = not (steps % self.apply_grad_penalty_every)

self.vae.train()
+discr = self.vae.module.discr if self.is_distributed else self.vae.discr
+ema_vae = self.ema_vae.module if self.is_distributed else self.ema_vae

# logs

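The `is_distributed` flag used above is not shown in this diff; given the new `DistributedType` import, it is presumably a property along these lines (a sketch, not the exact source):

```python
import torch.nn as nn
from accelerate import DistributedType

class VQGanVAETrainer(nn.Module):
    # constructor elided; self.accelerator is the Accelerator built in __init__

    @property
    def is_distributed(self):
        # more than one process, or a real distributed backend, means prepare()
        # wrapped the VAE (e.g. in DistributedDataParallel), so submodules such
        # as discr have to be reached through .module as done above
        return not (
            self.accelerator.distributed_type == DistributedType.NO
            and self.accelerator.num_processes == 1
        )
```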
@@ -287,7 +291,7 @@

# update discriminator

-if exists(self.vae.discr):
+if exists(discr):
self.discr_optim.zero_grad()

for _ in range(self.grad_accum_every):
@@ -301,7 +305,7 @@
accum_log(logs, {'discr_loss': loss.item() / self.grad_accum_every})

if exists(self.discr_max_grad_norm):
-self.accelerator.clip_grad_norm_(self.vae.discr.parameters(), self.discr_max_grad_norm)
+self.accelerator.clip_grad_norm_(discr.parameters(), self.discr_max_grad_norm)

self.discr_optim.step()

@@ -311,16 +315,16 @@

# update exponential moving averaged generator

-if self.is_main and self.use_ema:
-    self.ema_vae.update()
+if self.use_ema:
+    ema_vae.update()

# sample results every so often

-if self.is_main and not (steps % self.save_results_every):
+if not (steps % self.save_results_every):
vaes_to_evaluate = ((self.vae, str(steps)),)

if self.use_ema:
-vaes_to_evaluate = ((self.ema_vae.ema_model, f'{steps}.ema'),) + vaes_to_evaluate
+vaes_to_evaluate = ((ema_vae.ema_model, f'{steps}.ema'),) + vaes_to_evaluate

for model, filename in vaes_to_evaluate:
model.eval()
@@ -345,16 +349,16 @@
self.print(f'{steps}: saving to {str(self.results_folder)}')

# save model every so often

+self.accelerator.wait_for_everyone()
if self.is_main and not (steps % self.save_model_every):
-state_dict = self.vae.state_dict()
+state_dict = self.accelerator.unwrap_model(self.vae).state_dict()
model_path = str(self.results_folder / f'vae.{steps}.pt')
-torch.save(state_dict, model_path)
+self.accelerator.save(state_dict, model_path)

if self.use_ema:
-ema_state_dict = self.ema_vae.state_dict()
+ema_state_dict = self.accelerator.unwrap_model(self.ema_vae).state_dict()
model_path = str(self.results_folder / f'vae.{steps}.ema.pt')
-torch.save(ema_state_dict, model_path)
+self.accelerator.save(ema_state_dict, model_path)

self.print(f'{steps}: saving model to {str(self.results_folder)}')

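The checkpointing changes above follow the usual accelerate recipe: synchronize, unwrap, then save from the main process. A self-contained sketch of that pattern (toy model and path, not the trainer's actual objects):

```python
import torch
from accelerate import Accelerator

accelerator = Accelerator()
model = accelerator.prepare(torch.nn.Linear(8, 8))  # stand-in for the prepared VAE

accelerator.wait_for_everyone()                      # let every rank finish the step
if accelerator.is_main_process:
    # unwrap_model strips the DistributedDataParallel wrapper so the state dict
    # has plain parameter names and loads cleanly on a single device
    state_dict = accelerator.unwrap_model(model).state_dict()
    accelerator.save(state_dict, 'checkpoint.pt')    # accelerate-aware torch.save
```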
