diff --git a/image/vision_transformer/colo_vit/test_vit.py b/image/vision_transformer/colo_vit/test_vit.py
index 1e29f02..033508a 100644
--- a/image/vision_transformer/colo_vit/test_vit.py
+++ b/image/vision_transformer/colo_vit/test_vit.py
@@ -16,7 +16,12 @@ from colossalai.nn.parallel.data_parallel import ColoDDP
 
 
-def init_1d_row_spec(model):
+# Only Linear layers use the 1d_row split, because the Linear weight is transposed during computation.
+# Other layers use the 1d_col split.
+# LayerNorm is not supported for now.
+# patch_embeddings.projection uses nn.Conv2d:
+# https://github.com/huggingface/transformers/blob/dcb08b99f44919425f8ba9be9ddcc041af8ec25e/src/transformers/models/vit/modeling_vit.py#L182
+def init_1d_row_for_linear_weight_spec(model):
     spec = TensorSpec(
         distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
         ParallelAction(ComputePattern.TP1D))
@@ -26,11 +31,11 @@ def init_1d_row_spec(model):
             p.set_spec(spec)
 
 
-def init_1d_col_spec(model):
+# Similarly, Linear layers use a col split while the other layers use a row split.
+def init_1d_col_for_linear_weight_bias_spec(model):
     spec = TensorSpec(
         distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
         ParallelAction(ComputePattern.TP1D))
-
     with DistSpecManager.no_grad():
         for n, p in model.named_parameters():
             if ('weight' in n
@@ -46,11 +51,6 @@ def check_param_equal(model, torch_model):
 def check_grad_equal(model, torch_model):
     for (np, p), (ntp, torch_p) in zip(model.named_parameters(), torch_model.named_parameters()):
         if (torch_p.grad.shape == p.grad.shape):
-            print(torch.allclose(torch_p.grad, p.grad, rtol=1e-3, atol=2.0))
-            print(ntp)
-            print(torch_p.grad)
-            print(np)
-            print(p.grad)
             assert torch.allclose(torch_p.grad, p.grad, rtol=1e-3, atol=2.0) == True
         else:
             dims_not_eq = torch.nonzero(torch.tensor(torch_p.grad.shape) != torch.tensor(p.grad.shape))
@@ -100,11 +100,10 @@ def run_vit(init_spec_func, use_ddp):
         loss.backward()
         torch_loss.backward()
        check_grad_equal(model, torch_model)
-        # optimizer.step()
-        # torch_optimizer.step()
+        optimizer.step()
+        torch_optimizer.step()
         check_param_equal(model, torch_model)
-        if i > 0:
-            break
+        break
 
 
 def run_dist(rank, world_size, port, use_ddp):
@@ -113,8 +112,8 @@ def run_dist(rank, world_size, port, use_ddp):
     tp_world_size = world_size // 2 if use_ddp else world_size
     config = dict(parallel=dict(tensor=dict(mode="1d", size=tp_world_size),))
     colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    run_vit(init_1d_row_spec, use_ddp)
-    # run_vit(init_1d_col_spec, use_ddp)
+    run_vit(init_1d_row_for_linear_weight_spec, use_ddp)
+    run_vit(init_1d_col_for_linear_weight_bias_spec, use_ddp)
 
 
 @pytest.mark.dist
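
Note for reviewers: the sketch below contrasts the two shard specs that the renamed init functions build. The distspec.shard(...) and TensorSpec(...) calls are copied from the diff above; the make_1d_specs helper name and the import paths are illustrative assumptions based on test_vit.py's existing imports, not part of the change itself.

# Minimal sketch (assumed imports) of the row vs. col 1D tensor-parallel specs.
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.tensor import ComputePattern, ParallelAction, TensorSpec, distspec


def make_1d_specs():
    group = gpc.get_group(ParallelMode.PARALLEL_1D)
    num_shards = [gpc.get_world_size(ParallelMode.PARALLEL_1D)]
    # Row spec: shard the last dim (-1). Applied to Linear weights, which are
    # transposed in the forward pass.
    row_spec = TensorSpec(distspec.shard(group, [-1], num_shards),
                          ParallelAction(ComputePattern.TP1D))
    # Col spec: shard dim 0. Applied to Linear weights and biases in the
    # *_col_* init; for other layers this corresponds to a row split.
    col_spec = TensorSpec(distspec.shard(group, [0], num_shards),
                          ParallelAction(ComputePattern.TP1D))
    return row_spec, col_spec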