From e84bbf476170adedf8d96e9df1a6e8bfb3631694 Mon Sep 17 00:00:00 2001
From: Geeta Chauhan <4461127+chauhang@users.noreply.github.com>
Date: Sun, 24 Mar 2024 16:46:57 -0700
Subject: [PATCH] Add support for generating debug traces on failure

---
 train.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/train.py b/train.py
index 25bf37f6..bca795bc 100644
--- a/train.py
+++ b/train.py
@@ -17,6 +17,7 @@
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
 from torch.distributed.tensor.parallel import loss_parallel
+from torch.distributed.elastic.multiprocessing.errors import record
 
 from torchtrain.checkpoint import CheckpointManager, IntervalType
 from torchtrain.config_manager import JobConfig
@@ -97,7 +98,8 @@ def build_grad_scaler(model):
     return ShardedGradScaler(enabled=enable_grad_scaling)
 
 
-
+#Enable debug tracing on failure: https://pytorch.org/docs/stable/elastic/errors.html
+@record
 def main(job_config: JobConfig):
     init_logger()
     logger.info(f"Starting job: {job_config.job.description}")
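
Background on the change (not part of the patch itself): the @record decorator from
torch.distributed.elastic.multiprocessing.errors wraps the worker's entrypoint and, on an
uncaught exception, writes the traceback to the error file that torchrun points to via the
TORCHELASTIC_ERROR_FILE environment variable, so the launcher's error summary shows which
rank failed and why. Below is a minimal standalone sketch of that behavior; the script name
and the simulated failure are hypothetical and are not part of this repository.

# demo_record.py -- minimal sketch, assuming a 2-process torchrun launch.
import torch.distributed as dist
from torch.distributed.elastic.multiprocessing.errors import record


@record  # on an uncaught exception, records the traceback to $TORCHELASTIC_ERROR_FILE
def main():
    # torchrun supplies the env:// rendezvous variables (RANK, WORLD_SIZE, MASTER_ADDR, ...)
    dist.init_process_group(backend="gloo")
    if dist.get_rank() == 0:
        # Simulated failure: @record captures this traceback so the torchrun
        # error summary reports the failing rank and its stack trace.
        raise RuntimeError("simulated failure on rank 0")
    dist.destroy_process_group()


if __name__ == "__main__":
    # Launch with: torchrun --nproc_per_node=2 demo_record.py
    main()

As in the patch above, the decorator goes on the top-level entrypoint (main in train.py),
which is the usage the elastic errors documentation linked in the added comment describes.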