
Commit

lint
PeterSH6 committed Feb 7, 2025
1 parent 60f81d1 commit 58b0bc3
Showing 5 changed files with 20 additions and 24 deletions.
23 changes: 11 additions & 12 deletions verl/trainer/ppo/ray_trainer.py
@@ -36,6 +36,7 @@
from verl.utils.seqlen_balancing import get_seqlen_balanced_partitions, log_seqlen_unbalance
from verl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path
from verl.utils.dataset.rl_dataset import RLHFDataset, collate_fn

WorkerType = Type[Worker]


@@ -441,8 +442,7 @@ def _create_dataloader(self):
if self.config.data.shuffle:
train_dataloader_generator = torch.Generator()
train_dataloader_generator.manual_seed(self.config.data.get('seed', 1))
sampler = RandomSampler(data_source=self.train_dataset,
generator=train_dataloader_generator)
sampler = RandomSampler(data_source=self.train_dataset, generator=train_dataloader_generator)
else:
sampler = SequentialSampler(data_source=self.train_dataset)

@@ -606,8 +606,9 @@ def init_workers(self):

def _save_checkpoint(self):
# path: given_path + `/global_step_{global_steps}` + `/actor`
local_global_step_folder = os.path.join(self.config.trainer.default_local_dir, f'global_step_{self.global_steps}')
actor_local_path = os.path.join(local_global_step_folder,'actor')
local_global_step_folder = os.path.join(self.config.trainer.default_local_dir,
f'global_step_{self.global_steps}')
actor_local_path = os.path.join(local_global_step_folder, 'actor')

actor_remote_path = None if self.config.trainer.default_hdfs_dir is None else os.path.join(
self.config.trainer.default_hdfs_dir, f'global_step_{self.global_steps}', 'actor')
@@ -618,8 +619,7 @@ def _save_checkpoint(self):
critic_remote_path = None if self.config.trainer.default_hdfs_dir is None else os.path.join(
self.config.trainer.default_hdfs_dir, f'global_step_{self.global_steps}', 'critic')
self.critic_wg.save_checkpoint(critic_local_path, critic_remote_path, self.global_steps)



# save dataloader
dataloader_local_path = os.path.join(local_global_step_folder, 'data.pt')
import dill
@@ -630,16 +630,16 @@ def _save_checkpoint(self):
'latest_checkpointed_iteration.txt')
with open(local_latest_checkpointed_iteration, 'w') as f:
f.write(str(self.global_steps))

def _load_checkpoint(self):
if self.config.trainer.resume_mode == 'disable':
return 0

# load from hdfs
if self.config.trainer.default_hdfs_dir is not None:
raise NotImplementedError('load from hdfs is not implemented yet')
else:
checkpoint_folder = self.config.trainer.default_local_dir # TODO: check path
checkpoint_folder = self.config.trainer.default_local_dir # TODO: check path
if not os.path.isabs(checkpoint_folder):
working_dir = os.getcwd()
checkpoint_folder = os.path.join(working_dir, checkpoint_folder)
@@ -672,8 +672,8 @@ def _load_checkpoint(self):
# load critic
if self.use_critic:
self.critic_wg.load_checkpoint(critic_path)
# load dataloader,

# load dataloader,
# TODO: from remote not implemented yet
dataloader_local_path = os.path.join(global_step_folder, 'data.pt')
self.train_dataloader = torch.load(dataloader_local_path)
@@ -830,7 +830,6 @@ def fit(self):
with _timer('save_checkpoint', timing_raw):
self._save_checkpoint()


# collect metrics
metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic))
metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw))
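
For reference, a minimal standalone sketch of the checkpoint layout that _save_checkpoint above writes and the tracker file it updates (stdlib only; the helper name and the example 'checkpoints' path are illustrative, not verl's API):

import os

# Layout written per checkpoint, mirroring the code above:
#   <default_local_dir>/global_step_{N}/actor/     actor checkpoint
#   <default_local_dir>/global_step_{N}/critic/    critic checkpoint (if use_critic)
#   <default_local_dir>/global_step_{N}/data.pt    dataloader state (saved with dill)
#   <default_local_dir>/latest_checkpointed_iteration.txt   tracker file holding N

def save_checkpoint_layout(default_local_dir: str, global_step: int) -> str:
    """Create the per-step folder and update the tracker file (illustrative helper)."""
    step_folder = os.path.join(default_local_dir, f'global_step_{global_step}')
    os.makedirs(os.path.join(step_folder, 'actor'), exist_ok=True)
    os.makedirs(os.path.join(step_folder, 'critic'), exist_ok=True)
    # The tracker file records the latest checkpointed iteration so a restarted
    # run can locate the newest global_step_{N} folder.
    with open(os.path.join(default_local_dir, 'latest_checkpointed_iteration.txt'), 'w') as f:
        f.write(str(global_step))
    return step_folder

if __name__ == '__main__':
    print(save_checkpoint_layout('checkpoints', global_step=100))
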
3 changes: 2 additions & 1 deletion verl/utils/checkpoint/checkpoint_manager.py
@@ -119,8 +119,9 @@ def find_latest_ckpt_path(path, directory_format="global_step_{}"):
print("Found checkpoint: %s", ckpt_path)
return ckpt_path


def get_checkpoint_tracker_filename(root_path: str):
"""
Tracker file records the latest checkpoint during training to restart from.
"""
return os.path.join(root_path, "latest_checkpointed_iteration.txt")
return os.path.join(root_path, "latest_checkpointed_iteration.txt")
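
The resume side of the same convention, as a simplified sketch of the two helpers in this file: read the tracker file back, then resolve the matching global_step_{N} folder (not verl's exact code, which also logs the found path):

import os
from typing import Optional

def get_checkpoint_tracker_filename(root_path: str) -> str:
    # The tracker file records the latest checkpointed iteration during training.
    return os.path.join(root_path, "latest_checkpointed_iteration.txt")

def find_latest_ckpt_path(path: str, directory_format: str = "global_step_{}") -> Optional[str]:
    tracker_file = get_checkpoint_tracker_filename(path)
    if not os.path.exists(tracker_file):
        return None
    with open(tracker_file, 'r') as f:
        iteration = int(f.read().strip())
    ckpt_path = os.path.join(path, directory_format.format(iteration))
    return ckpt_path if os.path.isdir(ckpt_path) else None
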
8 changes: 1 addition & 7 deletions verl/utils/checkpoint/fsdp_checkpoint_manager.py
@@ -93,12 +93,7 @@ def load_checkpoint(self, path=None, del_local_after_load=True, *args, **kwargs)
if self.lr_scheduler is not None:
self.lr_scheduler.load_state_dict(lr_scheduler_state_dict)

def save_checkpoint(self,
local_path: str,
global_step: int,
remove_previous_ckpt=True,
*args,
**kwargs):
def save_checkpoint(self, local_path: str, global_step: int, remove_previous_ckpt=True, *args, **kwargs):
# record the previous global step
self.previous_global_step = global_step

@@ -140,7 +135,6 @@ def save_checkpoint(self,
torch.save(optimizer_state_dict, optim_path) # TODO: address optimizer is None
torch.save(extra_state_dict, extra_path)


# wait for everyone to dump to local
torch.distributed.barrier()

6 changes: 3 additions & 3 deletions verl/utils/dataset/rl_dataset.py
@@ -74,7 +74,7 @@ def __init__(self,
parquet_files = [parquet_files]

self.parquet_files = copy.deepcopy(parquet_files)
self.original_parquet_files = copy.deepcopy(parquet_files) # use for resume
self.original_parquet_files = copy.deepcopy(parquet_files) # use for resume
self.cache_dir = os.path.expanduser(cache_dir)
self.tokenizer = tokenizer

@@ -116,12 +116,12 @@ def _read_files_and_tokenize(self):
axis=1)]

print(f'filter dataset len: {len(self.dataframe)}')

def resume_dataset_state(self):
self.serialize_dataset = False if hasattr(self, 'original_parquet_files') else True
# resume dataframe if it's not serialized in data.pt
if not self.serialize_dataset:
self._download(use_origin_parquet=True) # download and resume from original parquet files
self._download(use_origin_parquet=True) # download and resume from original parquet files
self._read_files_and_tokenize()
else:
print(r'old dataloader ckpt file is used, please train from scratch for better ckpt performance')
4 changes: 3 additions & 1 deletion verl/workers/fsdp_workers.py
Expand Up @@ -560,10 +560,11 @@ def load_checkpoint(self, path, del_local_after_load=True):
load_grad=self._is_offload_grad)

self.checkpoint_manager.load_checkpoint(path=path, del_local_after_load=del_local_after_load)

if self._is_offload_param:
offload_fsdp_param_and_grad(module=self.actor_module_fsdp, offload_grad=self._is_offload_grad)


class CriticWorker(Worker):

def __init__(self, config):
@@ -839,6 +840,7 @@ def load_checkpoint(self, path, del_local_after_load=True):
if self._is_offload_param:
offload_fsdp_param_and_grad(module=self.critic_module, offload_grad=self._is_offload_grad)


# TODO(sgm): we may need to extract it to dp_reward_model.py
class RewardModelWorker(Worker):
"""
