Is 36 GB the minimum GPU memory required with batch size 1 and fp16 mixed-precision training? #61
Comments
Not the repo author, but I also had some related concerns; see my comment and the one below it here: #31 (comment) For starters, the entire unet is stored in fp32: https://github.com/pixeli99/SVD_Xtend/blob/main/train_svd.py#L730 Also, the number of frames it defaults to training on is 25, which can really blow up your GPU memory. |
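A note on that last point: activation memory scales roughly linearly with the number of frames, so training on 14 frames (the native length of the base SVD checkpoint; SVD-XT uses 25) is the quickest way to shrink the footprint. Whether train_svd.py exposes this as a flag such as --num_frames is an assumption worth verifying against the script.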
Thanks, this is helpful. I replied in the thread you linked with some follow-up questions. |
I'm sorry; when I first wrote this code, I was focused on supporting SVD training and didn't give much thought to memory usage. This has caused some inconvenience for everyone. As @christopher-beckham mentioned, this line of code should not have been commented out. I have fixed this issue. |
I uncommented the following line: Is this the correct line? What other lines did you modify? Thank you in advance. |
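For context, the standard mixed-precision pattern in diffusers training scripts casts the frozen model to half precision when moving it to the GPU. The sketch below illustrates that pattern with a stand-in module; it is a guess at what the uncommented line does, not a quote from train_svd.py:

import torch
import torch.nn as nn

# Stand-in for the frozen model; in train_svd.py this would be the
# UNetSpatioTemporalConditionModel (substituted here for illustration).
unet = nn.Linear(4, 4)

# Move the frozen model to the device in half precision so no fp32 copy
# of its weights occupies GPU memory.
weight_dtype = torch.float16
device = "cuda" if torch.cuda.is_available() else "cpu"
unet.to(device, dtype=weight_dtype)
print(next(unet.parameters()).dtype)  # torch.float16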
You can use the following to cast everything to fp16 except the trainable params.
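unet.requires_grad_(True)
parameters_list = []
# Customize the parameters that need to be trained; if necessary, you can uncomment them yourself.
for name, para in unet.named_parameters():
    if 'temporal_transformer_block' in name:
        parameters_list.append(para)
        para.requires_grad = True
        para.data = para.data.to(dtype=torch.float32)
    else:
        para.requires_grad = False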
|
Thank you for sharing. What do you think about keeping this line commented: https://github.com/pixeli99/SVD_Xtend/blob/main/train_svd.py#L739 and casting only the frozen parameters:
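unet.requires_grad_(True)
parameters_list = []
# Customize the parameters that need to be trained; if necessary, you can uncomment them yourself.
for name, para in unet.named_parameters():
    if 'temporal_transformer_block' in name:
        parameters_list.append(para)
        para.requires_grad = True
    else:
        para.requires_grad = False
        para.data = para.data.to(dtype=weight_dtype)  # torch.float16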
This way we don't lose the model's precision by downcasting to float16 and upcasting to float32 again. |
No. You don't want the untrained params in fp32. You're trying to save memory on the GPU.
|
You can also enable |
We are not doing that; the untrained params will be in fp16 either way. In my comment, I suggested keeping the model in fp32 and casting only the frozen parameters down to fp16, instead of casting everything to fp16 and then upcasting the trainable params back to fp32. Both will save the same GPU memory, but you will lose some precision on the training params following the second (downcast-then-upcast) approach. |
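A minimal sketch (not from the thread) of why the downcast-then-upcast path is lossy: fp16 keeps a 10-bit mantissa versus fp32's 23 bits, so round-tripping the trainable weights through fp16 perturbs them before training even starts.

import torch

# Round-trip fp32 -> fp16 -> fp32: the mantissa bits dropped in the
# downcast cannot be recovered by the upcast.
w = torch.randn(1_000_000, dtype=torch.float32)
roundtrip = w.half().float()
print(torch.equal(w, roundtrip))           # False
print((w - roundtrip).abs().max().item())  # > 0: precision already lost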
It seems 16 GB of GPU memory is not enough: I get a CUDA out-of-memory error immediately, and in the Colab resource monitor I can see GPU memory spike to the maximum before crashing.
So, to estimate how much GPU VRAM would be required, I first summed the total model params for CLIPVisionModelWithProjection, AutoencoderKLTemporalDecoder, and UNetSpatioTemporalConditionModel:
total params: 2254442729
Next, I multiply the model params by (2 + 2 + 12) bytes per parameter: 2 for the fp16 weights, 2 for the fp16 gradients, and 12 for the mixed-precision Adam optimizer state (an fp32 master copy of the weights plus two fp32 moments). These numbers are from:
Multiplying this out, I get 2254442729 * (2 + 2 + 12) = 36071083664 bytes, which is ~36 GB of GPU memory required to fine-tune using a batch size of 1 and fp16 mixed-precision training.
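The same arithmetic as a quick check (numbers taken from above; note this counts weights, gradients, and optimizer state only, not activations):

total_params = 2_254_442_729  # CLIP vision encoder + VAE + UNet, summed above

# Mixed-precision Adam accounting: 2 bytes fp16 weights + 2 bytes fp16 grads
# + 12 bytes optimizer state (fp32 master weights + two fp32 moments).
bytes_per_param = 2 + 2 + 12
total_bytes = total_params * bytes_per_param
print(f"{total_bytes:,} bytes ~= {total_bytes / 1e9:.1f} GB")  # 36,071,083,664 bytes ~= 36.1 GB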
Is this accurate?