Removed setting the global swap_tensors flag since it is no longer needed
ghstack-source-id: 484237b30ba8bf8bb9e7a9cf2c97180d9fb21295
Pull Request resolved: #178
awgu committed Mar 29, 2024
1 parent 49f9784 commit dca7657
Showing 1 changed file with 1 addition and 3 deletions.
train.py: 1 addition, 3 deletions
@@ -197,9 +197,7 @@ def main(job_config: JobConfig):
     model = models_parallelize_fns[model_name](
         model, world_mesh, parallel_dims, job_config
     )
 
-    # set this as required by DTensor to work with `to_empty`
-    # TODO: remove in the future when enabled by default for wrapper subclasses
-    torch.__future__.set_swap_module_params_on_conversion(True)
+    # allocate sharded model on GPU and initialize weights via DTensor
     model.to_empty(device="cuda")
     model.init_weights()

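For context, below is a minimal, self-contained sketch (not the torchtitan code) of the meta-device initialization pattern the changed lines rely on: parameters are created without real storage, materialized on the GPU with to_empty(), and then re-initialized. The toy nn.Linear and the explicit init calls are illustrative assumptions; the commented-out line is the global flag this commit removes, which the old comment says DTensor needed for to_empty() to work, with a TODO to drop it once the behavior became the default for wrapper subclasses.

    # Minimal sketch of meta-device init followed by materialization on GPU.
    # The module and init calls are placeholders; requires a CUDA device.
    import torch
    import torch.nn as nn

    with torch.device("meta"):
        model = nn.Linear(16, 16)  # parameters are created without real storage

    # Global flag removed by this commit; per the old comment it was needed for
    # DTensor (a tensor wrapper subclass) to work with to_empty():
    # torch.__future__.set_swap_module_params_on_conversion(True)

    model.to_empty(device="cuda")  # allocate uninitialized storage on the GPU
    with torch.no_grad():
        nn.init.trunc_normal_(model.weight)  # re-initialize after materialization
        nn.init.zeros_(model.bias)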
