diff --git a/open_lm/main.py b/open_lm/main.py
index cbc7d91f..331655da 100644
--- a/open_lm/main.py
+++ b/open_lm/main.py
@@ -51,7 +51,7 @@
 from open_lm.distributed import is_master, init_distributed_device, broadcast_object
 from open_lm.logger import setup_logging
 from open_lm.params import parse_args
-from open_lm.scheduler import cosine_lr, const_lr
+from open_lm.scheduler import cosine_lr, const_lr, cosine_rewarmed_lr
 from open_lm.train import train_one_epoch
 from open_lm.evaluate import evaluate_loop
 from open_lm.file_utils import (
@@ -691,8 +691,23 @@ def main(args):
                 # args.lr_cooldown_end,
                 # args.force_min_lr,
             )
+        elif args.lr_scheduler == "cosine-rewarmed":
+            resumed_step = (args.train_num_samples * start_epoch) // args.global_batch_size
+            scheduler = cosine_rewarmed_lr(
+                optimizer,
+                args.lr,
+                args.warmup,
+                total_steps,
+                args.lr_cooldown_end,
+                args.force_min_lr,
+                args.cosine_rewarmed_target_steps,
+                args.cosine_rewarmed_original_warmup,
+                resumed_step,
+            )
         else:
-            raise ValueError(f"Unknown scheduler, {args.lr_scheduler}. Available options are: cosine, const.")
+            raise ValueError(
+                f"Unknown scheduler, {args.lr_scheduler}. Available options are: cosine, const, cosine-rewarmed."
+            )
 
     # determine if this worker should save logs and checkpoints. only do so if it is rank == 0
     args.save_logs = args.logs and args.logs.lower() != "none" and is_master(args)
diff --git a/open_lm/params.py b/open_lm/params.py
index eea63106..52119a2a 100644
--- a/open_lm/params.py
+++ b/open_lm/params.py
@@ -237,9 +237,9 @@ def check_args(args):
         if args.remote_sync_protocol != "s3":
             raise ValueError("Sync protocol not supported when using resume latest.")
 
-    if args.lr_scheduler not in {"cosine", "const", "const-cooldown"}:
+    if args.lr_scheduler not in {"cosine", "const", "cosine-rewarmed"}:
         raise ValueError(
-            f"Unknown scheduler, {args.lr_scheduler}. Available options are: cosine, const, const-cooldown."
+            f"Unknown scheduler, {args.lr_scheduler}. Available options are: cosine, const, cosine-rewarmed."
         )
 
     if args.experimental_meta_device:
@@ -391,7 +391,19 @@ def parse_args(args):
         "--lr-scheduler",
         type=str,
         default="cosine",
-        help="LR scheduler. One of: 'cosine', 'const' (constant), 'const-cooldown' (constant w/ cooldown). Default: cosine",
+        help="LR scheduler. One of: 'cosine', 'const' (constant), 'const-cooldown' (constant w/ cooldown), 'cosine-rewarmed'. Default: cosine",
+    )
+    parser.add_argument(
+        "--cosine-rewarmed-target-steps",
+        type=int,
+        default=None,
+        help="For cosine-rewarmed: total steps of the original cosine schedule being resumed. Default: None",
+    )
+    parser.add_argument(
+        "--cosine-rewarmed-original-warmup",
+        type=int,
+        default=1000,
+        help="For cosine-rewarmed: warmup steps of the original cosine schedule. Default: 1000",
     )
     parser.add_argument(
         "--lr-cooldown-end",
diff --git a/open_lm/scheduler.py b/open_lm/scheduler.py
index 2505e57b..02d3d8f8 100644
--- a/open_lm/scheduler.py
+++ b/open_lm/scheduler.py
@@ -10,6 +10,17 @@ def _warmup_lr(base_lr, warmup_length, step):
     return base_lr * (step + 1) / warmup_length
 
 
+def _cosine_lr(step, base_lr, warmup_length, steps, min_lr, force_min_lr):
+    if step < warmup_length:
+        lr = _warmup_lr(base_lr, warmup_length, step)
+    else:
+        e = step - warmup_length
+        es = steps - warmup_length
+        lr = min_lr + 0.5 * (1 + np.cos(np.pi * e / es)) * (base_lr - min_lr)
+        lr = max(lr, force_min_lr)
+    return lr
+
+
 def const_lr(optimizer, base_lr, warmup_length):
     def _lr_adjuster(step):
         if step < warmup_length:
@@ -63,3 +74,28 @@ def _lr_adjuster(step):
         return lr
 
     return _lr_adjuster
+
+
+def cosine_rewarmed_lr(
+    optimizer, base_lr, warmup_length, steps, min_lr, force_min_lr, target_steps, original_warmup, resumed_step
+):
+    def _lr_adjuster(step):
+        step -= resumed_step
+        new_base_lr = _cosine_lr(
+            target_steps - steps + warmup_length, base_lr, original_warmup, target_steps, min_lr, force_min_lr
+        )
+        if step < warmup_length:
+            lr = _warmup_lr(new_base_lr, warmup_length, step)
+        else:
+            lr = _cosine_lr(
+                target_steps - steps + step - warmup_length,
+                base_lr,
+                warmup_length,
+                target_steps - warmup_length,
+                min_lr,
+                force_min_lr,
+            )
+        assign_learning_rate(optimizer, lr)
+        return lr
+
+    return _lr_adjuster
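
For reviewers, a minimal standalone sketch (not part of the patch) of how the new cosine_rewarmed_lr adjuster might be driven directly. It assumes this branch of open_lm and torch are importable; every number below (target_steps=100_000, resumed_step=40_000, steps=60_000, the warmup lengths and learning rates) is an illustrative placeholder standing in for the values main.py derives from args.lr, args.warmup, total_steps, args.lr_cooldown_end, args.force_min_lr, the two new flags, and the computed resumed_step.

# Hypothetical usage sketch: build the rewarmed scheduler outside the training
# loop and print the learning rate it assigns at a few global steps after resume.
import torch

from open_lm.scheduler import cosine_rewarmed_lr

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

scheduler = cosine_rewarmed_lr(
    optimizer,
    base_lr=3e-4,
    warmup_length=500,      # re-warmup length after resuming (args.warmup)
    steps=60_000,           # total_steps of the resumed run (illustrative)
    min_lr=3e-5,            # args.lr_cooldown_end
    force_min_lr=0.0,       # args.force_min_lr
    target_steps=100_000,   # --cosine-rewarmed-target-steps (illustrative)
    original_warmup=1_000,  # --cosine-rewarmed-original-warmup
    resumed_step=40_000,    # illustrative resume point
)

# The adjuster takes the global step and subtracts resumed_step internally.
for step in range(40_000, 60_000, 5_000):
    lr = scheduler(step)
    print(f"step {step:>6}: lr = {lr:.6e}")

Because the adjuster offsets the incoming step by resumed_step itself, the existing per-step scheduler call in the training loop should not need changes; the sketch mimics that by passing global step values.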