This repository has been archived by the owner on Mar 23, 2023. It is now read-only.

Commit c8ee7d1 (1 parent: 1371a5b)
Showing 5 changed files with 286 additions and 0 deletions.
image/vision_transformer/colo_vit/README.md (new file)
@@ -0,0 +1,24 @@
# Vision Transformer with ColoTensor

# Overview

In this example, we run a Vision Transformer with ColoTensor.
We use the **ViTForImageClassification** model from Hugging Face [Link](https://huggingface.co/docs/transformers/model_doc/vit).
You can change the world size or decide whether to use DDP in our code.

# Requirements

You should install colossalai from the **latest** main branch, and install pytest and transformers with:

```shell
pip install pytest transformers
```

# How to run

In your terminal, run:

```shell
pytest test_vit.py
```

This will evaluate the model with different **world_size** and **use_ddp** settings.
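The pytest command sweeps every combination of `world_size` and `use_ddp`. If you only want a single configuration, a minimal sketch is to call the test entry point directly, mirroring the `if __name__ == '__main__'` block of the `test_vit.py` shown below:

```python
# Run one configuration directly instead of the full pytest sweep.
from test_vit import test_vit

if __name__ == '__main__':
    test_vit(4, False)    # world_size=4, tensor parallelism only (no DDP wrapper)
```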
image/vision_transformer/colo_vit/test_vit.py (new file)
@@ -0,0 +1,114 @@
import torch
import pytest
import colossalai
from colossalai.context.parallel_mode import ParallelMode
import torch.multiprocessing as mp
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils.cuda import get_current_device
from colossalai.utils import free_port
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction, DistSpecManager, distspec
from colossalai.core import global_context as gpc
from functools import partial
from utils.util import tensor_equal, tensor_shard_equal, set_seed
from vit import get_training_components
from torch.nn.parallel import DistributedDataParallel as DDP
from colossalai.nn.parallel.data_parallel import ColoDDP


def init_1d_row_spec(model):
    # Shard eligible weights along the last dimension (row-style 1D tensor parallelism).
    spec = TensorSpec(
        distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
        ParallelAction(ComputePattern.TP1D))
    with DistSpecManager.no_grad():
        for n, p in model.named_parameters():
            if 'weight' in n and 'layernorm' not in n and 'embeddings.patch_embeddings.projection.weight' not in n:
                p.set_spec(spec)


def init_1d_col_spec(model):
    # Shard eligible weights and biases along dimension 0 (column-style 1D tensor parallelism).
    spec = TensorSpec(
        distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
        ParallelAction(ComputePattern.TP1D))

    with DistSpecManager.no_grad():
        for n, p in model.named_parameters():
            if ('weight' in n
                    or 'bias' in n) and 'layernorm' not in n and 'embeddings.patch_embeddings.projection' not in n:
                p.set_spec(spec)


def check_param_equal(model, torch_model):
    for p, torch_p in zip(model.parameters(), torch_model.parameters()):
        assert tensor_shard_equal(torch_p, p)


def check_grad_equal(model, torch_model):
    for p, torch_p in zip(model.parameters(), torch_model.parameters()):
        if torch_p.grad.shape == p.grad.shape:
            assert torch.allclose(torch_p.grad, p.grad, rtol=1e-3, atol=2.0)
        else:
            # The ColoTensor gradient is sharded: compare it against the matching chunk of the full gradient.
            dims_not_eq = torch.nonzero(torch.tensor(torch_p.grad.shape) != torch.tensor(p.grad.shape))
            dim = dims_not_eq.item()
            world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
            rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
            assert torch.allclose(torch_p.grad.chunk(world_size, dim)[rank], p.grad, rtol=1e-3, atol=2.0)


def run_vit(init_spec_func, use_ddp):
    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_training_components()

    # Build the ColoTensor model under ColoInitContext and a plain PyTorch reference model.
    with ColoInitContext(device=get_current_device()):
        model = model_builder()
    model = model.cuda()
    torch_model = model_builder().cuda()
    if use_ddp:
        model = ColoDDP(model)
        torch_model = DDP(torch_model,
                          device_ids=[gpc.get_global_rank()],
                          process_group=gpc.get_group(ParallelMode.DATA))
    for torch_p, p in zip(torch_model.parameters(), model.parameters()):
        torch_p.data.copy_(p)
    init_spec_func(model)

    check_param_equal(model, torch_model)
    model.train()
    torch_model.train()
    set_seed(gpc.get_local_rank(ParallelMode.DATA))
    for i, image_dict in enumerate(train_dataloader):
        logits = model(image_dict['pixel_values'])
        torch_logits = torch_model(image_dict['pixel_values'])
        assert tensor_equal(torch_logits.logits, logits.logits)
        loss = criterion(logits.logits, image_dict['label'])
        torch_loss = criterion(torch_logits.logits, image_dict['label'])
        if use_ddp:
            model.backward(loss)
        else:
            loss.backward()
        torch_loss.backward()
        check_grad_equal(model, torch_model)
        if i > 0:
            break


def run_dist(rank, world_size, port, use_ddp):
    if use_ddp and world_size == 1:
        return
    tp_world_size = world_size // 2 if use_ddp else world_size
    config = dict(parallel=dict(tensor=dict(mode="1d", size=tp_world_size),))
    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    run_vit(init_1d_row_spec, use_ddp)
    run_vit(init_1d_col_spec, use_ddp)


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@pytest.mark.parametrize('use_ddp', [False, True])
@rerun_if_address_is_in_use()
def test_vit(world_size, use_ddp):
    run_func = partial(run_dist, world_size=world_size, port=free_port(), use_ddp=use_ddp)
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_vit(4, False)
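Note how `run_dist` halves the tensor-parallel size whenever DDP is enabled, so `world_size=4, use_ddp=True` runs 2-way tensor parallelism combined with 2-way data parallelism. A small sketch of the resulting launch config (the logic is taken directly from `run_dist` above; the helper name is ours):

```python
def launch_config(world_size: int, use_ddp: bool) -> dict:
    # Mirrors run_dist: with DDP enabled, half of the ranks are used for data parallelism.
    tp_world_size = world_size // 2 if use_ddp else world_size
    return dict(parallel=dict(tensor=dict(mode="1d", size=tp_world_size)))


print(launch_config(4, False))  # {'parallel': {'tensor': {'mode': '1d', 'size': 4}}}
print(launch_config(4, True))   # {'parallel': {'tensor': {'mode': '1d', 'size': 2}}}
```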
image/vision_transformer/colo_vit/utils/dummy_data_generator.py (new file, 25 additions)
@@ -0,0 +1,25 @@
from abc import ABC, abstractmethod


class DummyDataGenerator(ABC):

    def __init__(self, length=10):
        self.length = length

    @abstractmethod
    def generate(self):
        pass

    def __iter__(self):
        self.step = 0
        return self

    def __next__(self):
        if self.step < self.length:
            self.step += 1
            return self.generate()
        else:
            raise StopIteration

    def __len__(self):
        return self.length
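A minimal usage sketch of this base class (the `RandomPairs` subclass here is hypothetical, made up purely for illustration; the real subclass used by the test is `DummyDataLoader` in `vit.py` below):

```python
import torch
from utils.dummy_data_generator import DummyDataGenerator


class RandomPairs(DummyDataGenerator):
    # Hypothetical subclass: produce a fresh (input, label) batch on every step.
    def generate(self):
        return torch.randn(4, 8), torch.randint(2, (4,))


loader = RandomPairs(length=3)
for x, y in loader:    # __iter__/__next__ yield exactly `length` batches
    print(x.shape, y.shape)
```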
image/vision_transformer/colo_vit/utils/util.py (new file)
@@ -0,0 +1,57 @@
import os
import random
import numpy as np
import torch
import torch.distributed as dist
from colossalai.core import global_context as gpc
from colossalai.context import ParallelMode


def set_seed(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True


def check_equal(A, B):
    assert torch.allclose(A, B, rtol=1e-3, atol=1e-1)


def replace_parameter_add_grad(layer, weight=None, bias=None):
    if weight is not None:
        delattr(layer, 'weight')
        setattr(layer, 'weight', weight)
        layer.weight.requires_grad = True
    if bias is not None:
        delattr(layer, 'bias')
        setattr(layer, 'bias', bias)
        layer.bias.requires_grad = True


def broadcast_tensor_chunk(tensor, chunk_size=1, local_rank=0):
    dist.broadcast(tensor, src=0)
    tensor_chunk = torch.chunk(tensor, chunk_size, dim=-1)[local_rank]
    return tensor_chunk.clone()


def tensor_equal(A, B):
    return torch.allclose(A, B, rtol=1e-3, atol=1e-1)


def tensor_shard_equal(tensor: torch.Tensor, shard: torch.Tensor):
    assert tensor.ndim == shard.ndim
    if tensor.shape == shard.shape:
        return tensor_equal(tensor, shard)
    else:
        dims_not_eq = torch.nonzero(torch.tensor(tensor.shape) != torch.tensor(shard.shape))
        if dims_not_eq.numel() == 1:
            # 1D shard
            dim = dims_not_eq.item()
            world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
            rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
            return tensor_equal(tensor.chunk(world_size, dim)[rank], shard)
        else:
            raise NotImplementedError
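For intuition, `tensor_shard_equal` compares a full (replicated) tensor against the 1D shard a given rank holds by chunking the full tensor along the mismatched dimension. Outside a `colossalai.launch` context you can reproduce the same chunk-and-compare step by hand; a sketch, where the 4-way split and the rank are assumptions:

```python
import torch
from utils.util import tensor_equal

full = torch.randn(8, 16)
world_size, rank = 4, 1                       # assumed PARALLEL_1D group layout
shard = full.chunk(world_size, dim=-1)[rank]  # what this rank would hold

# tensor_shard_equal performs exactly this comparison, but reads world_size
# and rank from the PARALLEL_1D group set up by colossalai.launch.
assert tensor_equal(full.chunk(world_size, dim=-1)[rank], shard)
```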
image/vision_transformer/colo_vit/vit.py (new file)
@@ -0,0 +1,66 @@
import torch
import torch.nn as nn
from transformers import ViTForImageClassification, ViTConfig
from utils.dummy_data_generator import DummyDataGenerator
from colossalai.utils.cuda import get_current_device


class DummyDataLoader(DummyDataGenerator):
    batch_size = 4
    channel = 3
    category = 8
    image_size = 224

    def generate(self):
        image_dict = {}
        image_dict['pixel_values'] = torch.rand(DummyDataLoader.batch_size,
                                                DummyDataLoader.channel,
                                                DummyDataLoader.image_size,
                                                DummyDataLoader.image_size,
                                                device=get_current_device()) * 2 - 1
        image_dict['label'] = torch.randint(DummyDataLoader.category, (DummyDataLoader.batch_size,),
                                            dtype=torch.int64,
                                            device=get_current_device())
        return image_dict


class ViTCVModel(nn.Module):

    def __init__(self,
                 hidden_size=768,
                 num_hidden_layers=12,
                 num_attention_heads=12,
                 image_size=224,
                 patch_size=16,
                 num_channels=3,
                 num_labels=8,
                 checkpoint=False):
        super().__init__()
        self.checkpoint = checkpoint
        self.model = ViTForImageClassification(
            ViTConfig(hidden_size=hidden_size,
                      num_hidden_layers=num_hidden_layers,
                      num_attention_heads=num_attention_heads,
                      image_size=image_size,
                      patch_size=patch_size,
                      num_channels=num_channels,
                      num_labels=num_labels))
        if checkpoint:
            self.model.gradient_checkpointing_enable()

    def forward(self, pixel_values):
        return self.model(pixel_values=pixel_values)


def vit_base_s(checkpoint=True):
    return ViTCVModel(checkpoint=checkpoint)


def vit_base_micro(checkpoint=True):
    return ViTCVModel(hidden_size=32, num_hidden_layers=2, num_attention_heads=4, checkpoint=checkpoint)


def get_training_components():
    trainloader = DummyDataLoader()
    testloader = DummyDataLoader()
    return vit_base_micro, trainloader, testloader, torch.optim.Adam, torch.nn.functional.cross_entropy
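Outside the distributed test, the components returned here can be wired into an ordinary training loop. A minimal single-GPU sketch (illustrative only; the learning rate and step count are assumptions, and the test above never builds the optimizer at all):

```python
import torch
from vit import get_training_components

model_builder, train_loader, _, optimizer_class, criterion = get_training_components()

model = model_builder().cuda()                            # vit_base_micro with gradient checkpointing
optimizer = optimizer_class(model.parameters(), lr=1e-3)  # assumed learning rate

model.train()
for step, batch in enumerate(train_loader):               # DummyDataLoader yields `length` random batches
    optimizer.zero_grad()
    output = model(batch['pixel_values'])
    loss = criterion(output.logits, batch['label'])
    loss.backward()
    optimizer.step()
    if step >= 2:                                          # a few steps are enough for a smoke test
        break
```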