diff --git a/distributed/FSDP/T5_training.py b/distributed/FSDP/T5_training.py
index 1aae5d0990..4ab136eace 100644
--- a/distributed/FSDP/T5_training.py
+++ b/distributed/FSDP/T5_training.py
@@ -121,6 +121,7 @@ def fsdp_main(args):
         device_id=torch.cuda.current_device(),
         limit_all_gathers=fsdp_config.limit_all_gathers)
 
+    # Enabling this causes https://github.com/pytorch/examples/issues/1210
     if fsdp_config.fsdp_activation_checkpointing:
         policies.apply_fsdp_checkpointing(model)
 
diff --git a/distributed/FSDP/configs/fsdp.py b/distributed/FSDP/configs/fsdp.py
index 301771cd26..220cc67c55 100644
--- a/distributed/FSDP/configs/fsdp.py
+++ b/distributed/FSDP/configs/fsdp.py
@@ -8,7 +8,7 @@ class fsdp_config:
     mixed_precision: bool=True
     use_fp16: bool=False
     seed: int=42
-    fsdp_activation_checkpointing: bool=True
+    fsdp_activation_checkpointing: bool=False
     limit_all_gathers: bool=True
     sharding_strategy: ShardingStrategy = ShardingStrategy.FULL_SHARD #HYBRID_SHARD, SHARD_GRAD_OP
     checkpoint_type: StateDictType = StateDictType.FULL_STATE_DICT # alternatively can use SHARDED_STATE_DICT to avoid OOMs
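
For context, `policies.apply_fsdp_checkpointing(model)` is the call that the new default now skips. The sketch below is a minimal, hedged illustration of what such a helper typically does in FSDP examples; the helper body is not part of this diff, and the use of `T5Block` as the wrap unit, the non-reentrant `checkpoint_wrapper`, and the exact keyword arguments are assumptions rather than the repository's verified implementation.

    # Minimal sketch (assumed, not taken from this diff) of applying activation
    # checkpointing to each transformer block before/after FSDP wrapping.
    from functools import partial

    from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
        CheckpointImpl,
        apply_activation_checkpointing,
        checkpoint_wrapper,
    )
    from transformers.models.t5.modeling_t5 import T5Block  # assumed wrap unit

    # Non-reentrant checkpointing recomputes activations during backward
    # instead of storing them, trading compute for memory.
    non_reentrant_wrapper = partial(
        checkpoint_wrapper,
        checkpoint_impl=CheckpointImpl.NO_REENTRANT,
    )

    def apply_fsdp_checkpointing(model):
        """Wrap every T5Block in the model with an activation-checkpoint wrapper."""
        apply_activation_checkpointing(
            model,
            checkpoint_wrapper_fn=non_reentrant_wrapper,
            check_fn=lambda submodule: isinstance(submodule, T5Block),
        )

With `fsdp_activation_checkpointing: bool=False`, this path is bypassed by default and can be re-enabled once the linked issue is resolved.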