From c8ee7d17fbff79f289482ce5bb2706a71f715fa5 Mon Sep 17 00:00:00 2001
From: ZhaoYi1222
Date: Thu, 16 Jun 2022 16:47:41 +0800
Subject: [PATCH 1/3] support ViT with ColoTensor

---
 image/vision_transformer/colo_vit/README.md   |  24 ++++
 image/vision_transformer/colo_vit/test_vit.py | 114 ++++++++++++++++++
 .../colo_vit/utils/dummy_data_generator.py    |  25 ++++
 .../vision_transformer/colo_vit/utils/util.py |  57 +++++++++
 image/vision_transformer/colo_vit/vit.py      |  66 ++++++++++
 5 files changed, 286 insertions(+)
 create mode 100644 image/vision_transformer/colo_vit/README.md
 create mode 100644 image/vision_transformer/colo_vit/test_vit.py
 create mode 100644 image/vision_transformer/colo_vit/utils/dummy_data_generator.py
 create mode 100644 image/vision_transformer/colo_vit/utils/util.py
 create mode 100644 image/vision_transformer/colo_vit/vit.py

diff --git a/image/vision_transformer/colo_vit/README.md b/image/vision_transformer/colo_vit/README.md
new file mode 100644
index 0000000..97987b6
--- /dev/null
+++ b/image/vision_transformer/colo_vit/README.md
@@ -0,0 +1,24 @@
+# Vision Transformer with ColoTensor
+
+## Overview
+
+In this example, we run a Vision Transformer with ColoTensor.
+We use the **ViTForImageClassification** model from Hugging Face ([docs](https://huggingface.co/docs/transformers/model_doc/vit)).
+You can change the world size and decide whether to use DDP in our code.
+
+## Requirements
+
+You should install colossalai from the **latest** main branch, then install pytest and transformers with:
+
+```shell
+pip install pytest transformers
+```
+
+## How to run
+
+In your terminal, run
+```shell
+pytest test_vit.py
+```
+
+This will evaluate the model under every combination of **world_size** and **use_ddp**.
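The test added next verifies that a 1D row or column split of each `nn.Linear` weight reproduces the unsharded computation. As an editorial aside, the arithmetic the two specs rely on can be shown with plain PyTorch; this is a minimal single-process sketch, not part of the patch, and the names in it are illustrative:

```python
import torch

world_size = 2
x = torch.randn(4, 8)    # (batch, in_features)
w = torch.randn(6, 8)    # nn.Linear stores its weight as (out_features, in_features)
y_full = x @ w.t()       # reference output of the unsharded Linear

# 1D row split (the [-1] spec below): shard the weight along the input
# dimension; each rank computes a partial product, and the full output is the
# sum of the partials (an all-reduce over the TP group in a real run).
partials = [xs @ ws.t() for xs, ws in zip(x.chunk(world_size, dim=-1),
                                          w.chunk(world_size, dim=-1))]
assert torch.allclose(y_full, sum(partials), atol=1e-5)

# 1D col split (the [0] spec below): shard the weight along the output
# dimension; each rank computes a slice of the output, and the slices are
# concatenated (an all-gather over the TP group in a real run).
slices = [x @ ws.t() for ws in w.chunk(world_size, dim=0)]
assert torch.allclose(y_full, torch.cat(slices, dim=-1), atol=1e-5)
```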
diff --git a/image/vision_transformer/colo_vit/test_vit.py b/image/vision_transformer/colo_vit/test_vit.py
new file mode 100644
index 0000000..e6728b7
--- /dev/null
+++ b/image/vision_transformer/colo_vit/test_vit.py
@@ -0,0 +1,114 @@
+import torch
+import pytest
+import colossalai
+from colossalai.context.parallel_mode import ParallelMode
+import torch.multiprocessing as mp
+from colossalai.testing import rerun_if_address_is_in_use
+from colossalai.utils.cuda import get_current_device
+from colossalai.utils import free_port
+from colossalai.utils.model.colo_init_context import ColoInitContext
+from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction, DistSpecManager, distspec
+from colossalai.core import global_context as gpc
+from functools import partial
+from utils.util import tensor_equal, tensor_shard_equal, set_seed
+from vit import get_training_components
+from torch.nn.parallel import DistributedDataParallel as DDP
+from colossalai.nn.parallel.data_parallel import ColoDDP
+
+
+def init_1d_row_spec(model):
+    spec = TensorSpec(
+        distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
+        ParallelAction(ComputePattern.TP1D))
+    with DistSpecManager.no_grad():
+        for n, p in model.named_parameters():
+            if 'weight' in n and 'layernorm' not in n and 'embeddings.patch_embeddings.projection.weight' not in n:
+                p.set_spec(spec)
+
+
+def init_1d_col_spec(model):
+    spec = TensorSpec(
+        distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
+        ParallelAction(ComputePattern.TP1D))
+
+    with DistSpecManager.no_grad():
+        for n, p in model.named_parameters():
+            if ('weight' in n
+                    or 'bias' in n) and 'layernorm' not in n and 'embeddings.patch_embeddings.projection' not in n:
+                p.set_spec(spec)
+
+
+def check_param_equal(model, torch_model):
+    for p, torch_p in zip(model.parameters(), torch_model.parameters()):
+        assert tensor_shard_equal(torch_p, p)
+
+
+def check_grad_equal(model, torch_model):
+    for p, torch_p in zip(model.parameters(), torch_model.parameters()):
+        if torch_p.grad.shape == p.grad.shape:
+            assert torch.allclose(torch_p.grad, p.grad, rtol=1e-3, atol=2.0)
+        else:
+            dims_not_eq = torch.nonzero(torch.tensor(torch_p.grad.shape) != torch.tensor(p.grad.shape))
+            dim = dims_not_eq.item()
+            world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
+            rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
+            assert torch.allclose(torch_p.grad.chunk(world_size, dim)[rank], p.grad, rtol=1e-3, atol=2.0)
+
+
+def run_vit(init_spec_func, use_ddp):
+    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_training_components()
+
+    with ColoInitContext(device=get_current_device()):
+        model = model_builder()
+    model = model.cuda()
+    torch_model = model_builder().cuda()
+    if use_ddp:
+        model = ColoDDP(model)
+        torch_model = DDP(torch_model,
+                          device_ids=[gpc.get_global_rank()],
+                          process_group=gpc.get_group(ParallelMode.DATA))
+    for torch_p, p in zip(torch_model.parameters(), model.parameters()):
+        torch_p.data.copy_(p)
+    init_spec_func(model)
+
+    check_param_equal(model, torch_model)
+    model.train()
+    torch_model.train()
+    set_seed(gpc.get_local_rank(ParallelMode.DATA))
+    for i, image_dict in enumerate(train_dataloader):
+        logits = model(image_dict['pixel_values'])
+        torch_logits = torch_model(image_dict['pixel_values'])
+        assert tensor_equal(torch_logits.logits, logits.logits)
+        loss = criterion(logits.logits, image_dict['label'])
+        torch_loss = criterion(torch_logits.logits, image_dict['label'])
+        if use_ddp:
+            model.backward(loss)
+        else:
+            loss.backward()
+        torch_loss.backward()
+        check_grad_equal(model, torch_model)
+        if i > 0:
+            break
+
+
+def run_dist(rank, world_size, port, use_ddp):
+    if use_ddp and world_size == 1:
+        return
+    tp_world_size = world_size // 2 if use_ddp else world_size
+    config = dict(parallel=dict(tensor=dict(mode="1d", size=tp_world_size),))
+    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    run_vit(init_1d_row_spec, use_ddp)
+    run_vit(init_1d_col_spec, use_ddp)
+
+
+@pytest.mark.dist
+@pytest.mark.parametrize('world_size', [1, 4])
+@pytest.mark.parametrize('use_ddp', [False, True])
+@rerun_if_address_is_in_use()
+def test_vit(world_size, use_ddp):
+    run_func = partial(run_dist, world_size=world_size, port=free_port(), use_ddp=use_ddp)
+    mp.spawn(run_func, nprocs=world_size)
+
+
+if __name__ == '__main__':
+    test_vit(4, False)
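In `run_dist` above, the tensor-parallel group takes half the world size when DDP is enabled. The snippet below is purely illustrative and enumerates how the parametrized configurations decompose; the data-parallel size is inferred as the remaining ranks, which is an assumption about how ColossalAI lays out the groups rather than something stated in the patch:

```python
# Process-group layout per test configuration (illustrative only).
for world_size in (1, 4):
    for use_ddp in (False, True):
        if use_ddp and world_size == 1:
            continue    # run_dist returns early for this combination
        tp = world_size // 2 if use_ddp else world_size
        dp = world_size // tp    # assumed: leftover ranks form the DP replicas
        print(f"world_size={world_size}, use_ddp={use_ddp}: TP={tp}, DP={dp}")
```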
diff --git a/image/vision_transformer/colo_vit/utils/dummy_data_generator.py b/image/vision_transformer/colo_vit/utils/dummy_data_generator.py
new file mode 100644
index 0000000..5ab33e8
--- /dev/null
+++ b/image/vision_transformer/colo_vit/utils/dummy_data_generator.py
@@ -0,0 +1,25 @@
+from abc import ABC, abstractmethod
+
+
+class DummyDataGenerator(ABC):
+
+    def __init__(self, length=10):
+        self.length = length
+
+    @abstractmethod
+    def generate(self):
+        pass
+
+    def __iter__(self):
+        self.step = 0
+        return self
+
+    def __next__(self):
+        if self.step < self.length:
+            self.step += 1
+            return self.generate()
+        else:
+            raise StopIteration
+
+    def __len__(self):
+        return self.length
diff --git a/image/vision_transformer/colo_vit/utils/util.py b/image/vision_transformer/colo_vit/utils/util.py
new file mode 100644
index 0000000..32f9129
--- /dev/null
+++ b/image/vision_transformer/colo_vit/utils/util.py
@@ -0,0 +1,57 @@
+import os
+import random
+import numpy as np
+import torch
+import torch.distributed as dist
+from colossalai.core import global_context as gpc
+from colossalai.context import ParallelMode
+
+
+def set_seed(seed):
+    random.seed(seed)
+    os.environ['PYTHONHASHSEED'] = str(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+    torch.backends.cudnn.deterministic = True
+
+
+def check_equal(A, B):
+    assert torch.allclose(A, B, rtol=1e-3, atol=1e-1)
+
+
+def replace_parameter_add_grad(layer, weight=None, bias=None):
+    if weight is not None:
+        delattr(layer, 'weight')
+        setattr(layer, 'weight', weight)
+        layer.weight.requires_grad = True
+    if bias is not None:
+        delattr(layer, 'bias')
+        setattr(layer, 'bias', bias)
+        layer.bias.requires_grad = True
+
+
+def broadcast_tensor_chunk(tensor, chunk_size=1, local_rank=0):
+    dist.broadcast(tensor, src=0)
+    tensor_chunk = torch.chunk(tensor, chunk_size, dim=-1)[local_rank]
+    return tensor_chunk.clone()
+
+
+def tensor_equal(A, B):
+    return torch.allclose(A, B, rtol=1e-3, atol=1e-1)
+
+
+def tensor_shard_equal(tensor: torch.Tensor, shard: torch.Tensor):
+    assert tensor.ndim == shard.ndim
+    if tensor.shape == shard.shape:
+        return tensor_equal(tensor, shard)
+    else:
+        dims_not_eq = torch.nonzero(torch.tensor(tensor.shape) != torch.tensor(shard.shape))
+        if dims_not_eq.numel() == 1:
+            # 1D shard
+            dim = dims_not_eq.item()
+            world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
+            rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
+            return tensor_equal(tensor.chunk(world_size, dim)[rank], shard)
+        else:
+            raise NotImplementedError
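`tensor_shard_equal` above compares a full tensor with one rank's 1D shard by chunking the full tensor along the single mismatched dimension. The same logic can be exercised without a launched process group; in this sketch (illustrative, not part of the patch) the world size and rank are passed in explicitly instead of being read from `gpc`:

```python
import torch

def shard_equal(full, shard, world_size, rank):
    # Mirrors tensor_shard_equal, minus the gpc lookups (which require colossalai.launch).
    if full.shape == shard.shape:
        return torch.allclose(full, shard, rtol=1e-3, atol=1e-1)
    dims_not_eq = torch.nonzero(torch.tensor(full.shape) != torch.tensor(shard.shape))
    assert dims_not_eq.numel() == 1, "only 1D shards are handled"
    dim = dims_not_eq.item()
    return torch.allclose(full.chunk(world_size, dim)[rank], shard, rtol=1e-3, atol=1e-1)

full = torch.arange(12.0).reshape(4, 3)
assert shard_equal(full, full.chunk(2, dim=0)[1], world_size=2, rank=1)
assert not shard_equal(full, full.chunk(2, dim=0)[0] + 1.0, world_size=2, rank=0)
```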
diff --git a/image/vision_transformer/colo_vit/vit.py b/image/vision_transformer/colo_vit/vit.py
new file mode 100644
index 0000000..73bc144
--- /dev/null
+++ b/image/vision_transformer/colo_vit/vit.py
@@ -0,0 +1,66 @@
+import torch
+import torch.nn as nn
+from transformers import ViTForImageClassification, ViTConfig
+from utils.dummy_data_generator import DummyDataGenerator
+from colossalai.utils.cuda import get_current_device
+
+
+class DummyDataLoader(DummyDataGenerator):
+    batch_size = 4
+    channel = 3
+    category = 8
+    image_size = 224
+
+    def generate(self):
+        image_dict = {}
+        image_dict['pixel_values'] = torch.rand(DummyDataLoader.batch_size,
+                                                DummyDataLoader.channel,
+                                                DummyDataLoader.image_size,
+                                                DummyDataLoader.image_size,
+                                                device=get_current_device()) * 2 - 1
+        image_dict['label'] = torch.randint(DummyDataLoader.category, (DummyDataLoader.batch_size,),
+                                            dtype=torch.int64,
+                                            device=get_current_device())
+        return image_dict
+
+
+class ViTCVModel(nn.Module):
+
+    def __init__(self,
+                 hidden_size=768,
+                 num_hidden_layers=12,
+                 num_attention_heads=12,
+                 image_size=224,
+                 patch_size=16,
+                 num_channels=3,
+                 num_labels=8,
+                 checkpoint=False):
+        super().__init__()
+        self.checkpoint = checkpoint
+        self.model = ViTForImageClassification(
+            ViTConfig(hidden_size=hidden_size,
+                      num_hidden_layers=num_hidden_layers,
+                      num_attention_heads=num_attention_heads,
+                      image_size=image_size,
+                      patch_size=patch_size,
+                      num_channels=num_channels,
+                      num_labels=num_labels))
+        if checkpoint:
+            self.model.gradient_checkpointing_enable()
+
+    def forward(self, pixel_values):
+        return self.model(pixel_values=pixel_values)
+
+
+def vit_base_s(checkpoint=True):
+    return ViTCVModel(checkpoint=checkpoint)
+
+
+def vit_base_micro(checkpoint=True):
+    return ViTCVModel(hidden_size=32, num_hidden_layers=2, num_attention_heads=4, checkpoint=checkpoint)
+
+
+def get_training_components():
+    trainloader = DummyDataLoader()
+    testloader = DummyDataLoader()
+    return vit_base_micro, trainloader, testloader, torch.optim.Adam, torch.nn.functional.cross_entropy

From 46ac9d3c796471930a42198d90a9735d286b20e5 Mon Sep 17 00:00:00 2001
From: ZhaoYi1222
Date: Fri, 17 Jun 2022 14:30:21 +0800
Subject: [PATCH 2/3] [draft] add optimizer for ViT

---
 image/vision_transformer/colo_vit/test_vit.py | 20 +++++++++++++++++--
 1 file changed, 18 insertions(+), 2 deletions(-)

diff --git a/image/vision_transformer/colo_vit/test_vit.py b/image/vision_transformer/colo_vit/test_vit.py
index e6728b7..1e29f02 100644
--- a/image/vision_transformer/colo_vit/test_vit.py
+++ b/image/vision_transformer/colo_vit/test_vit.py
@@ -44,8 +44,13 @@ def check_param_equal(model, torch_model):
 
 
 def check_grad_equal(model, torch_model):
-    for p, torch_p in zip(model.parameters(), torch_model.parameters()):
+    for (name, p), (torch_name, torch_p) in zip(model.named_parameters(), torch_model.named_parameters()):
         if torch_p.grad.shape == p.grad.shape:
+            print(torch.allclose(torch_p.grad, p.grad, rtol=1e-3, atol=2.0))
+            print(torch_name)
+            print(torch_p.grad)
+            print(name)
+            print(p.grad)
             assert torch.allclose(torch_p.grad, p.grad, rtol=1e-3, atol=2.0)
         else:
             dims_not_eq = torch.nonzero(torch.tensor(torch_p.grad.shape) != torch.tensor(p.grad.shape))
@@ -75,7 +80,15 @@ def run_vit(init_spec_func, use_ddp):
     model.train()
     torch_model.train()
     set_seed(gpc.get_local_rank(ParallelMode.DATA))
+
+    optimizer = optimizer_class(model.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
+    torch_optimizer = optimizer_class(torch_model.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
+
     for i, image_dict in enumerate(train_dataloader):
+        if use_ddp:
+            model.zero_grad()
+        else:
+            optimizer.zero_grad()
         logits = model(image_dict['pixel_values'])
         torch_logits = torch_model(image_dict['pixel_values'])
         assert tensor_equal(torch_logits.logits, logits.logits)
@@ -87,6 +100,9 @@ def run_vit(init_spec_func, use_ddp):
             loss.backward()
         torch_loss.backward()
         check_grad_equal(model, torch_model)
+        # optimizer.step()
+        # torch_optimizer.step()
+        check_param_equal(model, torch_model)
         if i > 0:
             break
 
@@ -98,7 +114,7 @@ def run_dist(rank, world_size, port, use_ddp):
     config = dict(parallel=dict(tensor=dict(mode="1d", size=tp_world_size),))
     colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     run_vit(init_1d_row_spec, use_ddp)
-    run_vit(init_1d_col_spec, use_ddp)
+    # run_vit(init_1d_col_spec, use_ddp)
 
 
 @pytest.mark.dist
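Patch 2 drafts the optimizer comparison that patch 3 enables: both models must take identical Adam steps so that `check_param_equal` still holds after stepping. The pattern, reduced to a toy single-process sketch (illustrative, not part of the patch):

```python
import copy
import torch

model = torch.nn.Linear(8, 4)
ref_model = copy.deepcopy(model)
opt = torch.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
ref_opt = torch.optim.Adam(ref_model.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)

x = torch.randn(2, 8)
for o, m in ((opt, model), (ref_opt, ref_model)):
    o.zero_grad()
    m(x).sum().backward()
    o.step()

# Identical inputs plus identical optimizer settings keep the two copies in
# lockstep, which is what check_param_equal asserts after optimizer.step().
for p, ref_p in zip(model.parameters(), ref_model.parameters()):
    assert torch.allclose(p, ref_p)
```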
From 24287308c72ea72ddf512194b647b62962434fa0 Mon Sep 17 00:00:00 2001
From: ZhaoYi1222
Date: Fri, 17 Jun 2022 16:52:30 +0800
Subject: [PATCH 3/3] [example] add optimizer for ViT and comments for projection

---
 image/vision_transformer/colo_vit/test_vit.py | 27 +++++++++----------
 1 file changed, 13 insertions(+), 14 deletions(-)

diff --git a/image/vision_transformer/colo_vit/test_vit.py b/image/vision_transformer/colo_vit/test_vit.py
index 1e29f02..033508a 100644
--- a/image/vision_transformer/colo_vit/test_vit.py
+++ b/image/vision_transformer/colo_vit/test_vit.py
@@ -16,7 +16,12 @@
 from colossalai.nn.parallel.data_parallel import ColoDDP
 
 
-def init_1d_row_spec(model):
+# For all Linear layers this is a 1D row split, because the Linear weight is
+# transposed when the matmul is computed; for the other layers it acts as a 1D column split.
+# LayerNorm is not supported for now.
+# patch_embeddings.projection is an nn.Conv2d, so it is excluded:
+# https://github.com/huggingface/transformers/blob/dcb08b99f44919425f8ba9be9ddcc041af8ec25e/src/transformers/models/vit/modeling_vit.py#L182
+def init_1d_row_for_linear_weight_spec(model):
     spec = TensorSpec(
         distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
         ParallelAction(ComputePattern.TP1D))
@@ -26,11 +31,11 @@ def init_1d_row_spec(model):
                 p.set_spec(spec)
 
 
-def init_1d_col_spec(model):
+# Similarly, this is a column split for Linear layers but a row split for the others.
+def init_1d_col_for_linear_weight_bias_spec(model):
     spec = TensorSpec(
         distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
         ParallelAction(ComputePattern.TP1D))
-
     with DistSpecManager.no_grad():
         for n, p in model.named_parameters():
             if ('weight' in n
@@ -46,11 +51,6 @@ def check_param_equal(model, torch_model):
 def check_grad_equal(model, torch_model):
     for (name, p), (torch_name, torch_p) in zip(model.named_parameters(), torch_model.named_parameters()):
         if torch_p.grad.shape == p.grad.shape:
-            print(torch.allclose(torch_p.grad, p.grad, rtol=1e-3, atol=2.0))
-            print(torch_name)
-            print(torch_p.grad)
-            print(name)
-            print(p.grad)
             assert torch.allclose(torch_p.grad, p.grad, rtol=1e-3, atol=2.0)
         else:
             dims_not_eq = torch.nonzero(torch.tensor(torch_p.grad.shape) != torch.tensor(p.grad.shape))
@@ -100,11 +100,10 @@ def run_vit(init_spec_func, use_ddp):
             loss.backward()
         torch_loss.backward()
         check_grad_equal(model, torch_model)
-        # optimizer.step()
-        # torch_optimizer.step()
+        optimizer.step()
+        torch_optimizer.step()
         check_param_equal(model, torch_model)
-        if i > 0:
-            break
+        break
 
 
 def run_dist(rank, world_size, port, use_ddp):
@@ -113,8 +112,8 @@ def run_dist(rank, world_size, port, use_ddp):
     tp_world_size = world_size // 2 if use_ddp else world_size
     config = dict(parallel=dict(tensor=dict(mode="1d", size=tp_world_size),))
     colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    run_vit(init_1d_row_spec, use_ddp)
-    # run_vit(init_1d_col_spec, use_ddp)
+    run_vit(init_1d_row_for_linear_weight_spec, use_ddp)
+    run_vit(init_1d_col_for_linear_weight_bias_spec, use_ddp)
 
 
 @pytest.mark.dist
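Besides `pytest test_vit.py`, a single configuration can be run directly, mirroring the `__main__` block that patch 1 added to `test_vit.py`. This is a minimal sketch, assuming four visible CUDA devices, colossalai installed, and that the script sits next to `test_vit.py`:

```python
# Manual entry point: spawn 4 ranks and run the TP-only configuration,
# equivalent to the test file's own `test_vit(4, False)` main block.
from functools import partial

import torch.multiprocessing as mp
from colossalai.utils import free_port

from test_vit import run_dist

if __name__ == '__main__':
    world_size = 4
    run_func = partial(run_dist, world_size=world_size, port=free_port(), use_ddp=False)
    mp.spawn(run_func, nprocs=world_size)
```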