diff --git a/llm/auto_parallel/llama/run_llama2_13b_xpu.sh b/llm/auto_parallel/llama/run_llama2_13b_xpu.sh
new file mode 100755
index 000000000000..17c60a600514
--- /dev/null
+++ b/llm/auto_parallel/llama/run_llama2_13b_xpu.sh
@@ -0,0 +1,104 @@
+#!/bin/bash
+
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+task_name_or_path="llama2-13b-auto"
+
+#export XPUAPI_DEBUG=0x1
+#export XPURT_DISPATCH_MODE=PROFILING
+export XBLAS_FC_HBM_VERSION=40
+
+# PaddlePaddle
+export FLAGS_use_stride_kernel="0"
+export XPU_PADDLE_L3_SIZE=98566144 # 94 MB
+export XPU_CDNN_CLUSTER_PARALLEL=1
+export XPU_CDNN_CLUSTER_PARALLEL_STREAM_NUMBER=2
+
+# PDC
+unset PADDLE_ELASTIC_JOB_ID
+unset PADDLE_TRAINER_ENDPOINTS
+unset DISTRIBUTED_TRAINER_ENDPOINTS
+unset FLAGS_START_PORT
+unset PADDLE_ELASTIC_TIMEOUT
+unset PADDLE_TRAINERS_NUM
+
+# BKCL
+# export BKCL_DEBUG=1
+# Multi-computer RDMA
+#export BKCL_ENABLE_XDR=1
+#export BKCL_RDMA_FORCE_TREE=1
+#export BKCL_TREE_THRESHOLD=0
+#export BKCL_RDMA_NICS=xgbe1,xgbe1,xgbe2,xgbe2,xgbe3,xgbe3,xgbe4,xgbe4
+#export BKCL_SOCKET_IFNAME=xgbe0
+#export BKCL_FORCE_L3_RDMA=0
+echo "bkcl version:"
+strings ${bkcl_location}/libbkcl.so | grep COM
+
+export CUDA_DEVICE_MAX_CONNECTIONS=8
+
+#PYTHONPATH
+export PYTHONPATH=../../../:$PYTHONPATH
+
+# for debug
+#export GLOG_v=6
+#export FLAGS_call_stack_level=2
+
+rm -rf output/$task_name_or_path
+PYTHONPATH=../:$PYTHONPATH \
+python -u -m paddle.distributed.launch \
+    --xpus "0,1,2,3,4,5,6,7" \
+    --log_dir "output/$task_name_or_path/" \
+    run_pretrain_auto.py \
+    --model_name_or_path "meta-llama/Llama-2-13b" \
+    --tokenizer_name_or_path "meta-llama/Llama-2-13b" \
+    --input_dir "./data" \
+    --output_dir "output/$task_name_or_path" \
+    --split 949,50,1 \
+    --max_seq_length 4096 \
+    --per_device_train_batch_size 1 \
+    --per_device_eval_batch_size 1 \
+    --use_flash_attention 1 \
+    --use_fused_rope 1 \
+    --fuse_attention_ffn 1 \
+    --fuse_attention_qkv 1 \
+    --use_fused_rms_norm 0 \
+    --num_hidden_layers 4 \
+    --bf16 \
+    --fp16_opt_level "O2" \
+    --amp_master_grad true \
+    --scale_loss 1024 \
+    --learning_rate 0.00003 \
+    --min_learning_rate 0.000005 \
+    --lr_scheduler_type "cosine" \
+    --max_steps 10 \
+    --save_steps 100000 \
+    --weight_decay 0.01 \
+    --warmup_ratio 0.01 \
+    --max_grad_norm 1.0 \
+    --logging_steps 1 \
+    --sequence_parallel 0 \
+    --dataloader_num_workers 4 \
+    --pipeline_parallel_degree 1 \
+    --tensor_parallel_degree 1 \
+    --gradient_accumulation_steps 32 \
+    --eval_steps 1000 \
+    --report_to "visualdl" \
+    --disable_tqdm true \
+    --continue_training 0 \
+    --recompute 0 \
+    --do_train \
+    --seed 1026 \
+    --device "xpu" \
+    --enable_auto_parallel 1
diff --git a/llm/auto_parallel/llama/run_pretrain_auto.py b/llm/auto_parallel/llama/run_pretrain_auto.py
index fa3d8855afb5..24e737de544b 100644
--- a/llm/auto_parallel/llama/run_pretrain_auto.py
+++ b/llm/auto_parallel/llama/run_pretrain_auto.py
@@ -59,6 +59,7 @@
     print_rank_0,
 )
 from paddlenlp.trainer.utils.doc import add_start_docstrings
+from paddlenlp.utils.tools import get_env_device
 
 
 @dataclass
@@ -173,6 +174,11 @@ class ModelArguments:
         default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
     )
+    use_fast_layer_norm: bool = field(
+        default=False,
+        metadata={"help": "GPT3 model, use fast layernorm"},
+    )
+
     config_name: Optional[str] = field(
         default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
     )
 
@@ -496,6 +502,8 @@ def main():
 
     config = config_class.from_pretrained(model_args.model_name_or_path)
 
+    config.use_fast_layer_norm = model_args.use_fast_layer_norm
+
     config.seq_length = data_args.max_seq_length
     # There are some technique extend RotaryEmbedding context. so don't change max_position_embeddings
     if not model_args.continue_training:
@@ -544,6 +552,15 @@ def main():
         pipeline = training_args.strategy.pipeline
         pipeline.vpp_degree = config.virtual_pp_degree
         pipeline.vpp_seg_method = training_args.virtual_pipeline_seg_method
+    if get_env_device() == "xpu" and training_args.gradient_accumulation_steps > 1:
+        try:
+            from paddle_xpu.layers.nn.linear import LinearConfig  # noqa: F401
+
+            LinearConfig.enable_accumulate_steps_opt()
+            LinearConfig.set_accumulate_steps(training_args.gradient_accumulation_steps)
+        except ImportError:
+            # It's OK not to use the accumulate_steps optimization
+            pass
 
     print("Final pre-training config:", config)
 
diff --git a/paddlenlp/trainer/auto_trainer.py b/paddlenlp/trainer/auto_trainer.py
index 288f10e8f6ef..deb523a7f713 100644
--- a/paddlenlp/trainer/auto_trainer.py
+++ b/paddlenlp/trainer/auto_trainer.py
@@ -26,6 +26,7 @@
 from tqdm.auto import tqdm
 
 from paddlenlp.trainer import Trainer
+from paddlenlp.utils.tools import get_env_device
 
 from ..utils.batch_sampler import DistributedBatchSampler as NlpDistributedBatchSampler
 from ..utils.log import logger
@@ -522,6 +523,10 @@ def _inner_training_loop(
 
         logger.info("\nTraining completed. \n")
 
+        # Hack for XPU that doesn't support Allgather yet. See LlamaPretrainingCriterion3DAuto in modeling_auto.py for details.
+        if get_env_device() == "xpu":
+            tr_loss = tr_loss.mean()
+
         self._total_loss_scalar += self._get_item_from_loss(tr_loss)
         train_loss = self._total_loss_scalar / self.state.global_step
 
diff --git a/paddlenlp/trainer/training_args.py b/paddlenlp/trainer/training_args.py
index bb0b91224180..b488ada1be0e 100644
--- a/paddlenlp/trainer/training_args.py
+++ b/paddlenlp/trainer/training_args.py
@@ -1632,13 +1632,12 @@ def is_segment_parallel_supported():
                     "enable_mp_async_allreduce",  # allreduce_matmul_grad_overlapping in auto_parallel
                     "enable_delay_scale_loss",
                     "replace_with_c_embedding",
-                    # "enable_mp_skip_c_identity",
                     # "enable_mp_fused_linear_param_grad_add",
                     "replace_with_parallel_cross_entropy",
                 ]:
                     raise ValueError(
                         f"Found unknown tensor parallell config {x}, "
-                        f"accept config is enable_mp_async_allreduce, replace_with_c_embedding, enable_mp_skip_c_identity and enable_mp_fused_linear_param_grad_add"
+                        f"accept config is enable_mp_async_allreduce, replace_with_c_embedding, and enable_mp_fused_linear_param_grad_add"
                     )
                 try:
                     if "enable_mp_async_allreduce" in mp_config:
diff --git a/paddlenlp/transformers/llama/modeling_auto.py b/paddlenlp/transformers/llama/modeling_auto.py
index a629cf1ec955..613fdac93e01 100644
--- a/paddlenlp/transformers/llama/modeling_auto.py
+++ b/paddlenlp/transformers/llama/modeling_auto.py
@@ -52,7 +52,9 @@ def swiglu(x, y=None):
     CausalLMOutputWithCrossAttentions,
 )
 from paddlenlp.transformers.model_utils import PretrainedModel, register_base_model
+from paddlenlp.utils.tools import get_env_device
 
+from . import fusion_ops
 from .configuration import (
     LLAMA_PRETRAINED_INIT_CONFIGURATION,
     LLAMA_PRETRAINED_RESOURCE_FILES_MAP,
@@ -69,7 +71,6 @@ def swiglu(x, y=None):
     build_alibi_tensor,
     get_triangle_upper_mask,
     repeat_kv,
-    rms_norm_fused,
 )
 
 try:
@@ -194,10 +195,6 @@ def scaled_dot_product_attention(
     return (attn_output, attn_weights) if output_attentions else attn_output
 
 
-colwise_placements = [dist.Replicate(), dist.Shard(1)]
-rowise_placement = [dist.Replicate(), dist.Shard(0)]
-
-
 class LlamaRMSNormAuto(nn.Layer):
     def __init__(self, config, ipp):
         super().__init__()
@@ -218,7 +215,9 @@ def __init__(self, config, ipp):
 
     def forward(self, hidden_states):
         if self.config.use_fused_rms_norm:
-            return rms_norm_fused(hidden_states, self.weight, self.variance_epsilon)
+            return fusion_ops.fusion_rms_norm(
+                hidden_states, self.weight, self.variance_epsilon, self.config.use_fast_layer_norm
+            )
 
         with paddle.amp.auto_cast(False):
             variance = hidden_states.astype("float32").pow(2).mean(-1, keepdim=True)
@@ -238,6 +237,16 @@ def __init__(self, config, ipp: Optional[int] = None):
         self.fuse_attention_ffn = config.fuse_attention_ffn
         self.ipp = ipp
         self.config = config
+        colwise_placements = (
+            [dist.Replicate(), dist.Shard(1)]
+            if self.config.tensor_parallel_degree > 1
+            else [dist.Replicate(), dist.Replicate()]
+        )
+        rowise_placement = (
+            [dist.Replicate(), dist.Shard(0)]
+            if self.config.tensor_parallel_degree > 1
+            else [dist.Replicate(), dist.Replicate()]
+        )
 
         if config.fuse_attention_ffn and not enable_fuse_ffn_qkv_pass():
             self.gate_up_fused_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias_attr=False)
@@ -307,8 +316,19 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False, ipp:
         self.recompute_granularity = config.recompute_granularity
         self.ipp = ipp
 
+        colwise_placements = (
+            [dist.Replicate(), dist.Shard(1)]
+            if self.config.tensor_parallel_degree > 1
+            else [dist.Replicate(), dist.Replicate()]
+        )
+        rowise_placement = (
+            [dist.Replicate(), dist.Shard(0)]
+            if self.config.tensor_parallel_degree > 1
+            else [dist.Replicate(), dist.Replicate()]
+        )
+
         self.use_fused_rope = config.use_fused_rope
-        if self.use_fused_rope:
+        if self.use_fused_rope and get_env_device() not in ["npu", "mlu", "xpu", "gcu", "intel_hpu"]:
             if "gpu" not in paddle.device.get_device() or fused_rotary_position_embedding is None:
                 warnings.warn(
                     "Enable fuse rope in the config, but fuse rope is not available. "
@@ -935,7 +955,22 @@ def _prepare_decoder_attention_mask(attention_mask, input_shape, past_key_values
         else:
             expanded_attn_mask = _make_causal_mask(input_shape, past_key_values_length=past_key_values_length)
         # Convert bool attention_mask to float attention mask, which will be added to attention_scores later
-        expanded_attn_mask = paddle.where(expanded_attn_mask, 0.0, paddle.finfo(dtype).min).astype(dtype)
+        if get_env_device() in ["npu", "mlu", "intel_hpu"]:
+            x = paddle.to_tensor(0.0, dtype="float32")
+            y = paddle.to_tensor(paddle.finfo(dtype).min, dtype="float32")
+            expanded_attn_mask = paddle.where(expanded_attn_mask.cast("bool"), x, y).astype(dtype)
+        elif get_env_device() == "xpu":
+            x = paddle.to_tensor(0.0, dtype="float32")
+            y = paddle.to_tensor(-1.7005809656952787e38, dtype="float32")
+            expanded_attn_mask = paddle.where(expanded_attn_mask.cast("bool"), x, y)
+        elif get_env_device() == "gcu":
+            min_val = paddle.finfo(dtype).min
+            x = paddle.to_tensor(0.0, dtype=dtype)
+            y = paddle.to_tensor(min_val, dtype=dtype)
+            expanded_attn_mask = paddle.where(expanded_attn_mask.cast("bool"), x, y).astype(dtype)
+        else:
+            expanded_attn_mask = paddle.where(expanded_attn_mask, 0.0, paddle.finfo(dtype).min)
+            expanded_attn_mask = expanded_attn_mask.astype(dtype)
         return expanded_attn_mask
 
     def forward(
@@ -1166,8 +1201,14 @@ def forward(self, prediction_scores, masked_lm_labels):
             masked_lm_labels.unsqueeze(2),
         )
 
-        masked_lm_loss = paddle.masked_select(masked_lm_loss, masked_lm_loss > 0).astype("float32")
-        loss = paddle.mean(masked_lm_loss)
+        # Hack for XPU that doesn't support Allgather yet.
+        if get_env_device() == "xpu":
+            # masked_lm_loss = paddle.masked_select(masked_lm_loss, masked_lm_loss > 0).astype("float32")
+            loss = paddle.mean(masked_lm_loss, axis=-1)
+        else:
+            masked_lm_loss = paddle.masked_select(masked_lm_loss, masked_lm_loss > 0).astype("float32")
+            loss = paddle.mean(masked_lm_loss)
+
         return loss
 
 
@@ -1175,6 +1216,11 @@ class LlamaLMHeadAuto(nn.Layer):
     def __init__(self, config: LlamaConfig):
         super(LlamaLMHeadAuto, self).__init__()
         self.config = config
+        colwise_placements = (
+            [dist.Replicate(), dist.Shard(1)]
+            if self.config.tensor_parallel_degree > 1
+            else [dist.Replicate(), dist.Replicate()]
+        )
         vocab_size = config.vocab_size
         self.weight = self.create_parameter(
             shape=[config.hidden_size, vocab_size],