
Support XPU for auto-parallel LLaMA #9796

Open
wants to merge 5 commits into develop

Conversation

From00 (Collaborator) commented Jan 20, 2025

PR types

New features

PR changes

Models

Description

Adapt the LLaMA model for XPU auto-parallel training. Currently only dynamic graph + pure DP (allreduce communication only) is supported.
Depends on the framework PR: PaddlePaddle/Paddle#70997
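For orientation only (not this PR's actual training entry point), a minimal hedged sketch of what "dynamic graph + pure DP" looks like with Paddle's auto-parallel API on XPU; the placeholder Linear model, the 2-rank mesh, and the launch setup are assumptions:

```python
import paddle
import paddle.distributed as dist

# Typically launched with paddle.distributed.launch across the XPU ranks.
paddle.set_device("xpu")

# Pure DP: one mesh axis, parameters replicated, the batch sharded along dim 0,
# so gradient synchronization reduces to allreduce across the dp ranks.
mesh = dist.ProcessMesh([0, 1], dim_names=["dp"])

model = paddle.nn.Linear(1024, 1024)  # placeholder for the LLaMA model
opt = paddle.optimizer.AdamW(parameters=model.parameters())

for _ in range(10):
    x = paddle.randn([8, 1024])
    x = dist.shard_tensor(x, mesh, [dist.Shard(0)])  # shard the batch over "dp"
    loss = model(x).mean()
    loss.backward()  # gradients are synchronized (allreduce) across dp ranks
    opt.step()
    opt.clear_grad()
```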


codecov bot commented Jan 20, 2025

Codecov Report

Attention: Patch coverage is 6.45161% with 29 lines in your changes missing coverage. Please review.

Project coverage is 52.20%. Comparing base (13053a7) to head (927d878).
Report is 12 commits behind head on develop.

Files with missing lines                        Patch %   Lines
paddlenlp/transformers/llama/modeling_auto.py   7.14%     26 Missing ⚠️
paddlenlp/trainer/auto_trainer.py               0.00%     3 Missing ⚠️

❌ Your patch check has failed because the patch coverage (6.45%) is below the target coverage (80.00%). You can increase the patch coverage or adjust the target coverage.
❌ Your project check has failed because the head coverage (52.20%) is below the target coverage (58.00%). You can increase the head coverage or adjust the target coverage.

Additional details and impacted files
@@             Coverage Diff             @@
##           develop    #9796      +/-   ##
===========================================
+ Coverage    52.06%   52.20%   +0.14%     
===========================================
  Files          734      730       -4     
  Lines       116591   115836     -755     
===========================================
- Hits         60703    60475     -228     
+ Misses       55888    55361     -527     


Comment on lines +958 to +970
if get_env_device() in ["npu", "mlu", "intel_hpu"]:
    x = paddle.to_tensor(0.0, dtype="float32")
    y = paddle.to_tensor(paddle.finfo(dtype).min, dtype="float32")
    expanded_attn_mask = paddle.where(expanded_attn_mask.cast("bool"), x, y).astype(dtype)
elif get_env_device() == "xpu":
    x = paddle.to_tensor(0.0, dtype="float32")
    y = paddle.to_tensor(-1.7005809656952787e38, dtype="float32")
    expanded_attn_mask = paddle.where(expanded_attn_mask.cast("bool"), x, y)
elif get_env_device() == "gcu":
    min_val = paddle.finfo(dtype).min
    x = paddle.to_tensor(0.0, dtype=dtype)
    y = paddle.to_tensor(min_val, dtype=dtype)
    expanded_attn_mask = paddle.where(expanded_attn_mask.cast("bool"), x, y).astype(dtype)
Collaborator
How does the mask generation differ between the different devices?

From00 (Collaborator, Author) commented Feb 5, 2025

The mask generation logic is the same as here: https://github.com/PaddlePaddle/PaddleNLP/blob/develop/paddlenlp/transformers/llama/modeling.py#L1606.

XPU needs a different mask for the following two reasons:

  1. The flash_attention kernel implemented on XPU differs from the GPU one and may overflow numerically when the mask value is too small, so the specific mask value -1.7005809656952787e38 is needed. @runzhech is fixing this issue; once it is fixed, we can use paddle.finfo(dtype).min as on GPU.
  2. The flash_attention kernel on XPU requires the mask input to be float32, so astype(dtype) cannot be added in the XPU mask generation.

See these two PRs for more details: #9495, #9652
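For illustration only (not code from this PR): a minimal sketch of the two additive-mask conventions described above; the helper name build_additive_mask and its signature are assumptions.

```python
import paddle

def build_additive_mask(bool_mask, dtype, device):
    # bool_mask: True where attention is allowed, False where it is masked out.
    zero = paddle.to_tensor(0.0, dtype="float32")
    if device == "xpu":
        # The XPU flash_attention kernel may overflow with paddle.finfo(dtype).min,
        # so a device-specific large-negative value is used, and the mask stays
        # float32 because the kernel requires a float32 mask input.
        neg = paddle.to_tensor(-1.7005809656952787e38, dtype="float32")
        return paddle.where(bool_mask, zero, neg)
    # GPU and most other devices: use the dtype's minimum and cast back to dtype.
    neg = paddle.to_tensor(paddle.finfo(dtype).min, dtype="float32")
    return paddle.where(bool_mask, zero, neg).astype(dtype)
```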
