Commit

Add tok/sec calculation to pretraining script
rasbt committed Oct 24, 2024
1 parent 32673d0 commit 9990730
Showing 2 changed files with 16 additions and 3 deletions.
17 changes: 15 additions & 2 deletions litgpt/pretrain.py
@@ -237,9 +237,22 @@ def main(
     # Save final checkpoint
     save_checkpoint(fabric, state, tokenizer_dir, out_dir / "final" / "lit_model.pth")
 
-    fabric.print(f"Training time: {(time.perf_counter()-train_time):.2f}s")
+    total_tokens = state["iter_num"] * train.micro_batch_size * model.max_seq_length * fabric.world_size
+
+    # Print formatted output
+    separator = "-" * 40
+    fabric.print(separator)
+    fabric.print("| Performance")
+    fabric.print(f"| - Total tokens  : {total_tokens:,}")
+    fabric.print(f"| - Training Time : {(time.perf_counter()-train_time):.2f} s")
+    fabric.print(f"| - Tok/sec       : {total_tokens / train_time:.2f} tok/s")
+    fabric.print("| " + "-" * 40)
+
     if fabric.device.type == "cuda":
-        fabric.print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB")
+        memory_used = torch.cuda.max_memory_allocated() / 1e9
+        fabric.print("| Memory Usage")
+        fabric.print(f"| - Memory Used   : {memory_used:.2f} GB")
+    fabric.print(separator)
 
 
 def fit(
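For reference, the reported throughput is just the total number of processed tokens divided by elapsed wall-clock seconds, with the token count summed across all ranks. A minimal standalone sketch of the same arithmetic (the function name and the sample numbers are illustrative, not part of the commit):

def tokens_per_sec(iter_num: int, micro_batch_size: int, max_seq_length: int,
                   world_size: int, elapsed_s: float) -> float:
    """Same formula as in the diff: tokens seen across all ranks per second."""
    total_tokens = iter_num * micro_batch_size * max_seq_length * world_size
    return total_tokens / elapsed_s

# Illustrative values: 1,000 iterations, micro-batch size 4, sequence length
# 2,048, 8 devices, 600 s of training -> 65,536,000 tokens, ~109,226.67 tok/s.
print(f"{tokens_per_sec(1000, 4, 2048, 8, 600.0):,.2f} tok/s")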
2 changes: 1 addition & 1 deletion litgpt/utils.py
@@ -782,7 +782,7 @@ def create_finetuning_performance_report(training_time, token_counts, device_type
         memory_used = torch.cuda.max_memory_allocated() / 1e9
         output += f"| Memory Usage \n"
         output += f"| - Memory Used   : {memory_used:.02f} GB \n"
-    output += "=======================================================\n"
+    output += "-------------------------------------------------------\n"
 
     return output
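With this change, the finetuning performance report closes with the same dashed rule as the new pretraining summary above. On a CUDA device the tail of the report would render roughly as follows (the memory figure is illustrative):

| Memory Usage
| - Memory Used   : 15.93 GB
-------------------------------------------------------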
