Skip to content
This repository has been archived by the owner on Mar 23, 2023. It is now read-only.

Commit

Permalink
[example]support ViT with ColoTensor
Browse files Browse the repository at this point in the history
  • Loading branch information
feifeibear authored Jun 23, 2022
2 parents f71dd4f + 2428730 commit b4cfda7
Show file tree
Hide file tree
Showing 5 changed files with 301 additions and 0 deletions.
24 changes: 24 additions & 0 deletions image/vision_transformer/colo_vit/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# Vision Transformer with ColoTensor

# Overview

In this example, we will run Vision Transformer with ColoTensor.
We use model **ViTForImageClassification** from Hugging Face [Link](https://huggingface.co/docs/transformers/model_doc/vit).
You can change world size or decide whether use DDP in our code.

# Requirement

You should install colossalai from the **latest** main branch and install pytest, transformers with:

```shell
pip install pytest transformers
```

# How to run

In your terminal
```shell
pytest test_vit.py
```

This will evaluate models with different **world_size** and **use_ddp**.
129 changes: 129 additions & 0 deletions image/vision_transformer/colo_vit/test_vit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
import torch
import pytest
import colossalai
from colossalai.context.parallel_mode import ParallelMode
import torch.multiprocessing as mp
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils.cuda import get_current_device
from colossalai.utils import free_port
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction, DistSpecManager, distspec
from colossalai.core import global_context as gpc
from functools import partial
from utils.util import tensor_equal, tensor_shard_equal, set_seed
from vit import get_training_components
from torch.nn.parallel import DistributedDataParallel as DDP
from colossalai.nn.parallel.data_parallel import ColoDDP


# Only for all Linear, it's 1d_row split because Linear will be transposed when calculating.
# But for other layers, it's 1d_col split.
# Layernorm is not supported for now.
# patch_embeddings.projection has nn.Conv2d
# https://github.com/huggingface/transformers/blob/dcb08b99f44919425f8ba9be9ddcc041af8ec25e/src/transformers/models/vit/modeling_vit.py#L182
def init_1d_row_for_linear_weight_spec(model):
spec = TensorSpec(
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
ParallelAction(ComputePattern.TP1D))
with DistSpecManager.no_grad():
for n, p in model.named_parameters():
if 'weight' in n and 'layernorm' not in n and 'embeddings.patch_embeddings.projection.weight' not in n:
p.set_spec(spec)


# Similarly, it's col split for Linear but row split for others.
def init_1d_col_for_linear_weight_bias_spec(model):
spec = TensorSpec(
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
ParallelAction(ComputePattern.TP1D))
with DistSpecManager.no_grad():
for n, p in model.named_parameters():
if ('weight' in n
or 'bias' in n) and 'layernorm' not in n and 'embeddings.patch_embeddings.projection' not in n:
p.set_spec(spec)


def check_param_equal(model, torch_model):
for p, torch_p in zip(model.parameters(), torch_model.parameters()):
assert tensor_shard_equal(torch_p, p)


def check_grad_equal(model, torch_model):
for (np, p), (ntp, torch_p) in zip(model.named_parameters(), torch_model.named_parameters()):
if (torch_p.grad.shape == p.grad.shape):
assert torch.allclose(torch_p.grad, p.grad, rtol=1e-3, atol=2.0) == True
else:
dims_not_eq = torch.nonzero(torch.tensor(torch_p.grad.shape) != torch.tensor(p.grad.shape))
dim = dims_not_eq.item()
world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
assert torch.allclose(torch_p.grad.chunk(world_size, dim)[rank], p.grad, rtol=1e-3, atol=2.0) == True


def run_vit(init_spec_func, use_ddp):
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_training_components()

with ColoInitContext(device=get_current_device()):
model = model_builder()
model = model.cuda()
torch_model = model_builder().cuda()
if use_ddp:
model = ColoDDP(model)
torch_model = DDP(torch_model,
device_ids=[gpc.get_global_rank()],
process_group=gpc.get_group(ParallelMode.DATA))
for torch_p, p in zip(torch_model.parameters(), model.parameters()):
torch_p.data.copy_(p)
init_spec_func(model)

check_param_equal(model, torch_model)
model.train()
torch_model.train()
set_seed(gpc.get_local_rank(ParallelMode.DATA))

optimizer = optimizer_class(model.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
torch_optimizer = optimizer_class(torch_model.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)

for i, image_dict in enumerate(train_dataloader):
if use_ddp:
model.zero_grad()
else:
optimizer.zero_grad()
logits = model(image_dict['pixel_values'])
torch_logits = torch_model(image_dict['pixel_values'])
assert tensor_equal(torch_logits.logits, logits.logits)
loss = criterion(logits.logits, image_dict['label'])
torch_loss = criterion(torch_logits.logits, image_dict['label'])
if use_ddp:
model.backward(loss)
else:
loss.backward()
torch_loss.backward()
check_grad_equal(model, torch_model)
optimizer.step()
torch_optimizer.step()
check_param_equal(model, torch_model)
break


def run_dist(rank, world_size, port, use_ddp):
if use_ddp and world_size == 1:
return
tp_world_size = world_size // 2 if use_ddp else world_size
config = dict(parallel=dict(tensor=dict(mode="1d", size=tp_world_size),))
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
run_vit(init_1d_row_for_linear_weight_spec, use_ddp)
run_vit(init_1d_col_for_linear_weight_bias_spec, use_ddp)


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@pytest.mark.parametrize('use_ddp', [False, True])
@rerun_if_address_is_in_use()
def test_vit(world_size, use_ddp):
run_func = partial(run_dist, world_size=world_size, port=free_port(), use_ddp=use_ddp)
mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
test_vit(4, False)
25 changes: 25 additions & 0 deletions image/vision_transformer/colo_vit/utils/dummy_data_generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from abc import ABC, abstractmethod


class DummyDataGenerator(ABC):

def __init__(self, length=10):
self.length = length

@abstractmethod
def generate(self):
pass

def __iter__(self):
self.step = 0
return self

def __next__(self):
if self.step < self.length:
self.step += 1
return self.generate()
else:
raise StopIteration

def __len__(self):
return self.length
57 changes: 57 additions & 0 deletions image/vision_transformer/colo_vit/utils/util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import os
import random
import numpy as np
import torch
import torch.distributed as dist
from colossalai.core import global_context as gpc
from colossalai.context import ParallelMode


def set_seed(seed):
random.seed(seed)
os.environ['PYTHONHASHSEED'] = str(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True


def check_equal(A, B):
assert torch.allclose(A, B, rtol=1e-3, atol=1e-1) == True


def replace_parameter_add_grad(layer, weight=None, bias=None):
if weight is not None:
delattr(layer, 'weight')
setattr(layer, 'weight', weight)
layer.weight.requires_grad = True
if bias is not None:
delattr(layer, 'bias')
setattr(layer, 'bias', bias)
layer.bias.requires_grad = True


def broadcast_tensor_chunk(tensor, chunk_size=1, local_rank=0):
dist.broadcast(tensor, src=0)
tensor_chunk = torch.chunk(tensor, chunk_size, dim=-1)[local_rank]
return tensor_chunk.clone()


def tensor_equal(A, B):
return torch.allclose(A, B, rtol=1e-3, atol=1e-1)


def tensor_shard_equal(tensor: torch.Tensor, shard: torch.Tensor):
assert tensor.ndim == shard.ndim
if tensor.shape == shard.shape:
return tensor_equal(tensor, shard)
else:
dims_not_eq = torch.nonzero(torch.tensor(tensor.shape) != torch.tensor(shard.shape))
if dims_not_eq.numel() == 1:
# 1D shard
dim = dims_not_eq.item()
world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
return tensor_equal(tensor.chunk(world_size, dim)[rank], shard)
else:
raise NotImplementedError
66 changes: 66 additions & 0 deletions image/vision_transformer/colo_vit/vit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import torch
import torch.nn as nn
from transformers import ViTForImageClassification, ViTConfig
from utils.dummy_data_generator import DummyDataGenerator
from colossalai.utils.cuda import get_current_device


class DummyDataLoader(DummyDataGenerator):
batch_size = 4
channel = 3
category = 8
image_size = 224

def generate(self):
image_dict = {}
image_dict['pixel_values'] = torch.rand(DummyDataLoader.batch_size,
DummyDataLoader.channel,
DummyDataLoader.image_size,
DummyDataLoader.image_size,
device=get_current_device()) * 2 - 1
image_dict['label'] = torch.randint(DummyDataLoader.category, (DummyDataLoader.batch_size,),
dtype=torch.int64,
device=get_current_device())
return image_dict


class ViTCVModel(nn.Module):

def __init__(self,
hidden_size=768,
num_hidden_layers=12,
num_attention_heads=12,
image_size=224,
patch_size=16,
num_channels=3,
num_labels=8,
checkpoint=False):
super().__init__()
self.checkpoint = checkpoint
self.model = ViTForImageClassification(
ViTConfig(hidden_size=hidden_size,
num_hidden_layers=num_hidden_layers,
num_attention_heads=num_attention_heads,
image_size=image_size,
patch_size=patch_size,
num_channels=num_channels,
num_labels=num_labels))
if checkpoint:
self.model.gradient_checkpointing_enable()

def forward(self, pixel_values):
return self.model(pixel_values=pixel_values)


def vit_base_s(checkpoint=True):
return ViTCVModel(checkpoint=checkpoint)


def vit_base_micro(checkpoint=True):
return ViTCVModel(hidden_size=32, num_hidden_layers=2, num_attention_heads=4, checkpoint=checkpoint)


def get_training_components():
trainloader = DummyDataLoader()
testloader = DummyDataLoader()
return vit_base_micro, trainloader, testloader, torch.optim.Adam, torch.nn.functional.cross_entropy

0 comments on commit b4cfda7

Please sign in to comment.