Skip to content

Commit

Permalink
add trainer for vae
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Jan 10, 2023
1 parent f227737 commit 6a196ad
Show file tree
Hide file tree
Showing 4 changed files with 385 additions and 9 deletions.
20 changes: 12 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,25 @@ First train your VAE - `VQGanVAE`

```python
import torch
from muse_maskgit_pytorch import VQGanVAE
from muse_maskgit_pytorch import VQGanVAE, VQGanVAETrainer

vae = VQGanVAE(
dim = 256,
vq_codebook_size = 512
).cuda()

# mock images
)

images = torch.randn(4, 3, 256, 256).cuda()
# train on folder of images, as many images as possible

# do this for as many images as possible
trainer = VQGanVAETrainer(
vae = vae,
image_size = 128, # you may want to start with small images, and then curriculum learn to larger ones, but because the vae is all convolution, it should generalize to 512 (as in paper) without training on it
folder = '/path/to/images',
batch_size = 4,
grad_accum_every = 8,
num_train_steps = 50000
).cuda()

loss = vae(images, return_loss = True)
loss.backward()
trainer.train()
```

Then pass the trained `VQGanVAE` and a `Transformer` to `MaskGit`
Expand Down
2 changes: 2 additions & 0 deletions muse_maskgit_pytorch/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
from muse_maskgit_pytorch.vqgan_vae import VQGanVAE
from muse_maskgit_pytorch.muse_maskgit_pytorch import Transformer, MaskGit, Muse

from muse_maskgit_pytorch.trainers import VQGanVAETrainer
Loading

0 comments on commit 6a196ad

Please sign in to comment.