fix a bug in flipping
jeffffffli committed Jan 15, 2019
1 parent 674a97e commit 4e78517
Showing 2 changed files with 22 additions and 16 deletions.
5 changes: 2 additions & 3 deletions train_sppe/src/train.py
@@ -77,9 +77,8 @@ def valid(val_loader, m, criterion, optimizer, writer):

         loss = criterion(out.mul(setMask), labels)

-        flip_out = m(flip(inps, cuda=True))
-        flip_out = flip(shuffleLR(
-            flip_out, val_loader.dataset, cuda=True), cuda=True)
+        flip_out = m(flip(inps))
+        flip_out = flip(shuffleLR(flip_out, val_loader.dataset))

         out = (flip_out + out) / 2

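Note (not part of the commit): valid() uses flip testing here. The batch is mirrored, passed through the network a second time, the mirrored heatmaps are flipped back with their left/right joint channels swapped, and the two predictions are averaged. The change simply drops the cuda=True arguments, since device handling now happens inside flip() itself (see img.py below). A minimal, self-contained sketch of the pattern, assuming a toy 4-joint network; flip_w, swap_lr, m, and the joint pairs are placeholders for illustration, not the repository's actual helpers:

import torch

def flip_w(t):
    # Mirror the last (width) dimension of a CHW or NCHW tensor.
    return t.flip(dims=(t.dim() - 1,))

def swap_lr(heatmaps, pairs):
    # Swap left/right joint channels, pairs = [(left_idx, right_idx), ...].
    out = heatmaps.clone()
    for l, r in pairs:
        out[:, l], out[:, r] = heatmaps[:, r], heatmaps[:, l]
    return out

def m(x):
    # Stand-in for the pose network: (N, 3, H, W) -> (N, 4, H, W) heatmaps.
    return x.mean(dim=1, keepdim=True).repeat(1, 4, 1, 1)

inps = torch.rand(2, 3, 64, 64)
pairs = [(0, 1), (2, 3)]  # hypothetical left/right joint index pairs

out = m(inps)
flip_out = m(flip_w(inps))                   # predictions on the mirrored batch
flip_out = flip_w(swap_lr(flip_out, pairs))  # undo the mirroring on the heatmaps
out = (flip_out + out) / 2                   # average the two predictions
print(out.shape)  # torch.Size([2, 4, 64, 64])

Averaging the original and mirrored predictions is a common test-time trick for pose estimators; it tends to smooth out left/right asymmetries in the heatmaps.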
33 changes: 20 additions & 13 deletions train_sppe/src/utils/img.py
@@ -190,19 +190,26 @@ def cv_rotate(img, rot, resW, resH):

 def flip(x):
     assert (x.dim() == 3 or x.dim() == 4)
-    # dim = x.dim() - 1
-    x = x.numpy().copy()
-    if x.ndim == 3:
-        x = np.transpose(np.fliplr(np.transpose(x, (0, 2, 1))), (0, 2, 1))
-    elif x.ndim == 4:
-        for i in range(x.shape[0]):
-            x[i] = np.transpose(
-                np.fliplr(np.transpose(x[i], (0, 2, 1))), (0, 2, 1))
-    # x = x.swapaxes(dim, 0)
-    # x = x[::-1, ...]
-    # x = x.swapaxes(0, dim)
-
-    return torch.from_numpy(x.copy())
+    if '0.4.1' in torch.__version__:
+        dim = x.dim() - 1
+
+        return x.flip(dims=(dim,))
+    else:
+        is_cuda = False
+        if x.is_cuda:
+            x = x.cpu()
+            is_cuda = True
+        x = x.numpy().copy()
+        if x.ndim == 3:
+            x = np.transpose(np.fliplr(np.transpose(x, (0, 2, 1))), (0, 2, 1))
+        elif x.ndim == 4:
+            for i in range(x.shape[0]):
+                x[i] = np.transpose(
+                    np.fliplr(np.transpose(x[i], (0, 2, 1))), (0, 2, 1))
+        x = torch.from_numpy(x.copy())
+        if is_cuda:
+            x = x.cuda()
+        return x


 def shuffleLR(x, dataset):
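Note (not part of the commit): the two branches of the rewritten flip() are meant to agree. The torch 0.4.1 branch mirrors the last (width) dimension with Tensor.flip, while the fallback first moves a CUDA tensor to the CPU (calling .numpy() directly on a CUDA tensor raises a TypeError) and then mirrors the width axis with numpy. A quick sanity check of that equivalence, written as a standalone sketch with made-up helper names:

import numpy as np
import torch

def flip_numpy(chw):
    # Fallback path: mirror the width axis of a CHW array via transpose + fliplr.
    return np.transpose(np.fliplr(np.transpose(chw, (0, 2, 1))), (0, 2, 1))

def flip_tensor(x):
    # Fast path: mirror the last (width) dimension with the built-in tensor op.
    return x.flip(dims=(x.dim() - 1,))

x = torch.rand(3, 4, 5)
a = flip_tensor(x)
b = torch.from_numpy(flip_numpy(x.numpy()).copy())
print(torch.allclose(a, b))  # expected: True

The .copy() matters because np.fliplr returns a view with negative strides, which torch.from_numpy cannot wrap; the same reasoning applies to the x.copy() in the committed fallback branch.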
