From 859900ec1ba703c10da54951ef016d0f8d3e76be Mon Sep 17 00:00:00 2001
From: chenruibiao
Date: Mon, 20 Jan 2025 11:47:37 +0800
Subject: [PATCH 1/6] Support XPU for auto-parallel LLaMa

---
 llm/auto_parallel/llama/run_pretrain_auto.py  | 10 +++
 paddlenlp/trainer/training_args.py            |  3 +-
 paddlenlp/transformers/llama/modeling_auto.py | 20 ++++-
 run_llama2_13b_4k_auto.sh                     | 87 +++++++++++++++++++
 run_llama2_13b_4k_pp2.sh                      | 87 +++++++++++++++++++
 5 files changed, 203 insertions(+), 4 deletions(-)
 create mode 100755 run_llama2_13b_4k_auto.sh
 create mode 100755 run_llama2_13b_4k_pp2.sh

diff --git a/llm/auto_parallel/llama/run_pretrain_auto.py b/llm/auto_parallel/llama/run_pretrain_auto.py
index fa3d8855afb5..bd749d9933d0 100644
--- a/llm/auto_parallel/llama/run_pretrain_auto.py
+++ b/llm/auto_parallel/llama/run_pretrain_auto.py
@@ -58,6 +58,7 @@
     check_data_split,
     print_rank_0,
 )
+from paddlenlp.utils.tools import get_env_device
 from paddlenlp.trainer.utils.doc import add_start_docstrings


@@ -544,6 +545,15 @@ def main():
         pipeline = training_args.strategy.pipeline
         pipeline.vpp_degree = config.virtual_pp_degree
         pipeline.vpp_seg_method = training_args.virtual_pipeline_seg_method
+    if get_env_device() == "xpu" and training_args.gradient_accumulation_steps > 1:
+        try:
+            from paddle_xpu.layers.nn.linear import LinearConfig  # noqa: F401
+
+            LinearConfig.enable_accumulate_steps_opt()
+            LinearConfig.set_accumulate_steps(training_args.gradient_accumulation_steps)
+        except ImportError:
+            # It's OK to run without the accumulate_steps optimization.
+            pass

     print("Final pre-training config:", config)

diff --git a/paddlenlp/trainer/training_args.py b/paddlenlp/trainer/training_args.py
index bb0b91224180..b488ada1be0e 100644
--- a/paddlenlp/trainer/training_args.py
+++ b/paddlenlp/trainer/training_args.py
@@ -1632,13 +1632,12 @@ def is_segment_parallel_supported():
                     "enable_mp_async_allreduce",  # allreduce_matmul_grad_overlapping in auto_parallel
                     "enable_delay_scale_loss",
                     "replace_with_c_embedding",
-                    # "enable_mp_skip_c_identity",
                     # "enable_mp_fused_linear_param_grad_add",
                     "replace_with_parallel_cross_entropy",
                 ]:
                     raise ValueError(
                         f"Found unknown tensor parallell config {x}, "
-                        f"accept config is enable_mp_async_allreduce, replace_with_c_embedding, enable_mp_skip_c_identity and enable_mp_fused_linear_param_grad_add"
+                        f"accept config is enable_mp_async_allreduce, replace_with_c_embedding, and enable_mp_fused_linear_param_grad_add"
                     )
                 try:
                     if "enable_mp_async_allreduce" in mp_config:

diff --git a/paddlenlp/transformers/llama/modeling_auto.py b/paddlenlp/transformers/llama/modeling_auto.py
index a629cf1ec955..aff512e286af 100644
--- a/paddlenlp/transformers/llama/modeling_auto.py
+++ b/paddlenlp/transformers/llama/modeling_auto.py
@@ -52,6 +52,7 @@ def swiglu(x, y=None):
     CausalLMOutputWithCrossAttentions,
 )
 from paddlenlp.transformers.model_utils import PretrainedModel, register_base_model
+from paddlenlp.utils.tools import get_env_device

 from .configuration import (
     LLAMA_PRETRAINED_INIT_CONFIGURATION,
@@ -308,7 +309,7 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False, ipp:
         self.ipp = ipp

         self.use_fused_rope = config.use_fused_rope
-        if self.use_fused_rope:
+        if self.use_fused_rope and get_env_device() not in ["npu", "mlu", "xpu", "gcu", "intel_hpu"]:
             if "gpu" not in paddle.device.get_device() or fused_rotary_position_embedding is None:
                 warnings.warn(
                     "Enable fuse rope in the config, but fuse rope is not available. 
" @@ -935,7 +936,22 @@ def _prepare_decoder_attention_mask(attention_mask, input_shape, past_key_values else: expanded_attn_mask = _make_causal_mask(input_shape, past_key_values_length=past_key_values_length) # Convert bool attention_mask to float attention mask, which will be added to attention_scores later - expanded_attn_mask = paddle.where(expanded_attn_mask, 0.0, paddle.finfo(dtype).min).astype(dtype) + if get_env_device() in ["npu", "mlu", "intel_hpu"]: + x = paddle.to_tensor(0.0, dtype="float32") + y = paddle.to_tensor(paddle.finfo(dtype).min, dtype="float32") + expanded_attn_mask = paddle.where(expanded_attn_mask.cast("bool"), x, y).astype(dtype) + elif get_env_device() == "xpu": + x = paddle.to_tensor(0.0, dtype="float32") + y = paddle.to_tensor(-1.7005809656952787e38, dtype="float32") + expanded_attn_mask = paddle.where(expanded_attn_mask.cast("bool"), x, y) + elif get_env_device() == "gcu": + min_val = paddle.finfo(dtype).min + x = paddle.to_tensor(0.0, dtype=dtype) + y = paddle.to_tensor(min_val, dtype=dtype) + expanded_attn_mask = paddle.where(expanded_attn_mask.cast("bool"), x, y).astype(dtype) + else: + expanded_attn_mask = paddle.where(expanded_attn_mask.cast("bool"), 0.0, paddle.finfo(dtype).min) + expanded_attn_mask = expanded_attn_mask.astype(dtype) return expanded_attn_mask def forward( diff --git a/run_llama2_13b_4k_auto.sh b/run_llama2_13b_4k_auto.sh new file mode 100755 index 000000000000..9ca0edd50c47 --- /dev/null +++ b/run_llama2_13b_4k_auto.sh @@ -0,0 +1,87 @@ +#!/bin/bash +cd llm +task_name_or_path="llama2-13b-4k" + +#export XPUAPI_DEBUG=0x1 +#export XPURT_DISPATCH_MODE=PROFILING +export XBLAS_FC_HBM_VERSION=40 + +# PaddlePaddle +export FLAGS_use_stride_kernel="0" +export XPU_PADDLE_L3_SIZE=98566144 # 94 MB +export XPU_CDNN_CLUSTER_PARALLEL=1 +export XPU_CDNN_CLUSTER_PARALLEL_STREAM_NUMBER=2 + +# PDC +unset PADDLE_ELASTIC_JOB_ID +unset PADDLE_TRAINER_ENDPOINTS +unset DISTRIBUTED_TRAINER_ENDPOINTS +unset FLAGS_START_PORT +unset PADDLE_ELASTIC_TIMEOUT +unset PADDLE_TRAINERS_NUM + +# BKCL +# export BKCL_DEBUG=1 +# Multi-computer RDMA +#export BKCL_ENABLE_XDR=1 +#export BKCL_RDMA_FORCE_TREE=1 +#export BKCL_TREE_THRESHOLD=0 +#export BKCL_RDMA_NICS=xgbe1,xgbe1,xgbe2,xgbe2,xgbe3,xgbe3,xgbe4,xgbe4 +#export BKCL_SOCKET_IFNAME=xgbe0 +#export BKCL_FORCE_L3_RDMA=0 +echo "bkcl version:" +strings ${bkcl_location}/libbkcl.so | grep COM + +export CUDA_DEVICE_MAX_CONNECTIONS=8 + +export GLOG_v=10 + +timestamp=$(date +%Y%m%d%H%M%S) +echo $timestamp +PYTHONPATH=../:$PYTHONPATH \ +python -u -m paddle.distributed.launch \ + --xpus "0,1,2,3,4,5,6,7" \ + --log_dir "output/$task_name_or_path/$timestamp""_log" \ + auto_parallel/llama/run_pretrain_auto.py \ + --model_name_or_path "meta-llama/Llama-2-13b" \ + --tokenizer_name_or_path "meta-llama/Llama-2-13b" \ + --input_dir "./data" \ + --output_dir "output/$task_name_or_path/$timestamp" \ + --split 949,50,1 \ + --max_seq_length 4096 \ + --per_device_train_batch_size 1 \ + --per_device_eval_batch_size 1 \ + --use_flash_attention 1 \ + --use_fused_rope 1 \ + --fuse_attention_ffn 1 \ + --fuse_attention_qkv 1 \ + --use_fused_rms_norm 1 \ + --num_hidden_layers 40 \ + --bf16 \ + --fp16_opt_level "O2" \ + --amp_master_grad true \ + --scale_loss 1024 \ + --learning_rate 0.00003 \ + --min_learning_rate 0.000005 \ + --lr_scheduler_type "cosine" \ + --max_steps 100000 \ + --save_steps 100000 \ + --weight_decay 0.01 \ + --warmup_ratio 0.01 \ + --max_grad_norm 1.0 \ + --logging_steps 1 \ + --sequence_parallel 0 \ + --dataloader_num_workers 4 \ + 
--pipeline_parallel_degree 2 \ + --tensor_parallel_degree 2 \ + --gradient_accumulation_steps 32 \ + --sharding "stage1" \ + --eval_steps 1000 \ + --report_to "visualdl" \ + --disable_tqdm true \ + --continue_training 0 \ + --recompute 0 \ + --do_train \ + --seed 1026 \ + --device "xpu" \ + --enable_auto_parallel 1 diff --git a/run_llama2_13b_4k_pp2.sh b/run_llama2_13b_4k_pp2.sh new file mode 100755 index 000000000000..f49df41249ca --- /dev/null +++ b/run_llama2_13b_4k_pp2.sh @@ -0,0 +1,87 @@ +#!/bin/bash +cd llm +task_name_or_path="llama2-13b-4k" + +#export XPUAPI_DEBUG=0x1 +#export XPURT_DISPATCH_MODE=PROFILING +export XBLAS_FC_HBM_VERSION=40 + +# PaddlePaddle +export FLAGS_use_stride_kernel="0" +export XPU_PADDLE_L3_SIZE=98566144 # 94 MB +export XPU_CDNN_CLUSTER_PARALLEL=1 +export XPU_CDNN_CLUSTER_PARALLEL_STREAM_NUMBER=2 + +# PDC +unset PADDLE_ELASTIC_JOB_ID +unset PADDLE_TRAINER_ENDPOINTS +unset DISTRIBUTED_TRAINER_ENDPOINTS +unset FLAGS_START_PORT +unset PADDLE_ELASTIC_TIMEOUT +unset PADDLE_TRAINERS_NUM + +# BKCL +# export BKCL_DEBUG=1 +# Multi-computer RDMA +#export BKCL_ENABLE_XDR=1 +#export BKCL_RDMA_FORCE_TREE=1 +#export BKCL_TREE_THRESHOLD=0 +#export BKCL_RDMA_NICS=xgbe1,xgbe1,xgbe2,xgbe2,xgbe3,xgbe3,xgbe4,xgbe4 +#export BKCL_SOCKET_IFNAME=xgbe0 +#export BKCL_FORCE_L3_RDMA=0 +echo "bkcl version:" +strings ${bkcl_location}/libbkcl.so | grep COM + +export CUDA_DEVICE_MAX_CONNECTIONS=8 + +timestamp=$(date +%Y%m%d%H%M%S) +echo $timestamp +PYTHONPATH=../:$PYTHONPATH \ +python -u -m paddle.distributed.launch \ + --xpus "0,1,2,3,4,5,6,7" \ + --log_dir "output/$task_name_or_path/$timestamp""_log" \ + run_pretrain.py \ + --model_name_or_path "meta-llama/Llama-2-13b" \ + --tokenizer_name_or_path "meta-llama/Llama-2-13b" \ + --input_dir "./data" \ + --output_dir "output/$task_name_or_path/$timestamp" \ + --split 949,50,1 \ + --max_seq_length 4096 \ + --per_device_train_batch_size 1 \ + --per_device_eval_batch_size 1 \ + --use_flash_attention 1 \ + --use_fused_rope 1 \ + --fuse_attention_ffn 1 \ + --fuse_attention_qkv 1 \ + --use_fused_rms_norm 1 \ + --num_hidden_layers 40 \ + --bf16 \ + --fp16_opt_level "O2" \ + --scale_loss 1024 \ + --learning_rate 0.00003 \ + --min_learning_rate 0.000005 \ + --lr_scheduler_type "cosine" \ + --max_steps 100000 \ + --save_steps 100000 \ + --weight_decay 0.01 \ + --warmup_ratio 0.01 \ + --max_grad_norm 1.0 \ + --logging_steps 1 \ + --sequence_parallel 0 \ + --dataloader_num_workers 4 \ + --pipeline_parallel_degree 2 \ + --pipeline_parallel_config "disable_partial_send_recv" \ + --tensor_parallel_degree 2 \ + --tensor_parallel_config "enable_mp_async_allreduce,enable_mp_skip_c_identity" \ + --gradient_accumulation_steps 32 \ + --sharding "stage1" \ + --sharding_parallel_config "split_param" \ + --eval_steps 1000 \ + --report_to "visualdl" \ + --disable_tqdm true \ + --continue_training 0 \ + --recompute 0 \ + --do_train \ + --seed 1026 \ + --device "xpu" \ + --amp_master_grad true From 925210d78d2b54512d0ac3fff4a12059b8534e00 Mon Sep 17 00:00:00 2001 From: chenruibiao Date: Wed, 29 Jan 2025 15:18:32 +0800 Subject: [PATCH 2/6] Update --- llm/auto_parallel/llama/run_pretrain_auto.py | 9 +++- paddlenlp/trainer/auto_trainer.py | 5 ++ paddlenlp/transformers/llama/modeling_auto.py | 46 +++++++++++++++---- 3 files changed, 51 insertions(+), 9 deletions(-) diff --git a/llm/auto_parallel/llama/run_pretrain_auto.py b/llm/auto_parallel/llama/run_pretrain_auto.py index bd749d9933d0..24e737de544b 100644 --- a/llm/auto_parallel/llama/run_pretrain_auto.py +++ 
b/llm/auto_parallel/llama/run_pretrain_auto.py @@ -58,8 +58,8 @@ check_data_split, print_rank_0, ) -from paddlenlp.utils.tools import get_env_device from paddlenlp.trainer.utils.doc import add_start_docstrings +from paddlenlp.utils.tools import get_env_device @dataclass @@ -174,6 +174,11 @@ class ModelArguments: default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} ) + use_fast_layer_norm: bool = field( + default=False, + metadata={"help": "GPT3 model, use fast layernorm"}, + ) + config_name: Optional[str] = field( default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} ) @@ -497,6 +502,8 @@ def main(): config = config_class.from_pretrained(model_args.model_name_or_path) + config.use_fast_layer_norm = model_args.use_fast_layer_norm + config.seq_length = data_args.max_seq_length # There are some technique extend RotaryEmbedding context. so don't change max_position_embeddings if not model_args.continue_training: diff --git a/paddlenlp/trainer/auto_trainer.py b/paddlenlp/trainer/auto_trainer.py index 288f10e8f6ef..deb523a7f713 100644 --- a/paddlenlp/trainer/auto_trainer.py +++ b/paddlenlp/trainer/auto_trainer.py @@ -26,6 +26,7 @@ from tqdm.auto import tqdm from paddlenlp.trainer import Trainer +from paddlenlp.utils.tools import get_env_device from ..utils.batch_sampler import DistributedBatchSampler as NlpDistributedBatchSampler from ..utils.log import logger @@ -522,6 +523,10 @@ def _inner_training_loop( logger.info("\nTraining completed. \n") + # Hack for XPU that doesn't support Allgather yet. See LlamaPretrainingCriterion3DAuto in modeling_auto.py for details. + if get_env_device() == "xpu": + tr_loss = tr_loss.mean() + self._total_loss_scalar += self._get_item_from_loss(tr_loss) train_loss = self._total_loss_scalar / self.state.global_step diff --git a/paddlenlp/transformers/llama/modeling_auto.py b/paddlenlp/transformers/llama/modeling_auto.py index aff512e286af..1aa0a2f71502 100644 --- a/paddlenlp/transformers/llama/modeling_auto.py +++ b/paddlenlp/transformers/llama/modeling_auto.py @@ -54,6 +54,7 @@ def swiglu(x, y=None): from paddlenlp.transformers.model_utils import PretrainedModel, register_base_model from paddlenlp.utils.tools import get_env_device +from . 
import fusion_ops from .configuration import ( LLAMA_PRETRAINED_INIT_CONFIGURATION, LLAMA_PRETRAINED_RESOURCE_FILES_MAP, @@ -70,7 +71,6 @@ def swiglu(x, y=None): build_alibi_tensor, get_triangle_upper_mask, repeat_kv, - rms_norm_fused, ) try: @@ -195,10 +195,6 @@ def scaled_dot_product_attention( return (attn_output, attn_weights) if output_attentions else attn_output -colwise_placements = [dist.Replicate(), dist.Shard(1)] -rowise_placement = [dist.Replicate(), dist.Shard(0)] - - class LlamaRMSNormAuto(nn.Layer): def __init__(self, config, ipp): super().__init__() @@ -219,7 +215,9 @@ def __init__(self, config, ipp): def forward(self, hidden_states): if self.config.use_fused_rms_norm: - return rms_norm_fused(hidden_states, self.weight, self.variance_epsilon) + return fusion_ops.fusion_rms_norm( + hidden_states, self.weight, self.variance_epsilon, self.config.use_fast_layer_norm + ) with paddle.amp.auto_cast(False): variance = hidden_states.astype("float32").pow(2).mean(-1, keepdim=True) @@ -239,6 +237,16 @@ def __init__(self, config, ipp: Optional[int] = None): self.fuse_attention_ffn = config.fuse_attention_ffn self.ipp = ipp self.config = config + colwise_placements = ( + [dist.Replicate(), dist.Shard(1)] + if self.config.tensor_parallel_degree > 1 + else [dist.Replicate(), dist.Replicate()] + ) + rowise_placement = ( + [dist.Replicate(), dist.Shard(0)] + if self.config.tensor_parallel_degree > 1 + else [dist.Replicate(), dist.Replicate()] + ) if config.fuse_attention_ffn and not enable_fuse_ffn_qkv_pass(): self.gate_up_fused_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias_attr=False) @@ -308,6 +316,17 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False, ipp: self.recompute_granularity = config.recompute_granularity self.ipp = ipp + colwise_placements = ( + [dist.Replicate(), dist.Shard(1)] + if self.config.tensor_parallel_degree > 1 + else [dist.Replicate(), dist.Replicate()] + ) + rowise_placement = ( + [dist.Replicate(), dist.Shard(0)] + if self.config.tensor_parallel_degree > 1 + else [dist.Replicate(), dist.Replicate()] + ) + self.use_fused_rope = config.use_fused_rope if self.use_fused_rope and get_env_device() not in ["npu", "mlu", "xpu", "gcu", "intel_hpu"]: if "gpu" not in paddle.device.get_device() or fused_rotary_position_embedding is None: @@ -1182,8 +1201,14 @@ def forward(self, prediction_scores, masked_lm_labels): masked_lm_labels.unsqueeze(2), ) - masked_lm_loss = paddle.masked_select(masked_lm_loss, masked_lm_loss > 0).astype("float32") - loss = paddle.mean(masked_lm_loss) + # Hack for XPU that doesn't support Allgather yet. 
+ if get_env_device() == "xpu": + # masked_lm_loss = paddle.masked_select(masked_lm_loss, masked_lm_loss > 0).astype("float32") + loss = paddle.mean(masked_lm_loss, axis=-1) + else: + masked_lm_loss = paddle.masked_select(masked_lm_loss, masked_lm_loss > 0).astype("float32") + loss = paddle.mean(masked_lm_loss, axis=-1) + return loss @@ -1191,6 +1216,11 @@ class LlamaLMHeadAuto(nn.Layer): def __init__(self, config: LlamaConfig): super(LlamaLMHeadAuto, self).__init__() self.config = config + colwise_placements = ( + [dist.Replicate(), dist.Shard(1)] + if self.config.tensor_parallel_degree > 1 + else [dist.Replicate(), dist.Replicate()] + ) vocab_size = config.vocab_size self.weight = self.create_parameter( shape=[config.hidden_size, vocab_size], From 11ae2035ce22e830ff6857b75d7be2ff652caa0f Mon Sep 17 00:00:00 2001 From: chenruibiao Date: Wed, 29 Jan 2025 18:45:48 +0800 Subject: [PATCH 3/6] Update --- .../auto_parallel/llama/run_llama2_13b_xpu.sh | 44 +++++++--- run_llama2_13b_4k_pp2.sh | 87 ------------------- 2 files changed, 31 insertions(+), 100 deletions(-) rename run_llama2_13b_4k_auto.sh => llm/auto_parallel/llama/run_llama2_13b_xpu.sh (64%) delete mode 100755 run_llama2_13b_4k_pp2.sh diff --git a/run_llama2_13b_4k_auto.sh b/llm/auto_parallel/llama/run_llama2_13b_xpu.sh similarity index 64% rename from run_llama2_13b_4k_auto.sh rename to llm/auto_parallel/llama/run_llama2_13b_xpu.sh index 9ca0edd50c47..6a36f36d1f71 100755 --- a/run_llama2_13b_4k_auto.sh +++ b/llm/auto_parallel/llama/run_llama2_13b_xpu.sh @@ -1,6 +1,21 @@ #!/bin/bash + +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ cd llm -task_name_or_path="llama2-13b-4k" +task_name_or_path="llama2-13b-auto" #export XPUAPI_DEBUG=0x1 #export XPURT_DISPATCH_MODE=PROFILING @@ -34,19 +49,23 @@ strings ${bkcl_location}/libbkcl.so | grep COM export CUDA_DEVICE_MAX_CONNECTIONS=8 -export GLOG_v=10 +#PYTHONPATH +export PYTHONPATH=../../../:$PYTHONPATH + +# for debug +#export GLOG_v=6 +#export FLAGS_call_stack_level=2 -timestamp=$(date +%Y%m%d%H%M%S) -echo $timestamp +rm -rf output/$task_name_or_path PYTHONPATH=../:$PYTHONPATH \ python -u -m paddle.distributed.launch \ --xpus "0,1,2,3,4,5,6,7" \ - --log_dir "output/$task_name_or_path/$timestamp""_log" \ - auto_parallel/llama/run_pretrain_auto.py \ + --log_dir "output/$task_name_or_path/" \ + run_pretrain_auto.py \ --model_name_or_path "meta-llama/Llama-2-13b" \ --tokenizer_name_or_path "meta-llama/Llama-2-13b" \ --input_dir "./data" \ - --output_dir "output/$task_name_or_path/$timestamp" \ + --output_dir "output/$task_name_or_path" \ --split 949,50,1 \ --max_seq_length 4096 \ --per_device_train_batch_size 1 \ @@ -55,8 +74,8 @@ python -u -m paddle.distributed.launch \ --use_fused_rope 1 \ --fuse_attention_ffn 1 \ --fuse_attention_qkv 1 \ - --use_fused_rms_norm 1 \ - --num_hidden_layers 40 \ + --use_fused_rms_norm 0 \ + --num_hidden_layers 4 \ --bf16 \ --fp16_opt_level "O2" \ --amp_master_grad true \ @@ -64,7 +83,7 @@ python -u -m paddle.distributed.launch \ --learning_rate 0.00003 \ --min_learning_rate 0.000005 \ --lr_scheduler_type "cosine" \ - --max_steps 100000 \ + --max_steps 10 \ --save_steps 100000 \ --weight_decay 0.01 \ --warmup_ratio 0.01 \ @@ -72,10 +91,9 @@ python -u -m paddle.distributed.launch \ --logging_steps 1 \ --sequence_parallel 0 \ --dataloader_num_workers 4 \ - --pipeline_parallel_degree 2 \ - --tensor_parallel_degree 2 \ + --pipeline_parallel_degree 1 \ + --tensor_parallel_degree 1 \ --gradient_accumulation_steps 32 \ - --sharding "stage1" \ --eval_steps 1000 \ --report_to "visualdl" \ --disable_tqdm true \ diff --git a/run_llama2_13b_4k_pp2.sh b/run_llama2_13b_4k_pp2.sh deleted file mode 100755 index f49df41249ca..000000000000 --- a/run_llama2_13b_4k_pp2.sh +++ /dev/null @@ -1,87 +0,0 @@ -#!/bin/bash -cd llm -task_name_or_path="llama2-13b-4k" - -#export XPUAPI_DEBUG=0x1 -#export XPURT_DISPATCH_MODE=PROFILING -export XBLAS_FC_HBM_VERSION=40 - -# PaddlePaddle -export FLAGS_use_stride_kernel="0" -export XPU_PADDLE_L3_SIZE=98566144 # 94 MB -export XPU_CDNN_CLUSTER_PARALLEL=1 -export XPU_CDNN_CLUSTER_PARALLEL_STREAM_NUMBER=2 - -# PDC -unset PADDLE_ELASTIC_JOB_ID -unset PADDLE_TRAINER_ENDPOINTS -unset DISTRIBUTED_TRAINER_ENDPOINTS -unset FLAGS_START_PORT -unset PADDLE_ELASTIC_TIMEOUT -unset PADDLE_TRAINERS_NUM - -# BKCL -# export BKCL_DEBUG=1 -# Multi-computer RDMA -#export BKCL_ENABLE_XDR=1 -#export BKCL_RDMA_FORCE_TREE=1 -#export BKCL_TREE_THRESHOLD=0 -#export BKCL_RDMA_NICS=xgbe1,xgbe1,xgbe2,xgbe2,xgbe3,xgbe3,xgbe4,xgbe4 -#export BKCL_SOCKET_IFNAME=xgbe0 -#export BKCL_FORCE_L3_RDMA=0 -echo "bkcl version:" -strings ${bkcl_location}/libbkcl.so | grep COM - -export CUDA_DEVICE_MAX_CONNECTIONS=8 - -timestamp=$(date +%Y%m%d%H%M%S) -echo $timestamp -PYTHONPATH=../:$PYTHONPATH \ -python -u -m paddle.distributed.launch \ - --xpus "0,1,2,3,4,5,6,7" \ - --log_dir "output/$task_name_or_path/$timestamp""_log" \ - run_pretrain.py \ - --model_name_or_path "meta-llama/Llama-2-13b" \ - --tokenizer_name_or_path "meta-llama/Llama-2-13b" \ - --input_dir "./data" \ - --output_dir "output/$task_name_or_path/$timestamp" \ - --split 949,50,1 \ - --max_seq_length 4096 \ - 
--per_device_train_batch_size 1 \ - --per_device_eval_batch_size 1 \ - --use_flash_attention 1 \ - --use_fused_rope 1 \ - --fuse_attention_ffn 1 \ - --fuse_attention_qkv 1 \ - --use_fused_rms_norm 1 \ - --num_hidden_layers 40 \ - --bf16 \ - --fp16_opt_level "O2" \ - --scale_loss 1024 \ - --learning_rate 0.00003 \ - --min_learning_rate 0.000005 \ - --lr_scheduler_type "cosine" \ - --max_steps 100000 \ - --save_steps 100000 \ - --weight_decay 0.01 \ - --warmup_ratio 0.01 \ - --max_grad_norm 1.0 \ - --logging_steps 1 \ - --sequence_parallel 0 \ - --dataloader_num_workers 4 \ - --pipeline_parallel_degree 2 \ - --pipeline_parallel_config "disable_partial_send_recv" \ - --tensor_parallel_degree 2 \ - --tensor_parallel_config "enable_mp_async_allreduce,enable_mp_skip_c_identity" \ - --gradient_accumulation_steps 32 \ - --sharding "stage1" \ - --sharding_parallel_config "split_param" \ - --eval_steps 1000 \ - --report_to "visualdl" \ - --disable_tqdm true \ - --continue_training 0 \ - --recompute 0 \ - --do_train \ - --seed 1026 \ - --device "xpu" \ - --amp_master_grad true From 3f4639caff8e33d126ff5292ec9095d6384571fd Mon Sep 17 00:00:00 2001 From: chenruibiao Date: Wed, 29 Jan 2025 18:48:04 +0800 Subject: [PATCH 4/6] Update --- llm/auto_parallel/llama/run_llama2_13b_xpu.sh | 1 - 1 file changed, 1 deletion(-) diff --git a/llm/auto_parallel/llama/run_llama2_13b_xpu.sh b/llm/auto_parallel/llama/run_llama2_13b_xpu.sh index 6a36f36d1f71..17c60a600514 100755 --- a/llm/auto_parallel/llama/run_llama2_13b_xpu.sh +++ b/llm/auto_parallel/llama/run_llama2_13b_xpu.sh @@ -14,7 +14,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -cd llm task_name_or_path="llama2-13b-auto" #export XPUAPI_DEBUG=0x1 From 927d878ce35655a1719c81d8228aeeedd3f11c43 Mon Sep 17 00:00:00 2001 From: chenruibiao Date: Wed, 5 Feb 2025 16:04:01 +0800 Subject: [PATCH 5/6] Update --- paddlenlp/transformers/llama/modeling_auto.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddlenlp/transformers/llama/modeling_auto.py b/paddlenlp/transformers/llama/modeling_auto.py index 1aa0a2f71502..825b836e39e0 100644 --- a/paddlenlp/transformers/llama/modeling_auto.py +++ b/paddlenlp/transformers/llama/modeling_auto.py @@ -1207,7 +1207,7 @@ def forward(self, prediction_scores, masked_lm_labels): loss = paddle.mean(masked_lm_loss, axis=-1) else: masked_lm_loss = paddle.masked_select(masked_lm_loss, masked_lm_loss > 0).astype("float32") - loss = paddle.mean(masked_lm_loss, axis=-1) + loss = paddle.mean(masked_lm_loss) return loss From bbbeb6275494e9c4d31b1e7b160ab6f338eada75 Mon Sep 17 00:00:00 2001 From: chenruibiao Date: Wed, 5 Feb 2025 19:35:50 +0800 Subject: [PATCH 6/6] Fix CI errors --- paddlenlp/transformers/llama/modeling_auto.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddlenlp/transformers/llama/modeling_auto.py b/paddlenlp/transformers/llama/modeling_auto.py index 825b836e39e0..613fdac93e01 100644 --- a/paddlenlp/transformers/llama/modeling_auto.py +++ b/paddlenlp/transformers/llama/modeling_auto.py @@ -969,7 +969,7 @@ def _prepare_decoder_attention_mask(attention_mask, input_shape, past_key_values y = paddle.to_tensor(min_val, dtype=dtype) expanded_attn_mask = paddle.where(expanded_attn_mask.cast("bool"), x, y).astype(dtype) else: - expanded_attn_mask = paddle.where(expanded_attn_mask.cast("bool"), 0.0, paddle.finfo(dtype).min) + expanded_attn_mask = paddle.where(expanded_attn_mask, 0.0, paddle.finfo(dtype).min) expanded_attn_mask = 
expanded_attn_mask.astype(dtype) return expanded_attn_mask
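
Note (not part of the patch series): for readers following the mask changes above, here is a minimal standalone sketch of the boolean-to-additive conversion that `_prepare_decoder_attention_mask` performs after this series, with 0.0 where attention is allowed, a very large negative value where it is not, and the XPU branch using a fixed float32 constant in place of `paddle.finfo(dtype).min`. The helper name `build_additive_mask`, the `device` argument, and the toy shapes are illustrative assumptions, not code from the patch.

```python
# Illustrative sketch only; mirrors the mask conversion in modeling_auto.py,
# but the helper name, the "device" argument, and the toy shapes are
# assumptions made for demonstration.
import paddle


def build_additive_mask(bool_mask, dtype, device="gpu"):
    """bool_mask is True where a query position may attend to a key position."""
    if device == "xpu":
        # XPU branch from the patch: keep the mask in float32 and use a fixed
        # large negative constant instead of paddle.finfo(dtype).min.
        zero = paddle.to_tensor(0.0, dtype="float32")
        neg = paddle.to_tensor(-1.7005809656952787e38, dtype="float32")
        return paddle.where(bool_mask.cast("bool"), zero, neg)
    # Default branch: 0.0 at visible positions, dtype-min at masked positions,
    # then cast to the compute dtype so it can be added to attention scores.
    mask = paddle.where(bool_mask, 0.0, paddle.finfo(dtype).min)
    return mask.astype(dtype)


# Toy 4-token causal mask: lower-triangular entries are attendable.
causal = paddle.tril(paddle.ones([1, 1, 4, 4])).cast("bool")
print(build_additive_mask(causal, paddle.float32))
print(build_additive_mask(causal, paddle.float32, device="xpu"))
```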