Only call torch.cuda.synchronize() when device_type is "cuda".

The synchronize call that precedes the t1 = time.time() measurement in
train_gpt2.py is wrapped in a device_type == "cuda" check, so it is
skipped for any other device type.

NOTE(review): this patch was recovered from a copy in which line breaks
and leading whitespace had been collapsed. Line boundaries follow the
hunk header counts (7 old lines, 8 new lines); the indentation of the
hunk lines is reconstructed (4-space base assumed) and must be confirmed
against train_gpt2.py before applying, or applied with whitespace
differences ignored.

diff --git a/train_gpt2.py b/train_gpt2.py
index 33a4468..403f213 100644
--- a/train_gpt2.py
+++ b/train_gpt2.py
@@ -506,7 +506,8 @@ def get_lr(it):
     for param_group in optimizer.param_groups:
         param_group['lr'] = lr
     optimizer.step()
-    torch.cuda.synchronize() # wait for the GPU to finish work
+    if device_type == "cuda":
+        torch.cuda.synchronize() # wait for the GPU to finish work
     t1 = time.time()
     dt = t1 - t0 # time difference in seconds
     tokens_processed = train_loader.B * train_loader.T * grad_accum_steps * ddp_world_size