
Commit

[draft] Add optimizer for ViT
ZhaoYi1222 committed Jun 17, 2022
1 parent c8ee7d1 commit 46ac9d3
Showing 1 changed file with 18 additions and 2 deletions.
image/vision_transformer/colo_vit/test_vit.py (20 changes: 18 additions & 2 deletions)
@@ -44,8 +44,13 @@ def check_param_equal(model, torch_model):
 
 
 def check_grad_equal(model, torch_model):
-    for p, torch_p in zip(model.parameters(), torch_model.parameters()):
+    for (np, p), (ntp, torch_p) in zip(model.named_parameters(), torch_model.named_parameters()):
         if (torch_p.grad.shape == p.grad.shape):
+            print(torch.allclose(torch_p.grad, p.grad, rtol=1e-3, atol=2.0))
+            print(ntp)
+            print(torch_p.grad)
+            print(np)
+            print(p.grad)
             assert torch.allclose(torch_p.grad, p.grad, rtol=1e-3, atol=2.0) == True
         else:
             dims_not_eq = torch.nonzero(torch.tensor(torch_p.grad.shape) != torch.tensor(p.grad.shape))
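The shape-mismatch branch above handles gradients that are sharded by tensor parallelism and therefore cannot be compared element-for-element against the full torch gradient. Below is a minimal sketch of how such a comparison could look, assuming 1D sharding along the single unequal dimension and a rank/world size taken from torch.distributed's default process group; the helper name check_sharded_grad and the use of torch.chunk are illustrative assumptions, not the repository's implementation.

import torch
import torch.distributed as dist


def check_sharded_grad(full_grad, sharded_grad):
    # Sketch only: compare a tensor-parallel shard against the matching slice
    # of the full gradient. Assumes exactly one dimension is sharded.
    dims_not_eq = torch.nonzero(torch.tensor(sharded_grad.shape) != torch.tensor(full_grad.shape))
    assert dims_not_eq.numel() == 1, "expected exactly one sharded dimension"
    dim = dims_not_eq.item()
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    # Take this rank's chunk of the full gradient along the sharded dimension.
    expected = torch.chunk(full_grad, world_size, dim=dim)[rank]
    assert torch.allclose(expected, sharded_grad, rtol=1e-3, atol=2.0)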
@@ -75,7 +80,15 @@ def run_vit(init_spec_func, use_ddp):
     model.train()
     torch_model.train()
     set_seed(gpc.get_local_rank(ParallelMode.DATA))
+
+    optimizer = optimizer_class(model.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
+    torch_optimizer = optimizer_class(torch_model.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
+
     for i, image_dict in enumerate(train_dataloader):
+        if use_ddp:
+            model.zero_grad()
+        else:
+            optimizer.zero_grad()
         logits = model(image_dict['pixel_values'])
         torch_logits = torch_model(image_dict['pixel_values'])
         assert tensor_equal(torch_logits.logits, logits.logits)
@@ -87,6 +100,9 @@ def run_vit(init_spec_func, use_ddp):
         loss.backward()
         torch_loss.backward()
         check_grad_equal(model, torch_model)
+        # optimizer.step()
+        # torch_optimizer.step()
         check_param_equal(model, torch_model)
         if i > 0:
             break

@@ -98,7 +114,7 @@ def run_dist(rank, world_size, port, use_ddp):
     config = dict(parallel=dict(tensor=dict(mode="1d", size=tp_world_size),))
     colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     run_vit(init_1d_row_spec, use_ddp)
-    run_vit(init_1d_col_spec, use_ddp)
+    # run_vit(init_1d_col_spec, use_ddp)
 
 
 @pytest.mark.dist
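For reference, a condensed sketch of the optimizer-equivalence check this commit is building toward (the optimizer.step() calls are still commented out in the diff): construct two optimizers with identical hyperparameters, run the same batch through both models, step both, then require the parameters to stay close. torch.optim.Adam stands in here for the unspecified optimizer_class, the models are assumed to start from identical, non-sharded weights, and model(batch) is assumed to return a HuggingFace-style output with a .logits field; this illustrates the pattern rather than reproducing the test's exact code.

import torch


def run_one_step(model, torch_model, batch, lr=0.001):
    # Paired optimizers with identical hyperparameters; Adam is an assumption,
    # standing in for the test's optimizer_class.
    opt = torch.optim.Adam(model.parameters(), lr=lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
    torch_opt = torch.optim.Adam(torch_model.parameters(), lr=lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)

    opt.zero_grad()
    torch_opt.zero_grad()

    # Same batch through both models, same scalar loss on both sides.
    loss = model(batch).logits.sum()
    torch_loss = torch_model(batch).logits.sum()
    loss.backward()
    torch_loss.backward()

    opt.step()
    torch_opt.step()

    # After an identical update, parameters should still agree within tolerance.
    for p, torch_p in zip(model.parameters(), torch_model.parameters()):
        assert torch.allclose(p, torch_p, rtol=1e-3, atol=1e-1)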
