Fix AC in T5 example (#1273)
msaroufim authored Jun 29, 2024
1 parent a38cbfc commit 26de419
Showing 2 changed files with 2 additions and 1 deletion.
1 change: 1 addition & 0 deletions distributed/FSDP/T5_training.py
@@ -121,6 +121,7 @@ def fsdp_main(args):
         device_id=torch.cuda.current_device(),
         limit_all_gathers=fsdp_config.limit_all_gathers)
 
+    # Enabling this causes https://github.com/pytorch/examples/issues/1210
     if fsdp_config.fsdp_activation_checkpointing:
         policies.apply_fsdp_checkpointing(model)
 
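For context, `policies.apply_fsdp_checkpointing` itself is not shown in this diff. The sketch below approximates what such a helper typically does with PyTorch's activation-checkpointing wrapper utilities; the function body and the `T5Block` filter are assumptions, not the repository's exact code.

# Hypothetical sketch of policies.apply_fsdp_checkpointing (assumed, not this repo's code)
from functools import partial

from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointImpl,
    apply_activation_checkpointing,
    checkpoint_wrapper,
)
from transformers.models.t5.modeling_t5 import T5Block


def apply_fsdp_checkpointing(model):
    """Wrap each T5Block so its activations are recomputed during backward."""
    non_reentrant_wrapper = partial(
        checkpoint_wrapper,
        checkpoint_impl=CheckpointImpl.NO_REENTRANT,
    )
    apply_activation_checkpointing(
        model,
        checkpoint_wrapper_fn=non_reentrant_wrapper,
        # Only wrap transformer blocks, not embeddings or the LM head
        check_fn=lambda submodule: isinstance(submodule, T5Block),
    )

Recomputing block activations in backward trades extra compute for lower peak memory, which is why it is normally desirable under FSDP; this commit disables it by default only because of the linked issue.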
2 changes: 1 addition & 1 deletion distributed/FSDP/configs/fsdp.py
@@ -8,7 +8,7 @@ class fsdp_config:
     mixed_precision: bool=True
     use_fp16: bool=False
     seed: int=42
-    fsdp_activation_checkpointing: bool=True
+    fsdp_activation_checkpointing: bool=False
     limit_all_gathers: bool=True
     sharding_strategy: ShardingStrategy = ShardingStrategy.FULL_SHARD #HYBRID_SHARD, SHARD_GRAD_OP
     checkpoint_type: StateDictType = StateDictType.FULL_STATE_DICT # alternatively can use SHARDED_STATE_DICT to avoid OOMs
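The effect of flipping the default: a freshly constructed config now leaves activation checkpointing off, and a caller must opt back in explicitly. A minimal usage sketch, assuming fsdp_config is a dataclass (as its field syntax suggests) and that the import path mirrors the file layout above:

# Minimal usage sketch (import path and dataclass behavior assumed)
from configs.fsdp import fsdp_config

cfg = fsdp_config()
assert cfg.fsdp_activation_checkpointing is False  # new default: AC disabled

# Opt back in explicitly, e.g. once the linked issue (#1210) is resolved
cfg_ac = fsdp_config(fsdp_activation_checkpointing=True)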
