Add support for generating debug traces on failure
chauhang authored Mar 24, 2024
1 parent 72aad15 · commit e84bbf4
Showing 1 changed file with 3 additions and 1 deletion.
train.py
@@ -17,6 +17,7 @@
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
 from torch.distributed.tensor.parallel import loss_parallel
+from torch.distributed.elastic.multiprocessing.errors import record

 from torchtrain.checkpoint import CheckpointManager, IntervalType
 from torchtrain.config_manager import JobConfig
@@ -97,7 +98,8 @@ def build_grad_scaler(model):

     return ShardedGradScaler(enabled=enable_grad_scaling)

-
+#Enable debug tracing on failure: https://pytorch.org/docs/stable/elastic/errors.html
+@record
 def main(job_config: JobConfig):
     init_logger()
     logger.info(f"Starting job: {job_config.job.description}")
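For context on what this change buys: record comes from PyTorch's torchelastic error-propagation machinery. Below is a minimal, self-contained sketch of the same pattern; the file name, failing function, and launch command are illustrative, not part of this commit.

# sketch.py -- hypothetical standalone illustration of @record
from torch.distributed.elastic.multiprocessing.errors import record

@record  # on an uncaught exception, writes a structured error report
         # (message, traceback, timestamp) to the file named by the
         # TORCHELASTIC_ERROR_FILE environment variable
def main():
    # Simulated failure so the recorded trace has something to capture.
    raise RuntimeError("simulated training failure")

if __name__ == "__main__":
    main()

Launched under torchrun (e.g. torchrun --nproc_per_node=2 sketch.py), the elastic agent sets TORCHELASTIC_ERROR_FILE for each worker and, when a rank crashes, surfaces a summary with the failing rank and its traceback instead of a bare non-zero exit. Run as plain python with no error file configured, the decorator just logs the exception and re-raises it.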
