Skip to content

Commit

Permalink
[utils] fix gradient checkpoint logic (#2275)
Browse files Browse the repository at this point in the history
  • Loading branch information
robin1001 authored Jan 4, 2024
1 parent a4a1e48 commit cacc562
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions wenet/utils/train_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,10 @@ def init_dataset_and_dataloader(args, configs, tokenizer):
def wrap_cuda_model(args, model):
local_world_size = int(os.environ.get('LOCAL_WORLD_SIZE', 1))
world_size = int(os.environ.get('WORLD_SIZE', 1))
grad_ckpt = getattr(model.encoder, 'gradient_checkpointing', False)
if hasattr(model, 'encoder'):
grad_ckpt = getattr(model.encoder, 'gradient_checkpointing', False)
else:
grad_ckpt = False
# TODO(xcsong): could one GPU use ddp? and int(os.environ.get('WORLD_SIZE', 1)) > 1
if args.train_engine == "torch_ddp": # native pytorch ddp
assert (torch.cuda.is_available())
Expand Down Expand Up @@ -425,7 +428,8 @@ def wenet_join(group_join, info_dict):
# operations are executed. If we add a communication operation that is not
# managed by Deepspeed in this group, it's highly likely to cause
# communication chaos, resulting in hard-to-troubleshoot hangs.
dist.monitored_barrier(group=group_join, timeout=group_join.options._timeout)
dist.monitored_barrier(group=group_join,
timeout=group_join.options._timeout)
except RuntimeError as e:
logging.info("Detected uneven workload distribution: {}\n".format(e) +
"Break current worker to manually join all workers, " +
Expand Down

0 comments on commit cacc562

Please sign in to comment.