From 6e2899fbc6d9367615e6eb35e46b07c3e33e8651 Mon Sep 17 00:00:00 2001 From: Nir Sonnenschein Date: Mon, 10 Jun 2024 13:39:01 +0300 Subject: [PATCH] WA for Torch-compile-Z3-act-apt accuracy issue from the Pytorch repo (#5590) We have encountered an accuracy issue when running Torch compile + zero3 + activation checkpointing. Specifically, some grads get zeroed (when running without torch compile, this issue is not encountered). This issue was also reproduced by Umesh Chand from the DS team. We found that in the PyTorch repo torch compile has been specifically disabled for the checkpoint function using the label: @torch._disable_dynamo() Reference to the WA in the PyTorch repo: (https://github.com/pytorch/pytorch/blob/ec8b254ef49b4a057cf89c2ae64520fb7b423a3e/torch/utils/checkpoint.py#L324). This indicates that there is some issue with torch compile and checkpointing (not necessarily DS related). Given that the checkpointing function in DeepSpeed is based on the PyTorch function, we propose to adopt this WA to ensure correct behavior (it can be removed later if the underlying issue is fixed). Note: this shouldn't impact non-torch compile cases. 
--------- Co-authored-by: Olatunji Ruwase --- deepspeed/runtime/activation_checkpointing/checkpointing.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/deepspeed/runtime/activation_checkpointing/checkpointing.py b/deepspeed/runtime/activation_checkpointing/checkpointing.py index 2a21cf7ca17a..529931ca0df1 100644 --- a/deepspeed/runtime/activation_checkpointing/checkpointing.py +++ b/deepspeed/runtime/activation_checkpointing/checkpointing.py @@ -29,6 +29,7 @@ from deepspeed.utils.timer import SynchronizedWallClockTimer as Timers, FORWARD_GLOBAL_TIMER from deepspeed.utils.bwc import bwc_tensor_model_parallel_rank from deepspeed.accelerator import get_accelerator +from deepspeed.runtime import compiler # DeepSpeed Checkpointing Enabled or Disabled deepspeed_checkpointing_enabled = False @@ -987,6 +988,7 @@ def after_backward_hook(_nonuse_grads): return tuple(all_outputs) +@compiler.disable # WA from Pytorch repo for compile + zero 3 accuracy issue def checkpoint(function, *args): """Checkpoint a model or part of the model. This has been directly copied from torch.utils.checkpoint. """