DeviceMesh initialization fix and VeDeviceMesh 2.0 (#33)
MackZackA authored Apr 26, 2024
1 parent 97735b1 commit dd44ba5
Showing 47 changed files with 843 additions and 466 deletions.
70 changes: 31 additions & 39 deletions README.md
@@ -1,63 +1,55 @@
# veScale: A PyTorch Native LLM Training Framework
<div align="center">
<img src="./docs/pictures/icon.png" width="150"/>
</div>

## Coming Soon
# A PyTorch Native LLM Training Framework

We are refactoring our [internal LLM training system](https://arxiv.org/abs/2402.15627) components to meet open-source standards. The tentative timeline is as follows:
_**An Industrial-Level Framework for Ease-of-Use**_

1. by mid April, 4D parallelism (tensor parallelism, sequence parallelism, data parallelism and ZeRO) examples for nanoGPT, Llama2 and Mixtral models
2. by end of May, fast checkpointing system
3. by end of July, CUDA event monitor, pipeline parallelism and supporting components for large-scale training
- 🔥 **PyTorch Native**: veScale is rooted in PyTorch-native data structures, operators, and APIs, enjoying the ecosystem of PyTorch that dominates the ML world.

## Installation
- 🛡 **Zero Model Code Change**: veScale decouples distributed system design from model architecture, requiring zero or near-zero modification of users' model code.

### From Source
- 🚀 **Single Device Abstraction**: veScale provides single-device semantics to users, automatically distributing and orchestrating model execution in a cluster of devices.

#### Install a Patched Version of PyTorch
- 🎯 **Automatic Parallelism Planning**: veScale parallelizes model execution with a synergy of strategies (tensor, sequence, data, ZeRO, pipeline parallelism) under semi- or full-automation [coming soon].

```bash
bash patches/build_pytorch_w_patch.sh
```
- ⚡ **Eager & Compile Mode**: veScale supports not only Eager-mode automation for parallel training and inference but also Compile-mode for ultimate performance [coming soon].

This will compile and install a patched version of PyTorch (based on v2.2.1_rc3).
The patch code can be found here: [PyTorch-Patch](patches/patched_pytorch_v2.2.1_rc3.patch)
- 📀 **Automatic Checkpoint Resharding**: veScale manages distributed checkpoints automatically with online resharding across different cluster sizes and different parallelism strategies.

#### Install a Patched Version of TorchDistX

```bash
bash patches/build_torchdistX_w_patch.sh
```
## Coming Soon

This will compile and install a patched version of TorchDistX (based on its master branch).
The patch code can be found here: [TorchDistX-Patch](patches/patched_torchdistX_9c1b9f.patch)
_**veScale**_ is still in its early phase. We are refactoring our [internal LLM training system](https://arxiv.org/abs/2402.15627) components to meet open-source standards. The tentative timeline is as follows:

#### Install veScale
- by end of May, fast checkpointing system

```bash
pushd python && pip3 install -r requirements.txt && pip3 install -e . && popd
```
- by end of July, CUDA event monitor, pipeline parallelism and supporting components for large-scale training

This will install veScale and its dependencies.
## Table of Contents ([web view](https://volcengine.github.io/veScaleWeb/))

### Docker Image
**[Introduction](./docs/texts/introduction.md)**

#### Build the Docker Image
**[Quick Start](./docs/texts/quick-start.md)**

Make sure you are in the veScale directory.
**[DTensor](./vescale/dtensor/README.md)**

```bash
docker build .
```
It may take a while to build the image.
**Parallel**
* [Overview](./docs/texts/parallel_overview.md)
* [Tensor Parallel & Sequence Parallel](./vescale/dmodule/README.md)
* [Data Parallel](./vescale/ddp/README.md)
* [Optimizer Parallel](./vescale/optim/README.md)
* [Pipeline Parallel](./vescale/pipe/README.md)
* [nD Device Mesh](./vescale/devicemesh_api/README.md)

Once the build finishes, you can `docker run` the resulting image id.
**Plan**
* [Auto TP & SP Plan](./vescale/dmp/README.md)

**[Checkpoint](./vescale/checkpoint/README.md)**

## [We Are Hiring!](https://volcengine.github.io/veScaleWeb/misc/join-us.html) ##

## [License](./LICENSE)

The veScale Project is under the Apache License v2.0.

## Acknowledgement

The veScale team would like to sincerely acknowledge the assistance of and collaboration with
the [PyTorch DTensor team](https://github.com/pytorch/pytorch/tree/main/torch/distributed/_tensor).
Binary file added docs/pictures/ddp.png
Binary file added docs/pictures/dmodule.png
Binary file added docs/pictures/doptimizer.png
Binary file added docs/pictures/dtensor.png
Binary file added docs/pictures/icon.png
Binary file added docs/pictures/overview.png
Binary file added docs/pictures/parallel5d.png
Binary file added docs/pictures/pytorch.png
Binary file added docs/pictures/tldr.png
Binary file added docs/pictures/vedevicemesh.png
Binary file added docs/pictures/vescale-logo-dark.png
Binary file added docs/pictures/vescale-logo-light.png
58 changes: 58 additions & 0 deletions docs/texts/introduction.md
@@ -0,0 +1,58 @@
# veScale: A PyTorch Native LLM Training Framework

## TLDR

An _**Industrial-Level**_ Framework for _**Ease-of-Use**_:

<img src="../../docs/pictures/tldr.png" alt="TLDR" width="400"/>

(`*` is under development.)

## Why veScale

The era of giant models calls for distributed training. Yet despite the countless distributed training frameworks published over the past decade, few have excelled at the _**Ease-of-Use**_ and development extensibility demanded by real industry production, for the quality most favored in a framework is often _**Ease-of-Use**_ rather than pure _Performance_.
Companies developing hundreds to thousands of models per week benefit most from a framework that is easy to use and extend, encapsulates models elegantly, and exposes clean APIs.

The _**Ease-of-Use**_ of a framework for training and developing LLMs lies in the following essentials:

- 🔥 **PyTorch Native**: The _PyTorch_ ecosystem dominates the ML world, owning 92% of models on _HuggingFace_ and 70% of research on _Papers with Code_; straying from the _PyTorch_ ecosystem makes a framework hard to adopt and extend.

- 🛡 **Zero Model Code Change**: Users' model code should remain untouched, rather than intertwined with framework code, which forces users not only to rewrite the model for distributed training with great care, but also to debug painfully inside deeply coupled model and framework code.

- 🚀 **Single Device Abstraction**: Model developers should focus on the model architecture itself with single-device semantics, rather than being distracted by the complex and error-prone management of multiple devices and diverse interconnects in distributed environments.

- 🎯 **Automatic Parallelism Planning**: Gigantic models cannot be trained without _nD Parallelism_ (_Tensor, Sequence, Data, ZeRO, Pipeline Parallelism, etc._). Users' giant models should be automatically scaled by a framework for _nD_ parallel training, instead of being manually planned and tuned for each operator or layer under different cluster settings, which takes forever.

- ⚡ **Eager & Compile Mode**: Users should enjoy both _Eager_ and _Compile_ mode offered by a framework with:
- _Eager_ mode for fast development, convenient debugging, and customization with callbacks and control flows;
- _Compile_ mode for ultimate performance boost with a single click.

- 📀 **Automatic Checkpoint Resharding**: Training models and optimizer states should be saved/loaded automatically and performantly in distributed settings, and can even be _online resharded_ across different cluster sizes and different _nD Parallelism_.
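
To make the last item concrete, here is a minimal sketch of a distributed checkpoint round-trip; the `vescale.checkpoint` save/load signatures are our assumption of the API shape, not confirmed by this commit (see the [Checkpoint README](../../vescale/checkpoint/README.md) for the authoritative interface):

```python
# Hedged sketch: save a parallelized model and optimizer, then reload them,
# possibly on a different cluster size or nD-parallel layout.
# The save/load signatures below are ASSUMED, not the confirmed API.
import vescale.checkpoint

# `ddp_model` and `doptimizer` are produced by veScale's parallelization
# APIs, as in the 4D-parallelism example in the parallel overview.
checkpoint_state = {"model": ddp_model, "optimizer": doptimizer}

vescale.checkpoint.save("ckpt/step_1000", checkpoint_state)  # each rank writes its shards

# Later: online resharding is handled transparently on load.
vescale.checkpoint.load("ckpt/step_1000", checkpoint_state)
```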

## What is veScale

**veScale**'s overview is as follows:

<img src="../../docs/pictures/overview.png" alt="overview" width="700"/>

We take an initial step toward an _**Industrial-Level**_ framework, **veScale**, that focuses on _**Ease-of-Use**_ for scaling LLM training by combining _PyTorch Nativeness_ with _Automatic Parallelism*_.

Ideally, **veScale** only expects model developers to write simple model code with native _torch.nn.Module_ under _Zero Code Change_, as if running on a _Single Device_; **veScale** then automatically parallelizes it across a cluster of devices within an _nD Parallelism_ search space, with all the optimizations and heavy lifting handled transparently.

Unlike existing frameworks that rely on _Compile_ mode and a "perfect model graph" for _Automatic Parallelism_, **veScale** is inventing an _Eager-Mode-ONLY*_ _Automatic Parallelism_ that does not rely on the model graph at all.
Furthermore, **veScale** is also developing a _Mixed Mode_* of partial _Eager_ and partial _Compile_.

**veScale** is designed and implemented on top of a primitive called _DTensor_, which provides global tensor semantics with local shards distributed across multiple devices.
**veScale** extends and enhances _PyTorch DTensor_ to our production standard, and further develops _Auto-Plan*_ and _Auto-Parallelize_ with a unified configuration and API.
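
For intuition, a minimal DTensor sketch follows; `distribute_tensor` and `VESCALE_DEVICE_MESH` appear in this codebase, while the `placement_types` import path and the `Shard` usage mirror the PyTorch DTensor convention and are assumptions here:

```python
# Hedged sketch of global-tensor semantics over local shards (assumes 4 GPUs).
import torch
from vescale import distribute_tensor
from vescale.dtensor.placement_types import Replicate, Shard  # assumed path
from vescale.devicemesh_api import VESCALE_DEVICE_MESH

VESCALE_DEVICE_MESH.init_device_mesh("cuda", (4,), mesh_dim_names=["TP"])
mesh = VESCALE_DEVICE_MESH["TP"]

t = torch.randn(8, 8)
x = distribute_tensor(t, mesh, [Shard(0)])     # one global 8x8 tensor; each device holds 2 rows
y = distribute_tensor(t, mesh, [Replicate()])  # full copy on every device
```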

Furthermore, **veScale** also supports online _Auto-Reshard_ for distributed checkpoints, which will be open-sourced as a new project -- **OmniStore**.

(`*` is under development)

## Status of veScale

**veScale** is still in its early phase.

The tentative open-source timeline can be found in the **veScale** [**repo**](https://github.com/volcengine/veScale/tree/main).
50 changes: 50 additions & 0 deletions docs/texts/parallel_overview.md
@@ -0,0 +1,50 @@
# veScale Parallel Overview

The overview of veScale _n-D parallelism_ is as follows:

<img src="../../docs/pictures/parallel5d.png" alt="5D" width="600"/>

(`*` is under development)

The _Auto-Parallelize_ block takes the untouched _Model_ from the user together with a _Parallel Plan_ (written manually, predefined per model type, or generated automatically by _Auto-Plan*_), and parallelizes the single-device model into _nD Parallelism_ across a mesh of devices.
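
As a sketch of what such a _Parallel Plan_ can look like (the two-part parameter/forward structure follows the examples shipped in this repo; the FQNs and the import path below are hypothetical):

```python
# Hedged sketch of a manual Parallel Plan: map parameter / activation FQNs
# to DTensor placements. FQNs are hypothetical; real plans live in <repo>/examples.
from vescale.dtensor.placement_types import Replicate, Shard  # assumed path

param_sharding_plan = {
    "attn.qkv.weight": [Shard(0)],   # column-parallel linear
    "attn.proj.weight": [Shard(1)],  # row-parallel linear
}
fwd_resharding_plan = {
    "attn.input": [[Replicate()]],   # replicate activations entering the block
}
sharding_plan = {"parameter": param_sharding_plan, "forward": fwd_resharding_plan}
```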

veScale's _nD Parallelism_ follows a decoupled design in which each D of parallelism is handled by an independent sub-block (e.g., _DModule_ handles only _Tensor & Sequence Parallel_, without coupling to the other parallelisms).
In contrast to the conventional _coupled_ design that intertwines all parallelisms, such decoupled _nD Parallelism_ enjoys composability, debuggability, explainability, and extensibility, all of which are of great value for hyper-scale training in production.

## 4D Parallelism API

Our 4D parallelism (_Tensor, Sequence, Data, and ZeRO2_) API is used as follows:

```python
# zero model code change
from <HuggingFace> import <ModelCls>, <ModelArgs>

# create fake model without actual memory usage (optional)
fake_model = deferred_init(<ModelCls>, <ModelArgs>)

# initialize 4D device mesh
mesh = init_device_mesh("cuda", (dp_zero_size, tp_sp_size), mesh_dim_names=["DP_ZERO", "TP_SP"])

# parallelize model in tp & sp
from <PredefinedPlan> import sharding_plan
real_tp_sp_model = parallelize_module(fake_model, mesh["TP_SP"], sharding_plan)

# parallelize model in dp
ddp_model = DDP(real_tp_sp_model, mesh["DP_ZERO"])

# parallelize model with zero2
doptimizer = DistributedOptimizer(torch.optim.AdamW, models=[ddp_model])

# train model as if on a single device
for x in dataset:
    loss = ddp_model(x)
    loss.backward()
    doptimizer.step()
    doptimizer.zero_grad()
```

More examples can be found in: `<repo>/examples/`.

## 5D Parallelism API

Coming Soon
52 changes: 52 additions & 0 deletions docs/texts/quick-start.md
@@ -0,0 +1,52 @@
# Quick Start

First, find the **veScale** [**repo**](https://github.com/volcengine/veScale/tree/main).

## Installation

### From Source

#### Install a Patched Version of PyTorch

```bash
bash [repo]/patches/build_pytorch_w_patch.sh
```

This will compile and install a patched version of PyTorch.

#### Install a Patched Version of TorchDistX

```bash
bash [repo]/patches/build_torchdistX_w_patch.sh
```

This will compile and install a patched version of TorchDistX (based on its master branch).

#### Install veScale

```bash
pushd python && pip3 install -r requirements.txt && pip3 install -e . && popd
```

This will install **veScale** and its dependencies.
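
A quick smoke test for the install (nothing veScale-specific is assumed beyond the package name):

```bash
python3 -c "import vescale; print('veScale imported OK')"
```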

### Docker Image

#### Build the Docker Image

Make sure you are in the veScale directory.

```bash
docker build .
```
It may take a while to build the image.

Once the build finishes, you can `docker run` the resulting image id.
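
For convenience, a tagged build avoids copying image ids around; this is standard Docker usage rather than a veScale requirement, and `--gpus all` assumes the NVIDIA Container Toolkit is installed:

```bash
docker build -t vescale:dev .
docker run --gpus all -it vescale:dev /bin/bash
```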

## Run Examples

- Nano GPT: `<repo>/examples/nanogpt_4D_finetune`

- Open LLAMA: `<repo>/examples/open_llama_4D_benchmark`

- Mixtral: `<repo>/examples/mixtral_4D_benchmark`
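
A typical invocation for the nanoGPT example might look like the following; the script path matches this repo, but the flags and GPU count are illustrative assumptions, so check each example's README for the exact command:

```bash
# Assumes 4 GPUs on one node; dp_size * tp_size should equal nproc_per_node.
torchrun --standalone --nproc_per_node=4 \
    examples/nanogpt_4D_finetune/finetune_4D.py \
    --dp_size=2 --tp_size=2
```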
9 changes: 6 additions & 3 deletions examples/mixtral_4D_benchmark/mixtral_train.py
@@ -40,10 +40,13 @@ def estimate_mixtral(config, bsz, sqence_length):
embed = 4 * bsz * sqence_length * config.hidden_size
# MixtralMoE consists of 3 linear layers.
ff = 3 * 2 * config.num_experts_per_tok * config.hidden_size * config.intermediate_size * bsz * sqence_length
attn_qkv = 2 * bsz * sqence_length * config.hidden_size * 3 * config.hidden_size
# GQA
head_size = config.hidden_size // config.num_attention_heads
attn_q = 2 * bsz * sqence_length * config.hidden_size * config.hidden_size
attn_kv = 2 * 2 * bsz * sqence_length * config.hidden_size * config.num_key_value_heads * head_size
attn_mask = 2 * sqence_length * config.hidden_size
attn_proj = 2 * config.hidden_size * config.intermediate_size * bsz * sqence_length
attn = attn_qkv + attn_mask + attn_proj
attn_proj = 2 * config.hidden_size * config.hidden_size * bsz * sqence_length
attn = attn_q + attn_kv + attn_mask + attn_proj
return embed + (ff + attn) * config.num_hidden_layers
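
For context, this hunk replaces a dense-QKV estimate with one that accounts for grouped-query attention (GQA): the K and V projections have only `num_key_value_heads * head_size` output columns rather than `hidden_size`, and the output projection maps `hidden_size` to `hidden_size` rather than `intermediate_size`. A small worked example with illustrative (assumed) config values:

```python
# Illustrative numbers only (resembling a GQA config); not taken from this repo.
hidden_size, num_attention_heads, num_key_value_heads = 4096, 32, 8
bsz, seq = 1, 4096

head_size = hidden_size // num_attention_heads                    # 128
attn_q    = 2 * bsz * seq * hidden_size * hidden_size             # Q projection
attn_kv   = 2 * 2 * bsz * seq * hidden_size * num_key_value_heads * head_size
attn_proj = 2 * hidden_size * hidden_size * bsz * seq             # output projection

# With 8 of 32 heads as KV heads, each of K and V costs 4x less than the old
# dense-QKV formula assumed: attn_q / (attn_kv / 2) == 4.
print(attn_q, attn_kv, attn_proj)
```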


12 changes: 6 additions & 6 deletions examples/nanogpt_4D_finetune/finetune_4D.py
Expand Up @@ -33,7 +33,7 @@
from torch.distributed import broadcast, all_reduce, barrier, init_process_group, destroy_process_group, get_rank

from model import GPTConfig, GPT
from vescale.devicemesh_api.device_mesh_api import veDeviceMesh
from vescale.devicemesh_api import VESCALE_DEVICE_MESH

from vescale import distribute_tensor
from vescale.dmodule.api import parallelize_module
@@ -114,7 +114,7 @@ def main():
torch.cuda.set_device(device)
init_process_group(backend=backend, world_size=world_size, rank=rank)

mesh = veDeviceMesh.init_device_mesh(device, (dp_size, tp_size), mesh_dim_names=["DP", "TP"])
VESCALE_DEVICE_MESH.init_device_mesh(device, (dp_size, tp_size), mesh_dim_names=["DP", "TP"])
ddp_rank = get_rank() // tp_size
else:
rank = 0
@@ -162,8 +162,8 @@ def get_batch(split, bsz=batch_size, lbsz=local_batch_size):
else:
x, y = x.to(device), y.to(device)
if ddp:
x = distribute_tensor(x, mesh["TP"], [Replicate()])
y = distribute_tensor(y, mesh["TP"], [Replicate()])
x = distribute_tensor(x, VESCALE_DEVICE_MESH["TP"], [Replicate()])
y = distribute_tensor(y, VESCALE_DEVICE_MESH["TP"], [Replicate()])
return x, y

# init these up here, can override if init_from='resume' (i.e. from a checkpoint)
@@ -235,10 +235,10 @@ def get_batch(split, bsz=batch_size, lbsz=local_batch_size):

# + + + parallelize the model and wrap it with DDP using veScale APIs
if ddp:
model = parallelize_module(model, mesh["TP"], nanoGPT_plan)
model = parallelize_module(model, VESCALE_DEVICE_MESH["TP"], nanoGPT_plan)
model = DDP(
model,
data_pg_or_device_mesh=mesh["DP"],
data_pg_or_device_mesh=VESCALE_DEVICE_MESH["DP"],
accumulate_allreduce_grads_in_fp32=DDP_grads_in_fp32,
overlap_grad_reduce=False,
use_distributed_optimizer=use_DO,
18 changes: 8 additions & 10 deletions test/checkpoint/common_func.py
@@ -122,22 +122,22 @@ def build_gpt_model_optimizer_and_dataset(init_method, dp_size=1, tp_size=1):

open_source = False
try:
from vescale.devicemesh_api import veDeviceMesh
from vescale.devicemesh_api import VESCALE_DEVICE_MESH
except ImportError:
open_source = True
device_mesh = veDeviceMesh.init_device_mesh(
VESCALE_DEVICE_MESH.init_device_mesh(
device_type="cuda",
mesh_shape=(dp_size, tp_size),
mesh_dim_names=("DP", "TP"),
)

# Enable tensor Parallel
tp_gpt = parallelize_module(gpt, device_mesh["TP"], nanoGPT_plan)
tp_gpt = parallelize_module(gpt, VESCALE_DEVICE_MESH["TP"], nanoGPT_plan)

# Enable data Parallel
ddp_gpt = DDP(
tp_gpt,
data_pg_or_device_mesh=device_mesh["DP"],
data_pg_or_device_mesh=VESCALE_DEVICE_MESH["DP"],
accumulate_allreduce_grads_in_fp32=True,
overlap_grad_reduce=False,
use_distributed_optimizer=True,
@@ -280,24 +280,22 @@ def get_open_llama_model(layer_number=None):


def get_open_llama_model_optimizer(dp_size, tp_size, layer_number=None):
from vescale.devicemesh_api import veDeviceMesh
from vescale.devicemesh_api import VESCALE_DEVICE_MESH

device_mesh = veDeviceMesh.init_device_mesh(
"cuda", (dp_size, tp_size), mesh_dim_names=("DP", "TP"), check_uniqueness=True
)
VESCALE_DEVICE_MESH.init_device_mesh("cuda", (dp_size, tp_size), mesh_dim_names=("DP", "TP"), check_uniqueness=True)
# Set 4 layers to avoid timeout on CI
# Use 32 layers when running on training platform
vescale_decoder, config = get_open_llama_model(layer_number=layer_number)

vescale_decoder = parallelize_module(
vescale_decoder,
device_mesh["TP"],
VESCALE_DEVICE_MESH["TP"],
sharding_plan,
)

ddp_decoder = DDP(
vescale_decoder,
data_pg_or_device_mesh=device_mesh["DP"],
data_pg_or_device_mesh=VESCALE_DEVICE_MESH["DP"],
accumulate_allreduce_grads_in_fp32=True,
overlap_grad_reduce=False,
use_distributed_optimizer=True,