-
Notifications
You must be signed in to change notification settings - Fork 269
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Summary: This is a WIP TorchFT integration PR. Test Plan: ``` TORCHFT_MANAGER_PORT=29520 REPLICA_GROUP_ID=0 CUDA_VISIBLE_DEVICES=0,1 NGPU=2 ./run_llama_train.sh --training.data_parallel_shard_degree=2 --experimental.enable_torchft --experimental.ft_replica_group_id=0 ``` ``` TORCHFT_MANAGER_PORT=29522 REPLICA_GROUP_ID=1 CUDA_VISIBLE_DEVICES=2,3 NGPU=2 ./run_llama_train.sh --training.data_parallel_shard_degree=2 --experimental.enable_torchft --experimental.ft_replica_group_id=1 ``` ghstack-source-id: 07a02ffa43cbc1e16ff35ef9be820db52905d683 Pull Request resolved: #806
- Loading branch information
Showing
9 changed files
with
266 additions
and
108 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
from typing import Any, Callable, Optional | ||
import importlib | ||
|
||
from torchtitan.config_manager import JobConfig | ||
from torch.distributed._state_dict_utils import ( | ||
_copy_state_dict, | ||
_create_cpu_state_dict, | ||
) | ||
|
||
if importlib.util.find_spec("torchft") is not None: | ||
import torchft as ft | ||
has_torchft = True | ||
else: | ||
has_torchft = False | ||
|
||
|
||
def init_ft_manager(job: JobConfig) -> Optional["ft.Manager"]: | ||
""" | ||
Initialize the FT manager for the given job. | ||
""" | ||
if not job.experimental.enable_torchft: | ||
return None | ||
|
||
if not has_torchft: | ||
raise ImportError("torchft is not installed. Please install it.") | ||
|
||
pg = ft.ProcessGroupBabyNCCL() | ||
manager = ft.Manager( | ||
pg=pg, | ||
min_replica_size=1, | ||
load_state_dict=None, | ||
state_dict=None, | ||
use_async_quorum=True, | ||
replica_id=f"torchtitan_ft_{job.experimental.ft_replica_group_id}", | ||
) | ||
|
||
return manager | ||
|
||
|
||
def set_ft_state_dict_fns(manager: Optional["ft.Manager"], ckpt_manager) -> None: | ||
""" | ||
Set the state dict for the given manager. | ||
""" | ||
if manager is None: | ||
return | ||
|
||
def state_dict(): | ||
ret = {} | ||
for k, v in ckpt_manager.staging_results().items(): | ||
if k in {"model", "optimizer", "lr_schedulers"}: | ||
ret[k] = v | ||
return ret | ||
|
||
def load_state_dict(state_dict): | ||
assert state_dict is not None | ||
for k, v in state_dict.items(): | ||
ckpt_manager.states[k].load_state_dict(v) | ||
|
||
manager.set_state_dict_fns(load_state_dict, state_dict) |
Oops, something went wrong.