[WIP][RFC] TorchFT integration
Summary:
This is a WIP TorchFT integration PR.

Test Plan:
```
TORCHFT_MANAGER_PORT=29520 REPLICA_GROUP_ID=0 CUDA_VISIBLE_DEVICES=0,1 NGPU=2 ./run_llama_train.sh --training.data_parallel_shard_degree=2 --experimental.enable_torchft --experimental.ft_replica_group_id=0
```

```
TORCHFT_MANAGER_PORT=29522 REPLICA_GROUP_ID=1 CUDA_VISIBLE_DEVICES=2,3 NGPU=2 ./run_llama_train.sh --training.data_parallel_shard_degree=2 --experimental.enable_torchft --experimental.ft_replica_group_id=1
```
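Both commands launch against the same TorchFT lighthouse (run_llama_train.sh points TORCHFT_LIGHTHOUSE at http://localhost:29510); only the manager port, the visible GPUs, and the replica group ID differ. A minimal sketch of what the replica group ID turns into, restating init_ft_manager from the torchtitan/ft.py diff below (the manager_for helper here is illustrative):

```
# Sketch only: restates init_ft_manager() (see torchtitan/ft.py below) to show how the
# two launches map to distinct TorchFT replica IDs behind one shared lighthouse.
import torchft as ft

def manager_for(replica_group_id: int) -> ft.Manager:
    pg = ft.ProcessGroupBabyNCCL()
    return ft.Manager(
        pg=pg,
        min_replica_size=1,
        load_state_dict=None,
        state_dict=None,
        use_async_quorum=True,
        replica_id=f"torchtitan_ft_{replica_group_id}",
    )

# --experimental.ft_replica_group_id=0 -> replica_id "torchtitan_ft_0" (GPUs 0,1)
# --experimental.ft_replica_group_id=1 -> replica_id "torchtitan_ft_1" (GPUs 2,3)
```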

ghstack-source-id: 9244a1078fa9a10e564d6c28001bb508d75a1434
Pull Request resolved: #806
fegin committed Feb 3, 2025
1 parent cca0702 commit e9e9345
Showing 9 changed files with 265 additions and 99 deletions.
4 changes: 4 additions & 0 deletions run_llama_train.sh
@@ -19,7 +19,11 @@ if [ $# -ne 0 ]; then
overrides="$*"
fi

TORCHFT_MANAGER_PORT=${TORCHFT_MANAGER_PORT:-"29512"}

PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True" \
TORCHFT_LIGHTHOUSE=http://localhost:29510 \
TORCHFT_MANAGER_PORT=${TORCHFT_MANAGER_PORT} \
torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \
--local-ranks-filter ${LOG_RANK} --role rank --tee 3 \
train.py --job.config_file ${CONFIG_FILE} $overrides
204 changes: 119 additions & 85 deletions torchtitan/checkpoint.py
@@ -13,12 +13,13 @@
from dataclasses import dataclass, field
from io import BytesIO
from multiprocessing import get_context
from typing import Any, Dict, List, Union
from typing import Any, Dict, List, Optional, Union

import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
import torch.nn as nn
from torch.distributed._state_dict_utils import _copy_state_dict, _create_cpu_state_dict
from torch.distributed.checkpoint.state_dict import (
get_model_state_dict,
set_model_state_dict,
@@ -143,50 +144,29 @@ def __init__(
lr_schedulers: SchedulersContainer,
states: Dict[str, Any],
job_config: JobConfig,
ft_manager: Optional[Any] = None,
) -> None:
ckpt_config = job_config.checkpoint
self.enable_checkpoint = ckpt_config.enable_checkpoint
self.keep_latest_k = ckpt_config.keep_latest_k
self.ft_manager = ft_manager
self.enable_staging = (
self.enable_checkpoint and async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM
) or self.ft_manager

if not self.enable_checkpoint:
if not self.enable_checkpoint and self.ft_manager is None:
return
"""
Note: Pipeline Parallelism and Virtual Stages
1. Even for simple PP schedules, there is a separate optimizer for each PP rank.
rank0's optimizer would have a param_group[0] which refers to layers.0 in the original model.
rank1's would _also_ have a param_group[0], since it's index based, but referring to layers.1.
When saving, these collide and one of them is lost. Then when reloading, only one stage can
restore its optimizer states, others will error.
The solution to this problem is optimizer flattening: it landed in #127071 and is enabled in TorchTitan
by passing the 'flatten_optimizer_state_dict' kwarg to DCP functions called in the OptimizerContainer.
2. With complex PP schedules, we have multiple model chunks per pp rank. This compounds challenge (1) by also
requiring us to reason about multiple 'optim' objects locally.
We solve this in the Model and Optimizer wrapper classes by flattening the state dicts from each object
into one state dict before saving/loading. We rely on the individual state_dicts to not collide,
which is guaranteed for the model by correct pipeline splitting and for the optimizer by the flattening
support described in (1).
3. LR schedulers also index model states like optimizers and would need to be flattened properly to support
resharding. Unfortunately, the implementations of different lr_schedulers do not follow a clear pattern like
optimizers do, so it's hard to write a generic 'flattener' utility.
TODO: This is currently unsolved and needs a fix.
"""
self.states = states

self.states.update(
{
"model": ModelWrapper(model_parts),
"optimizer": optimizers,
"dataloader": dataloader,
"lr_scheduler": lr_schedulers,
}
self._initialize_states(
states, dataloader, model_parts, optimizers, lr_schedulers
)

async_mode = ckpt_config.async_mode.lower()
self.staging = False
self.sending_to_checkpoint_mp = False
self.staging_id = None
self.cpu_offload_state_dict = None
self.staging_stream = torch.cuda.Stream() if self.enable_staging else None

self.folder = os.path.join(job_config.job.dump_folder, ckpt_config.folder)
self.interval_type = (
IntervalType.SECONDS
@@ -201,6 +181,7 @@ def __init__(
if async_mode == AsyncMode.ASYNC or self.interval_type == IntervalType.SECONDS:
self.pg = dist.new_group(backend="gloo")

self.keep_latest_k = ckpt_config.keep_latest_k
self.model_weights_only = ckpt_config.model_weights_only
self.export_dtype = TORCH_DTYPE_MAP[ckpt_config.export_dtype]

@@ -224,10 +205,6 @@ def __init__(
daemon=True,
)
self.mp.start()
self.cpu_offload_state_dict = None
self.staging = False
self.staging_id = None
self.staging_stream = torch.cuda.Stream()
else:
raise ValueError(f"Unknown checkpoint async_mode {ckpt_config.async_mode}")

@@ -241,8 +218,61 @@ def __del__(self):
self.mp.join()

def reset(self) -> None:
# We need to stage the local state if another replica joins during the
# first step.
if self.ft_manager:
self.cpu_staging(None)
self.begin_time = time.monotonic()

def _initialize_states(
self,
states: Dict[str, Any],
dataloader: DataLoader,
model_parts: List[nn.Module],
optimizers: OptimizersContainer,
lr_schedulers: SchedulersContainer,
) -> None:
"""
Note: Pipeline Parallelism and Virtual Stages
1. Even for simple PP schedules, there is a separate optimizer for each PP rank.
rank0's optimizer would have a param_group[0] which refers to layers.0 in the
original model. rank1's would _also_ have a param_group[0], since it's index based,
but referring to layers.1.
When saving, these collide and one of them is lost. Then when reloading, only one
stage can restore its optimizer states, others will error.
The solution to this problem is optimizer flattening: it landed in #127071
and is enabled in TorchTitan by passing the 'flatten_optimizer_state_dict'
kwarg to DCP functions called in the OptimizerContainer.
2. With complex PP schedules, we have multiple model chunks per pp rank. This
compounds challenge (1) by also requiring us to reason about multiple 'optim'
objects locally.
We solve this in the Model and Optimizer wrapper classes by flattening the
state dicts from each object into one state dict before saving/loading.
We rely on the individual state_dicts to not collide, which is guaranteed for
the model by correct pipeline splitting and for the optimizer by the flattening
support described in (1).
3. LR schedulers also index model states like optimizers and would need to be
flattened properly to support resharding. Unfortunately, the implementations of
different lr_schedulers do not follow a clear pattern like optimizers do, so it's
hard to write a generic 'flattener' utility.
TODO: This is currently unsolved and needs a fix.
"""
self.states = states
self.states.update(
{
"model": ModelWrapper(model_parts),
"optimizer": optimizers,
"dataloader": dataloader,
"lr_scheduler": lr_schedulers,
}
)

def _create_checkpoint_id(self, step: int) -> str:
return os.path.join(self.folder, f"step-{step}")

@@ -325,31 +355,8 @@ def _async_wait(self) -> None:
self.async_future.result()

def _async_with_pinned_memory(self, checkpoint_id: str) -> None:
try:
from torch.distributed._state_dict_utils import (
_copy_state_dict,
_create_cpu_state_dict,
)
except ImportError as e:
raise ImportError(
"Please install the latest PyTorch nightly to use async checkpointing with pinned memory."
) from e
state_dict = dcp.state_dict_saver._stateful_to_state_dict(self.states)
if self.cpu_offload_state_dict is None:
logger.debug(f"Preparing the CPU memory, {time.monotonic()=}.:.2f")
self.cpu_offload_state_dict = _create_cpu_state_dict(
state_dict, pin_memory=True, share_memory=True
)

logger.debug(f"Staging the state_dict, {time.monotonic()=}.:.2f")
with torch.cuda.stream(self.staging_stream):
self.cpu_offload_state_dict = _copy_state_dict(
state_dict,
self.cpu_offload_state_dict,
non_blocking=True,
)
self.staging = True
self.staging_id = checkpoint_id
self.cpu_staging(checkpoint_id)
self.sending_to_checkpoint_mp = True

def save(self, curr_step: int, force: bool = False) -> None:
"""
@@ -359,6 +366,8 @@ def save(self, curr_step: int, force: bool = False) -> None:
for initial seed checkpoint.
"""
if not self._should_save(curr_step, force):
if self.ft_manager:
self.cpu_staging(None)
return

begin = time.monotonic()
@@ -382,26 +391,51 @@
f"in {time.monotonic() - begin:.2f} seconds."
)

def cpu_staging(self, checkpoint_id: Optional[str]) -> None:
"""Offload state_dict to CPU memory"""
state_dict = dcp.state_dict_saver._stateful_to_state_dict(self.states)
if self.cpu_offload_state_dict is None:
logger.debug(f"Preparing the CPU memory, {time.monotonic()=:.2f}")
self.cpu_offload_state_dict = _create_cpu_state_dict(
state_dict, pin_memory=True, share_memory=True
)

logger.debug(f"Staging the state_dict, {time.monotonic()=:.2f}")
with torch.cuda.stream(self.staging_stream):
self.cpu_offload_state_dict = _copy_state_dict(
state_dict,
self.cpu_offload_state_dict,
non_blocking=True,
)
self.staging = True
self.staging_id = checkpoint_id

def wait_for_staging(self) -> None:
if not self.staging_stream.query():
self.staging_stream.synchronize()
self.staging = False

def staging_results(self) -> Dict[str, Any]:
self.maybe_wait_for_staging()
return self.cpu_offload_state_dict

def maybe_wait_for_staging(self) -> None:
if (
self.enable_checkpoint
and self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM
and self.staging
):
if not self.staging_stream.query():
self.staging_stream.synchronize()

def sync_func():
self.mp_queue_send.put_nowait(
(self.cpu_offload_state_dict, self.staging_id)
)

# This may be a faster way to do zero-overhead checkpoint staging,
# but we need more thorough investigation before switching to this method.
# self.my_thread = threading.Thread(target=func).start()
sync_func()
self.staging = False
if self.enable_staging and self.staging:
self.wait_for_staging()

if self.sending_to_checkpoint_mp:
# Copy the sync staging result to another process.
def sync_func():
self.mp_queue_send.put_nowait(
(self.cpu_offload_state_dict, self.staging_id)
)

# This may be a faster way to do zero-overhead checkpoint staging,
# but we need more thorough investigation before switching to this method.
# self.my_thread = threading.Thread(target=func).start()
sync_func()
self.sending_to_checkpoint_mp = False

def load(self, step: int = -1) -> bool:
if not self.enable_checkpoint:
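The _initialize_states docstring above refers to optimizer state-dict flattening (landed in #127071). A minimal sketch of that option, assuming the torch.distributed.checkpoint.state_dict API and a toy single-process model; TorchTitan passes the same flag from its OptimizerContainer:

```
# Hedged sketch of the 'flatten_optimizer_state_dict' option; the toy model and
# optimizer below are placeholders, not TorchTitan code.
import torch
from torch.distributed.checkpoint.state_dict import (
    StateDictOptions,
    get_optimizer_state_dict,
)

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
model(torch.randn(2, 8)).sum().backward()
optimizer.step()  # populate optimizer state before exporting it

# The flattened form keys optimizer state by parameter FQN instead of integer
# param_group indices, so per-PP-rank optimizers no longer collide when saved together.
flat_osd = get_optimizer_state_dict(
    model,
    optimizer,
    options=StateDictOptions(flatten_optimizer_state_dict=True),
)
```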
13 changes: 13 additions & 0 deletions torchtitan/config_manager.py
@@ -585,6 +585,19 @@ def __init__(self):
action="store_true",
)

self.parser.add_argument(
"--experimental.enable_torchft",
action="store_true",
help="Enable TorchFT integration.",
)

self.parser.add_argument(
"--experimental.ft_replica_group_id",
type=int,
default=-1,
help="The FT replica group ID of this run.",
)

def to_dict(self):
return self.args_dict

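A quick sketch of how the new flags surface on the parsed config, assuming config_manager's usual mapping of "--experimental.<name>" arguments onto job_config.experimental.<name> attributes (the form consumed by torchtitan/ft.py below):

```
# Hedged sketch: exercises only the two new experimental flags; parse_args() accepting
# an explicit argument list is assumed from config_manager's CLI handling.
from torchtitan.config_manager import JobConfig

job_config = JobConfig()
job_config.parse_args(
    ["--experimental.enable_torchft", "--experimental.ft_replica_group_id", "0"]
)
assert job_config.experimental.enable_torchft is True
assert job_config.experimental.ft_replica_group_id == 0
```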
58 changes: 58 additions & 0 deletions torchtitan/ft.py
@@ -0,0 +1,58 @@
import importlib.util
from typing import Any, Callable, Optional

from torch.distributed._state_dict_utils import _copy_state_dict, _create_cpu_state_dict

from torchtitan.config_manager import JobConfig

if importlib.util.find_spec("torchft") is not None:
import torchft as ft

has_torchft = True
else:
has_torchft = False


def init_ft_manager(job: JobConfig) -> Optional["ft.Manager"]:
"""
Initialize the FT manager for the given job.
"""
if not job.experimental.enable_torchft:
return None

if not has_torchft:
raise ImportError("torchft is not installed. Please install it.")

pg = ft.ProcessGroupBabyNCCL()
manager = ft.Manager(
pg=pg,
min_replica_size=1,
load_state_dict=None,
state_dict=None,
use_async_quorum=True,
replica_id=f"torchtitan_ft_{job.experimental.ft_replica_group_id}",
)

return manager


def set_ft_state_dict_fns(manager: Optional["ft.Manager"], ckpt_manager) -> None:
"""
Set the state dict functions used by the given TorchFT manager.
"""
if manager is None:
return

def state_dict():
ret = {}
for k, v in ckpt_manager.staging_results().items():
if k in {"model", "optimizer", "lr_schedulers"}:
ret[k] = v
return ret

def load_state_dict(state_dict):
assert state_dict is not None
for k, v in state_dict.items():
ckpt_manager.states[k].load_state_dict(v)

manager.set_state_dict_fns(load_state_dict, state_dict)
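How these helpers plug into the training loop is presumably part of the train.py change, which is not among the loaded diffs, so the wiring below is an assumption based only on the signatures shown above (init_ft_manager, CheckpointManager(..., ft_manager=...), set_ft_state_dict_fns):

```
# Hedged sketch of the assumed train.py wiring; dataloader, model_parts, optimizers,
# lr_schedulers, and train_state are placeholders produced by the usual trainer setup.
from torchtitan.checkpoint import CheckpointManager
from torchtitan.config_manager import JobConfig
from torchtitan.ft import init_ft_manager, set_ft_state_dict_fns

job_config = JobConfig()
job_config.parse_args()

ft_manager = init_ft_manager(job_config)  # None unless --experimental.enable_torchft

checkpoint = CheckpointManager(
    dataloader=dataloader,
    model_parts=model_parts,
    optimizers=optimizers,
    lr_schedulers=lr_schedulers,
    states={"train_state": train_state},
    job_config=job_config,
    ft_manager=ft_manager,
)

# TorchFT serves a joining replica from the CPU-staged copy that cpu_staging() /
# staging_results() keep in checkpoint.py above.
set_ft_state_dict_fns(ft_manager, checkpoint)
```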
(Diffs for the remaining changed files were not loaded in this view.)
