diff --git a/backgroundremover/u2net/data_loader.py b/backgroundremover/u2net/data_loader.py index 92991f5..2e3b29a 100644 --- a/backgroundremover/u2net/data_loader.py +++ b/backgroundremover/u2net/data_loader.py @@ -139,7 +139,7 @@ def __call__(self, sample): # change the r,g,b to b,r,g from [0,255] to [0,1] # transforms.Normalize(mean = (0.485, 0.456, 0.406), std = (0.229, 0.224, 0.225)) tmpImg = tmpImg.transpose((2, 0, 1)) - tmpLbl = label.transpose((2, 0, 1)) + tmpLbl = tmpLbl.transpose((2, 0, 1)) return { "imidx": torch.from_numpy(imidx),