[FlashCheckpoint] support EMA #9815

Open · wants to merge 4 commits into base: incubate/paddlenlp-fleety
Changes from all commits
151 changes: 46 additions & 105 deletions paddlenlp/trainer/trainer.py
@@ -18,7 +18,6 @@

import collections
import contextlib
import copy
import inspect
import json
import math
@@ -42,6 +41,7 @@
import paddle.nn as nn
from packaging import version
from paddle import framework
from paddle.distributed.fleet.meta_parallel import PipelineLayer

try:
from paddle.base import core
@@ -91,7 +91,6 @@
from ..transformers.model_utils import (
PretrainedModel,
_add_variant,
get_parameter_dtype,
load_sharded_checkpoint,
unwrap_model,
)
@@ -161,8 +160,12 @@
from .utils.async_save import AsyncSaver

try:
from .utils.flash_checkpoint import FlashCheckpointManager, get_fused_param_mappings
except:
from .utils.flash_checkpoint import (
FlashCheckpointCallback,
FlashCheckpointManager,
get_fused_param_mappings,
)
except (ImportError, ModuleNotFoundError):
FlashCheckpointManager, get_fused_param_mappings = None, None
from .utils.helper import ( # nested_truncate,
broadcast_dp_optimizer,
@@ -350,8 +353,6 @@ def __init__(
)

if self.args.pipeline_parallel_degree > 1 and self.args.use_hybrid_parallel:
from paddle.distributed.fleet.meta_parallel import PipelineLayer

assert (isinstance(model, LoRAModel) and isinstance(model.model, PipelineLayer)) or isinstance(
model, PipelineLayer
), "Only support pipeline parallel mode when model is PipelineLayer!!!"
@@ -402,12 +403,7 @@ def __init__(
assert not self.args.save_rng_states, "save_rng_states is not supported when using flash save mode"

# init attributes for flash save mode
self.manipulated_state_dict = None
self.manipulated_config_to_save = None
self.manipulated_weight_suffix = None
self.model_meta = None
self.flash_checkpoint_manager = None
self.user_file_list = []

if self.args.ordered_save_group_size > 0:
logger.info(f"using save in order, its group size is {self.args.ordered_save_group_size}")
@@ -692,83 +688,53 @@ def _wrap_model_and_load_sharded_checkpoint(self, resume_from_checkpoint):
self._load_from_checkpoint(resume_from_checkpoint)
return model

def create_flash_checkpoint_manager(self, unwrapped_model):
def create_flash_checkpoint_manager(self, unwrapped_model, resume_from_checkpoint=None):
"""
Create flash checkpoint manager.
Has to be called after pipeline model is created.
resume_from_checkpoint: if using Flash Checkpoint EMA, load the previous checkpoint status
"""
assert isinstance(self.model, PretrainedModel), "model should be a PretrainedModel when using flash"
logger.info("Create flash checkpoint manager...")
pipeline_hooks_capacity = (
unwrapped_model.forward_pipeline_parallel_hook_capacity
+ unwrapped_model.backward_pipeline_parallel_hook_capacity
)
self.flash_checkpoint_manager = FlashCheckpointManager(
worker_num=self.args.flash_workers_num,
pipeline_hooks_capacity=pipeline_hooks_capacity,
capacity_usage=self.args.flash_pipeline_hooks_capacity_usage,
)
for i in range(unwrapped_model.forward_pipeline_parallel_hook_capacity):
unwrapped_model.register_forward_pipeline_parallel_hook(
location=i, hook=self.flash_checkpoint_manager.flash_checkpoint_pipeline_hook
if isinstance(unwrapped_model, PipelineLayer):
pipeline_hooks_capacity = (
unwrapped_model.forward_pipeline_parallel_hook_capacity
+ unwrapped_model.backward_pipeline_parallel_hook_capacity
)
for i in range(unwrapped_model.backward_pipeline_parallel_hook_capacity):
unwrapped_model.register_backward_pipeline_parallel_hook(
location=i, hook=self.flash_checkpoint_manager.flash_checkpoint_pipeline_hook
self.flash_checkpoint_manager = FlashCheckpointManager(
worker_num=self.args.flash_workers_num,
pipeline_hooks_capacity=pipeline_hooks_capacity,
capacity_usage=self.args.flash_pipeline_hooks_capacity_usage,
use_expert_parallel=self.args.use_expert_parallel,
ema_coef=self.args.flash_save_ema_coef,
)
logger.info("Create flash checkpoint manager done.")

def maybe_update_flash_checkpoint_worker(self):
if self.optimizer.fused_buffer_version == self.flash_checkpoint_manager.cache_version:
return

logger.info("Flash checkpoint workers need upgrade.")
self._cache_meta_for_sharded_save()
param_mappings, ipc_meta_mappings = get_fused_param_mappings(self.optimizer, self.manipulated_state_dict)
optimizer_states_meta = (
self.optimizer.fused_states_accumulators_meta,
self.optimizer.fused_states_master_weights_meta,
None,
self.optimizer.fused_states_buffer_ipc_meta,
)
model_states_meta = (param_mappings, ipc_meta_mappings)
optimizer_states_name_path = _add_variant(PADDLE_OPTIMIZER_NAME, self.args.optimizer_name_suffix)
model_states_name_path = _add_variant(PADDLE_WEIGHTS_NAME, self.manipulated_weight_suffix)

dynamic_objecs = {}
dynamic_objecs["optimizer_states_meta"] = optimizer_states_meta
dynamic_objecs["model_states_meta"] = model_states_meta
dynamic_objecs["optimizer_states_name_path"] = optimizer_states_name_path
dynamic_objecs["model_states_name_path"] = model_states_name_path

static_objects = {}
static_objects["model_config"] = self.manipulated_config_to_save
static_objects["training_args"] = self.args
static_objects["model_meta"] = self.model_meta
static_objects["user_file"] = self.user_file_list

self.flash_checkpoint_manager.update_flash_workers(
self.optimizer.fused_buffer_version, dynamic_objecs, static_objects
for i in range(unwrapped_model.forward_pipeline_parallel_hook_capacity):
unwrapped_model.register_forward_pipeline_parallel_hook(
location=i, hook=self.flash_checkpoint_manager.flash_checkpoint_pipeline_hook
)
for i in range(unwrapped_model.backward_pipeline_parallel_hook_capacity):
unwrapped_model.register_backward_pipeline_parallel_hook(
location=i, hook=self.flash_checkpoint_manager.flash_checkpoint_pipeline_hook
)
else:
pipeline_hooks_capacity = self.args.gradient_accumulation_steps
self.flash_checkpoint_manager = FlashCheckpointManager(
worker_num=self.args.flash_workers_num,
pipeline_hooks_capacity=pipeline_hooks_capacity,
capacity_usage=self.args.flash_pipeline_hooks_capacity_usage,
use_expert_parallel=self.args.use_expert_parallel,
ema_coef=self.args.flash_save_ema_coef,
)
_callback = FlashCheckpointCallback(
self.args, self.flash_checkpoint_manager, self.runtime_timer, self.sharding_io
)

def _cache_meta_for_sharded_save(self):
logger.info("Start caching metas for sharded save...")
(
self.manipulated_state_dict,
self.manipulated_config_to_save,
self.manipulated_weight_suffix,
) = self.sharding_io.manipulate_state_dict_and_config(self.model, merge_tensor_parallel=False)
logger.info("Cache manipulated static dict done.")
if self.manipulated_config_to_save is None:
model_to_save = unwrap_model(self.model)
dtype = get_parameter_dtype(model_to_save)
model_to_save.config.dtype = str(dtype).split(".")[1]
self.manipulated_config_to_save = copy.deepcopy(model_to_save.config)
self.manipulated_config_to_save.architectures = [model_to_save.__class__.__name__]
self.manipulated_config_to_save = self.manipulated_config_to_save.to_json_string(use_diff=True)
logger.info("Cache manipulated model config done")
self.model_meta = self.sharding_io.gather_distributed_model_meta()
logger.info("Cache distributed model meta done.")
self.add_callback(_callback)
if resume_from_checkpoint is not None:
path = _add_variant(PADDLE_OPTIMIZER_NAME, self.args.optimizer_name_suffix)
path = os.path.join(resume_from_checkpoint, path).replace("optimizer", "ema")
logger.info(f"FC EMA load from {path}")
self.flash_checkpoint_manager.set_ema_state_dict(path)
logger.info("Create flash checkpoint manager done.")

def train(
self,
@@ -900,7 +866,7 @@ def train(
self._load_optimizer_and_scheduler(resume_from_checkpoint)

if self.args.enable_flash_save_mode:
self.create_flash_checkpoint_manager(model)
self.create_flash_checkpoint_manager(model, resume_from_checkpoint)

logger.info(f"{self.runtime_timer.log()}")

@@ -1209,10 +1175,6 @@ def fused_allreduce_gradients_no_sync(paramlist, hcg):
self.callback_handler.on_optimizer_begin(
args, self.state, self.control, scaler=self.scaler if self.do_grad_scaling else None
)
if self.args.enable_flash_save_mode and self.flash_checkpoint_manager.current_worker is not None:
logger.info("Start syncing flash checkpoints")
self.flash_checkpoint_manager.sync_offload_status()
logger.info("Synced flash checkpoints.")
optimizer_was_run = True

if self.args.offload_optim:
@@ -2441,28 +2403,7 @@ def _ordered_save(self, state_dict, save_path):
paddle.save(state_dict, save_path)
dist.barrier(mp_group)

def _get_save_infos_based_on_steps(self, checkpoint_folder):
flash_checkpoint_dir = None
persistent_checkpoint_dir = None
if self.args.flash_save_steps > 0 and self.state.global_step % self.args.flash_save_steps == 0:
flash_checkpoint_dir = os.path.join(FLASH_DEVICE, checkpoint_folder)
if self.args.save_steps > 0 and self.state.global_step % self.args.save_steps == 0:
persistent_checkpoint_dir = os.path.join(self.args.output_dir, checkpoint_folder)
return (flash_checkpoint_dir, persistent_checkpoint_dir)

def _save_checkpoint_flash(self):
self.runtime_timer.start("checkpoint saving time")
self.maybe_update_flash_checkpoint_worker()
checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
save_infos = self._get_save_infos_based_on_steps(checkpoint_folder)
non_cached_objects = (self.lr_scheduler.state_dict(), self.state)
self.flash_checkpoint_manager.get_idle_worker_for_saving(save_infos, non_cached_objects)
self.runtime_timer.stop()

def _save_checkpoint(self, model, metrics=None):
if self.args.enable_flash_save_mode:
self._save_checkpoint_flash()
return
# assert unwrap_model(model) is self.model, "internal model should be a reference to self.model"
self.runtime_timer.start("checkpoint saving time")

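Editor's note on the resume logic above: `create_flash_checkpoint_manager` locates the EMA state dict by taking the sharded optimizer state file name and swapping its `optimizer` prefix for `ema`. A minimal sketch of that naming scheme, assuming `_add_variant` simply inserts the suffix before the file extension and `PADDLE_OPTIMIZER_NAME` resolves to `optimizer.pdopt` (both are assumptions, not verified against this PR):

```python
import os


def ema_checkpoint_path(resume_dir, optimizer_name="optimizer.pdopt", suffix=None):
    """Sketch of how the EMA state-dict path is derived from the optimizer one.

    The suffix handling mimics an _add_variant-style rename (assumption); the
    only step taken verbatim from the diff is the "optimizer" -> "ema" swap.
    """
    if suffix:
        base, ext = optimizer_name.rsplit(".", 1)
        optimizer_name = f"{base}-{suffix}.{ext}"
    # Same file name as the optimizer state, but under the "ema" prefix.
    return os.path.join(resume_dir, optimizer_name).replace("optimizer", "ema")


# Example (hypothetical suffix):
# ema_checkpoint_path("checkpoint-1000", suffix="tp00_pp00")
# -> "checkpoint-1000/ema-tp00_pp00.pdopt"
```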
8 changes: 8 additions & 0 deletions paddlenlp/trainer/training_args.py
@@ -876,6 +876,14 @@ class TrainingArguments:
"help": "Set pipeline hook capacity usage ratio. Lower value brings faster save speed but may effect calculation speed."
},
)
flash_save_ema_coef: Optional[float] = field(
default=None,
metadata={"help": "The coefficient of EMA parameters in flash save mode. If set to 0, the EMA process is skipped."},
)
flash_ema_interval: Optional[int] = field(
default=1,
metadata={"help": "Interval between updating EMA parameters."},
)
save_tokenizer: Optional[bool] = field(
default=True,
metadata={"help": "Save tokenizer to output_dir."},
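For context on the two new arguments: `flash_save_ema_coef` is the blending coefficient and `flash_ema_interval` the step interval at which the flash checkpoint worker would refresh its EMA copy of the weights. A minimal sketch of that update rule under those assumptions (the actual bookkeeping lives in `FlashCheckpointManager` and is not shown in this diff; names and signature below are illustrative):

```python
def maybe_update_ema(ema_state, model_state, global_step, ema_coef, ema_interval=1):
    """Sketch of an EMA update gated by flash_save_ema_coef / flash_ema_interval.

    ema_state and model_state are {name: Tensor} dicts; this is not the API
    added by the PR, only the standard EMA rule its arguments suggest.
    """
    if not ema_coef or global_step % ema_interval != 0:
        return ema_state  # EMA disabled (coef None or 0) or not an update step
    if not ema_state:
        # First update: start the EMA from a copy of the current weights.
        return {name: p.clone() for name, p in model_state.items()}
    for name, p in model_state.items():
        # Standard EMA rule: ema = coef * ema + (1 - coef) * param
        ema_state[name] = ema_coef * ema_state[name] + (1.0 - ema_coef) * p
    return ema_state
```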