[example] Support ViT with ColoTensor
Showing 5 changed files with 301 additions and 0 deletions.
image/vision_transformer/colo_vit/README.md
@@ -0,0 +1,24 @@
# Vision Transformer with ColoTensor

# Overview

In this example, we run Vision Transformer with ColoTensor.
We use the **ViTForImageClassification** model from Hugging Face ([link](https://huggingface.co/docs/transformers/model_doc/vit)).
You can change the world size and decide whether to use DDP in our code.

# Requirements

You should install colossalai from the **latest** main branch, and install pytest and transformers with:

```shell
pip install pytest transformers
```

# How to run

In your terminal:

```shell
pytest test_vit.py
```

This will evaluate the model with different **world_size** and **use_ddp** settings.
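Each parametrized case spawns one process per rank. The sketch below shows how a single configuration could be launched by hand; it is a minimal illustration, not part of the commit, and assumes `run_dist` from `test_vit.py` (shown below) and the colossalai `free_port` helper are importable.

```python
# Minimal sketch (not part of the commit): launching one configuration by hand,
# mirroring what test_vit() below does via pytest parametrization.
import torch.multiprocessing as mp
from functools import partial

from colossalai.utils import free_port
from test_vit import run_dist   # assumes test_vit.py is on the Python path

if __name__ == '__main__':
    world_size = 4      # number of spawned ranks
    use_ddp = True      # with DDP, the tensor-parallel size becomes world_size // 2 (see run_dist)
    run_func = partial(run_dist, world_size=world_size, port=free_port(), use_ddp=use_ddp)
    mp.spawn(run_func, nprocs=world_size)
```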
image/vision_transformer/colo_vit/test_vit.py
@@ -0,0 +1,129 @@
import torch
import pytest
import colossalai
from colossalai.context.parallel_mode import ParallelMode
import torch.multiprocessing as mp
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils.cuda import get_current_device
from colossalai.utils import free_port
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction, DistSpecManager, distspec
from colossalai.core import global_context as gpc
from functools import partial
from utils.util import tensor_equal, tensor_shard_equal, set_seed
from vit import get_training_components
from torch.nn.parallel import DistributedDataParallel as DDP
from colossalai.nn.parallel.data_parallel import ColoDDP


# For Linear layers this is a 1d_row split, because Linear weights are transposed
# during computation; for other layers it is a 1d_col split.
# LayerNorm is not supported for now.
# patch_embeddings.projection is an nn.Conv2d, so it is also excluded:
# https://github.com/huggingface/transformers/blob/dcb08b99f44919425f8ba9be9ddcc041af8ec25e/src/transformers/models/vit/modeling_vit.py#L182
def init_1d_row_for_linear_weight_spec(model):
    spec = TensorSpec(
        distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
        ParallelAction(ComputePattern.TP1D))
    with DistSpecManager.no_grad():
        for n, p in model.named_parameters():
            if 'weight' in n and 'layernorm' not in n and 'embeddings.patch_embeddings.projection.weight' not in n:
                p.set_spec(spec)


# Similarly, this is a column split for Linear layers and a row split for the others.
def init_1d_col_for_linear_weight_bias_spec(model):
    spec = TensorSpec(
        distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
        ParallelAction(ComputePattern.TP1D))
    with DistSpecManager.no_grad():
        for n, p in model.named_parameters():
            if ('weight' in n
                    or 'bias' in n) and 'layernorm' not in n and 'embeddings.patch_embeddings.projection' not in n:
                p.set_spec(spec)


def check_param_equal(model, torch_model):
    for p, torch_p in zip(model.parameters(), torch_model.parameters()):
        assert tensor_shard_equal(torch_p, p)


def check_grad_equal(model, torch_model):
    for (np, p), (ntp, torch_p) in zip(model.named_parameters(), torch_model.named_parameters()):
        if torch_p.grad.shape == p.grad.shape:
            assert torch.allclose(torch_p.grad, p.grad, rtol=1e-3, atol=2.0)
        else:
            # The parameter is sharded along one dimension: compare the local chunk of the full grad.
            dims_not_eq = torch.nonzero(torch.tensor(torch_p.grad.shape) != torch.tensor(p.grad.shape))
            dim = dims_not_eq.item()
            world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
            rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
            assert torch.allclose(torch_p.grad.chunk(world_size, dim)[rank], p.grad, rtol=1e-3, atol=2.0)


def run_vit(init_spec_func, use_ddp):
    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_training_components()

    with ColoInitContext(device=get_current_device()):
        model = model_builder()
    model = model.cuda()
    torch_model = model_builder().cuda()
    if use_ddp:
        model = ColoDDP(model)
        torch_model = DDP(torch_model,
                          device_ids=[gpc.get_global_rank()],
                          process_group=gpc.get_group(ParallelMode.DATA))
    for torch_p, p in zip(torch_model.parameters(), model.parameters()):
        torch_p.data.copy_(p)
    init_spec_func(model)

    check_param_equal(model, torch_model)
    model.train()
    torch_model.train()
    set_seed(gpc.get_local_rank(ParallelMode.DATA))

    optimizer = optimizer_class(model.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
    torch_optimizer = optimizer_class(torch_model.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)

    for i, image_dict in enumerate(train_dataloader):
        if use_ddp:
            model.zero_grad()
        else:
            optimizer.zero_grad()
        logits = model(image_dict['pixel_values'])
        torch_logits = torch_model(image_dict['pixel_values'])
        assert tensor_equal(torch_logits.logits, logits.logits)
        loss = criterion(logits.logits, image_dict['label'])
        torch_loss = criterion(torch_logits.logits, image_dict['label'])
        if use_ddp:
            # ColoDDP drives the backward pass itself.
            model.backward(loss)
        else:
            loss.backward()
        torch_loss.backward()
        check_grad_equal(model, torch_model)
        optimizer.step()
        torch_optimizer.step()
        check_param_equal(model, torch_model)
        break


def run_dist(rank, world_size, port, use_ddp):
    if use_ddp and world_size == 1:
        return
    # With DDP enabled, half of the ranks form the tensor-parallel group.
    tp_world_size = world_size // 2 if use_ddp else world_size
    config = dict(parallel=dict(tensor=dict(mode="1d", size=tp_world_size),))
    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    run_vit(init_1d_row_for_linear_weight_spec, use_ddp)
    run_vit(init_1d_col_for_linear_weight_bias_spec, use_ddp)


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@pytest.mark.parametrize('use_ddp', [False, True])
@rerun_if_address_is_in_use()
def test_vit(world_size, use_ddp):
    run_func = partial(run_dist, world_size=world_size, port=free_port(), use_ddp=use_ddp)
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_vit(4, False)
image/vision_transformer/colo_vit/utils/dummy_data_generator.py
@@ -0,0 +1,25 @@
from abc import ABC, abstractmethod


class DummyDataGenerator(ABC):

    def __init__(self, length=10):
        self.length = length

    @abstractmethod
    def generate(self):
        pass

    def __iter__(self):
        self.step = 0
        return self

    def __next__(self):
        if self.step < self.length:
            self.step += 1
            return self.generate()
        else:
            raise StopIteration

    def __len__(self):
        return self.length
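A quick illustration of how a `DummyDataGenerator` subclass behaves as a finite iterable of freshly generated items. `ConstantGenerator` is a hypothetical class used only for this sketch; the real subclass in this commit is `DummyDataLoader` in vit.py below.

```python
# Illustrative only: ConstantGenerator is a made-up subclass, not part of the commit.
from utils.dummy_data_generator import DummyDataGenerator


class ConstantGenerator(DummyDataGenerator):

    def generate(self):
        # Each __next__ call returns a freshly generated item.
        return 42


loader = ConstantGenerator(length=3)
print(list(loader))   # -> [42, 42, 42]
print(len(loader))    # -> 3
```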
image/vision_transformer/colo_vit/utils/util.py
@@ -0,0 +1,57 @@
import os
import random
import numpy as np
import torch
import torch.distributed as dist
from colossalai.core import global_context as gpc
from colossalai.context import ParallelMode


def set_seed(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True


def check_equal(A, B):
    assert torch.allclose(A, B, rtol=1e-3, atol=1e-1)


def replace_parameter_add_grad(layer, weight=None, bias=None):
    if weight is not None:
        delattr(layer, 'weight')
        setattr(layer, 'weight', weight)
        layer.weight.requires_grad = True
    if bias is not None:
        delattr(layer, 'bias')
        setattr(layer, 'bias', bias)
        layer.bias.requires_grad = True


def broadcast_tensor_chunk(tensor, chunk_size=1, local_rank=0):
    dist.broadcast(tensor, src=0)
    tensor_chunk = torch.chunk(tensor, chunk_size, dim=-1)[local_rank]
    return tensor_chunk.clone()


def tensor_equal(A, B):
    return torch.allclose(A, B, rtol=1e-3, atol=1e-1)


def tensor_shard_equal(tensor: torch.Tensor, shard: torch.Tensor):
    assert tensor.ndim == shard.ndim
    if tensor.shape == shard.shape:
        return tensor_equal(tensor, shard)
    else:
        dims_not_eq = torch.nonzero(torch.tensor(tensor.shape) != torch.tensor(shard.shape))
        if dims_not_eq.numel() == 1:
            # 1D shard: compare the matching chunk of the full tensor with the shard.
            dim = dims_not_eq.item()
            world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
            rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
            return tensor_equal(tensor.chunk(world_size, dim)[rank], shard)
        else:
            raise NotImplementedError
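The shape comparison in `tensor_shard_equal` can be exercised without a distributed launch. The sketch below reproduces its chunking logic with assumed `world_size` and `rank` values standing in for what would normally come from `gpc`.

```python
# Non-distributed walk-through of the comparison done in tensor_shard_equal;
# world_size and rank are assumed values standing in for the PARALLEL_1D group info.
import torch

full = torch.randn(8, 6)
world_size, rank = 2, 1
shard = full.chunk(world_size, dim=0)[rank]    # simulate a 1D row shard

# Locate the single dimension where the shapes differ (here: dim 0).
dims_not_eq = torch.nonzero(torch.tensor(full.shape) != torch.tensor(shard.shape))
dim = dims_not_eq.item()

# The local chunk of the full tensor should match the shard exactly.
assert torch.allclose(full.chunk(world_size, dim)[rank], shard)
```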
image/vision_transformer/colo_vit/vit.py
@@ -0,0 +1,66 @@
import torch
import torch.nn as nn
from transformers import ViTForImageClassification, ViTConfig
from utils.dummy_data_generator import DummyDataGenerator
from colossalai.utils.cuda import get_current_device


class DummyDataLoader(DummyDataGenerator):
    batch_size = 4
    channel = 3
    category = 8
    image_size = 224

    def generate(self):
        image_dict = {}
        image_dict['pixel_values'] = torch.rand(DummyDataLoader.batch_size,
                                                DummyDataLoader.channel,
                                                DummyDataLoader.image_size,
                                                DummyDataLoader.image_size,
                                                device=get_current_device()) * 2 - 1
        image_dict['label'] = torch.randint(DummyDataLoader.category, (DummyDataLoader.batch_size,),
                                            dtype=torch.int64,
                                            device=get_current_device())
        return image_dict


class ViTCVModel(nn.Module):

    def __init__(self,
                 hidden_size=768,
                 num_hidden_layers=12,
                 num_attention_heads=12,
                 image_size=224,
                 patch_size=16,
                 num_channels=3,
                 num_labels=8,
                 checkpoint=False):
        super().__init__()
        self.checkpoint = checkpoint
        self.model = ViTForImageClassification(
            ViTConfig(hidden_size=hidden_size,
                      num_hidden_layers=num_hidden_layers,
                      num_attention_heads=num_attention_heads,
                      image_size=image_size,
                      patch_size=patch_size,
                      num_channels=num_channels,
                      num_labels=num_labels))
        if checkpoint:
            self.model.gradient_checkpointing_enable()

    def forward(self, pixel_values):
        return self.model(pixel_values=pixel_values)


def vit_base_s(checkpoint=True):
    return ViTCVModel(checkpoint=checkpoint)


def vit_base_micro(checkpoint=True):
    return ViTCVModel(hidden_size=32, num_hidden_layers=2, num_attention_heads=4, checkpoint=checkpoint)


def get_training_components():
    trainloader = DummyDataLoader()
    testloader = DummyDataLoader()
    return vit_base_micro, trainloader, testloader, torch.optim.Adam, torch.nn.functional.cross_entropy
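For completeness, a minimal single-process sketch of consuming `get_training_components()` outside the distributed test. It assumes one CUDA device is available and skips the ColoTensor specs and DDP wrapping that test_vit.py adds.

```python
# Minimal single-process usage sketch (assumes one CUDA device; no ColoTensor specs, no DDP).
import torch
from vit import get_training_components

model_builder, train_loader, _, optimizer_class, criterion = get_training_components()
model = model_builder().cuda()
optimizer = optimizer_class(model.parameters(), lr=1e-3)

model.train()
for image_dict in train_loader:
    optimizer.zero_grad()
    output = model(image_dict['pixel_values'])
    loss = criterion(output.logits, image_dict['label'])
    loss.backward()
    optimizer.step()
    break   # one step is enough to smoke-test the components
```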