Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

deepspeed zero3 QLoRA finetuning #11625

Merged
merged 30 commits into from
Aug 13, 2024
Merged

Conversation

Uxito-Ada
Copy link
Contributor

Description

transferred from #11048

1. Why the change?

2. User API changes

3. Summary of the change

4. How to test?

  • N/A
  • Unit test: Please manually trigger the PR Validation here by inputting the PR number (e.g., 1234). And paste your action link here once it has been successfully finished.
  • Application test
  • Document test
  • ...

5. New dependencies

  • New Python dependencies
    - Dependency1
    - Dependency2
    - ...
  • New Java/Scala dependencies and their license
    - Dependency1 and license1
    - Dependency2 and license2
    - ...

@Uxito-Ada Uxito-Ada requested review from qiyuangong and glorysdj July 19, 2024 07:42
Comment on lines 233 to 235
if enable_deepspeed_zero3:
dst_tensor = torch.empty(dst_size // 2, dtype=torch.bfloat16,
device=device)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should always do that for NF4 (only)?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Other NF4s are packed in torch.uint8, which do not make the buffer length redundant.
Only deepspeed zero3 needs NF4 to be packed in torch.bfloat16, which needs to halve the buffer.

@@ -259,9 +264,12 @@ def ggml_convert_qtype(tensor: torch.Tensor, qtype: int,


def ggml_q_format_convet_cpu2xpu(tensor: torch.Tensor, num_elem: int, qtype: int):
import os
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Move os import to top, because other module may share this import.

dst_tensor = torch.empty(dst_size, dtype=torch.uint8,
device=device)
if enable_deepspeed_zero3:
dst_tensor = torch.empty(dst_size // 2, dtype=torch.bfloat16,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add comments for magic value 2 and hard-coded type bfloat16.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done



# Arc platfrom does not support FP64,
# Disable FP64 in DeepSpeedZeroOptimizer_Stage3's _constant_buffered_norm2 method
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's different between our implementation and ds's one?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ds is double(), fp64
here removes double(), as Arc does not support fp64

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK

@Uxito-Ada
Copy link
Contributor Author

Any more comment or approve? @qiyuangong

@@ -524,7 +525,8 @@ def load_convert(cls, q_k, optimize_model, *args, **kwargs):
imatrix_data=imatrix_data,
embedding_qtype=embedding_qtype,
enable_xetla=enable_xetla,
mixed_precision=mixed_precision)
mixed_precision=mixed_precision,
enable_deepspeed_zero3=enable_deepspeed_zero3)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we want to introduce this user-level parameter; we should either change all NF4 to BF16, or all training (QLoRA) NF4 to BF16, instead of doing something special for zero3 only.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pls take a look again @jason-dai @qiyuangong


invalidInputError(tensor.dtype == torch.uint8,
"Input tensor must be uint8")
invalidInputError(tensor.dtype == torch.bfloat16,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will this change impact other features?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NF4 applications e.g. QLoRA (zero2) will not be influenced. While maybe better add judgement qtype == NF4?

Copy link
Contributor

@qiyuangong qiyuangong left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@Uxito-Ada
Copy link
Contributor Author

Passed PR validation.

@Uxito-Ada Uxito-Ada merged commit 70c828b into intel:main Aug 13, 2024
1 check passed
@Uxito-Ada Uxito-Ada deleted the heyang_24_7_19 branch August 13, 2024 08:15
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants