Support XPU for auto-parallel LLaMa #9796

Open · wants to merge 4 commits into base: develop
104 changes: 104 additions & 0 deletions llm/auto_parallel/llama/run_llama2_13b_xpu.sh
@@ -0,0 +1,104 @@
#!/bin/bash

# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

task_name_or_path="llama2-13b-auto"

#export XPUAPI_DEBUG=0x1
#export XPURT_DISPATCH_MODE=PROFILING
export XBLAS_FC_HBM_VERSION=40

# PaddlePaddle
export FLAGS_use_stride_kernel="0"
export XPU_PADDLE_L3_SIZE=98566144 # 94 MB
export XPU_CDNN_CLUSTER_PARALLEL=1
export XPU_CDNN_CLUSTER_PARALLEL_STREAM_NUMBER=2

# PDC
unset PADDLE_ELASTIC_JOB_ID
unset PADDLE_TRAINER_ENDPOINTS
unset DISTRIBUTED_TRAINER_ENDPOINTS
unset FLAGS_START_PORT
unset PADDLE_ELASTIC_TIMEOUT
unset PADDLE_TRAINERS_NUM

# BKCL
# export BKCL_DEBUG=1
# Multi-node RDMA
#export BKCL_ENABLE_XDR=1
#export BKCL_RDMA_FORCE_TREE=1
#export BKCL_TREE_THRESHOLD=0
#export BKCL_RDMA_NICS=xgbe1,xgbe1,xgbe2,xgbe2,xgbe3,xgbe3,xgbe4,xgbe4
#export BKCL_SOCKET_IFNAME=xgbe0
#export BKCL_FORCE_L3_RDMA=0
echo "bkcl version:"
strings ${bkcl_location}/libbkcl.so | grep COM

export CUDA_DEVICE_MAX_CONNECTIONS=8

#PYTHONPATH
export PYTHONPATH=../../../:$PYTHONPATH

# for debug
#export GLOG_v=6
#export FLAGS_call_stack_level=2

rm -rf output/$task_name_or_path
PYTHONPATH=../:$PYTHONPATH \
python -u -m paddle.distributed.launch \
--xpus "0,1,2,3,4,5,6,7" \
--log_dir "output/$task_name_or_path/" \
run_pretrain_auto.py \
--model_name_or_path "meta-llama/Llama-2-13b" \
--tokenizer_name_or_path "meta-llama/Llama-2-13b" \
--input_dir "./data" \
--output_dir "output/$task_name_or_path" \
--split 949,50,1 \
--max_seq_length 4096 \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 1 \
--use_flash_attention 1 \
--use_fused_rope 1 \
--fuse_attention_ffn 1 \
--fuse_attention_qkv 1 \
--use_fused_rms_norm 0 \
--num_hidden_layers 4 \
--bf16 \
--fp16_opt_level "O2" \
--amp_master_grad true \
--scale_loss 1024 \
--learning_rate 0.00003 \
--min_learning_rate 0.000005 \
--lr_scheduler_type "cosine" \
--max_steps 10 \
--save_steps 100000 \
--weight_decay 0.01 \
--warmup_ratio 0.01 \
--max_grad_norm 1.0 \
--logging_steps 1 \
--sequence_parallel 0 \
--dataloader_num_workers 4 \
--pipeline_parallel_degree 1 \
--tensor_parallel_degree 1 \
--gradient_accumulation_steps 32 \
--eval_steps 1000 \
--report_to "visualdl" \
--disable_tqdm true \
--continue_training 0 \
--recompute 0 \
--do_train \
--seed 1026 \
--device "xpu" \
--enable_auto_parallel 1
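
Note on the launch configuration above: with tensor_parallel_degree and pipeline_parallel_degree both set to 1, the 8 XPUs act as pure data-parallel ranks, so one optimizer step consumes 8 × 1 × 32 = 256 sequences of 4,096 tokens (devices × per_device_train_batch_size × gradient_accumulation_steps). Before launching, a quick sanity check of the local PaddlePaddle build can save a failed run; a minimal sketch using stock Paddle APIs (not part of this PR):

import paddle

# Confirm the installed PaddlePaddle wheel was built with XPU (Kunlun) support
# and that the "xpu" backend can be selected, mirroring --device "xpu" above.
assert paddle.is_compiled_with_xpu(), "this PaddlePaddle build has no XPU support"
paddle.device.set_device("xpu")
print("running on:", paddle.device.get_device())  # e.g. "xpu:0"
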
17 changes: 17 additions & 0 deletions llm/auto_parallel/llama/run_pretrain_auto.py
@@ -59,6 +59,7 @@
print_rank_0,
)
from paddlenlp.trainer.utils.doc import add_start_docstrings
from paddlenlp.utils.tools import get_env_device


@dataclass
@@ -173,6 +174,11 @@ class ModelArguments:
default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
)

use_fast_layer_norm: bool = field(
default=False,
metadata={"help": "GPT3 model, use fast layernorm"},
)

config_name: Optional[str] = field(
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
)
@@ -496,6 +502,8 @@ def main():

config = config_class.from_pretrained(model_args.model_name_or_path)

config.use_fast_layer_norm = model_args.use_fast_layer_norm

config.seq_length = data_args.max_seq_length
# Some techniques extend the RotaryEmbedding context, so don't change max_position_embeddings
if not model_args.continue_training:
@@ -544,6 +552,15 @@ def main():
pipeline = training_args.strategy.pipeline
pipeline.vpp_degree = config.virtual_pp_degree
pipeline.vpp_seg_method = training_args.virtual_pipeline_seg_method
if get_env_device() == "xpu" and training_args.gradient_accumulation_steps > 1:
try:
from paddle_xpu.layers.nn.linear import LinearConfig # noqa: F401

LinearConfig.enable_accumulate_steps_opt()
LinearConfig.set_accumulate_steps(training_args.gradient_accumulation_steps)
except ImportError:
# paddle_xpu is unavailable; it's fine to run without the accumulate_steps optimization
pass

print("Final pre-training config:", config)

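The gradient-accumulation hook above treats paddle_xpu as an optional dependency: the LinearConfig optimization is enabled only when the package is importable and more than one accumulation step is configured. A hedged sketch of that pattern as a standalone helper (the helper name is illustrative, not part of the PR):

import importlib

def maybe_enable_xpu_linear_opt(accumulation_steps: int) -> bool:
    """Enable paddle_xpu's accumulate-steps optimization if the package exists."""
    if accumulation_steps <= 1:
        return False
    try:
        linear = importlib.import_module("paddle_xpu.layers.nn.linear")
    except ImportError:
        return False  # plain Paddle path; the optimization is optional
    linear.LinearConfig.enable_accumulate_steps_opt()
    linear.LinearConfig.set_accumulate_steps(accumulation_steps)
    return True
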
5 changes: 5 additions & 0 deletions paddlenlp/trainer/auto_trainer.py
@@ -26,6 +26,7 @@
from tqdm.auto import tqdm

from paddlenlp.trainer import Trainer
from paddlenlp.utils.tools import get_env_device

from ..utils.batch_sampler import DistributedBatchSampler as NlpDistributedBatchSampler
from ..utils.log import logger
@@ -522,6 +523,10 @@

logger.info("\nTraining completed. \n")

# Hack for XPU that doesn't support Allgather yet. See LlamaPretrainingCriterion3DAuto in modeling_auto.py for details.
if get_env_device() == "xpu":
tr_loss = tr_loss.mean()

self._total_loss_scalar += self._get_item_from_loss(tr_loss)
train_loss = self._total_loss_scalar / self.state.global_step

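The extra tr_loss.mean() exists because, on XPU, the pretraining criterion (see the modeling_auto.py changes below) returns a per-sequence loss vector rather than a masked scalar, so the trainer has to reduce it before logging. A toy shape walk-through (the tensors are made up for illustration):

import paddle

token_loss = paddle.rand([2, 8])                 # stand-in for [batch, seq_len] token losses
per_seq_loss = paddle.mean(token_loss, axis=-1)  # [batch]; what the XPU branch of the criterion returns
tr_loss = per_seq_loss                           # accumulated by the trainer as-is
print(tr_loss.shape)                             # [2]
print(tr_loss.mean().shape)                      # [] -> scalar, matching the non-XPU path
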
3 changes: 1 addition & 2 deletions paddlenlp/trainer/training_args.py
@@ -1632,13 +1632,12 @@ def is_segment_parallel_supported():
"enable_mp_async_allreduce", # allreduce_matmul_grad_overlapping in auto_parallel
"enable_delay_scale_loss",
"replace_with_c_embedding",
# "enable_mp_skip_c_identity",
# "enable_mp_fused_linear_param_grad_add",
"replace_with_parallel_cross_entropy",
]:
raise ValueError(
f"Found unknown tensor parallell config {x}, "
f"accept config is enable_mp_async_allreduce, replace_with_c_embedding, enable_mp_skip_c_identity and enable_mp_fused_linear_param_grad_add"
f"accept config is enable_mp_async_allreduce, replace_with_c_embedding, and enable_mp_fused_linear_param_grad_add"
)
try:
if "enable_mp_async_allreduce" in mp_config:
66 changes: 56 additions & 10 deletions paddlenlp/transformers/llama/modeling_auto.py
@@ -52,7 +52,9 @@
CausalLMOutputWithCrossAttentions,
)
from paddlenlp.transformers.model_utils import PretrainedModel, register_base_model
from paddlenlp.utils.tools import get_env_device

from . import fusion_ops
from .configuration import (
LLAMA_PRETRAINED_INIT_CONFIGURATION,
LLAMA_PRETRAINED_RESOURCE_FILES_MAP,
@@ -69,7 +71,6 @@
build_alibi_tensor,
get_triangle_upper_mask,
repeat_kv,
rms_norm_fused,
)

try:
@@ -194,10 +195,6 @@
return (attn_output, attn_weights) if output_attentions else attn_output


colwise_placements = [dist.Replicate(), dist.Shard(1)]
rowise_placement = [dist.Replicate(), dist.Shard(0)]


class LlamaRMSNormAuto(nn.Layer):
def __init__(self, config, ipp):
super().__init__()
@@ -218,7 +215,9 @@

def forward(self, hidden_states):
if self.config.use_fused_rms_norm:
return rms_norm_fused(hidden_states, self.weight, self.variance_epsilon)
return fusion_ops.fusion_rms_norm(
hidden_states, self.weight, self.variance_epsilon, self.config.use_fast_layer_norm
)

with paddle.amp.auto_cast(False):
variance = hidden_states.astype("float32").pow(2).mean(-1, keepdim=True)
Expand All @@ -238,6 +237,16 @@
self.fuse_attention_ffn = config.fuse_attention_ffn
self.ipp = ipp
self.config = config
colwise_placements = (
[dist.Replicate(), dist.Shard(1)]
if self.config.tensor_parallel_degree > 1
else [dist.Replicate(), dist.Replicate()]
)
rowise_placement = (
[dist.Replicate(), dist.Shard(0)]
if self.config.tensor_parallel_degree > 1
else [dist.Replicate(), dist.Replicate()]
)

if config.fuse_attention_ffn and not enable_fuse_ffn_qkv_pass():
self.gate_up_fused_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias_attr=False)
@@ -307,8 +316,19 @@
self.recompute_granularity = config.recompute_granularity
self.ipp = ipp

colwise_placements = (
[dist.Replicate(), dist.Shard(1)]
if self.config.tensor_parallel_degree > 1
else [dist.Replicate(), dist.Replicate()]
)
rowise_placement = (
[dist.Replicate(), dist.Shard(0)]
if self.config.tensor_parallel_degree > 1
else [dist.Replicate(), dist.Replicate()]
)

self.use_fused_rope = config.use_fused_rope
if self.use_fused_rope:
if self.use_fused_rope and get_env_device() not in ["npu", "mlu", "xpu", "gcu", "intel_hpu"]:
if "gpu" not in paddle.device.get_device() or fused_rotary_position_embedding is None:
warnings.warn(
"Enable fuse rope in the config, but fuse rope is not available. "
@@ -935,7 +955,22 @@
else:
expanded_attn_mask = _make_causal_mask(input_shape, past_key_values_length=past_key_values_length)
# Convert bool attention_mask to float attention mask, which will be added to attention_scores later
expanded_attn_mask = paddle.where(expanded_attn_mask, 0.0, paddle.finfo(dtype).min).astype(dtype)
if get_env_device() in ["npu", "mlu", "intel_hpu"]:
x = paddle.to_tensor(0.0, dtype="float32")
y = paddle.to_tensor(paddle.finfo(dtype).min, dtype="float32")
expanded_attn_mask = paddle.where(expanded_attn_mask.cast("bool"), x, y).astype(dtype)
elif get_env_device() == "xpu":
x = paddle.to_tensor(0.0, dtype="float32")
y = paddle.to_tensor(-1.7005809656952787e38, dtype="float32")
expanded_attn_mask = paddle.where(expanded_attn_mask.cast("bool"), x, y)
elif get_env_device() == "gcu":
min_val = paddle.finfo(dtype).min
x = paddle.to_tensor(0.0, dtype=dtype)
y = paddle.to_tensor(min_val, dtype=dtype)
expanded_attn_mask = paddle.where(expanded_attn_mask.cast("bool"), x, y).astype(dtype)
else:
expanded_attn_mask = paddle.where(expanded_attn_mask.cast("bool"), 0.0, paddle.finfo(dtype).min)
expanded_attn_mask = expanded_attn_mask.astype(dtype)
return expanded_attn_mask

def forward(
@@ -1166,15 +1201,26 @@
masked_lm_labels.unsqueeze(2),
)

masked_lm_loss = paddle.masked_select(masked_lm_loss, masked_lm_loss > 0).astype("float32")
loss = paddle.mean(masked_lm_loss)
# Hack for XPU that doesn't support Allgather yet.
if get_env_device() == "xpu":
# masked_lm_loss = paddle.masked_select(masked_lm_loss, masked_lm_loss > 0).astype("float32")
loss = paddle.mean(masked_lm_loss, axis=-1)
else:
masked_lm_loss = paddle.masked_select(masked_lm_loss, masked_lm_loss > 0).astype("float32")
loss = paddle.mean(masked_lm_loss, axis=-1)

return loss


class LlamaLMHeadAuto(nn.Layer):
def __init__(self, config: LlamaConfig):
super(LlamaLMHeadAuto, self).__init__()
self.config = config
colwise_placements = (
[dist.Replicate(), dist.Shard(1)]
if self.config.tensor_parallel_degree > 1
else [dist.Replicate(), dist.Replicate()]
)
vocab_size = config.vocab_size
self.weight = self.create_parameter(
shape=[config.hidden_size, vocab_size],
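
The placement logic repeated in LlamaMLPAuto, LlamaAttentionAuto, and LlamaLMHeadAuto follows one rule: shard a Linear weight along the model-parallel mesh axis only when tensor parallelism is actually enabled, and replicate otherwise. A minimal sketch of that rule as a standalone helper (the helper is illustrative, not part of the PR):

import paddle.distributed as dist

def linear_placements(tensor_parallel_degree: int):
    # Column-parallel Linear: shard the weight on its output dim -> Shard(1).
    # Row-parallel Linear: shard the weight on its input dim -> Shard(0).
    # The first mesh axis (data parallel) always replicates the weight.
    if tensor_parallel_degree > 1:
        return [dist.Replicate(), dist.Shard(1)], [dist.Replicate(), dist.Shard(0)]
    # No tensor parallelism: nothing to split, so replicate on both mesh axes.
    return [dist.Replicate(), dist.Replicate()], [dist.Replicate(), dist.Replicate()]

colwise_placements, rowise_placement = linear_placements(tensor_parallel_degree=2)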