
36 GB minimum GPU memory required using batch size 1 and fp16 mixed precision training? #61

Open
danielvegamyhre opened this issue Jul 16, 2024 · 9 comments

Comments

@danielvegamyhre
Contributor

It seems 16 GB of GPU memory is not enough; I get a CUDA out-of-memory error immediately, and in the Colab resource monitor I can see GPU memory spike to the max before the crash.

So, to estimate how much GPU VRAM would be required, I first summed the total model parameters for CLIPVisionModelWithProjection, AutoencoderKLTemporalDecoder, and UNetSpatioTemporalConditionModel:

total params: 2254442729

Next, I multiplied the model params by (2 + 2 + 12) bytes per parameter. These numbers come from:

  • 2 bytes for the fp16 copy of the model params (used in fp16 mixed precision training)
  • 2 bytes for the fp16 model gradients (used in fp16 mixed precision training)
  • 12 bytes for optimizer state (with Adam: 4 bytes each for the fp32 master copy of the parameter, the momentum, and the variance)

Multiplying this out, I get 2254442729 * (2 + 2 + 12) = 36071083664 bytes, which is ~36 GB of GPU memory required to fine-tune with a batch size of 1 and fp16 mixed precision training.
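
For reference, here is a minimal sketch of that estimate, assuming the stabilityai/stable-video-diffusion-img2vid-xt checkpoint and its standard image_encoder / vae / unet subfolders (adjust if you load a different checkpoint):

    # Sketch: re-derive the parameter count and the (2 + 2 + 12)-bytes-per-param estimate.
    from diffusers import AutoencoderKLTemporalDecoder, UNetSpatioTemporalConditionModel
    from transformers import CLIPVisionModelWithProjection

    repo = "stabilityai/stable-video-diffusion-img2vid-xt"  # assumed checkpoint
    image_encoder = CLIPVisionModelWithProjection.from_pretrained(repo, subfolder="image_encoder")
    vae = AutoencoderKLTemporalDecoder.from_pretrained(repo, subfolder="vae")
    unet = UNetSpatioTemporalConditionModel.from_pretrained(repo, subfolder="unet")

    total_params = sum(p.numel() for m in (image_encoder, vae, unet) for p in m.parameters())
    bytes_per_param = 2 + 2 + 12  # fp16 weights + fp16 grads + Adam state (fp32 copy, momentum, variance)

    print(f"total params: {total_params}")
    print(f"estimated training memory: {total_params * bytes_per_param / 1e9:.1f} GB "
          f"(weights/grads/optimizer only, no activations)")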

Is this accurate?

@christopher-beckham

Not the repo author, but I had some related concerns; see my comment and the one below it here: #31 (comment)

For starters, the entire UNet is stored in fp32 in the script, because for some reason this cast is commented out:

https://github.com/pixeli99/SVD_Xtend/blob/main/train_svd.py#L730

Also, the number of frames it defaults to training on is 25, which can really blow up your GPU memory.
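
For concreteness, the cast in question usually looks like the standard diffusers training-script pattern sketched below (names such as `weight_dtype` and `accelerator` follow that convention and are not copied verbatim from this repo):

    # Typical diffusers-style casting block (sketch, not verbatim from train_svd.py).
    weight_dtype = torch.float16 if accelerator.mixed_precision == "fp16" else torch.float32

    # Frozen modules can safely live in fp16.
    image_encoder.to(accelerator.device, dtype=weight_dtype)
    vae.to(accelerator.device, dtype=weight_dtype)

    # This is the kind of line that was commented out; without it the whole UNet
    # stays in fp32. (See the later comments for keeping only the trainable
    # params in fp32.)
    unet.to(accelerator.device, dtype=weight_dtype)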

@danielvegamyhre
Contributor Author

> Not the repo author, but I had some related concerns; see my comment and the one below it here: #31 (comment)
>
> For starters, the entire UNet is stored in fp32 in the script, because for some reason this cast is commented out:
>
> https://github.com/pixeli99/SVD_Xtend/blob/main/train_svd.py#L730
>
> Also, the number of frames it defaults to training on is 25, which can really blow up your GPU memory.

Thanks, this is helpful. I replied in the thread you linked with some follow-up questions.

@pixeli99
Owner

I'm sorry; when I first wrote this code, I was more focused on supporting SVD training and didn't think much about memory usage, which has caused everyone some inconvenience. As @christopher-beckham mentioned, this line of code should not have been commented out. I have fixed this issue.

@KhaledButainy

> I'm sorry; when I first wrote this code, I was more focused on supporting SVD training and didn't think much about memory usage, which has caused everyone some inconvenience. As @christopher-beckham mentioned, this line of code should not have been commented out. I have fixed this issue.

I uncommented the following line:
https://github.com/pixeli99/SVD_Xtend/blob/main/train_svd.py#L739

Is this the correct line? What other lines did you modify?

Thank you in advance.

@christopher-beckham

You can use the following to cast everything to fp16 except the trainable params.

    unet.requires_grad_(True)
    parameters_list = []

    # Customize which parameters are trained; adjust the name filter below as needed.
    for name, para in unet.named_parameters():
        if 'temporal_transformer_block' in name:
            parameters_list.append(para)
            para.requires_grad = True
            para.data = para.data.to(dtype=torch.float32)
        else:
            para.requires_grad = False

@KhaledButainy

KhaledButainy commented Jul 23, 2024

> You can use the following to cast everything to fp16 except the trainable params.
>
>     unet.requires_grad_(True)
>     parameters_list = []
>
>     # Customize which parameters are trained; adjust the name filter below as needed.
>     for name, para in unet.named_parameters():
>         if 'temporal_transformer_block' in name:
>             parameters_list.append(para)
>             para.requires_grad = True
>             para.data = para.data.to(dtype=torch.float32)
>         else:
>             para.requires_grad = False

Thank you for sharing.

What do you think about keeping this line commented:
https://github.com/pixeli99/SVD_Xtend/blob/main/train_svd.py#L739

and cast only the frozen parameters:

    unet.requires_grad_(True)
    parameters_list = []

    # Customize which parameters are trained; adjust the name filter below as needed.
    for name, para in unet.named_parameters():
        if 'temporal_transformer_block' in name:
            parameters_list.append(para)
            para.requires_grad = True
        else:
            para.requires_grad = False
            para.data = para.data.to(dtype=weight_dtype) # torch.float16

This way we don't lose model precision by downcasting to float16 and then upcasting to float32 again.

@christopher-beckham

christopher-beckham commented Jul 23, 2024 via email

@KhaledButainy

You can also enable --gradient-checkpointing to save more GPU memory. However, this might result in slower training.
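
For reference, gradient checkpointing can be enabled either via the script flag or directly on the model; this is a sketch, and the flag spelling (`--gradient_checkpointing`, following the diffusers example scripts) has not been verified against this repo:

    # diffusers models expose a helper for activation/gradient checkpointing:
    unet.enable_gradient_checkpointing()

    # Or from the command line (flag spelling assumed from diffusers examples):
    #   accelerate launch train_svd.py --gradient_checkpointing ...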

@KhaledButainy

KhaledButainy commented Jul 23, 2024

> No. You don't want the untrained params in f32. You're trying to save memory on the GPU.

We are not doing that: untrained params will be in fp16, and only the trained params are in fp32, as you suggested.

In my comment, I suggested keeping the model in fp32 and downcasting only the frozen params to fp16, instead of downcasting the full model to fp16 and then upcasting the trainable params back to fp32.

Both save the same amount of GPU memory, but with the second approach you lose some precision on the trainable params.
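
To make the precision point concrete, here is a small, self-contained PyTorch check (not tied to the training script) of what an fp32 -> fp16 -> fp32 round trip does to weights:

    import torch

    torch.manual_seed(0)
    w_fp32 = torch.randn(1_000_000)  # stand-in for pretrained fp32 weights

    # The "cast everything to fp16, then upcast the trainable params" approach:
    w_roundtrip = w_fp32.to(torch.float16).to(torch.float32)

    err = (w_fp32 - w_roundtrip).abs()
    print(f"max abs error:  {err.max().item():.2e}")
    print(f"mean abs error: {err.mean().item():.2e}")
    # The round-tripped weights retain only fp16 precision (~3 significant
    # decimal digits), which is the loss described above.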
