Showing 9 changed files with 798 additions and 0 deletions.
# Requirements

```bash
pip install cupy-cuda12x
python -m cupyx.tools.install_library --cuda 12.x --library nccl
```
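These two packages are needed because `ray.util.collective`'s NCCL backend is built on CuPy. As a quick sanity check (hypothetical, not part of this commit), you can confirm the NCCL binding is importable after installation, assuming a CUDA 12.x runtime is present:

```python
# Hypothetical post-install check; assumes a CUDA 12.x runtime.
import cupy
from cupy.cuda import nccl  # raises ImportError if the NCCL library is missing

print("CuPy version:", cupy.__version__)
print("NCCL version:", nccl.get_version())  # NCCL version as an integer
```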
Empty file.
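The next file defines broadcast helpers on top of `ray.util.collective`; `consumer.py` later imports them via `from .comm import ray_broadcast_tensor_dict`.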
from typing import Any, Dict

import ray.util.collective as cc
import torch
import torch.distributed.distributed_c10d as c10d
from packaging.version import Version


def ray_broadcast_object(obj: Any, src: int = 0, device=None, group_name: str = "default") -> Any:
    """Broadcast an arbitrary picklable object from rank `src` to every rank in the group."""
    rank = cc.get_rank(group_name)
    if rank == src:
        # Serialize the object into a byte tensor; the private c10d helper's
        # signature changed across torch releases.
        if Version(torch.__version__) >= Version("2.3.0"):
            obj_tensor, size_tensor = c10d._object_to_tensor(obj, device=device, group=None)
        elif Version(torch.__version__) >= Version("1.13.0"):
            obj_tensor, size_tensor = c10d._object_to_tensor(obj, device=device)
        else:
            obj_tensor, size_tensor = c10d._object_to_tensor(obj)
        obj_tensor = obj_tensor.to(device)
        size_tensor = size_tensor.to(device)
    else:
        size_tensor = torch.empty(1, dtype=torch.int64, device=device)
    # Broadcast the payload size first so receivers can allocate a buffer.
    cc.broadcast(size_tensor, src, group_name)
    if rank != src:
        obj_tensor = torch.empty(size_tensor.item(), dtype=torch.uint8, device=device)
    cc.broadcast(obj_tensor, src, group_name)
    if rank != src:
        # Deserialize the received byte tensor back into the object.
        if Version(torch.__version__) >= Version("2.3.0"):
            obj = c10d._tensor_to_object(obj_tensor, size_tensor.item(), group=None)
        else:
            obj = c10d._tensor_to_object(obj_tensor, size_tensor.item())
    return obj


def ray_broadcast_tensor_dict(
    tensor_dict: Dict[str, torch.Tensor], src: int = 0, device=None, group_name: str = "default"
) -> Dict[str, torch.Tensor]:
    """Broadcast a dict of tensors from rank `src`; non-src ranks may pass None."""
    rank = cc.get_rank(group_name)
    if rank == src:
        metadata = []
        for k, v in tensor_dict.items():
            metadata.append((k, v.shape, v.dtype))
    else:
        metadata = None
    # Share (key, shape, dtype) metadata so receivers can allocate buffers.
    metadata = ray_broadcast_object(metadata, src, device, group_name)
    if rank != src:
        out_dict = {}
    for k, shape, dtype in metadata:
        if rank == src:
            tensor = tensor_dict[k]
        else:
            tensor = torch.empty(shape, dtype=dtype, device=device)
        cc.broadcast(tensor, src, group_name)
        if rank != src:
            out_dict[k] = tensor
    if rank == src:
        out_dict = tensor_dict
    return out_dict
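To illustrate how these helpers are meant to be called, here is a minimal two-actor sketch (hypothetical, not part of the commit). The module name `comm`, the group name `demo`, and the actor layout are assumptions; only `init_collective_group` and `ray_broadcast_tensor_dict` come from the code above.

```python
# Hypothetical sketch; assumes the helpers above are importable as `comm`
# and that CuPy/NCCL (see Requirements) plus one GPU per actor are available.
import ray
import ray.util.collective as cc
import torch

from comm import ray_broadcast_tensor_dict


@ray.remote(num_gpus=1)
class Worker:
    def __init__(self, rank: int, world_size: int):
        self.rank = rank
        # All ranks must join the same named NCCL group before communicating.
        cc.init_collective_group(world_size, rank, backend="nccl", group_name="demo")

    def run(self):
        if self.rank == 0:
            payload = {"w": torch.ones(2, 2, device="cuda")}
        else:
            payload = None  # receivers allocate buffers from the broadcast metadata
        out = ray_broadcast_tensor_dict(payload, src=0, device=torch.device("cuda"), group_name="demo")
        return {k: tuple(v.shape) for k, v in out.items()}


if __name__ == "__main__":
    ray.init()
    workers = [Worker.remote(r, 2) for r in range(2)]
    # Both run() calls must be in flight together, since broadcast is collective.
    print(ray.get([w.run.remote() for w in workers]))
```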
applications/ColossalChat/coati/distributed/consumer.py (190 additions, 0 deletions)
from contextlib import nullcontext
from typing import Any, Dict, Optional

import ray
import ray.util.collective as cc
import torch
import torch.distributed as dist
from tqdm import tqdm
from transformers import AutoModelForCausalLM

from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from colossalai.initialize import launch
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device

from .comm import ray_broadcast_tensor_dict
from .utils import bind_batch, post_recv, unbind_batch

class BaseConsumer:
    def __init__(
        self,
        num_producers: int,
        num_episodes: int,
        rank: int,
        world_size: int,
        master_addr: str,
        master_port: int,
        num_update_per_episode: int,
        num_recv_per_update: int,
        batch_size: int,
        model_config: Dict[str, Any],
        plugin_config: Dict[str, Any],
        microbatch_size: int = 1,
    ):
        self.num_producers = num_producers
        self.num_episodes = num_episodes
        self.rank = rank
        self.world_size = world_size
        self.master_addr = master_addr
        self.master_port = master_port
        self.num_update_per_episode = num_update_per_episode
        self.num_recv_per_update = num_recv_per_update
        self.batch_size = batch_size
        self.microbatch_size = microbatch_size
        assert batch_size % microbatch_size == 0, "batch_size should be divisible by microbatch_size"
        self.num_microbatches = batch_size // microbatch_size

        self.model_config = model_config
        self.plugin_config = plugin_config
        assert self.plugin_config.get("pp_size", 1) == 1, "pp_size > 1 is not supported yet"

        self.device = get_current_device()

    def setup(self) -> None:
        # Join one collective group per producer for data transfer; consumer
        # ranks are offset by 1 because the producer takes rank 0 in its group.
        for i in range(self.num_producers):
            cc.init_collective_group(self.world_size + 1, self.rank + 1, group_name=f"sync_data_{i}")
        # Only consumer rank 0 joins the weight-sync group shared with the producers.
        if self.rank == 0:
            cc.init_collective_group(self.num_producers + 1, self.num_producers, group_name="sync_model")
        launch(self.rank, self.world_size, self.master_addr, self.master_port, local_rank=0)

        # Defaults; any user-provided plugin_config entries take precedence via update().
        plugin_config = dict(
            tp_size=1,
            pp_size=1,
            precision="bf16",
            zero_stage=1,
        )
        if self.plugin_config.get("pp_size", 1) > 1 and "num_microbatches" not in self.plugin_config:
            plugin_config["microbatch_size"] = self.microbatch_size
        plugin_config.update(self.plugin_config)
        self.plugin = HybridParallelPlugin(**plugin_config)
        self.booster = Booster(plugin=self.plugin)
        self.dp_rank = dist.get_rank(self.plugin.dp_group)
        self.dp_size = dist.get_world_size(self.plugin.dp_group)

        self.buffer = []

        self.recv_cnt = 0

    def state_dict(self) -> Dict[str, torch.Tensor]:
        raise NotImplementedError

    def step(self, step_idx: int, **kwargs) -> Optional[float]:
        raise NotImplementedError

    def loop(self) -> None:
        print(
            f"Consumer{self.rank} num_update: {self.num_update_per_episode}, num_recv: {self.num_recv_per_update}, nmb: {self.num_microbatches}"
        )
        for episode in range(self.num_episodes):
            with tqdm(range(self.num_update_per_episode), desc=f"Episode {episode}", disable=self.rank != 0) as pbar:
                for step in pbar:
                    i = 0
                    for _ in range(self.num_recv_per_update):
                        # Receive a batch broadcast from each producer and split it
                        # into individual samples for the buffer.
                        for r in range(self.num_producers):
                            print(f"[T{dist.get_rank()}] Recv data episode {episode} step {step} from {r}")
                            self.buffer.extend(
                                unbind_batch(
                                    ray_broadcast_tensor_dict(
                                        None, src=0, device=self.device, group_name=f"sync_data_{r}"
                                    )
                                )
                            )
                        # Each DP rank trains on its own microbatch-sized slice of the buffer.
                        while len(self.buffer) >= self.dp_size * self.microbatch_size:
                            batches = self.buffer[
                                self.dp_rank * self.microbatch_size : (self.dp_rank + 1) * self.microbatch_size
                            ]
                            self.buffer = self.buffer[self.dp_size * self.microbatch_size :]
                            batch = bind_batch(batches)
                            batch = post_recv(batch)
                            loss = self.step(i, **batch)
                            if loss is not None:
                                pbar.set_postfix({"loss": loss})
                            i += 1
                    assert len(self.buffer) == 0
                    # Push updated weights to the producers, except after the very last update.
                    if episode != self.num_episodes - 1 or step != self.num_update_per_episode - 1:
                        print(f"[T{dist.get_rank()}] Sync model episode {episode} step {step}")
                        state_dict = self.state_dict()
                        if self.rank == 0:
                            ray_broadcast_tensor_dict(
                                state_dict, src=self.num_producers, device=self.device, group_name="sync_model"
                            )

@ray.remote
class SimpleConsumer(BaseConsumer):
    def __init__(
        self,
        num_producers,
        num_episodes,
        rank,
        world_size,
        master_addr,
        master_port,
        num_update_per_episode,
        num_recv_per_update,
        batch_size,
        model_config,
        plugin_config,
        microbatch_size=1,
    ):
        super().__init__(
            num_producers,
            num_episodes,
            rank,
            world_size,
            master_addr,
            master_port,
            num_update_per_episode,
            num_recv_per_update,
            batch_size,
            model_config,
            plugin_config,
            microbatch_size,
        )
        # `path` is consumed here; the remaining entries are forwarded to
        # AutoModelForCausalLM.from_pretrained as keyword arguments.
        path = model_config.pop("path")
        self.model = AutoModelForCausalLM.from_pretrained(path, **model_config)
        self.model.train()
        self.model.gradient_checkpointing_enable()
        self.optimizer = HybridAdam(self.model.parameters(), lr=1e-3)
        self.accum_loss = torch.zeros(1, device=self.device)

    def setup(self):
        super().setup()
        self.model, self.optimizer, *_ = self.booster.boost(self.model, self.optimizer)

    def step(self, step_idx: int, **kwargs) -> Optional[float]:
        # Step the optimizer only once a full batch of microbatches has accumulated.
        need_update = (step_idx + 1) % self.num_microbatches == 0

        # Skip gradient sync on accumulation-only microbatches.
        ctx = nullcontext() if need_update else self.booster.no_sync(self.model, self.optimizer)
        with ctx:
            out = self.model(**kwargs)
            loss = out.loss / self.num_microbatches
            self.accum_loss.add_(loss.data)
            self.booster.backward(loss, self.optimizer)
        if need_update:
            self.optimizer.step()
            self.optimizer.zero_grad()
            loss_scalar = self.accum_loss.item()
            self.accum_loss.zero_()
            return loss_scalar

    def state_dict(self):
        # Flush any in-flight ZeRO all-gathers before unwrapping the model.
        self.model._force_wait_all_gather()
        model = self.model.unwrap()
        state_dict = model.state_dict()
        return state_dict
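For orientation, a driver for `SimpleConsumer` might look like the sketch below. This is hypothetical and not part of the commit: the producer side lives in other files of this change, so the argument values, the module name `consumer`, and the `setup()`/`loop()` call pattern are assumptions based on the class above. `loop()` will block waiting on the `sync_data_*` groups until matching producer actors join them.

```python
# Hypothetical driver sketch, not part of this commit; values are made up.
import ray

from consumer import SimpleConsumer  # assuming the file is importable as `consumer`

ray.init()
consumer = SimpleConsumer.options(num_gpus=1).remote(
    num_producers=1,
    num_episodes=1,
    rank=0,
    world_size=1,
    master_addr="localhost",
    master_port=29500,
    num_update_per_episode=10,
    num_recv_per_update=1,
    batch_size=8,
    model_config={"path": "Qwen/Qwen2.5-0.5B-Instruct"},  # any causal LM path
    plugin_config={},
    microbatch_size=2,
)
ray.get(consumer.setup.remote())  # joins collective groups, boosts model and optimizer
ray.get(consumer.loop.remote())   # receives batches from producers and trains
```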