From 46ac9d3c796471930a42198d90a9735d286b20e5 Mon Sep 17 00:00:00 2001 From: ZhaoYi1222 Date: Fri, 17 Jun 2022 14:30:21 +0800 Subject: [PATCH] [draft]add optimizer for ViT --- image/vision_transformer/colo_vit/test_vit.py | 20 +++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/image/vision_transformer/colo_vit/test_vit.py b/image/vision_transformer/colo_vit/test_vit.py index e6728b7..1e29f02 100644 --- a/image/vision_transformer/colo_vit/test_vit.py +++ b/image/vision_transformer/colo_vit/test_vit.py @@ -44,8 +44,13 @@ def check_param_equal(model, torch_model): def check_grad_equal(model, torch_model): - for p, torch_p in zip(model.parameters(), torch_model.parameters()): + for (np, p), (ntp, torch_p) in zip(model.named_parameters(), torch_model.named_parameters()): if (torch_p.grad.shape == p.grad.shape): + print(torch.allclose(torch_p.grad, p.grad, rtol=1e-3, atol=2.0)) + print(ntp) + print(torch_p.grad) + print(np) + print(p.grad) assert torch.allclose(torch_p.grad, p.grad, rtol=1e-3, atol=2.0) == True else: dims_not_eq = torch.nonzero(torch.tensor(torch_p.grad.shape) != torch.tensor(p.grad.shape)) @@ -75,7 +80,15 @@ def run_vit(init_spec_func, use_ddp): model.train() torch_model.train() set_seed(gpc.get_local_rank(ParallelMode.DATA)) + + optimizer = optimizer_class(model.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0) + torch_optimizer = optimizer_class(torch_model.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0) + for i, image_dict in enumerate(train_dataloader): + if use_ddp: + model.zero_grad() + else: + optimizer.zero_grad() logits = model(image_dict['pixel_values']) torch_logits = torch_model(image_dict['pixel_values']) assert tensor_equal(torch_logits.logits, logits.logits) @@ -87,6 +100,9 @@ def run_vit(init_spec_func, use_ddp): loss.backward() torch_loss.backward() check_grad_equal(model, torch_model) + # optimizer.step() + # torch_optimizer.step() + check_param_equal(model, torch_model) if i > 0: break @@ -98,7 +114,7 @@ def run_dist(rank, world_size, port, use_ddp): config = dict(parallel=dict(tensor=dict(mode="1d", size=tp_world_size),)) colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') run_vit(init_1d_row_spec, use_ddp) - run_vit(init_1d_col_spec, use_ddp) + # run_vit(init_1d_col_spec, use_ddp) @pytest.mark.dist