This repository has been archived by the owner on Mar 23, 2023. It is now read-only.

Commit

[example]add optimizer for ViT and comments for projection
ZhaoYi1222 committed Jun 17, 2022
1 parent 46ac9d3 commit 2428730
Showing 1 changed file with 13 additions and 14 deletions.
27 changes: 13 additions & 14 deletions image/vision_transformer/colo_vit/test_vit.py
@@ -16,7 +16,12 @@
from colossalai.nn.parallel.data_parallel import ColoDDP


-def init_1d_row_spec(model):
+# For Linear layers only, this is a 1d_row split, because the Linear weight is transposed during the computation.
+# For all other layers it is effectively a 1d_col split.
+# LayerNorm is not supported for now.
+# patch_embeddings.projection uses nn.Conv2d:
+# https://github.com/huggingface/transformers/blob/dcb08b99f44919425f8ba9be9ddcc041af8ec25e/src/transformers/models/vit/modeling_vit.py#L182
+def init_1d_row_for_linear_weight_spec(model):
spec = TensorSpec(
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
ParallelAction(ComputePattern.TP1D))
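
The new comments above say the Linear weight is sharded along its last dimension for a 1d_row split because nn.Linear stores its weight as (out_features, in_features) and transposes it when computing x @ W^T. A minimal single-process sketch of that equivalence (plain PyTorch, no ColossalAI; sizes and the world_size of 2 are illustrative), also covering the dim-0 column split:

```python
import torch

torch.manual_seed(0)
world_size = 2
linear = torch.nn.Linear(8, 4, bias=False)
x = torch.randn(3, 8)

full_out = linear(x)

# 1d_row split: shard the stored weight (out_features, in_features) along dim -1,
# i.e. along in_features; each "rank" sees the matching slice of the input and
# produces a partial sum that an all-reduce would combine.
w_row_shards = linear.weight.chunk(world_size, dim=-1)
x_shards = x.chunk(world_size, dim=-1)
row_out = sum(xs @ ws.t() for xs, ws in zip(x_shards, w_row_shards))
assert torch.allclose(full_out, row_out, atol=1e-5)

# 1d_col split: shard the stored weight along dim 0 (out_features); each rank
# produces a slice of the output, and an all-gather (here: concat) recovers it.
w_col_shards = linear.weight.chunk(world_size, dim=0)
col_out = torch.cat([x @ ws.t() for ws in w_col_shards], dim=-1)
assert torch.allclose(full_out, col_out, atol=1e-5)
```
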
@@ -26,11 +31,11 @@ def init_1d_row_spec(model):
p.set_spec(spec)


-def init_1d_col_spec(model):
+# Similarly, this is a 1d_col split for Linear layers but a 1d_row split for the others.
+def init_1d_col_for_linear_weight_bias_spec(model):
spec = TensorSpec(
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
ParallelAction(ComputePattern.TP1D))
-
with DistSpecManager.no_grad():
for n, p in model.named_parameters():
if ('weight' in n
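
The col-split initializer above gives a spec to both the 'weight' and 'bias' parameters of Linear layers while leaving LayerNorm alone. A hedged, single-process sketch of that selection step in plain PyTorch; the returned dict stands in for the real TensorSpec / p.set_spec machinery, and the helper name is illustrative:

```python
import torch.nn as nn

def plan_1d_col_split(model: nn.Module) -> dict:
    """Map parameter name -> shard dim for a 1d_col-style split of Linear layers."""
    plan = {}
    for mod_name, mod in model.named_modules():
        if isinstance(mod, nn.LayerNorm):
            continue  # LayerNorm is not supported for now, so it keeps full parameters
        if isinstance(mod, nn.Linear):
            plan[f"{mod_name}.weight"] = 0      # shard out_features
            if mod.bias is not None:
                plan[f"{mod_name}.bias"] = 0    # bias follows the output dimension
    return plan

toy = nn.Sequential(nn.Linear(8, 16), nn.LayerNorm(16), nn.Linear(16, 4))
print(plan_1d_col_split(toy))
# {'0.weight': 0, '0.bias': 0, '2.weight': 0, '2.bias': 0}
```
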
@@ -46,11 +51,6 @@ def check_param_equal(model, torch_model):
def check_grad_equal(model, torch_model):
for (np, p), (ntp, torch_p) in zip(model.named_parameters(), torch_model.named_parameters()):
if (torch_p.grad.shape == p.grad.shape):
-print(torch.allclose(torch_p.grad, p.grad, rtol=1e-3, atol=2.0))
-print(ntp)
-print(torch_p.grad)
-print(np)
-print(p.grad)
assert torch.allclose(torch_p.grad, p.grad, rtol=1e-3, atol=2.0) == True
else:
dims_not_eq = torch.nonzero(torch.tensor(torch_p.grad.shape) != torch.tensor(p.grad.shape))
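
check_grad_equal compares gradients between the parallel model and the reference torch model; the dims_not_eq line in the else branch locates the dimension along which a parameter was sharded. The rest of that branch is not shown in this hunk, so the following is only a hedged sketch of one plausible implementation: narrow the reference gradient to this rank's slice before comparing. rank and world_size are illustrative parameters, and the tolerances mirror the assert above.

```python
import torch

def grads_close(torch_grad: torch.Tensor, shard_grad: torch.Tensor,
                rank: int, world_size: int) -> bool:
    if torch_grad.shape == shard_grad.shape:
        return torch.allclose(torch_grad, shard_grad, rtol=1e-3, atol=2.0)
    # Locate the (single) dimension whose size differs, as dims_not_eq does above,
    # then compare this rank's slice of the reference gradient against the shard.
    dims_not_eq = torch.nonzero(torch.tensor(torch_grad.shape) != torch.tensor(shard_grad.shape))
    dim = int(dims_not_eq[0])
    chunk = torch_grad.shape[dim] // world_size
    return torch.allclose(torch_grad.narrow(dim, rank * chunk, chunk), shard_grad,
                          rtol=1e-3, atol=2.0)

# Rank 1 of 2 holds the second half of a gradient sharded along dim 0.
full = torch.arange(12.).reshape(4, 3)
assert grads_close(full, full[2:4], rank=1, world_size=2)
```
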
@@ -100,11 +100,10 @@ def run_vit(init_spec_func, use_ddp):
loss.backward()
torch_loss.backward()
check_grad_equal(model, torch_model)
-# optimizer.step()
-# torch_optimizer.step()
+optimizer.step()
+torch_optimizer.step()
check_param_equal(model, torch_model)
-if i > 0:
-break
+break


def run_dist(rank, world_size, port, use_ddp):
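
With this commit the step loop no longer stops at gradient checking: it applies one optimizer step to both models, compares parameters, and breaks after the first iteration. A hedged single-process sketch of that pattern, with plain nn.Linear models and SGD standing in for the ViT models and the test's real optimizers:

```python
import torch
from torch import nn

torch.manual_seed(0)
model = nn.Linear(4, 2)
torch_model = nn.Linear(4, 2)
torch_model.load_state_dict(model.state_dict())          # start from identical weights

optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=0.1)

x = torch.randn(8, 4)
target = torch.randn(8, 2)
criterion = nn.MSELoss()

loss = criterion(model(x), target)
torch_loss = criterion(torch_model(x), target)
loss.backward()
torch_loss.backward()
for p, tp in zip(model.parameters(), torch_model.parameters()):
    assert torch.allclose(p.grad, tp.grad)                # check_grad_equal analogue

optimizer.step()                                          # the step added by this commit
torch_optimizer.step()
for p, tp in zip(model.parameters(), torch_model.parameters()):
    assert torch.allclose(p, tp)                          # check_param_equal analogue
# The real test then breaks out of its batch loop after this first iteration.
```
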
@@ -113,8 +112,8 @@ def run_dist(rank, world_size, port, use_ddp):
tp_world_size = world_size // 2 if use_ddp else world_size
config = dict(parallel=dict(tensor=dict(mode="1d", size=tp_world_size),))
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-run_vit(init_1d_row_spec, use_ddp)
-# run_vit(init_1d_col_spec, use_ddp)
+run_vit(init_1d_row_for_linear_weight_spec, use_ddp)
+run_vit(init_1d_col_for_linear_weight_bias_spec, use_ddp)


@pytest.mark.dist
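
run_dist builds a 1d tensor-parallel config (half the ranks go to tensor parallelism when use_ddp is on) and calls colossalai.launch, while @pytest.mark.dist marks the test for multi-process execution. A hedged sketch of how such a test is typically driven from pytest via torch.multiprocessing; free_port() is a local stdlib helper defined here, not necessarily the utility this repo uses, and the parametrize values are illustrative:

```python
import socket
from functools import partial

import pytest
import torch.multiprocessing as mp


def free_port() -> int:
    # Ask the OS for an unused localhost port that colossalai.launch can bind.
    with socket.socket() as s:
        s.bind(("localhost", 0))
        return s.getsockname()[1]


@pytest.mark.dist
@pytest.mark.parametrize("world_size", [1, 4])
@pytest.mark.parametrize("use_ddp", [False, True])
def test_vit(world_size, use_ddp):
    # run_dist is the function shown in the hunk above; mp.spawn calls run_func(rank)
    # in each of the world_size worker processes.
    run_func = partial(run_dist, world_size=world_size, port=free_port(), use_ddp=use_ddp)
    mp.spawn(run_func, nprocs=world_size)
```
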
