[applications/ColossalChat/examples/training_scripts/lora_finetune.py]: Fixed bug, added save_interval and added auto resume functions #6223

Open · wants to merge 2 commits into base: main
241 changes: 215 additions & 26 deletions applications/ColossalChat/examples/training_scripts/lora_finetune.py
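
In brief, the diff below adds loading of previously saved LoRA weights via --lora_path, a --save_interval flag that periodically saves the LoRA adapter together with an optimizer/scheduler checkpoint, and auto-resume: when --lora_path / --optmizer_path are not given, the script looks in --save_dir for the second most recent checkpoint pair and resumes from it. Below is a minimal, self-contained sketch of the optimizer checkpoint round-trip these changes implement; the dummy model, the scheduler choice, and the file name are illustrative assumptions, only the dict keys mirror the diff.

import os

import torch

# Stand-in training objects so the sketch runs on its own (assumptions, not from the PR).
model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
lr_scheduler = torch.optim.lr_scheduler.LinearLR(optimizer)

save_dir = "checkpoint_dir"
os.makedirs(save_dir, exist_ok=True)

# Save, as the training loop does every `save_interval` steps.
checkpoint = {
    "epoch": 0,
    "step": 100,
    "optimizer_state_dict": optimizer.state_dict(),
    "lr_scheduler_state_dict": lr_scheduler.state_dict(),
}
torch.save(checkpoint, os.path.join(save_dir, "optimizer_epoch0_step99.pth"))

# Resume, as done before the training loop when an optimizer checkpoint is found.
resumed = torch.load(os.path.join(save_dir, "optimizer_epoch0_step99.pth"))
optimizer.load_state_dict(resumed["optimizer_state_dict"])
lr_scheduler.load_state_dict(resumed["lr_scheduler_state_dict"])
start_epoch, start_step = resumed["epoch"], resumed["step"]
print(f"Resuming from epoch {start_epoch}, step {start_step}")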
@@ -43,7 +43,36 @@ def all_reduce_mean(loss: torch.Tensor, plugin: Plugin) -> torch.Tensor:
return loss / dist.get_world_size(group)


def get_second_latest_subfolder_and_optimizer_file(folder_path):
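    # Find the second most recent LoRA checkpoint in `folder_path` and the optimizer
    # checkpoint saved alongside it. The expected layout (an assumption, based on the
    # naming used by the save logic added further down) is pairs such as:
    #     lora_epoch0_step99/      optimizer_epoch0_step99.pth
    #     lora_epoch0_step199/     optimizer_epoch0_step199.pth
    # The second most recent pair is returned, presumably so auto-resume never picks the
    # newest checkpoint, which may be partially written if training was interrupted.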
    os.makedirs(folder_path, exist_ok=True)

    # Collect all subfolders whose names start with "lora"
subfolders = [
f for f in os.listdir(folder_path) if os.path.isdir(os.path.join(folder_path, f)) and f.startswith("lora")
]

    # Check that there are at least two subfolders
    if len(subfolders) < 2:
        return None, None  # fewer than two subfolders, nothing to resume from

    # Sort by last modification time, newest first
subfolders.sort(key=lambda x: os.path.getmtime(os.path.join(folder_path, x)), reverse=True)

    # Path of the second most recent "lora" subfolder
    second_latest_subfolder = subfolders[1]
second_latest_lora_subfolder_path = os.path.join(folder_path, second_latest_subfolder)

    # Derive the path of the optimizer checkpoint (".pth") saved alongside that LoRA folder
second_latest_optimizer_subfolder_path = os.path.join(
folder_path, second_latest_subfolder.replace("lora_", "optimizer_") + ".pth"
)

return second_latest_lora_subfolder_path, second_latest_optimizer_subfolder_path


def train(args) -> None:

# ==============================
# Initialize Distributed Training
# ==============================
@@ -208,7 +237,12 @@ def is_master():
)
else:
lora_config = LoraConfig(task_type="CAUSAL_LM", r=args.lora_rank, lora_alpha=args.lora_alpha)
model = booster.enable_lora(model, lora_config=lora_config)
if args.lora_path:
coordinator.print_on_master(f"Loading lora weights from: {args.lora_path}")
model = booster.enable_lora(model, pretrained_dir=args.lora_path)
else:
model = booster.enable_lora(model, lora_config=lora_config)
model.enable_input_require_grads()

# this is essential, otherwise the grad checkpoint will not work.
model.train()
@@ -257,7 +291,7 @@ def is_master():
)

torch.set_default_dtype(torch.float)
booster.load_model(model, args.pretrained)
booster.load_model(model, args.pretrained, strict=False)

coordinator.print_on_master(
f"Booster init max device memory: {accelerator.max_memory_allocated() / 1024 ** 2:.2f} MB"
@@ -269,6 +303,26 @@ def is_master():
start_epoch = 0
start_step = 0

if not (args.lora_path or args.optmizer_path):
args.lora_path, args.optmizer_path = get_second_latest_subfolder_and_optimizer_file(args.save_dir)
coordinator.print_on_master(f"Lora Path:{args.lora_path}")
coordinator.print_on_master(f"Optimizer Path:{args.optmizer_path}")

# Load checkpoint if available
if args.optmizer_path:
checkpoint_path = args.optmizer_path
if os.path.exists(checkpoint_path):
checkpoint = torch.load(checkpoint_path, map_location=get_current_device())
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
lr_scheduler.load_state_dict(checkpoint["lr_scheduler_state_dict"])
start_epoch = checkpoint["epoch"]
start_step = checkpoint["step"]
coordinator.print_on_master(f"Resuming optimizer from epoch {start_epoch}, step {start_step}")
else:
coordinator.print_on_master("optimizer checkpoint not found, starting training from scratch")
else:
coordinator.print_on_master("Starting training from optimizer scratch")

num_steps_per_epoch = len(dataloader) // args.accumulation_steps

for epoch in range(start_epoch, args.num_epochs):
@@ -316,10 +370,12 @@ def is_master():
dataloader,
desc=f"Epoch {epoch}",
disable=not is_master(),
initial=start_step // args.accumulation_steps,
                initial=start_step // args.accumulation_steps,  # set the starting position when resuming
)
total_loss = torch.tensor(0.0, device=get_current_device())
for step, batch in enumerate(pbar, start=start_step // args.accumulation_steps):
if step > num_steps_per_epoch:
break
batch = {k: v.to(get_current_device()) for k, v in batch.items() if isinstance(v, torch.Tensor)}

batch_output = model(**batch)
@@ -348,9 +404,25 @@ def is_master():

lr_scheduler.step()
optimizer.zero_grad()
# print(lr_scheduler.get_last_lr()[0])

total_loss.fill_(0.0)

if (step + 1) % args.save_interval == 0:
if args.lora_rank > 0:
booster.save_lora_as_pretrained(
model, os.path.join(args.save_dir, f"lora_epoch{epoch}_step{step}")
)
checkpoint = {
"epoch": epoch,
"step": step + 1,
"optimizer_state_dict": optimizer.state_dict(),
"lr_scheduler_state_dict": lr_scheduler.state_dict(),
}
torch.save(checkpoint, os.path.join(args.save_dir, f"optimizer_epoch{epoch}_step{step}.pth"))
coordinator.print_on_master(f"Saved checkpoint at epoch {epoch}, step {step + 1}")

start_step = 0
# Delete cache.
# del batch, batch_labels, batch_output, loss
accelerator.empty_cache()
@@ -373,10 +445,16 @@ def is_master():
"-m",
"--pretrained",
type=str,
required=True,
default=None,
help="Address of the pre-trained model",
)
parser.add_argument("-d", "--dataset", type=str, required=True, help="Raw Jonl dataset for training.")
parser.add_argument(
"-d",
"--dataset",
type=str,
default=None,
help="Raw Jonl dataset for training.",
)
parser.add_argument(
"-p",
"--plugin",
@@ -385,30 +463,99 @@ def is_master():
choices=["gemini", "gemini_auto", "zero2", "zero2_cpu", "3d", "ddp", "moe"],
help="Choose which plugin to use",
)
parser.add_argument("--save_dir", type=str, default="checkpoint_dir", help="Checkpoint directory")
parser.add_argument("--tensorboard_dir", type=str, default=None, help="Tensorboard directory")
parser.add_argument("--config_file", type=str, default="training_config.json", help="Config file")
parser.add_argument(
"--save_dir",
type=str,
default=None,
help="Checkpoint directory",
)
parser.add_argument("--save_interval", type=int, default=100, help="Save interval")
parser.add_argument(
"--lora_path",
type=str,
default=None,
help="Lora checkpoint directory",
)
parser.add_argument(
"--optmizer_path",
type=str,
default=None,
help="Optmizer checkpoint directory",
)
parser.add_argument(
"--tensorboard_dir",
type=str,
default="logs",
help="Tensorboard directory",
)
parser.add_argument(
"--config_file",
type=str,
default="training_config.json",
help="Config file",
)
# Training parameters
parser.add_argument("-n", "--num_epochs", type=int, default=1, help="Number of training epochs")
parser.add_argument("--accumulation_steps", type=int, default=1, help="Number of accumulation steps")
parser.add_argument("--batch_size", type=int, default=2, help="Global Batch size of each process")
parser.add_argument("--lr", type=float, default=3e-4, help="Learning rate")
parser.add_argument("--max_length", type=int, default=8192, help="Model max length")
parser.add_argument(
"-n",
"--num_epochs",
type=int,
default=2,
help="Number of training epochs",
)
parser.add_argument(
"--accumulation_steps",
type=int,
default=1,
help="Number of accumulation steps",
)
parser.add_argument(
"--batch_size",
type=int,
default=2,
help="Global Batch size of each process",
)
parser.add_argument(
"--lr",
type=float,
default=2e-5,
help="Learning rate",
)
parser.add_argument(
"--max_length",
type=int,
default=256,
help="Model max length",
)
parser.add_argument(
"--mixed_precision",
type=str,
default="bf16",
choices=["fp16", "bf16"],
help="Mixed precision",
)
parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping value")
parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay")
parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps")
parser.add_argument(
"--grad_clip",
type=float,
default=1.0,
help="Gradient clipping value",
)
parser.add_argument(
"--weight_decay",
type=float,
default=0.1,
help="Weight decay",
)
parser.add_argument(
"--warmup_steps",
type=int,
default=8,
help="Warmup steps",
)
parser.add_argument(
"-g",
"--use_grad_checkpoint",
action="store_true",
default=False,
default=True,
help="Use gradient checkpointing",
)
parser.add_argument(
@@ -420,11 +567,37 @@ def is_master():
)

# Additional arguments for 3d plugin.
parser.add_argument("--tp", type=int, default=1, help="TP size, used for 3d plugin.")
parser.add_argument("--pp", type=int, default=1, help="PP size, used for 3d plugin.")
parser.add_argument("--sp", type=int, default=1, help="SP size, used for 3d plugin.")
parser.add_argument("--ep", type=int, default=1, help="EP size, used for moe plugin.")
parser.add_argument("--zero_stage", type=int, default=1, help="Zero stage, used for 3d plugin.", choices=[0, 1, 2])
parser.add_argument(
"--tp",
type=int,
default=1,
help="TP size, used for 3d plugin.",
)
parser.add_argument(
"--pp",
type=int,
default=1,
help="PP size, used for 3d plugin.",
)
parser.add_argument(
"--sp",
type=int,
default=1,
help="SP size, used for 3d plugin.",
)
parser.add_argument(
"--ep",
type=int,
default=1,
help="EP size, used for moe plugin.",
)
parser.add_argument(
"--zero_stage",
type=int,
default=1,
help="Zero stage, used for 3d plugin.",
choices=[0, 1, 2],
)
parser.add_argument(
"--sp_mode",
type=str,
@@ -439,13 +612,29 @@ def is_master():
help="Whether to enable SP, used for 3d plugin.",
)
parser.add_argument(
"--zero_cpu_offload", default=False, action="store_true", help="Whether to use offloading, used for 3d plugin."
"--zero_cpu_offload",
default=False,
action="store_true",
help="Whether to use offloading, used for 3d plugin.",
)
parser.add_argument(
"--microbatch_size",
type=int,
default=1,
help="Batch size for each process in PP, used for 3d plugin.",
)
parser.add_argument(
"--lora_rank",
type=int,
default=8,
help="lora rank when using lora to train.",
)
parser.add_argument(
"--microbatch_size", type=int, default=1, help="Batch size for each process in PP, used for 3d plugin."
"--lora_alpha",
type=int,
default=16,
help="lora alpha when using lora to train.",
)
parser.add_argument("--lora_rank", type=int, default=0, help="lora rank when using lora to train.")
parser.add_argument("--lora_alpha", type=int, default=8, help="lora alpha when using lora to train.")

args = parser.parse_args()
