Merge branch 'master' into comet-config-enabled-fix
tjruwase authored Jun 10, 2024
2 parents cbce287 + a41729f commit fa7cfe6
Showing 2 changed files with 5 additions and 2 deletions.
deepspeed/runtime/activation_checkpointing/checkpointing.py (2 additions, 0 deletions)

@@ -29,6 +29,7 @@
 from deepspeed.utils.timer import SynchronizedWallClockTimer as Timers, FORWARD_GLOBAL_TIMER
 from deepspeed.utils.bwc import bwc_tensor_model_parallel_rank
 from deepspeed.accelerator import get_accelerator
+from deepspeed.runtime import compiler

 # DeepSpeed Checkpointing Enabled or Disabled
 deepspeed_checkpointing_enabled = False
@@ -987,6 +988,7 @@ def after_backward_hook(_nonuse_grads):
         return tuple(all_outputs)


+@compiler.disable  # WA from PyTorch repo for compile + ZeRO-3 accuracy issue
 def checkpoint(function, *args):
     """Checkpoint a model or part of the model.
     This has been directly copied from torch.utils.checkpoint. """
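Note: deepspeed.runtime.compiler provides a disable decorator that keeps torch.compile (Dynamo) from tracing the decorated function, which is how the checkpoint entry point above is excluded from compilation. A minimal sketch of that pattern, assuming torch.compiler.disable is available; the fallback and the example function are illustrative, not DeepSpeed's actual implementation:

import torch

def disable(func):
    # Keep torch.compile / Dynamo from tracing the decorated function.
    if hasattr(torch, "compiler") and hasattr(torch.compiler, "disable"):
        return torch.compiler.disable(func)
    return func  # older torch builds: nothing to disable

@disable
def checkpointed_region(x):
    # Runs eagerly even when the surrounding module is torch.compile'd.
    return x * 2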
deepspeed/runtime/zero/stage_1_and_2.py (3 additions, 2 deletions)

@@ -1039,6 +1039,7 @@ def average_tensor(self, tensor):
             stream = self.reduction_stream
             if not get_accelerator().resolves_data_dependency():
                 stream.wait_stream(get_accelerator().current_stream())
+                get_accelerator().current_stream().wait_stream(stream)
         else:
             stream = get_accelerator().current_stream()

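The added wait makes the synchronization two-way: the reduction stream already waits for work queued on the current stream before averaging, and the current stream now also waits for the reduction stream, so kernels launched afterwards cannot race ahead of the averaged gradients. A standalone sketch of that pattern with plain PyTorch CUDA streams (the variable names are illustrative, not DeepSpeed's):

import torch

if torch.cuda.is_available():
    reduction_stream = torch.cuda.Stream()

    # The side stream first waits for work already queued on the current stream.
    reduction_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(reduction_stream):
        grads = torch.ones(4, device="cuda")
        grads /= 2  # stands in for the gradient averaging

    # The current stream then waits for the reduction to finish, so
    # subsequent kernels see fully averaged gradients.
    torch.cuda.current_stream().wait_stream(reduction_stream)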
@@ -1962,8 +1963,8 @@ def unscale_and_clip_grads(self, grad_groups_flat, total_norm):
         if self.clip_grad > 0.:
             # norm is in fact norm*scale
             clip = ((total_norm / self.loss_scale) + 1e-6) / self.clip_grad
-            if clip > 1:
-                combined_scale = clip * self.loss_scale
+            clip = torch.clamp(clip, min=1.0)
+            combined_scale = clip * self.loss_scale

         for grad in grad_groups_flat:
             if isinstance(grad, list):
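Replacing the data-dependent branch if clip > 1: with torch.clamp keeps the math unchanged (gradients are only ever scaled down, never up) while avoiding a Python-level branch on a tensor value, which forces a host sync and is a graph break under torch.compile. A quick numeric check with illustrative values:

import torch

loss_scale, clip_grad = 1024.0, 1.0
for total_norm in (512.0, 4096.0):  # below / above the clipping threshold
    clip = ((torch.tensor(total_norm) / loss_scale) + 1e-6) / clip_grad
    clip = torch.clamp(clip, min=1.0)
    combined_scale = clip * loss_scale
    print(float(combined_scale))  # 1024.0 (no clipping), then ~4096.0

Dividing each gradient by combined_scale then both undoes the loss scaling and clips the global norm to clip_grad in one step.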
