How to train VA-VAE? #1
The training code for VA-VAE is primarily based on the autoencoder training code from LDM. Implementing it should be relatively straightforward, as described in Section 3 of the paper. We are currently considering the most concise way to release it, such as a fork or something else.
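Since the official code was not yet released at this point in the thread, here is a minimal, hedged sketch of the VF alignment loss described in Section 3 of the paper (a marginal cosine-similarity term plus a marginal distance-matrix-similarity term). The function name `vf_loss` and the margin values `m1`/`m2` are assumptions for illustration; the paper's adaptive loss weighting is omitted:

```python
import torch
import torch.nn.functional as F

def vf_loss(z, f, m1=0.5, m2=0.25):
    """Sketch of the VF alignment loss (paper Sec. 3), NOT the official code.
    z: projected VAE latents, f: frozen foundation-model (e.g. DINOv2)
    features, both of shape (B, N, D). m1/m2 are illustrative margins."""
    # Marginal cosine similarity: penalize tokens whose cosine similarity
    # to the foundation feature falls below 1 - m1.
    cos = F.cosine_similarity(z, f, dim=-1)            # (B, N)
    loss_mcos = F.relu(1.0 - m1 - cos).mean()

    # Marginal distance-matrix similarity: match pairwise token-similarity
    # structure between the latent space and the foundation feature space.
    zn = F.normalize(z, dim=-1)
    fn = F.normalize(f, dim=-1)
    sim_z = zn @ zn.transpose(-1, -2)                  # (B, N, N)
    sim_f = fn @ fn.transpose(-1, -2)
    loss_mdms = F.relu((sim_z - sim_f).abs() - m2).mean()

    return loss_mcos + loss_mdms
```

In an LDM-style training loop this term would be added to the usual reconstruction + KL + adversarial losses on the encoder output.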
Thanks! I tried to reproduce your code, but I found that the vf_loss does not converge easily. After training for 1000 steps, the model collapsed, and the output turned into solid-color images. Therefore, I would like to see more details in the code.
Hi! Thanks for the great work! I noticed similar behavior to what @lavinal712 observed. Releasing the VA-VAE training code could help us better understand the process.
@gkakogeorgiou @JingfengYao Is there any problem with the vf_loss?
I'm not sure about the shapes involved. Here are my implementations:
By the way, here are 2 small issues:
@JingfengYao Thanks for your clarification.
This ensures compatibility with DINOv2's expected input format. And thank you for pointing this out! I agree that the connection layer should be included in the optimizer. I will update the code to ensure that the connection layer's parameters are properly added to the optimizer for training.
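For readers following along, a small sketch of what "DINOv2's expected input format" typically means in practice: images rescaled to [0, 1], resized to a side length divisible by DINOv2's patch size of 14, and normalized with ImageNet statistics. The helper name and the 224-pixel target are assumptions, not the repository's actual code:

```python
import torch
import torch.nn.functional as F

# Standard ImageNet normalization constants used by DINOv2 preprocessing.
IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)

def to_dinov2_input(x, patch=14, target=224):
    """Illustrative helper: map images in [-1, 1] of shape (B, 3, H, W)
    to a DINOv2-compatible tensor (side divisible by the 14-pixel patch)."""
    x = (x + 1.0) / 2.0                       # [-1, 1] -> [0, 1]
    side = (target // patch) * patch          # e.g. 224 = 16 x 16 patches
    x = F.interpolate(x, size=(side, side),
                      mode="bicubic", align_corners=False)
    return (x - IMAGENET_MEAN) / IMAGENET_STD
```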
Well, the model collapses again and I do not know why. Here is the code: lavinal712/VA-VAE
Could you please provide your tensorboard logs?
May I ask how many GPUs you use for training and your starting command? @lavinal712
It seems I found the possible reason. Here are my reproductions:
Hope this helps.
Thanks! Is it normal for the KL loss to gradually increase?
4 GPUs,
The batch size I used for training the VAE is too small. How many GPUs did you use for training in your paper? Are there any methods to increase the batch size?
Yes, the weight of the KL loss is relatively small among the various losses. A very large KL loss might impact the generation performance. We utilized 32/64 GPUs to train the VA-VAE. Perhaps you could experiment with mixed-precision training, gradient checkpointing, and gradient accumulation (which seems to have been already employed) to increase the batch size.
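Of the techniques mentioned, gradient accumulation is the one most directly tied to emulating a larger batch on few GPUs. A minimal sketch (the function name and loop structure are illustrative, not the repository's trainer; mixed precision via `torch.autocast`/`GradScaler` would wrap the forward pass in a real GPU setup):

```python
import torch

def accumulate_steps(model, opt, batches, accum_steps=4):
    """Illustrative sketch: emulate an effective batch that is accum_steps
    times larger by summing gradients over several micro-batches before
    each optimizer step."""
    opt.zero_grad(set_to_none=True)
    for i, batch in enumerate(batches):
        # Divide by accum_steps so the accumulated gradient is an average
        # over the effective batch, not a sum.
        loss = model(batch).mean() / accum_steps
        loss.backward()
        if (i + 1) % accum_steps == 0:
            opt.step()
            opt.zero_grad(set_to_none=True)
```

With 4 GPUs and `accum_steps=8`, the effective batch is 32 times the per-GPU micro-batch, at the cost of proportionally fewer optimizer steps per epoch.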
Thank you for your guidance. |
Can you release the code? Thanks for your work!