From 5f7a2827cb8fc9c89a681f44ced642f55bb51226 Mon Sep 17 00:00:00 2001 From: chenxuyi Date: Thu, 23 Jan 2025 16:43:42 +0800 Subject: [PATCH 1/4] [FlashCheckpoint] support EMA --- paddlenlp/trainer/trainer.py | 13 +- paddlenlp/trainer/training_args.py | 4 + paddlenlp/trainer/utils/flash_checkpoint.py | 197 ++++++++++++++++++-- 3 files changed, 193 insertions(+), 21 deletions(-) diff --git a/paddlenlp/trainer/trainer.py b/paddlenlp/trainer/trainer.py index 2bd5d0081f0d..250beea787a9 100644 --- a/paddlenlp/trainer/trainer.py +++ b/paddlenlp/trainer/trainer.py @@ -162,7 +162,7 @@ try: from .utils.flash_checkpoint import FlashCheckpointManager, get_fused_param_mappings -except: +except (ImportError, ModuleNotFoundError): FlashCheckpointManager, get_fused_param_mappings = None, None from .utils.helper import ( # nested_truncate, broadcast_dp_optimizer, @@ -692,10 +692,11 @@ def _wrap_model_and_load_sharded_checkpoint(self, resume_from_checkpoint): self._load_from_checkpoint(resume_from_checkpoint) return model - def create_flash_checkpoint_manager(self, unwrapped_model): + def create_flash_checkpoint_manager(self, unwrapped_model, resume_from_checkpoint=None): """ Create flash checkpoint manager. Has to be called after pipeline model is created. + resume_from_checkpoint: if use Flash checkpoing EMA, load previous checkpoint status """ assert isinstance(self.model, PretrainedModel), "model should be a PretrainedModel when using flash" logger.info("Create flash checkpoint manager...") @@ -707,6 +708,7 @@ def create_flash_checkpoint_manager(self, unwrapped_model): worker_num=self.args.flash_workers_num, pipeline_hooks_capacity=pipeline_hooks_capacity, capacity_usage=self.args.flash_pipeline_hooks_capacity_usage, + ema_coef=self.args.flash_save_ema_coef, ) for i in range(unwrapped_model.forward_pipeline_parallel_hook_capacity): unwrapped_model.register_forward_pipeline_parallel_hook( @@ -716,6 +718,11 @@ def create_flash_checkpoint_manager(self, unwrapped_model): unwrapped_model.register_backward_pipeline_parallel_hook( location=i, hook=self.flash_checkpoint_manager.flash_checkpoint_pipeline_hook ) + if resume_from_checkpoint is not None: + path = _add_variant(PADDLE_OPTIMIZER_NAME, self.args.optimizer_name_suffix) + path = os.path.join(resume_from_checkpoint, path).replace("optimizer", "ema") + logger.info(f"FC EMA load from {path}") + self.flash_checkpoint_manager.set_ema_state_dict(path) logger.info("Create flash checkpoint manager done.") def maybe_update_flash_checkpoint_worker(self): @@ -900,7 +907,7 @@ def train( self._load_optimizer_and_scheduler(resume_from_checkpoint) if self.args.enable_flash_save_mode: - self.create_flash_checkpoint_manager(model) + self.create_flash_checkpoint_manager(model, resume_from_checkpoint) logger.info(f"{self.runtime_timer.log()}") diff --git a/paddlenlp/trainer/training_args.py b/paddlenlp/trainer/training_args.py index c991548cb841..1e8cf7da2836 100644 --- a/paddlenlp/trainer/training_args.py +++ b/paddlenlp/trainer/training_args.py @@ -876,6 +876,10 @@ class TrainingArguments: "help": "Set pipeline hook capacity usage ratio. Lower value brings faster save speed but may effect calculation speed." }, ) + flash_save_ema_coef: Optional[float] = field( + default=0, + metadata={"help": "The coefficient of EMA parameters in flash save mode. 
if set to 0, skip EMA process"}, + ) save_tokenizer: Optional[bool] = field( default=True, metadata={"help": "Save tokenizer to output_dir."}, diff --git a/paddlenlp/trainer/utils/flash_checkpoint.py b/paddlenlp/trainer/utils/flash_checkpoint.py index 572ab7d8d02e..8e8703559140 100644 --- a/paddlenlp/trainer/utils/flash_checkpoint.py +++ b/paddlenlp/trainer/utils/flash_checkpoint.py @@ -13,6 +13,7 @@ # limitations under the License. import atexit +import hashlib import json # import copy @@ -32,6 +33,7 @@ ) from paddle.optimizer.fusion_utils import FusionStorageHelper +from paddlenlp.transformers.utils import device_guard from paddlenlp.utils.env import ( CONFIG_NAME, MODEL_META_NAME, @@ -43,6 +45,13 @@ from paddlenlp.utils.log import logger +def md5(tensor): + """debug use""" + numpy_array = tensor.numpy() + array_bytes = numpy_array.tobytes() + return hashlib.md5(array_bytes).hexdigest() + + class FCTaskType(Enum): """ TaskType defines the type of tasks that can be executed by the FlashCheckpointWorker. @@ -52,6 +61,7 @@ class FCTaskType(Enum): PREPARE = 1 OFFLOAD = 2 FINISH = 3 + SET_EMA_STATE_DICT = 4 class FCWorkerStatus(Enum): @@ -88,6 +98,124 @@ def get_fused_param_mappings(optimizer, manipulated_state_dict): return param_mappings, ipc_meta_mappings +class FlashEMAProcessor: + """ + 生活在 FC worker 里面的 EMA 处理模块. + 通过 `optimizer_fusion_storage_helper` 以及 `param_fusion_storage_helper` 获取主模型的参数 + """ + + def __init__(self, optimizer_fusion_storage_helper, param_fusion_storage_helper, ema_coef): + self.optimizer_fusion_storage_helper = optimizer_fusion_storage_helper + self.param_fusion_storage_helper = param_fusion_storage_helper + self.ema_coef = ema_coef + ( + self.ema_buffer, + self.ema_buffer_model_params, + self.master_min_offset, + self.master_max_offset, + ) = self.build_ema_buffer() + + def status(self): + if self.ema_buffer is None: + return "[EMA buffer] not initizied" + opt_md = md5(self.ema_buffer) + param_md = {k: md5(v) for k, v in self.ema_buffer_model_params.items()} + return f"[EMA buffer] opt:{opt_md}, param:{param_md}" + + @imperative_base.no_grad() + def build_ema_buffer(self): + logger.info("[FC EMA] build ema buffer") + master_max_offset = max( + self.optimizer_fusion_storage_helper.master_weights_meta.values(), key=lambda i: i["end"] + )["end"] + master_min_offset = min( + self.optimizer_fusion_storage_helper.master_weights_meta.values(), key=lambda i: i["start"] + )["start"] + ema_buffer = paddle.Tensor( + self.optimizer_fusion_storage_helper.cpu_buffer._slice(master_min_offset, master_max_offset), + place=paddle.CPUPlace(), + ) + # ema model params, only works on float32 model weights (aka, moe gates) + ema_buffer_model_params = { + k: cpu_buf.clone() + for k, (cuda_buf, cpu_buf) in self.param_fusion_storage_helper.inited_buffers.items() + if cuda_buf.dtype == paddle.float32 + } + return ema_buffer, ema_buffer_model_params, master_min_offset, master_max_offset + + def ema_reset(self): + self.ema_buffer = None + self.ema_buffer_modele_params = None + + @imperative_base.no_grad() + def ema_accumulate(self): + """ + perform ema update : ` \alpha * EMA + (1-\alpha) + model` + buid `self.ema_buffer` if necessary + """ + # logger.info(f'[FC EMA] wait all done, doing EMA w/ coef: {self.ema_coef}, status:{self.status()}') + # do update: ema = alpha * ema + (1-alpha) * model + cpu_master_weights = self.optimizer_fusion_storage_helper.cpu_buffer._slice( + self.master_min_offset, self.master_max_offset + ).cpu() + self.ema_buffer = self.ema_coef * self.ema_buffer + (1 - 
self.ema_coef) * cpu_master_weights + # logger.info(f'[FC EMA2] wait all done, doing EMA w/ coef: {self.ema_coef}, status:{self.status()}') + for index, ema_buf in self.ema_buffer_model_params.items(): + _, cpu_buf = self.param_fusion_storage_helper.inited_buffers[index] + updated_ema = self.ema_coef * ema_buf + (1 - self.ema_coef) * cpu_buf + self.ema_buffer_model_params[index] = updated_ema + logger.info(f"[FC EMA] done, buffer type:{self.ema_buffer.dtype}") + + @imperative_base.no_grad() + def ema_state_dict(self): + assert self.optimizer_fusion_storage_helper is not None + logger.info("[FC EMA] convert ema master weights state dict") + ema_state_dict = {} + for k, tensor_meta in self.param_fusion_storage_helper.model_weights_metas.items(): + shape = tensor_meta["shape"] + name = tensor_meta["name"] + start = tensor_meta["start"] + end = tensor_meta["end"] + if tensor_meta["buffer_index"] not in self.ema_buffer_model_params: + continue # non fp32 has no `self.ema_buffer_model_params` + cpu_buffer = self.ema_buffer_model_params[tensor_meta["buffer_index"]] + tensor = cpu_buffer._slice(start, end) + tensor.get_tensor()._set_dims(shape) + tensor.name = name + ema_state_dict[k] = tensor + ema_state_dict_master_weights = {} + for k, meta in self.optimizer_fusion_storage_helper.master_weights_meta.items(): + t = self.ema_buffer._slice(meta["start"] - self.master_min_offset, meta["end"] - self.master_min_offset) + t.get_tensor()._set_dims(meta["shape"]) + t.name = meta["name"] + ema_state_dict_master_weights[k] = t + ema_state_dict["master_weights"] = ema_state_dict_master_weights + logger.info("[FC EMA] done covert") + return ema_state_dict + + def load_ema_state_dict(self, path): + with device_guard(): + logger.info(f"[FC EMA] load state dict from {path}") + state_dict = paddle.load(path) + for k, tensor_meta in self.param_fusion_storage_helper.model_weights_metas.items(): + logger.info(f"[FC EMA] load model weight key={k}") + start = tensor_meta["start"] + end = tensor_meta["end"] + if tensor_meta["buffer_index"] not in self.ema_buffer_model_params: + continue # non fp32 has no `self.ema_buffer_model_params` + cpu_buffer = self.ema_buffer_model_params[tensor_meta["buffer_index"]] + tensor = state_dict[k].flatten() + cpu_buffer[start:end] = tensor + + ema_master = state_dict["master_weights"] + for k, meta in self.optimizer_fusion_storage_helper.master_weights_meta.items(): + logger.info(f"[FC EMA] load optimizer weight key={k}") + s = meta["start"] - self.master_min_offset + e = meta["end"] - self.master_min_offset + self.ema_buffer[s:e] = ema_master[k] + logger.info("[FC EMA] done loading") + + class ParamFusionStorageHelper: def __init__( self, @@ -208,7 +336,7 @@ def restore_tensor_from_meta(self, tensor_meta): class FlashCheckpointManager: - def __init__(self, worker_num, pipeline_hooks_capacity, capacity_usage): + def __init__(self, worker_num, pipeline_hooks_capacity, capacity_usage, ema_coef=None): assert worker_num > 0, "worker_num must be greater than 0" assert capacity_usage <= 1.0, "capacity_usage must be less than or equal to 1.0" self.cache_version = 0 @@ -218,8 +346,11 @@ def __init__(self, worker_num, pipeline_hooks_capacity, capacity_usage): self.current_worker = None self.device_id = int(os.getenv("FLAGS_selected_gpus")) self.pipeline_hooks_steps = max(int(pipeline_hooks_capacity * capacity_usage), 1) + self.ema_coef = ema_coef logger.info( - f"[FC manager] pipeline hooks capacity: {pipeline_hooks_capacity}; pipeline hooks steps for offloading: {self.pipeline_hooks_steps}" 
+ f"[FC manager] pipeline hooks capacity: {pipeline_hooks_capacity}; " + f"pipeline hooks steps for offloading: {self.pipeline_hooks_steps} " + f"ema coefficient: {ema_coef} " ) self.current_pipeline_hook_step = 0 ctx = multiprocessing.get_context("spawn") @@ -235,6 +366,7 @@ def __init__(self, worker_num, pipeline_hooks_capacity, capacity_usage): worker_task_queue, worker_status, worker_version, + ema_coef, ) p = ctx.Process(target=worker_loop, args=(worker,)) p.start() @@ -243,6 +375,13 @@ def __init__(self, worker_num, pipeline_hooks_capacity, capacity_usage): self.ready_to_save = False atexit.register(self.terminate_workers) + def set_ema_state_dict(self, path): + logger.info(f"[FC manager] setting EMA state dict: {path}") + for worker in self.workers: + assert worker.status.value == FCWorkerStatus.IDLE.value, "[FC manager] worker should be idle, when " + worker.task_queue.put((FCTaskType.SET_EMA_STATE_DICT, path)) + logger.info("[FC manager] done setting EMA state dict") + def update_flash_workers(self, new_version, dynamic_objecs, static_objects): self.report_error_worker() self.cache_version = new_version @@ -337,7 +476,7 @@ def worker_loop(worker): class FlashCheckpointWorker: - def __init__(self, worker_id, device_id, global_rank, offload_chunks, task_queue, status, version): + def __init__(self, worker_id, device_id, global_rank, offload_chunks, task_queue, status, version, ema_coef=None): super().__init__() self.worker_id = worker_id self.device_id = device_id @@ -346,6 +485,7 @@ def __init__(self, worker_id, device_id, global_rank, offload_chunks, task_queue self.task_queue = task_queue self.status = status self.version = version + self.ema_coef = ema_coef # for dynamic objects saving self.optimizer_fusion_storage_helper = None @@ -370,6 +510,7 @@ def __init__(self, worker_id, device_id, global_rank, offload_chunks, task_queue # for dumping self.flash_save_dir = None self.persistent_save_dir = None + self.flash_ema_processor = None def process_update_task(self, updates): """ @@ -426,6 +567,8 @@ def process_offload_task(self): if self.offloaded_numels == self.all_numel: self.optimizer_fusion_storage_helper.wait_all() self.param_fusion_storage_helper.wait_all() + if self.ema_coef: + self.flash_ema_processor.ema_accumulate() self.status.value = FCWorkerStatus.DUMPING.value # continue to process dumping task at the last chunk @@ -494,7 +637,9 @@ def process_dump_task_impl(self, output_dir): # Step2.2: save optimizer states optimizer_state_name_path = os.path.join(output_dir, self.optimizer_states_name_path) paddle.save(self.optimizer_fusion_storage_helper.state_dict(), optimizer_state_name_path) - + if self.ema_coef: + ema_name_path = os.path.join(output_dir, self.optimizer_states_name_path).replace("optimizer", "ema") + paddle.save(self.flash_ema_processor.ema_state_dict(), ema_name_path) # Step2.3: save LR Scheduler (To be removed) lr_state_name_path = os.path.join(output_dir, SCHEDULER_NAME) if self.device_id == 0: @@ -514,20 +659,36 @@ def run(self): core.set_cuda_current_device_id(self.device_id) paddle.set_device(f"gpu:{self.device_id}") logger.info(f"[FC worker{self.worker_id}] Worker{self.worker_id} started.") - while True: - task = self.task_queue.get() - task_type, task_body = task - if task_type == FCTaskType.FINISH: - logger.info(f"[FC worker{self.worker_id}] Flash checkpoint worker{self.worker_id} exit") - break - elif task_type == FCTaskType.UPDATE: - self.process_update_task(task_body) - elif task_type == FCTaskType.PREPARE: - self.process_prepare_task(task_body) - 
elif task_type == FCTaskType.OFFLOAD: - self.process_offload_task() - else: - raise ValueError(f"[FC worker{self.worker_id}] Unknown task type: {task_type}") + ema_ckpt_path = None + try: + while True: + task = self.task_queue.get() + task_type, task_body = task + # logger.info(f'[FC worker{self.worker_id}] Received a new task of type {task_type}., ema:{self.flash_ema_processor.status() if self.flash_ema_processor is not None else None}') + if task_type == FCTaskType.FINISH: + logger.info(f"[FC worker{self.worker_id}] Flash checkpoint worker{self.worker_id} exit") + break + elif task_type == FCTaskType.UPDATE: + self.process_update_task(task_body) + self.flash_ema_processor = FlashEMAProcessor( # 在 updte task 后刷新 EMA buffer + self.optimizer_fusion_storage_helper, self.param_fusion_storage_helper, self.ema_coef + ) + if ema_ckpt_path is not None: # update ema if needed + self.flash_ema_processor.load_ema_state_dict(ema_ckpt_path) + ema_ckpt_path = None + elif task_type == FCTaskType.PREPARE: + self.process_prepare_task(task_body) + elif task_type == FCTaskType.OFFLOAD: + self.process_offload_task() + elif task_type == FCTaskType.SET_EMA_STATE_DICT: + ema_ckpt_path = task_body # mark ema state dict path + else: + raise ValueError(f"[FC worker{self.worker_id}] Unknown task type: {task_type}") + except Exception as e: + import traceback + + logger.info(f"[FC worker{self.worker_id}] failed!!, Exception:{e}\n Traceback:{traceback.format_exc()}\n") + raise e def build_fusion_storage_helper(self, optimizer_states_meta, model_states_meta): ( From 94246f6b734cc4771371f32d3a6a841baa542b2a Mon Sep 17 00:00:00 2001 From: chenxuyi Date: Thu, 23 Jan 2025 22:27:01 +0800 Subject: [PATCH 2/4] support non-pp --- paddlenlp/trainer/trainer.py | 52 +++++++++++++-------- paddlenlp/trainer/utils/flash_checkpoint.py | 9 ++++ 2 files changed, 42 insertions(+), 19 deletions(-) diff --git a/paddlenlp/trainer/trainer.py b/paddlenlp/trainer/trainer.py index 250beea787a9..f42ea241a266 100644 --- a/paddlenlp/trainer/trainer.py +++ b/paddlenlp/trainer/trainer.py @@ -42,6 +42,7 @@ import paddle.nn as nn from packaging import version from paddle import framework +from paddle.distributed.fleet.meta_parallel import PipelineLayer try: from paddle.base import core @@ -161,7 +162,11 @@ from .utils.async_save import AsyncSaver try: - from .utils.flash_checkpoint import FlashCheckpointManager, get_fused_param_mappings + from .utils.flash_checkpoint import ( + FlashCheckpointCallback, + FlashCheckpointManager, + get_fused_param_mappings, + ) except (ImportError, ModuleNotFoundError): FlashCheckpointManager, get_fused_param_mappings = None, None from .utils.helper import ( # nested_truncate, @@ -350,8 +355,6 @@ def __init__( ) if self.args.pipeline_parallel_degree > 1 and self.args.use_hybrid_parallel: - from paddle.distributed.fleet.meta_parallel import PipelineLayer - assert (isinstance(model, LoRAModel) and isinstance(model.model, PipelineLayer)) or isinstance( model, PipelineLayer ), "Only support pipeline parallel mode when model is PipelineLayer!!!" 
@@ -700,24 +703,35 @@ def create_flash_checkpoint_manager(self, unwrapped_model, resume_from_checkpoin """ assert isinstance(self.model, PretrainedModel), "model should be a PretrainedModel when using flash" logger.info("Create flash checkpoint manager...") - pipeline_hooks_capacity = ( - unwrapped_model.forward_pipeline_parallel_hook_capacity - + unwrapped_model.backward_pipeline_parallel_hook_capacity - ) - self.flash_checkpoint_manager = FlashCheckpointManager( - worker_num=self.args.flash_workers_num, - pipeline_hooks_capacity=pipeline_hooks_capacity, - capacity_usage=self.args.flash_pipeline_hooks_capacity_usage, - ema_coef=self.args.flash_save_ema_coef, - ) - for i in range(unwrapped_model.forward_pipeline_parallel_hook_capacity): - unwrapped_model.register_forward_pipeline_parallel_hook( - location=i, hook=self.flash_checkpoint_manager.flash_checkpoint_pipeline_hook + if isinstance(unwrapped_model, PipelineLayer): + pipeline_hooks_capacity = ( + unwrapped_model.forward_pipeline_parallel_hook_capacity + + unwrapped_model.backward_pipeline_parallel_hook_capacity ) - for i in range(unwrapped_model.backward_pipeline_parallel_hook_capacity): - unwrapped_model.register_backward_pipeline_parallel_hook( - location=i, hook=self.flash_checkpoint_manager.flash_checkpoint_pipeline_hook + self.flash_checkpoint_manager = FlashCheckpointManager( + worker_num=self.args.flash_workers_num, + pipeline_hooks_capacity=pipeline_hooks_capacity, + capacity_usage=self.args.flash_pipeline_hooks_capacity_usage, + ema_coef=self.args.flash_save_ema_coef, + ) + for i in range(unwrapped_model.forward_pipeline_parallel_hook_capacity): + unwrapped_model.register_forward_pipeline_parallel_hook( + location=i, hook=self.flash_checkpoint_manager.flash_checkpoint_pipeline_hook + ) + for i in range(unwrapped_model.backward_pipeline_parallel_hook_capacity): + unwrapped_model.register_backward_pipeline_parallel_hook( + location=i, hook=self.flash_checkpoint_manager.flash_checkpoint_pipeline_hook + ) + else: + pipeline_hooks_capacity = self.args.gradient_accumulation_steps + self.flash_checkpoint_manager = FlashCheckpointManager( + worker_num=self.args.flash_workers_num, + pipeline_hooks_capacity=pipeline_hooks_capacity, + capacity_usage=self.args.flash_pipeline_hooks_capacity_usage, + ema_coef=self.args.flash_save_ema_coef, ) + _callback = FlashCheckpointCallback(self.flash_checkpoint_manager) + self.add_callback(_callback) if resume_from_checkpoint is not None: path = _add_variant(PADDLE_OPTIMIZER_NAME, self.args.optimizer_name_suffix) path = os.path.join(resume_from_checkpoint, path).replace("optimizer", "ema") diff --git a/paddlenlp/trainer/utils/flash_checkpoint.py b/paddlenlp/trainer/utils/flash_checkpoint.py index 8e8703559140..adb538f9c4af 100644 --- a/paddlenlp/trainer/utils/flash_checkpoint.py +++ b/paddlenlp/trainer/utils/flash_checkpoint.py @@ -33,6 +33,7 @@ ) from paddle.optimizer.fusion_utils import FusionStorageHelper +from paddlenlp.trainer.trainer_callback import TrainerCallback from paddlenlp.transformers.utils import device_guard from paddlenlp.utils.env import ( CONFIG_NAME, @@ -335,6 +336,14 @@ def restore_tensor_from_meta(self, tensor_meta): return tensor +class FlashCheckpointCallback(TrainerCallback): + def __init__(self, flash_checkpoint_manager): + self.manager = flash_checkpoint_manager + + def on_substep_end(self, args, state, control, **kwargs): + self.manager.flash_checkpoint_pipeline_hook(0) + + class FlashCheckpointManager: def __init__(self, worker_num, pipeline_hooks_capacity, 
capacity_usage, ema_coef=None): assert worker_num > 0, "worker_num must be greater than 0" From 4c15ff2edfd04f63e011b34db77f16031f2c1dbe Mon Sep 17 00:00:00 2001 From: chenxuyi Date: Fri, 24 Jan 2025 13:00:18 +0800 Subject: [PATCH 3/4] refactor: move [flash checkpoint manager] to callback --- paddlenlp/trainer/trainer.py | 90 +---------- paddlenlp/trainer/training_args.py | 4 + paddlenlp/trainer/utils/flash_checkpoint.py | 168 +++++++++++++++++--- 3 files changed, 158 insertions(+), 104 deletions(-) diff --git a/paddlenlp/trainer/trainer.py b/paddlenlp/trainer/trainer.py index f42ea241a266..b80343339585 100644 --- a/paddlenlp/trainer/trainer.py +++ b/paddlenlp/trainer/trainer.py @@ -18,7 +18,6 @@ import collections import contextlib -import copy import inspect import json import math @@ -92,7 +91,6 @@ from ..transformers.model_utils import ( PretrainedModel, _add_variant, - get_parameter_dtype, load_sharded_checkpoint, unwrap_model, ) @@ -405,12 +403,7 @@ def __init__( assert not self.args.save_rng_states, "save_rng_states is not supported when using flash save mode" # init attributes for flash save mode - self.manipulated_state_dict = None - self.manipulated_config_to_save = None - self.manipulated_weight_suffix = None - self.model_meta = None self.flash_checkpoint_manager = None - self.user_file_list = [] if self.args.ordered_save_group_size > 0: logger.info(f"using save in order, its group size is {self.args.ordered_save_group_size}") @@ -730,8 +723,10 @@ def create_flash_checkpoint_manager(self, unwrapped_model, resume_from_checkpoin capacity_usage=self.args.flash_pipeline_hooks_capacity_usage, ema_coef=self.args.flash_save_ema_coef, ) - _callback = FlashCheckpointCallback(self.flash_checkpoint_manager) - self.add_callback(_callback) + _callback = FlashCheckpointCallback( + self.args, self.flash_checkpoint_manager, self.runtime_timer, self.sharding_io + ) + self.add_callback(_callback) if resume_from_checkpoint is not None: path = _add_variant(PADDLE_OPTIMIZER_NAME, self.args.optimizer_name_suffix) path = os.path.join(resume_from_checkpoint, path).replace("optimizer", "ema") @@ -739,58 +734,6 @@ def create_flash_checkpoint_manager(self, unwrapped_model, resume_from_checkpoin self.flash_checkpoint_manager.set_ema_state_dict(path) logger.info("Create flash checkpoint manager done.") - def maybe_update_flash_checkpoint_worker(self): - if self.optimizer.fused_buffer_version == self.flash_checkpoint_manager.cache_version: - return - - logger.info("Flash checkpoint workers need upgrade.") - self._cache_meta_for_sharded_save() - param_mappings, ipc_meta_mappings = get_fused_param_mappings(self.optimizer, self.manipulated_state_dict) - optimizer_states_meta = ( - self.optimizer.fused_states_accumulators_meta, - self.optimizer.fused_states_master_weights_meta, - None, - self.optimizer.fused_states_buffer_ipc_meta, - ) - model_states_meta = (param_mappings, ipc_meta_mappings) - optimizer_states_name_path = _add_variant(PADDLE_OPTIMIZER_NAME, self.args.optimizer_name_suffix) - model_states_name_path = _add_variant(PADDLE_WEIGHTS_NAME, self.manipulated_weight_suffix) - - dynamic_objecs = {} - dynamic_objecs["optimizer_states_meta"] = optimizer_states_meta - dynamic_objecs["model_states_meta"] = model_states_meta - dynamic_objecs["optimizer_states_name_path"] = optimizer_states_name_path - dynamic_objecs["model_states_name_path"] = model_states_name_path - - static_objects = {} - static_objects["model_config"] = self.manipulated_config_to_save - static_objects["training_args"] = self.args - 
static_objects["model_meta"] = self.model_meta - static_objects["user_file"] = self.user_file_list - - self.flash_checkpoint_manager.update_flash_workers( - self.optimizer.fused_buffer_version, dynamic_objecs, static_objects - ) - - def _cache_meta_for_sharded_save(self): - logger.info("Start caching metas for sharded save...") - ( - self.manipulated_state_dict, - self.manipulated_config_to_save, - self.manipulated_weight_suffix, - ) = self.sharding_io.manipulate_state_dict_and_config(self.model, merge_tensor_parallel=False) - logger.info("Cache manipulated static dict done.") - if self.manipulated_config_to_save is None: - model_to_save = unwrap_model(self.model) - dtype = get_parameter_dtype(model_to_save) - model_to_save.config.dtype = str(dtype).split(".")[1] - self.manipulated_config_to_save = copy.deepcopy(model_to_save.config) - self.manipulated_config_to_save.architectures = [model_to_save.__class__.__name__] - self.manipulated_config_to_save = self.manipulated_config_to_save.to_json_string(use_diff=True) - logger.info("Cache manipulated model config done") - self.model_meta = self.sharding_io.gather_distributed_model_meta() - logger.info("Cache distributed model meta done.") - def train( self, resume_from_checkpoint: Optional[Union[str, bool]] = None, @@ -1230,10 +1173,6 @@ def fused_allreduce_gradients_no_sync(paramlist, hcg): self.callback_handler.on_optimizer_begin( args, self.state, self.control, scaler=self.scaler if self.do_grad_scaling else None ) - if self.args.enable_flash_save_mode and self.flash_checkpoint_manager.current_worker is not None: - logger.info("Start syncing flash checkpoints") - self.flash_checkpoint_manager.sync_offload_status() - logger.info("Synced flash checkpoints.") optimizer_was_run = True if self.args.offload_optim: @@ -2462,28 +2401,7 @@ def _ordered_save(self, state_dict, save_path): paddle.save(state_dict, save_path) dist.barrier(mp_group) - def _get_save_infos_based_on_steps(self, checkpoint_folder): - flash_checkpoint_dir = None - persistent_checkpoint_dir = None - if self.args.flash_save_steps > 0 and self.state.global_step % self.args.flash_save_steps == 0: - flash_checkpoint_dir = os.path.join(FLASH_DEVICE, checkpoint_folder) - if self.args.save_steps > 0 and self.state.global_step % self.args.save_steps == 0: - persistent_checkpoint_dir = os.path.join(self.args.output_dir, checkpoint_folder) - return (flash_checkpoint_dir, persistent_checkpoint_dir) - - def _save_checkpoint_flash(self): - self.runtime_timer.start("checkpoint saving time") - self.maybe_update_flash_checkpoint_worker() - checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}" - save_infos = self._get_save_infos_based_on_steps(checkpoint_folder) - non_cached_objects = (self.lr_scheduler.state_dict(), self.state) - self.flash_checkpoint_manager.get_idle_worker_for_saving(save_infos, non_cached_objects) - self.runtime_timer.stop() - def _save_checkpoint(self, model, metrics=None): - if self.args.enable_flash_save_mode: - self._save_checkpoint_flash() - return # assert unwrap_model(model) is self.model, "internal model should be a reference to self.model" self.runtime_timer.start("checkpoint saving time") diff --git a/paddlenlp/trainer/training_args.py b/paddlenlp/trainer/training_args.py index 1e8cf7da2836..611b2b9540e4 100644 --- a/paddlenlp/trainer/training_args.py +++ b/paddlenlp/trainer/training_args.py @@ -880,6 +880,10 @@ class TrainingArguments: default=0, metadata={"help": "The coefficient of EMA parameters in flash save mode. 
if set to 0, skip EMA process"}, ) + flash_ema_interval: Optional[int] = field( + default=1, + metadata={"help": "Interval between updating EMA parameters."}, + ) save_tokenizer: Optional[bool] = field( default=True, metadata={"help": "Save tokenizer to output_dir."}, diff --git a/paddlenlp/trainer/utils/flash_checkpoint.py b/paddlenlp/trainer/utils/flash_checkpoint.py index adb538f9c4af..89a57a22617c 100644 --- a/paddlenlp/trainer/utils/flash_checkpoint.py +++ b/paddlenlp/trainer/utils/flash_checkpoint.py @@ -13,10 +13,9 @@ # limitations under the License. import atexit +import copy import hashlib import json - -# import copy import multiprocessing import os import time @@ -34,16 +33,25 @@ from paddle.optimizer.fusion_utils import FusionStorageHelper from paddlenlp.trainer.trainer_callback import TrainerCallback +from paddlenlp.transformers.model_utils import ( + _add_variant, + get_parameter_dtype, + unwrap_model, +) from paddlenlp.transformers.utils import device_guard from paddlenlp.utils.env import ( CONFIG_NAME, MODEL_META_NAME, + PADDLE_OPTIMIZER_NAME, + PADDLE_WEIGHTS_NAME, + PREFIX_CHECKPOINT_DIR, SCHEDULER_NAME, TRAINER_STATE_NAME, TRAINING_ARGS_NAME, ) from paddlenlp.utils.fault_tolerance import FC_DUMP_ERROR, PC_DUMP_ERROR from paddlenlp.utils.log import logger +from paddlenlp.utils.pdc_sdk import FLASH_DEVICE def md5(tensor): @@ -62,7 +70,7 @@ class FCTaskType(Enum): PREPARE = 1 OFFLOAD = 2 FINISH = 3 - SET_EMA_STATE_DICT = 4 + SET_EMA_STATE_DICT = 5 class FCWorkerStatus(Enum): @@ -156,6 +164,7 @@ def ema_accumulate(self): """ # logger.info(f'[FC EMA] wait all done, doing EMA w/ coef: {self.ema_coef}, status:{self.status()}') # do update: ema = alpha * ema + (1-alpha) * model + logger.info("[FC EMA] start") cpu_master_weights = self.optimizer_fusion_storage_helper.cpu_buffer._slice( self.master_min_offset, self.master_max_offset ).cpu() @@ -337,12 +346,125 @@ def restore_tensor_from_meta(self, tensor_meta): class FlashCheckpointCallback(TrainerCallback): - def __init__(self, flash_checkpoint_manager): + """ + call FlashCheckpointManager during training in following order: + + on_step_end: + * call get_idle_worker_for_saving, set manager.current_worker + * call maybe_update_flash_checkpoint_worker + + * on_substep_end(call `gradient_accumulate` times): call flash_checkpoint_pipeline_hook (in non-pp model) + * (when offload done, dump model) + on_optimizer_begin: call sync_offload_status, unset set manager.current_worker + """ + + def __init__(self, args, flash_checkpoint_manager, timer, sharding_io): self.manager = flash_checkpoint_manager + self.runtime_timer = timer + self.user_file_list = [] + self.manipulated_state_dict = None + self.manipulated_config_to_save = None + self.manipulated_weight_suffix = None + self.model_meta = None + self.sharding_io = sharding_io + assert ( + args.flash_save_steps % args.flash_ema_interval == 0 + ), f"flash_save_steps:{args.flash_save_steps} must be divisible by flash_ema_interval:{args.flash_ema_interval}" + assert ( + args.save_steps % args.flash_ema_interval == 0 + ), f"save_steps:{args.save_steps} must be divisible by flash_ema_interval:{args.flash_ema_interval}" + self.flash_ema_interval = args.flash_ema_interval + if args.flash_save_ema_coef: + assert args.flash_workers_num == 1, "[FC EMA] not support #worker > 1" def on_substep_end(self, args, state, control, **kwargs): self.manager.flash_checkpoint_pipeline_hook(0) + def on_optimizer_begin(self, args, state, control, **kwargs): + if args.enable_flash_save_mode and 
self.manager.current_worker is not None: + logger.info("Start syncing flash checkpoints") + self.manager.sync_offload_status() + logger.info("Synced flash checkpoints.") + + def on_step_end(self, args, state, control, model, lr_scheduler, optimizer, **kwargs): + self.manager.flash_checkpoint_pipeline_hook(0) + logger.info( + f"check coef: {args.flash_save_ema_coef} {control.should_save}, {state.global_step}, {self.flash_ema_interval}" + ) + if not control.should_save: + if args.flash_save_ema_coef and state.global_step % self.flash_ema_interval == 0: + self.maybe_update_flash_checkpoint_worker(args, model, optimizer) + self.manager.get_idle_worker_for_saving() # prepare for dumping + else: + self.runtime_timer.start("checkpoint saving time") + self.maybe_update_flash_checkpoint_worker(args, model, optimizer) + checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}" + save_infos = self._get_save_infos_based_on_steps(state, args, checkpoint_folder) + non_cached_objects = (lr_scheduler.state_dict(), state) + self.manager.get_idle_worker_for_saving((save_infos, non_cached_objects)) + self.runtime_timer.stop() + control.should_save = False # avoid regular saving + + def _get_save_infos_based_on_steps(self, state, args, checkpoint_folder): + flash_checkpoint_dir = None + persistent_checkpoint_dir = None + if args.flash_save_steps > 0 and state.global_step % args.flash_save_steps == 0: + flash_checkpoint_dir = os.path.join(FLASH_DEVICE, checkpoint_folder) + if args.save_steps > 0 and state.global_step % args.save_steps == 0: + persistent_checkpoint_dir = os.path.join(args.output_dir, checkpoint_folder) + return (flash_checkpoint_dir, persistent_checkpoint_dir) + + def maybe_update_flash_checkpoint_worker(self, args, model, optimizer): + # logger.info(f'check should update :{optimizer.fused_buffer_version} vs {self.manager.cache_version}') + if optimizer.fused_buffer_version == self.manager.cache_version: + return + + logger.info("Flash checkpoint workers need upgrade.") + self._cache_meta_for_sharded_save(model) + param_mappings, ipc_meta_mappings = get_fused_param_mappings(optimizer, self.manipulated_state_dict) + optimizer_states_meta = ( + optimizer.fused_states_accumulators_meta, + optimizer.fused_states_master_weights_meta, + None, + optimizer.fused_states_buffer_ipc_meta, + ) + model_states_meta = (param_mappings, ipc_meta_mappings) + optimizer_states_name_path = _add_variant(PADDLE_OPTIMIZER_NAME, args.optimizer_name_suffix) + model_states_name_path = _add_variant(PADDLE_WEIGHTS_NAME, self.manipulated_weight_suffix) + + dynamic_objecs = {} + dynamic_objecs["optimizer_states_meta"] = optimizer_states_meta + dynamic_objecs["model_states_meta"] = model_states_meta + dynamic_objecs["optimizer_states_name_path"] = optimizer_states_name_path + dynamic_objecs["model_states_name_path"] = model_states_name_path + + static_objects = {} + static_objects["model_config"] = self.manipulated_config_to_save + static_objects["training_args"] = args + static_objects["model_meta"] = self.model_meta + static_objects["user_file"] = self.user_file_list + + self.manager.update_flash_workers(optimizer.fused_buffer_version, dynamic_objecs, static_objects) + + def _cache_meta_for_sharded_save(self, model): + logger.info("Start caching metas for sharded save...") + ( + self.manipulated_state_dict, + self.manipulated_config_to_save, + self.manipulated_weight_suffix, + ) = self.sharding_io.manipulate_state_dict_and_config(model, merge_tensor_parallel=False) + logger.info("Cache manipulated static dict 
done.") + if self.manipulated_config_to_save is None: + model_to_save = unwrap_model(model) + dtype = get_parameter_dtype(model_to_save) + model_to_save.config.dtype = str(dtype).split(".")[1] + self.manipulated_config_to_save = copy.deepcopy(model_to_save.config) + self.manipulated_config_to_save.architectures = [model_to_save.__class__.__name__] + self.manipulated_config_to_save = self.manipulated_config_to_save.to_json_string(use_diff=True) + logger.info("Cache manipulated model config done") + self.model_meta = self.sharding_io.gather_distributed_model_meta() + logger.info("Cache distributed model meta done.") + class FlashCheckpointManager: def __init__(self, worker_num, pipeline_hooks_capacity, capacity_usage, ema_coef=None): @@ -355,7 +477,6 @@ def __init__(self, worker_num, pipeline_hooks_capacity, capacity_usage, ema_coef self.current_worker = None self.device_id = int(os.getenv("FLAGS_selected_gpus")) self.pipeline_hooks_steps = max(int(pipeline_hooks_capacity * capacity_usage), 1) - self.ema_coef = ema_coef logger.info( f"[FC manager] pipeline hooks capacity: {pipeline_hooks_capacity}; " f"pipeline hooks steps for offloading: {self.pipeline_hooks_steps} " @@ -412,7 +533,10 @@ def update_flash_workers(self, new_version, dynamic_objecs, static_objects): logger.info("[FC manager] update all flash workers done") self.ready_to_save = True - def get_idle_worker_for_saving(self, save_infos, non_cached_objects): + def get_idle_worker_for_saving(self, save_infos_and_non_cached_objects=None): + """ + if `save_infos_and_non_cached_objects` is None, do offload without dumping. + """ self.report_error_worker() assert self.current_worker is None, "[FC manager] current_worker must be None" found_worker = False @@ -424,12 +548,16 @@ def get_idle_worker_for_saving(self, save_infos, non_cached_objects): break if found_worker: break - logger.info("[FC manager] Waiting for idle worker...") + logger.info("[FC manager] Waiting for idle worker..., consider increse `save-step` or `global-batch-size`") time.sleep(1) - task = (FCTaskType.PREPARE, (save_infos, non_cached_objects)) - logger.info("[FC manager] before putting task for prepare") + task = (FCTaskType.PREPARE, save_infos_and_non_cached_objects) + logger.info( + f"[FC manager] before putting task for prepare, dumping={save_infos_and_non_cached_objects is not None}" + ) self.current_worker.task_queue.put(task) - logger.info("[FC manager] after putting task for prepare") + logger.info( + f"[FC manager] after putting task for prepare, dumping={save_infos_and_non_cached_objects is not None}" + ) def sync_offload_status(self): self.report_error_worker() @@ -542,13 +670,15 @@ def process_update_task(self, updates): self.version.value = version def process_prepare_task(self, prepares): - save_infos, non_cached_objects = prepares self.offloaded_numels = 0 self.status.value = FCWorkerStatus.OFFLOADING.value + if prepares is None: # when `prepares` is None, not dumping + return + save_infos, non_cached_objects = prepares self.flash_save_dir, self.persistent_save_dir = save_infos self.lr_scheduler, self.trainer_state = non_cached_objects - def process_offload_task(self): + def process_offload_task(self, dump): actual_offload_size = ( min(self.offloaded_numels + self.chunk_size_in_numel, self.all_numel) - self.offloaded_numels ) @@ -582,12 +712,12 @@ def process_offload_task(self): # continue to process dumping task at the last chunk if self.offloaded_numels == self.all_numel: - need_report_error = self.process_dump_task() - self.offloaded_numels = 0 - if 
need_report_error: - self.status.value = FCWorkerStatus.ERROR.value + if dump: + need_report_error = self.process_dump_task() else: - self.status.value = FCWorkerStatus.IDLE.value + need_report_error = False + self.offloaded_numels = 0 + self.status.value = FCWorkerStatus.ERROR.value if need_report_error else FCWorkerStatus.IDLE.value def process_dump_task(self): """ @@ -669,6 +799,7 @@ def run(self): paddle.set_device(f"gpu:{self.device_id}") logger.info(f"[FC worker{self.worker_id}] Worker{self.worker_id} started.") ema_ckpt_path = None + save_info_tuple = None # save dir... try: while True: task = self.task_queue.get() @@ -686,9 +817,10 @@ def run(self): self.flash_ema_processor.load_ema_state_dict(ema_ckpt_path) ema_ckpt_path = None elif task_type == FCTaskType.PREPARE: + save_info_tuple = task_body self.process_prepare_task(task_body) elif task_type == FCTaskType.OFFLOAD: - self.process_offload_task() + self.process_offload_task(dump=save_info_tuple is not None) elif task_type == FCTaskType.SET_EMA_STATE_DICT: ema_ckpt_path = task_body # mark ema state dict path else: From c5e86429351973dcaee26d2d569b089bed803201 Mon Sep 17 00:00:00 2001 From: chenxuyi Date: Wed, 5 Feb 2025 21:19:49 +0800 Subject: [PATCH 4/4] [flashcheckpoint] fix save under dp_degree > 1 with use-expert-parallel --- paddlenlp/trainer/trainer.py | 2 + paddlenlp/trainer/training_args.py | 2 +- paddlenlp/trainer/utils/flash_checkpoint.py | 88 ++++++++++++++++++--- paddlenlp/trainer/utils/sharding_io.py | 3 +- 4 files changed, 84 insertions(+), 11 deletions(-) diff --git a/paddlenlp/trainer/trainer.py b/paddlenlp/trainer/trainer.py index b80343339585..533f2e2d6abe 100644 --- a/paddlenlp/trainer/trainer.py +++ b/paddlenlp/trainer/trainer.py @@ -705,6 +705,7 @@ def create_flash_checkpoint_manager(self, unwrapped_model, resume_from_checkpoin worker_num=self.args.flash_workers_num, pipeline_hooks_capacity=pipeline_hooks_capacity, capacity_usage=self.args.flash_pipeline_hooks_capacity_usage, + use_expert_parallel=self.args.use_expert_parallel, ema_coef=self.args.flash_save_ema_coef, ) for i in range(unwrapped_model.forward_pipeline_parallel_hook_capacity): @@ -721,6 +722,7 @@ def create_flash_checkpoint_manager(self, unwrapped_model, resume_from_checkpoin worker_num=self.args.flash_workers_num, pipeline_hooks_capacity=pipeline_hooks_capacity, capacity_usage=self.args.flash_pipeline_hooks_capacity_usage, + use_expert_parallel=self.args.use_expert_parallel, ema_coef=self.args.flash_save_ema_coef, ) _callback = FlashCheckpointCallback( diff --git a/paddlenlp/trainer/training_args.py b/paddlenlp/trainer/training_args.py index 611b2b9540e4..35ca38188440 100644 --- a/paddlenlp/trainer/training_args.py +++ b/paddlenlp/trainer/training_args.py @@ -877,7 +877,7 @@ class TrainingArguments: }, ) flash_save_ema_coef: Optional[float] = field( - default=0, + default=None, metadata={"help": "The coefficient of EMA parameters in flash save mode. 
if set to 0, skip EMA process"}, ) flash_ema_interval: Optional[int] = field( diff --git a/paddlenlp/trainer/utils/flash_checkpoint.py b/paddlenlp/trainer/utils/flash_checkpoint.py index 89a57a22617c..d3138ac309d4 100644 --- a/paddlenlp/trainer/utils/flash_checkpoint.py +++ b/paddlenlp/trainer/utils/flash_checkpoint.py @@ -26,6 +26,7 @@ import paddle.autograd as imperative_base import paddle.distributed as dist from paddle.base import core +from paddle.distributed.fleet import fleet from paddle.incubate.tensor.manipulation import ( async_offload_with_offset, create_async_load, @@ -374,7 +375,7 @@ def __init__(self, args, flash_checkpoint_manager, timer, sharding_io): args.save_steps % args.flash_ema_interval == 0 ), f"save_steps:{args.save_steps} must be divisible by flash_ema_interval:{args.flash_ema_interval}" self.flash_ema_interval = args.flash_ema_interval - if args.flash_save_ema_coef: + if args.flash_save_ema_coef is not None: assert args.flash_workers_num == 1, "[FC EMA] not support #worker > 1" def on_substep_end(self, args, state, control, **kwargs): @@ -392,7 +393,7 @@ def on_step_end(self, args, state, control, model, lr_scheduler, optimizer, **kw f"check coef: {args.flash_save_ema_coef} {control.should_save}, {state.global_step}, {self.flash_ema_interval}" ) if not control.should_save: - if args.flash_save_ema_coef and state.global_step % self.flash_ema_interval == 0: + if args.flash_save_ema_coef is not None and state.global_step % self.flash_ema_interval == 0: self.maybe_update_flash_checkpoint_worker(args, model, optimizer) self.manager.get_idle_worker_for_saving() # prepare for dumping else: @@ -467,7 +468,7 @@ def _cache_meta_for_sharded_save(self, model): class FlashCheckpointManager: - def __init__(self, worker_num, pipeline_hooks_capacity, capacity_usage, ema_coef=None): + def __init__(self, worker_num, pipeline_hooks_capacity, capacity_usage, use_expert_parallel, ema_coef=None): assert worker_num > 0, "worker_num must be greater than 0" assert capacity_usage <= 1.0, "capacity_usage must be less than or equal to 1.0" self.cache_version = 0 @@ -484,6 +485,7 @@ def __init__(self, worker_num, pipeline_hooks_capacity, capacity_usage, ema_coef ) self.current_pipeline_hook_step = 0 ctx = multiprocessing.get_context("spawn") + assert hasattr(fleet, "_hcg"), "FlashCheckpoint Only support `use_hybrid_parallel`" for i in range(worker_num): worker_task_queue = ctx.Queue() worker_status = ctx.Value("i", FCWorkerStatus.IDLE.value) @@ -496,6 +498,11 @@ def __init__(self, worker_num, pipeline_hooks_capacity, capacity_usage, ema_coef worker_task_queue, worker_status, worker_version, + use_expert_parallel, + fleet.get_hybrid_communicate_group().get_data_parallel_rank(), + fleet.get_hybrid_communicate_group().get_model_parallel_rank(), + fleet.get_hybrid_communicate_group()._get_pipe_parallel_id(), + fleet.get_hybrid_communicate_group().get_sharding_parallel_rank(), ema_coef, ) p = ctx.Process(target=worker_loop, args=(worker,)) @@ -613,7 +620,22 @@ def worker_loop(worker): class FlashCheckpointWorker: - def __init__(self, worker_id, device_id, global_rank, offload_chunks, task_queue, status, version, ema_coef=None): + def __init__( + self, + worker_id, + device_id, + global_rank, + offload_chunks, + task_queue, + status, + version, + use_expert_parallel, + dp_rank, + mp_rank, + pp_rank, + sd_rank, + ema_coef=None, + ): super().__init__() self.worker_id = worker_id self.device_id = device_id @@ -623,6 +645,11 @@ def __init__(self, worker_id, device_id, global_rank, offload_chunks, 
task_queue self.status = status self.version = version self.ema_coef = ema_coef + self.use_expert_parallel = use_expert_parallel + self.dp_rank = dp_rank + self.mp_rank = mp_rank + self.pp_rank = pp_rank + self.sd_rank = sd_rank # for dynamic objects saving self.optimizer_fusion_storage_helper = None @@ -706,7 +733,7 @@ def process_offload_task(self, dump): if self.offloaded_numels == self.all_numel: self.optimizer_fusion_storage_helper.wait_all() self.param_fusion_storage_helper.wait_all() - if self.ema_coef: + if self.ema_coef is not None: self.flash_ema_processor.ema_accumulate() self.status.value = FCWorkerStatus.DUMPING.value @@ -744,6 +771,35 @@ def process_dump_task(self): need_report_error = True return need_report_error + def _filter_moe_no_sync_optimizer_params(self, model_meta, optimzier_state_dict): + """ + filter optimizer params which should not sync, copy from paddlenlp.Trainer + """ + filter_optimzier_state_dict = OrderedDict() + assert "master_weights" in optimzier_state_dict, optimzier_state_dict.keys() + param_names_in_master_weights = list(optimzier_state_dict["master_weights"].keys()) + filter_optimzier_state_dict["master_weights"] = OrderedDict() + suffix = f"tp{self.mp_rank:0>2d}_pp{self.pp_rank:0>2d}" + dyname_to_pname = model_meta["sharding_metas"][suffix]["structure_name_mapping"] + dyname_to_meta = model_meta["sharding_metas"][suffix]["param_meta"] + for k, pname in dyname_to_pname.items(): + shape, dtype, is_dist, is_no_sync = dyname_to_meta[k] + if is_no_sync: + if pname in param_names_in_master_weights: + filter_optimzier_state_dict["master_weights"][pname] = optimzier_state_dict["master_weights"][ + pname + ] + else: + pass + # logger.info(f"filter out master weight:{pname} -> {k}") + for op_k, op_v in optimzier_state_dict.items(): + if op_k.startswith(pname): + filter_optimzier_state_dict[op_k] = op_v + else: + # logger.info(f"filter out key={k}, when dp!=0") + pass + return filter_optimzier_state_dict + def process_dump_task_impl(self, output_dir): os.makedirs(output_dir, exist_ok=True) # Step1: save static objects @@ -771,14 +827,28 @@ def process_dump_task_impl(self, output_dir): # Step2: save dynamic objects # Step2.1: save model states model_states_name_path = os.path.join(output_dir, self.model_states_name_path) - paddle.save(self.param_fusion_storage_helper.state_dict(), model_states_name_path) + state_dict = self.param_fusion_storage_helper.state_dict() # Step2.2: save optimizer states optimizer_state_name_path = os.path.join(output_dir, self.optimizer_states_name_path) - paddle.save(self.optimizer_fusion_storage_helper.state_dict(), optimizer_state_name_path) - if self.ema_coef: + opt_state_dict = self.optimizer_fusion_storage_helper.state_dict() + + if self.ema_coef is not None: ema_name_path = os.path.join(output_dir, self.optimizer_states_name_path).replace("optimizer", "ema") - paddle.save(self.flash_ema_processor.ema_state_dict(), ema_name_path) + ema_state_dict = self.flash_ema_processor.ema_state_dict() + + if self.dp_rank <= 0 or self.use_expert_parallel: + if self.dp_rank > 0: # ep + opt_state_dict = self._filter_moe_no_sync_optimizer_params(self.model_meta_content, opt_state_dict) + if self.ema_coef is not None: + # non master-weights in `ema-state-dict` when dp >1 will be filterd, which is acceptable + ema_state_dict = self._filter_moe_no_sync_optimizer_params(self.model_meta_content, ema_state_dict) + + paddle.save(state_dict, model_states_name_path) + paddle.save(opt_state_dict, optimizer_state_name_path) + if self.ema_coef is not None: 
+ paddle.save(ema_state_dict, ema_name_path) + # Step2.3: save LR Scheduler (To be removed) lr_state_name_path = os.path.join(output_dir, SCHEDULER_NAME) if self.device_id == 0: diff --git a/paddlenlp/trainer/utils/sharding_io.py b/paddlenlp/trainer/utils/sharding_io.py index 19f1c11d196a..18fa7c67c0a0 100644 --- a/paddlenlp/trainer/utils/sharding_io.py +++ b/paddlenlp/trainer/utils/sharding_io.py @@ -597,7 +597,8 @@ def _gather_sharding_metas(self): for k, v in model.state_dict().items(): structure_name_mapping[k] = v.name is_distributed = getattr(v, "is_distributed", False) - param_meta[k] = (v.shape, int(v.dtype), is_distributed) + no_sync = getattr(v, "no_sync", False) + param_meta[k] = (v.shape, int(v.dtype), is_distributed, no_sync) sharding_metas = {} sharding_meta = {}
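
Note (illustrative only, not part of the patches above): the update that FlashEMAProcessor.ema_accumulate applies to the fused CPU buffers is the standard exponential moving average, ema = coef * ema + (1 - coef) * params, run once every `flash_ema_interval` optimizer steps and skipped entirely when `flash_save_ema_coef` is unset. The sketch below replays that rule on plain paddle tensors; the SimpleEMA class, the interval value, and the output path are made up for the example, and the real worker's buffer slicing, pinned-memory offload, and "optimizer" -> "ema" file renaming are omitted.

    import paddle


    class SimpleEMA:
        """Keep an exponential moving average of a dict of tensors on CPU."""

        def __init__(self, coef, params):
            self.coef = coef
            # shadow copies live on CPU, mirroring the worker's offloaded CPU buffers
            self.shadow = {name: p.cpu().clone() for name, p in params.items()}

        @paddle.no_grad()
        def accumulate(self, params):
            # ema = coef * ema + (1 - coef) * params
            for name, p in params.items():
                self.shadow[name] = self.coef * self.shadow[name] + (1.0 - self.coef) * p.cpu()

        def state_dict(self):
            return dict(self.shadow)


    # usage: accumulate on every `flash_ema_interval`-th step, then dump the shadow dict
    params = {"w": paddle.randn([4, 4]), "b": paddle.zeros([4])}
    ema = SimpleEMA(coef=0.999, params=params)
    flash_ema_interval = 2  # hypothetical value for the example
    for step in range(1, 11):
        params["w"] = params["w"] + 0.01  # stand-in for an optimizer update
        if step % flash_ema_interval == 0:
            ema.accumulate(params)
    paddle.save(ema.state_dict(), "ema_shadow.pdparams")

Resuming works the same way as set_ema_state_dict in the patch: a previously saved shadow dict is loaded back into the EMA buffers before accumulation continues.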