-
Notifications
You must be signed in to change notification settings - Fork 4.2k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[COMPILE] workflow for deepspeed + torch.compile (#6570)
We use a simple model + DeepSpeed ZeRO stage 3 + torch.compile and count the number of graph breaks to demonstrate the current status of combining DeepSpeed with torch.compile. --------- Co-authored-by: Masahiro Tanaka <[email protected]>
- Loading branch information
Showing
3 changed files
with
199 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
name: xpu-compile | ||
|
||
on: | ||
workflow_dispatch: | ||
schedule: | ||
- cron: "0 0 * * *" | ||
pull_request: | ||
paths: | ||
- ".github/workflows/xpu-compile.yml" | ||
|
||
concurrency: | ||
group: ${{ github.workflow }}-${{ github.ref }} | ||
cancel-in-progress: true | ||
|
||
permissions: | ||
contents: read | ||
issues: write | ||
|
||
jobs: | ||
compile-tests: | ||
runs-on: [self-hosted, intel, xpu] | ||
container: | ||
image: intel/oneapi-basekit:2024.2.1-0-devel-ubuntu22.04 | ||
ports: | ||
- 80 | ||
options: --privileged -it --rm --device /dev/dri:/dev/dri -v /dev/dri/by-path:/dev/dri/by-path --ipc=host --cap-add=ALL | ||
|
||
steps: | ||
- uses: actions/checkout@v4 | ||
- name: Install prerequisite | ||
run: | | ||
apt-get update | ||
apt-get install clinfo libaio-dev python3-pip -y | ||
pip install torch==2.3.1 -f https://pytorch-extension.intel.com/release-whl/stable/xpu/us/torch/ | ||
pip install intel-extension-for-pytorch==2.3.110+xpu -f https://pytorch-extension.intel.com/release-whl/stable/xpu/us/intel-extension-for-pytorch/ | ||
pip install oneccl_bind_pt==2.3.100+xpu -f https://pytorch-extension.intel.com/release-whl/stable/xpu/us/oneccl-bind-pt/ | ||
pip install torchvision==0.18.1 -f https://pytorch-extension.intel.com/release-whl/stable/xpu/us/torchvision/ | ||
pip install https://github.com/intel/intel-xpu-backend-for-triton/releases/download/v3.0.0b2/triton_xpu-3.0.0b2-cp310-cp310-linux_x86_64.whl | ||
pip install py-cpuinfo numpy | ||
pip install .[dev,autotuning] | ||
- name: Check container state | ||
run: | | ||
ldd --version | ||
ds_report | ||
python3 -c "import torch; print('torch:', torch.__version__, torch)" | ||
python3 -c "import torch; import intel_extension_for_pytorch; print('XPU available:', torch.xpu.is_available())" | ||
python3 -c "from deepspeed.accelerator import get_accelerator; print('accelerator:', get_accelerator()._name)" | ||
pip list | ||
- name: Compile Status | ||
shell: bash | ||
run: | | ||
export FI_HMEM=system | ||
ulimit -n 1048575 | ||
cd tests/torch_compile | ||
export ZE_AFFINITY_MASK=0,1 | ||
deepspeed test_compile.py --deepspeed_config ds_config.json 2>&1 | tee log.txt | ||
cat log.txt | grep "'graph_breaks'" | sed 's/,/ /g' | awk '{print $2}' >> $GITHUB_STEP_SUMMARY |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
{ | ||
"train_batch_size": 8, | ||
"steps_per_print": 2000, | ||
"optimizer": { | ||
"type": "Adam", | ||
"params": { | ||
"lr": 0.001, | ||
"betas": [ | ||
0.8, | ||
0.999 | ||
], | ||
"eps": 1e-8, | ||
"weight_decay": 3e-7 | ||
} | ||
}, | ||
"scheduler": { | ||
"type": "WarmupLR", | ||
"params": { | ||
"warmup_min_lr": 0, | ||
"warmup_max_lr": 0.001, | ||
"warmup_num_steps": 1000 | ||
} | ||
}, | ||
"gradient_clipping": 1.0, | ||
"prescale_gradients": false, | ||
"bf16": { | ||
"enabled": true, | ||
"loss_scale": 0, | ||
"loss_scale_window": 500, | ||
"hysteresis": 2, | ||
"min_loss_scale": 1, | ||
"initial_scale_power": 15 | ||
}, | ||
"wall_clock_breakdown": false, | ||
"zero_optimization": { | ||
"stage": 3, | ||
"reduce_scatter": true, | ||
"overlap_comm": false, | ||
"contiguous_gradients": false | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
# DeepSpeed Team | ||
|
||
import argparse | ||
import deepspeed | ||
from deepspeed.accelerator import get_accelerator | ||
from deepspeed import comm | ||
|
||
import torch | ||
import intel_extension_for_pytorch # noqa: F401 # type: ignore | ||
from torch.utils.data import Dataset, DataLoader | ||
|
||
torch._dynamo.config.cache_size_limit = 100 | ||
|
||
import collections | ||
|
||
|
||
def get_dynamo_stats():
    """Return a snapshot of selected TorchDynamo counters.

    The values are read from ``torch._dynamo.utils.counters`` at call time.
    Because the snapshot is a ``collections.Counter``, an earlier snapshot can
    be removed from a later one with ``Counter.subtract`` to obtain the
    activity that happened in between.

    Returns:
        collections.Counter: counts keyed by ``calls_captured``,
        ``unique_graphs``, ``graph_breaks``, ``unique_graph_breaks``,
        ``autograd_captures``, ``autograd_compiles`` and ``cudagraph_skips``.
    """
    # TODO: consider deepcopy'ing the entire counters struct and
    # adding a helper to do subtraction on it
    counters = torch._dynamo.utils.counters
    stats = counters["stats"]
    graph_break = counters["graph_break"]
    compiled_autograd = counters["compiled_autograd"]
    return collections.Counter({
        "calls_captured": stats["calls_captured"],
        "unique_graphs": stats["unique_graphs"],
        # Total number of breaks, summed over every recorded break reason.
        "graph_breaks": sum(graph_break.values()),
        # NB: The unary plus removes zero counts, so only reasons that
        # actually occurred are counted.
        "unique_graph_breaks": len(+graph_break),
        "autograd_captures": compiled_autograd["captures"],
        "autograd_compiles": compiled_autograd["compiles"],
        "cudagraph_skips": counters["inductor"]["cudagraph_skips"],
    })
|
||
|
||
class RandomDataset(Dataset):
    """In-memory dataset of random bfloat16 vectors.

    Args:
        size: number of features in each sample.
        length: number of samples held by the dataset.
    """

    def __init__(self, size, length):
        self.len = length
        # Draw from a standard normal in the default dtype, then cast the
        # whole (length, size) tensor to bfloat16 once up front.
        self.data = torch.randn(length, size).to(torch.bfloat16)

    def __getitem__(self, index):
        """Return the sample stored at row ``index``."""
        return self.data[index]

    def __len__(self):
        """Return the number of samples given at construction."""
        return self.len
|
||
|
||
# Synthetic input: 100 random samples of 1024 features each, served one sample
# per step in a fixed order. The feature count matches the input width of the
# first linear layer of the model defined below.
data_size = 1024
data_length = 100
rand_loader = DataLoader(dataset=RandomDataset(data_size, data_length), batch_size=1, shuffle=False)
|
||
|
||
class MyModule(torch.nn.Module):
    """Small residual MLP used as the compile target.

    The module applies dropout (p=0.5), then two bias-free linear layers
    (1024 -> 256 -> 256), scales the result by 0.5 and adds it to a
    caller-supplied residual.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.fc0 = torch.nn.Linear(1024, 256, bias=False)
        self.fc1 = torch.nn.Linear(256, 256, bias=False)
        self.dropout = torch.nn.Dropout(0.5)

    def forward(self, data, residual):
        """Return ``residual + fc1(fc0(dropout(data))) * 0.5``.

        Args:
            data: input tensor whose last dimension is 1024.
            residual: tensor added to the scaled projection; it must be
                broadcast-compatible with a last dimension of 256.
        """
        projected = self.fc1(self.fc0(self.dropout(data)))
        return residual + projected * 0.5
|
||
|
||
# Driver: build the model, wrap it in a DeepSpeed engine, compile it, run one
# pass over the synthetic data and report how the Dynamo counters changed.
model = MyModule()
params = model.parameters()

parser = argparse.ArgumentParser()
parser.add_argument('--local_rank', type=int, default=-1, help='local rank passed from distributed launcher')
parser.add_argument('--deepspeed_config',
                    type=str,
                    default='ds_config.json',
                    help='path to DeepSpeed configuration file')
cmd_args = parser.parse_args()

# initialize the DeepSpeed engine
model_engine, optimizer, _, _ = deepspeed.initialize(args=cmd_args, model=model, model_parameters=params)
model_engine.compile()

# The target device is looked up once and reused for every batch.
device = get_accelerator().current_device_name()
residual = torch.rand(256, 256, dtype=torch.float).to(device)

# Baseline snapshot, so that only counter changes from the loop are reported.
start_stats = get_dynamo_stats()

for step, batch in enumerate(rand_loader):
    if step % 10 == 0 and comm.get_rank() == 0:
        print(f'step={step}')
    # forward() method
    loss = model_engine(batch.to(device), residual).sum()
    # runs backpropagation
    model_engine.backward(loss)
    # weight update
    model_engine.step()

dynamo_stats = get_dynamo_stats()
dynamo_stats.subtract(start_stats)

# Only rank 0 prints, so the counter summary appears once in the log.
if comm.get_rank() == 0:
    print(dynamo_stats)