How to train VA-VAE? #1

Open · lavinal712 opened this issue Jan 3, 2025 · 16 comments

@lavinal712

Can you release the code? Thanks for your work!

@JingfengYao (Member)

The training code for VA-VAE is primarily based on the autoencoder training code from LDM; implementing it should be relatively straightforward following Section 3 of the paper. We are currently considering the most concise way to release it, such as a fork or something similar.

@lavinal712 (Author)

Thanks! I tried to reproduce your code, but I found that the vf_loss does not converge easily. After training for 1000 steps, the model collapsed, and the output turned into solid-color images. Therefore, I would like to see more details in the code.

@gkakogeorgiou

Hi! Thanks for the great work! I noticed similar behavior to what @lavinal712 observed. Releasing the VA-VAE training code could help us better understand the process.

@lavinal712 (Author) commented Jan 8, 2025

import torch
import torch.nn as nn

from taming.modules.losses.vqperceptual import *  # TODO: taming dependency yes/no?


class LPIPSWithDiscriminator(nn.Module):
    def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0,
                 disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
                 perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
                 disc_loss="hinge", 
                 vf_weight=0.1, adaptive_vf=True, vf_loss_type="combined_v3", distmat_margin=0.25, cos_margin=0.5):

        super().__init__()
        assert disc_loss in ["hinge", "vanilla"]
        self.kl_weight = kl_weight
        self.pixel_weight = pixelloss_weight
        self.perceptual_loss = LPIPS().eval()
        self.perceptual_weight = perceptual_weight
        # output log variance
        self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)

        self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
                                                 n_layers=disc_num_layers,
                                                 use_actnorm=use_actnorm
                                                 ).apply(weights_init)
        self.discriminator_iter_start = disc_start
        self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
        self.disc_factor = disc_factor
        self.discriminator_weight = disc_weight
        self.disc_conditional = disc_conditional
        
        self.vf_weight = vf_weight
        self.adaptive_vf = adaptive_vf
        self.vf_loss_type = vf_loss_type
        self.distmat_margin = distmat_margin
        self.cos_margin = cos_margin

    def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
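        # balance an auxiliary loss (GAN or VF) against the reconstruction NLL by
        # matching their gradient norms at the given last layer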
        if last_layer is not None:
            nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
            g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
        else:
            nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
            g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]

        d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
        d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
        d_weight = d_weight * self.discriminator_weight
        return d_weight

    def forward(self, inputs, reconstructions, posteriors, z_prime, features, optimizer_idx,
                global_step, last_layer=None, encoder_last_layer=None, cond=None, split="train",
                weights=None):
        rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
        if self.perceptual_weight > 0:
            p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
            rec_loss = rec_loss + self.perceptual_weight * p_loss

        nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
        weighted_nll_loss = nll_loss
        if weights is not None:
            weighted_nll_loss = weights*nll_loss
        weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
        nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
        kl_loss = posteriors.kl()
        kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]

        # now the GAN part
        if optimizer_idx == 0:
            # generator update
            if cond is None:
                assert not self.disc_conditional
                logits_fake = self.discriminator(reconstructions.contiguous())
            else:
                assert self.disc_conditional
                logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
            g_loss = -torch.mean(logits_fake)

            if self.disc_factor > 0.0:
                try:
                    d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
                except RuntimeError:
                    assert not self.training
                    d_weight = torch.tensor(0.0)
            else:
                d_weight = torch.tensor(0.0)
                
            if z_prime is not None:
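                # vision-foundation-model (VF) alignment loss between the projected
                # latent z_prime and the frozen foundation-model features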
                if self.vf_loss_type == "combined_v3":
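                    # marginal cosine similarity: only penalize tokens whose cosine
                    # similarity with the VF feature falls below 1 - cos_margin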
                    mcos_loss = nn.functional.cosine_similarity(z_prime, features, dim=-1)
                    mcos_loss = nn.functional.relu(1 - self.cos_margin - mcos_loss)
                    mcos_loss = torch.mean(mcos_loss)
                    
                    z_prime = nn.functional.normalize(z_prime, dim=-1)
                    features = nn.functional.normalize(features, dim=-1)
                    
                    cossim_z_prime = torch.matmul(z_prime, z_prime.transpose(1, 2))
                    cossim_features = torch.matmul(features, features.transpose(1, 2))
                    
                    mdms_loss = torch.abs(cossim_z_prime - cossim_features)
                    mdms_loss = nn.functional.relu(mdms_loss - self.distmat_margin)
                    mdms_loss = torch.mean(mdms_loss)
                    
                    vf_loss = mcos_loss + mdms_loss
                else:
                    raise ValueError(f"Unknown vf_loss_type: {self.vf_loss_type}")
                
                if self.adaptive_vf:
                    try:
                        vf_weight = self.calculate_adaptive_weight(nll_loss, vf_loss, last_layer=encoder_last_layer)
                    except RuntimeError:
                        assert not self.training
                        vf_weight = torch.tensor(1.0)
                    vf_loss = vf_weight * vf_loss
            else:
                vf_loss = torch.tensor(0.0)

            disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
            loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss + self.vf_weight * vf_loss

            log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(),
                   "{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(),
                   "{}/rec_loss".format(split): rec_loss.detach().mean(),
                   "{}/d_weight".format(split): d_weight.detach(),
                   "{}/disc_factor".format(split): torch.tensor(disc_factor),
                   "{}/g_loss".format(split): g_loss.detach().mean(),
                   "{}/vf_loss".format(split): vf_loss.detach().mean(),
                   }
            return loss, log

        if optimizer_idx == 1:
            # second pass for discriminator update
            if cond is None:
                logits_real = self.discriminator(inputs.contiguous().detach())
                logits_fake = self.discriminator(reconstructions.contiguous().detach())
            else:
                logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
                logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))

            disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
            d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)

            log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
                   "{}/logits_real".format(split): logits_real.detach().mean(),
                   "{}/logits_fake".format(split): logits_fake.detach().mean()
                   }
            return d_loss, log

@gkakogeorgiou @JingfengYao Is there any problem with the vf_loss?

@JingfengYao (Member) commented Jan 8, 2025

I'm not sure about the shapes of features and z_prime. Judging from your mcos_loss = nn.functional.cosine_similarity(z_prime, features, dim=-1) and z_prime = nn.functional.normalize(z_prime, dim=-1), are you placing the channel dimension in the last dimension?

Here are my implementations:

elif self.vf_loss_type == "combined_v3":
    # we also give a 0.25 margin to image dist mat loss
    z_flat = rearrange(z, 'b c h w -> b c (h w)')
    aux_feature_flat = rearrange(aux_feature, 'b c h w -> b c (h w)')
    z_norm = torch.nn.functional.normalize(z_flat, dim=1)
    aux_feature_norm = torch.nn.functional.normalize(aux_feature_flat, dim=1)
    z_cos_sim = torch.einsum('bci,bcj->bij', z_norm, z_norm)
    aux_feature_cos_sim = torch.einsum('bci,bcj->bij', aux_feature_norm, aux_feature_norm)
    diff = torch.abs(z_cos_sim - aux_feature_cos_sim)
    vf_loss_1 = torch.nn.functional.relu(diff-self.distmat_margin).mean()

    # margin_cos
    vf_loss_2 = torch.nn.functional.relu(1 - self.cos_margin - torch.nn.functional.cosine_similarity(aux_feature, z)).mean()
    vf_loss = vf_loss_1 + vf_loss_2
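
For anyone trying to run the snippet above in isolation, here is a self-contained version of it. It assumes z and aux_feature are (B, C, H, W) tensors with matching shapes (i.e. after the latent and the foundation-model feature have been brought to the same channel count and resolution); the function name and the toy shapes are illustrative:

import torch
from einops import rearrange

def vf_loss_combined_v3(z, aux_feature, distmat_margin=0.25, cos_margin=0.5):
    # marginal distance-matrix loss: match the pairwise cosine-similarity matrices
    # of the latent tokens and the foundation-model tokens, with a margin
    z_flat = rearrange(z, 'b c h w -> b c (h w)')
    aux_flat = rearrange(aux_feature, 'b c h w -> b c (h w)')
    z_norm = torch.nn.functional.normalize(z_flat, dim=1)
    aux_norm = torch.nn.functional.normalize(aux_flat, dim=1)
    z_cos_sim = torch.einsum('bci,bcj->bij', z_norm, z_norm)
    aux_cos_sim = torch.einsum('bci,bcj->bij', aux_norm, aux_norm)
    vf_loss_1 = torch.nn.functional.relu(torch.abs(z_cos_sim - aux_cos_sim) - distmat_margin).mean()

    # marginal cosine loss: cosine_similarity defaults to dim=1, i.e. the channel dimension
    vf_loss_2 = torch.nn.functional.relu(
        1 - cos_margin - torch.nn.functional.cosine_similarity(aux_feature, z)).mean()
    return vf_loss_1 + vf_loss_2

z = torch.randn(2, 32, 16, 16)     # toy latent
aux = torch.randn(2, 32, 16, 16)   # toy foundation-model feature with matching shape
print(vf_loss_combined_v3(z, aux))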

By the way, here are 2 small issues:

  1. We do not impose an upper limit on the adaptive weighting for the vf loss: the upper bound in d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() is raised to a much larger value, d_weight = torch.clamp(d_weight, 0.0, 1e8).detach(), specifically for vf.
  2. There is one particular spot in the LDM code that requires attention: https://github.com/CompVis/latent-diffusion/blob/main/ldm/models/autoencoder.py#L386. Please ensure that the newly added projection layer for z is correctly included in the optimizer (see the sketch below).
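
For reference, here is a minimal sketch of what point 2 means in LDM's AutoencoderKL.configure_optimizers (the method linked above), assuming the new projection layer for z is stored as an attribute called self.proj (the attribute name is illustrative):

def configure_optimizers(self):
    lr = self.learning_rate
    opt_ae = torch.optim.Adam(list(self.encoder.parameters()) +
                              list(self.decoder.parameters()) +
                              list(self.quant_conv.parameters()) +
                              list(self.post_quant_conv.parameters()) +
                              list(self.proj.parameters()),  # newly added projection layer for z
                              lr=lr, betas=(0.5, 0.9))
    opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
                                lr=lr, betas=(0.5, 0.9))
    return [opt_ae, opt_disc], []

Without the extra list(self.proj.parameters()) term, the projection never receives weight updates, which makes the vf loss much harder to optimize.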

@lavinal712 (Author)

@JingfengYao Thanks for your clarification.
The shape of z_prime is correct. I intentionally placed the channel dimension last to match the token layout of the DINOv2 features. The code snippet is as follows:

z = posterior.sample()
z_prime = z.flatten(2).transpose(1, 2)
z_prime = self.proj(z_prime)

This keeps z_prime in the same (B, N, D) token layout as the DINOv2 features.
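
A quick shape check of the snippet, with illustrative sizes (32 latent channels and a 768-dimensional DINOv2 embedding are assumptions, not values taken from the repository):

import torch
import torch.nn as nn

proj = nn.Linear(32, 768)                 # assumed: latent channels -> DINOv2 embedding dim
z = torch.randn(4, 32, 16, 16)            # (B, C, H, W) sample from the posterior
z_prime = z.flatten(2).transpose(1, 2)    # (4, 256, 32): spatial tokens in dim 1, channels last
z_prime = proj(z_prime)                   # (4, 256, 768): same (B, N, D) layout as DINOv2 patch tokens
print(z_prime.shape)                      # torch.Size([4, 256, 768])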

And thank you for pointing this out! I agree that the projection layer should be included in the optimizer; I will update the code so that its parameters are properly added to the optimizer for training.

@lavinal712 (Author) commented Jan 8, 2025

Well, the model collapses again and I do not know why. Here is the code: lavinal712/VA-VAE

@JingfengYao (Member)

Could you please provide your tensorboard logs?

@lavinal712 (Author)

[TensorBoard screenshots: loss curves]
The KL loss is extremely high.

@JingfengYao (Member) commented Jan 8, 2025

May I ask how many GPUs you use for training and your starting command? @lavinal712

@JingfengYao (Member) commented Jan 8, 2025

It seems I have found the possible reason. Here are my reproduction steps:

  1. I cloned your code and ran it on one GPU (I am not sure of your starting command; in the VA-VAE implementation, I modified it to use torchrun for multi-GPU training).

    I started with

    python main.py --base configs/autoencoder/vavae_f16d32.yaml --train --gpu 0,
    

    and got a problem similar to yours: my KL loss was also extremely high. That is caused by this:

    if opt.scale_lr:
        model.learning_rate = accumulate_grad_batches * ngpu * bs * base_lr
        print(
            "Setting learning rate to {:.2e} = {} (accumulate_grad_batches) * {} (num_gpus) * {} (batchsize) * {:.2e} (base_lr)".format(
                model.learning_rate, accumulate_grad_batches, ngpu, bs, base_lr))
    

    In my setting, the learning rate is automatically scaled to 2.4e-3, which means a much larger learning rate (and a much smaller batch size) than in our real training. We use a fixed learning rate of 1e-4 and a batch size of 256 or 512 for the VA-VAE training reported in the paper (a rough arithmetic sketch follows after this list).

  2. Then I modified the command as follows and logged vf_loss before adaptive weighting:

    python main.py --base configs/autoencoder/vavae_f16d32.yaml --train --gpu 0, --scale_lr false
    

    The training becomes more stable:
    [TensorBoard screenshot: vf_loss curve]
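
A rough arithmetic sketch of point 1. The individual factors below are assumptions chosen only to reproduce the 2.4e-3 figure, not the actual config values:

# scale_lr sets the effective lr to accumulate_grad_batches * ngpu * bs * base_lr
accumulate_grad_batches = 3   # assumed
ngpu = 1                      # single-GPU run as above
bs = 8                        # assumed per-GPU batch size
base_lr = 1e-4                # assumed base_learning_rate from the yaml
print(accumulate_grad_batches * ngpu * bs * base_lr)   # 0.0024, i.e. 2.4e-3 vs. the fixed 1e-4 used in the paper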

Hope this helps.

@lavinal712 (Author)

Thanks! Is it normal for the KL loss to gradually increase?

@lavinal712 (Author)

> May I ask how many GPUs you use for training and your starting command? @lavinal712

4 GPUs,

python main.py --b configs/autoencoder/vavae_f16d32.yaml --t --gpu 0,1,2,3

@lavinal712 (Author)

The batch size I used for training the VAE is too small. How many GPUs did you use for training in your paper? Are there any methods to increase the batch size?

@JingfengYao (Member)

Yes; the KL loss has a relatively small weight among the various losses, though an extremely large KL loss might impact generation performance. We used 32/64 GPUs to train VA-VAE. Perhaps you could experiment with mixed-precision training, gradient checkpointing, and gradient accumulation (which you already seem to be using) to increase the effective batch size.
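
A minimal sketch of two of those knobs with a plain pytorch_lightning Trainer; LDM actually configures these through the lightning section of its yaml, so treat this only as an illustration of the idea:

import pytorch_lightning as pl

trainer = pl.Trainer(
    gpus=4,                       # the older PL API used by LDM selects GPUs this way
    precision=16,                 # mixed-precision training to reduce memory per sample
    accumulate_grad_batches=4,    # 4x larger effective batch size at the same memory cost
)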

@lavinal712 (Author)

Thank you for your guidance.
