From 6104ab1b53920f6e2159749676073ff7d815c1fa Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Mon, 17 Jun 2024 18:16:01 +0000 Subject: [PATCH] fix guard of torch.cuda based on device type --- train_gpt2.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/train_gpt2.py b/train_gpt2.py index 33a4468..403f213 100644 --- a/train_gpt2.py +++ b/train_gpt2.py @@ -506,7 +506,8 @@ def get_lr(it): for param_group in optimizer.param_groups: param_group['lr'] = lr optimizer.step() - torch.cuda.synchronize() # wait for the GPU to finish work + if device_type == "cuda": + torch.cuda.synchronize() # wait for the GPU to finish work t1 = time.time() dt = t1 - t0 # time difference in seconds tokens_processed = train_loader.B * train_loader.T * grad_accum_steps * ddp_world_size