Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update torch.cuda.amp to torch.amp #13244

Open
wants to merge 18 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion classify/val.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def run(
action = "validating" if dataloader.dataset.root.stem == "val" else "testing"
desc = f"{pbar.desc[:-36]}{action:>36}" if pbar else f"{action}"
bar = tqdm(dataloader, desc, n, not training, bar_format=TQDM_BAR_FORMAT, position=0)
with torch.cuda.amp.autocast(enabled=device.type != "cpu"):
with torch.amp.autocast("cuda", enabled=device.type != "cpu"):
for images, labels in bar:
with dt[0]:
images, labels = images.to(device, non_blocking=True), labels.to(device)
Expand Down
4 changes: 2 additions & 2 deletions segment/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,7 @@ def lf(x):
maps = np.zeros(nc) # mAP per class
results = (0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) # P, R, [email protected], [email protected], val_loss(box, obj, cls)
scheduler.last_epoch = start_epoch - 1 # do not move
scaler = torch.cuda.amp.GradScaler(enabled=amp)
scaler = torch.amp.GradScaler("cuda", enabled=amp)
stopper, stop = EarlyStopping(patience=opt.patience), False
compute_loss = ComputeLoss(model, overlap=overlap) # init loss class
# callbacks.run('on_train_start')
Expand Down Expand Up @@ -380,7 +380,7 @@ def lf(x):
imgs = nn.functional.interpolate(imgs, size=ns, mode="bilinear", align_corners=False)

# Forward
with torch.cuda.amp.autocast(amp):
with torch.amp.autocast("cuda", enabled=amp):
pred = model(imgs) # forward
loss, loss_items = compute_loss(pred, targets.to(device), masks=masks.to(device).float())
if RANK != -1:
Expand Down
4 changes: 2 additions & 2 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,7 @@ def lf(x):
maps = np.zeros(nc) # mAP per class
results = (0, 0, 0, 0, 0, 0, 0) # P, R, [email protected], [email protected], val_loss(box, obj, cls)
scheduler.last_epoch = start_epoch - 1 # do not move
scaler = torch.cuda.amp.GradScaler(enabled=amp)
scaler = torch.amp.GradScaler("cuda", enabled=amp)
stopper, stop = EarlyStopping(patience=opt.patience), False
compute_loss = ComputeLoss(model) # init loss class
callbacks.run("on_train_start")
Expand Down Expand Up @@ -409,7 +409,7 @@ def lf(x):
imgs = nn.functional.interpolate(imgs, size=ns, mode="bilinear", align_corners=False)

# Forward
with torch.cuda.amp.autocast(amp):
with torch.amp.autocast("cuda", enabled=amp):
pred = model(imgs) # forward
loss, loss_items = compute_loss(pred, targets.to(device)) # loss scaled by batch_size
if RANK != -1:
Expand Down
2 changes: 1 addition & 1 deletion utils/autobatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

def check_train_batch_size(model, imgsz=640, amp=True):
"""Checks and computes optimal training batch size for YOLOv5 model, given image size and AMP setting."""
with torch.cuda.amp.autocast(amp):
with torch.amp.autocast("cuda", enabled=amp):
return autobatch(deepcopy(model).train(), imgsz) # compute optimal batch size


Expand Down
Loading