Commit

[chat] add distributed impl
ver217 committed Feb 20, 2025
1 parent 9379cbd commit c99c4b4
Showing 9 changed files with 798 additions and 0 deletions.
6 changes: 6 additions & 0 deletions applications/ColossalChat/coati/distributed/README.md
@@ -0,0 +1,6 @@
# Requirements

```bash
pip install cupy-cuda12x
python -m cupyx.tools.install_library --cuda 12.x --library nccl
```
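
These NCCL bindings back the GPU collectives in `ray.util.collective` that the new modules use. As a quick sanity check, here is a minimal sketch of a ray actor joining an NCCL-backed group; the actor layout, group name, and world size are illustrative assumptions, not part of this commit:

```python
# Sketch: initialize a ray collective group over NCCL (provided by cupy).
# Group name, world size, and ranks here are illustrative assumptions.
import ray
import ray.util.collective as cc
import torch


@ray.remote(num_gpus=1)
class Worker:
    def setup(self, world_size: int, rank: int) -> None:
        # Registers this actor in an NCCL-backed collective group.
        cc.init_collective_group(world_size, rank, backend="nccl", group_name="default")

    def allreduce_ones(self) -> torch.Tensor:
        t = torch.ones(4, device="cuda")
        cc.allreduce(t, group_name="default")  # in-place sum across ranks
        return t.cpu()
```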
Empty file.
57 changes: 57 additions & 0 deletions applications/ColossalChat/coati/distributed/comm.py
@@ -0,0 +1,57 @@
from typing import Any, Dict

import ray.util.collective as cc
import torch
import torch.distributed.distributed_c10d as c10d
from packaging.version import Version


def ray_broadcast_object(obj: Any, src: int = 0, device=None, group_name: str = "default") -> Any:
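    # Two-phase broadcast: ship the serialized object's byte size first so
    # non-source ranks can allocate a receive buffer of matching length.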
rank = cc.get_rank(group_name)
if rank == src:
if Version(torch.__version__) >= Version("2.3.0"):
obj_tensor, size_tensor = c10d._object_to_tensor(obj, device=device, group=None)
elif Version(torch.__version__) >= Version("1.13.0"):
obj_tensor, size_tensor = c10d._object_to_tensor(obj, device=device)
else:
obj_tensor, size_tensor = c10d._object_to_tensor(obj)
obj_tensor = obj_tensor.to(device)
size_tensor = size_tensor.to(device)
else:
size_tensor = torch.empty(1, dtype=torch.int64, device=device)
cc.broadcast(size_tensor, src, group_name)
if rank != src:
obj_tensor = torch.empty(size_tensor.item(), dtype=torch.uint8, device=device)
cc.broadcast(obj_tensor, src, group_name)
if rank != src:
if Version(torch.__version__) >= Version("2.3.0"):
obj = c10d._tensor_to_object(obj_tensor, size_tensor.item(), group=None)
else:
            obj = c10d._tensor_to_object(obj_tensor, size_tensor.item())
return obj


def ray_broadcast_tensor_dict(
tensor_dict: Dict[str, torch.Tensor], src: int = 0, device=None, group_name: str = "default"
) -> Dict[str, torch.Tensor]:
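    # Broadcast (key, shape, dtype) metadata first so receivers can allocate
    # matching buffers, then broadcast each tensor in turn.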
rank = cc.get_rank(group_name)
if rank == src:
metadata = []
for k, v in tensor_dict.items():
metadata.append((k, v.shape, v.dtype))
else:
metadata = None
metadata = ray_broadcast_object(metadata, src, device, group_name)
if rank != src:
out_dict = {}
for k, shape, dtype in metadata:
if rank == src:
tensor = tensor_dict[k]
else:
tensor = torch.empty(shape, dtype=dtype, device=device)
cc.broadcast(tensor, src, group_name)
if rank != src:
out_dict[k] = tensor
if rank == src:
out_dict = tensor_dict
return out_dict
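
For reference, a sketch of how `ray_broadcast_tensor_dict` is called from both sides of a group. The group name `"sync_model"` matches the consumer below, but the payload and rank handling are illustrative:

```python
# Sketch: the source rank passes the real dict; every other rank passes None
# and receives tensors allocated from the broadcast metadata.
import torch

from coati.distributed.comm import ray_broadcast_tensor_dict


def sync_weights(rank: int, src: int) -> dict:
    """Both sides call this; only the source supplies real tensors."""
    device = torch.device("cuda")
    weights = {"w": torch.randn(2, 2, device=device)} if rank == src else None
    # Non-source ranks get buffers allocated from the broadcast metadata.
    return ray_broadcast_tensor_dict(weights, src=src, device=device, group_name="sync_model")
```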
190 changes: 190 additions & 0 deletions applications/ColossalChat/coati/distributed/consumer.py
@@ -0,0 +1,190 @@
from contextlib import nullcontext
from typing import Any, Dict, Optional

import ray
import ray.util.collective as cc
import torch
import torch.distributed as dist
from tqdm import tqdm
from transformers import AutoModelForCausalLM

from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from colossalai.initialize import launch
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device

from .comm import ray_broadcast_tensor_dict
from .utils import bind_batch, post_recv, unbind_batch


class BaseConsumer:
def __init__(
self,
num_producers: int,
num_episodes: int,
rank: int,
world_size: int,
master_addr: str,
master_port: int,
num_update_per_episode: int,
num_recv_per_update: int,
batch_size: int,
model_config: Dict[str, Any],
plugin_config: Dict[str, Any],
microbatch_size: int = 1,
):
self.num_producers = num_producers
self.num_episodes = num_episodes
self.rank = rank
self.world_size = world_size
self.master_addr = master_addr
self.master_port = master_port
self.num_update_per_episode = num_update_per_episode
self.num_recv_per_update = num_recv_per_update
self.batch_size = batch_size
self.microbatch_size = microbatch_size
assert batch_size % microbatch_size == 0, "batch_size should be divisible by microbatch_size"
self.num_microbatches = batch_size // microbatch_size

self.model_config = model_config
self.plugin_config = plugin_config
        assert self.plugin_config.get("pp_size", 1) == 1, "pp_size > 1 is not yet supported"

self.device = get_current_device()

def setup(self) -> None:
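        # Collective topology: one "sync_data_{i}" group per producer (the
        # producer holds rank 0, consumers are shifted up by one), plus a
        # single "sync_model" group joined only by consumer rank 0.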
for i in range(self.num_producers):
cc.init_collective_group(self.world_size + 1, self.rank + 1, group_name=f"sync_data_{i}")
if self.rank == 0:
cc.init_collective_group(self.num_producers + 1, self.num_producers, group_name="sync_model")
launch(self.rank, self.world_size, self.master_addr, self.master_port, local_rank=0)

plugin_config = dict(
tp_size=1,
pp_size=1,
precision="bf16",
zero_stage=1,
)
if self.plugin_config.get("pp_size", 1) > 1 and "num_microbatches" not in self.plugin_config:
plugin_config["microbatch_size"] = self.microbatch_size
plugin_config.update(self.plugin_config)
self.plugin = HybridParallelPlugin(**plugin_config)
self.booster = Booster(plugin=self.plugin)
self.dp_rank = dist.get_rank(self.plugin.dp_group)
self.dp_size = dist.get_world_size(self.plugin.dp_group)

self.buffer = []

self.recv_cnt = 0

def state_dict(self) -> Dict[str, torch.Tensor]:
raise NotImplementedError

def step(self, step_idx: int, **kwargs) -> Optional[float]:
raise NotImplementedError

def loop(self) -> None:
print(
f"Consumer{self.rank} num_update: {self.num_update_per_episode}, num_recv: {self.num_recv_per_update}, nmb: {self.num_microbatches}"
)
for episode in range(self.num_episodes):
with tqdm(range(self.num_update_per_episode), desc=f"Episode {episode}", disable=self.rank != 0) as pbar:
for step in pbar:
i = 0
for _ in range(self.num_recv_per_update):
# receive data from producers

for r in range(self.num_producers):
print(f"[T{dist.get_rank()}] Recv data episode {episode} step {step} from {r}")
self.buffer.extend(
unbind_batch(
ray_broadcast_tensor_dict(
None, src=0, device=self.device, group_name=f"sync_data_{r}"
)
)
)
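                        # Each DP rank takes its own microbatch-sized slice of
                        # the shared buffer, then all ranks drop the window.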
while len(self.buffer) >= self.dp_size * self.microbatch_size:
batches = self.buffer[
self.dp_rank * self.microbatch_size : (self.dp_rank + 1) * self.microbatch_size
]
self.buffer = self.buffer[self.dp_size * self.microbatch_size :]
batch = bind_batch(batches)
batch = post_recv(batch)
loss = self.step(i, **batch)
if loss is not None:
pbar.set_postfix({"loss": loss})
i += 1
assert len(self.buffer) == 0
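                    # Sync updated weights to producers after every update,
                    # except the very last one (nothing would consume it).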
if episode != self.num_episodes - 1 or step != self.num_update_per_episode - 1:
print(f"[T{dist.get_rank()}] Sync model episode {episode} step {step}")
state_dict = self.state_dict()
if self.rank == 0:
ray_broadcast_tensor_dict(
state_dict, src=self.num_producers, device=self.device, group_name="sync_model"
)


@ray.remote
class SimpleConsumer(BaseConsumer):
def __init__(
self,
num_producers,
num_episodes,
rank,
world_size,
master_addr,
master_port,
num_update_per_episode,
num_recv_per_update,
batch_size,
model_config,
plugin_config,
microbatch_size=1,
):
super().__init__(
num_producers,
num_episodes,
rank,
world_size,
master_addr,
master_port,
num_update_per_episode,
num_recv_per_update,
batch_size,
model_config,
plugin_config,
microbatch_size,
)
path = model_config.pop("path")
self.model = AutoModelForCausalLM.from_pretrained(path, **model_config)
self.model.train()
self.model.gradient_checkpointing_enable()
self.optimizer = HybridAdam(self.model.parameters(), lr=1e-3)
self.accum_loss = torch.zeros(1, device=self.device)

def setup(self):
super().setup()
self.model, self.optimizer, *_ = self.booster.boost(self.model, self.optimizer)

def step(self, step_idx: int, **kwargs) -> Optional[float]:
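        # Gradient accumulation: intermediate microbatches skip grad sync via
        # booster.no_sync; only the final microbatch steps the optimizer.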
need_update = (step_idx + 1) % self.num_microbatches == 0

ctx = nullcontext() if need_update else self.booster.no_sync(self.model, self.optimizer)
with ctx:
out = self.model(**kwargs)
loss = out.loss / self.num_microbatches
self.accum_loss.add_(loss.data)
self.booster.backward(loss, self.optimizer)
if need_update:
self.optimizer.step()
self.optimizer.zero_grad()
loss_scalar = self.accum_loss.item()
self.accum_loss.zero_()
return loss_scalar

def state_dict(self):
self.model._force_wait_all_gather()
model = self.model.unwrap()
state_dict = model.state_dict()
return state_dict
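
The commit ships no launcher, so the following driver is only a sketch under assumed values; the model path, port, and sizes are all placeholders:

```python
# Sketch: spawn a single SimpleConsumer actor and run its training loop.
# All arguments below are illustrative; the real launcher is not in this commit.
import ray

from coati.distributed.consumer import SimpleConsumer

ray.init()
consumer = SimpleConsumer.options(num_gpus=1).remote(
    num_producers=1,
    num_episodes=1,
    rank=0,
    world_size=1,
    master_addr="localhost",
    master_port=29505,
    num_update_per_episode=10,
    num_recv_per_update=1,
    batch_size=8,
    model_config={"path": "Qwen/Qwen2.5-0.5B"},  # hypothetical checkpoint
    plugin_config={},
    microbatch_size=2,
)
# Note: setup() blocks until matching producer actors join the
# "sync_data_*" and "sync_model" groups, so producers must be launched too.
ray.get(consumer.setup.remote())
ray.get(consumer.loop.remote())
```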
