update matching
jeffffffli committed Jan 16, 2019
1 parent 4e78517 commit 344e2e0
Showing 6 changed files with 337 additions and 11 deletions.
2 changes: 1 addition & 1 deletion .gitignore
@@ -18,7 +18,7 @@ images
 *.pyc
 .ipynb_checkpoints
 */.ipynb_checkpoints/
-*/.tensorboard
+*/.tensorboard/*
 */exp
 
 *.pth
26 changes: 24 additions & 2 deletions SPPE/src/utils/eval.py
@@ -1,8 +1,8 @@
 from opt import opt
 try:
-    from utils.img import transformBoxInvert, transformBoxInvert_batch
+    from utils.img import transformBoxInvert, transformBoxInvert_batch, findPeak, processPeaks
 except ImportError:
-    from SPPE.src.utils.img import transformBoxInvert, transformBoxInvert_batch
+    from SPPE.src.utils.img import transformBoxInvert, transformBoxInvert_batch, findPeak, processPeaks
 import torch
 
 
@@ -147,6 +147,28 @@ def getPrediction(hms, pt1, pt2, inpH, inpW, resH, resW):
     return preds, preds_tf, maxval
 
 
+def getMultiPeakPrediction(hms, pt1, pt2, inpH, inpW, resH, resW):
+
+    assert hms.dim() == 4, 'Score maps should be 4-dim'
+
+    preds_img = {}
+    hms = hms.numpy()
+    for n in range(hms.shape[0]):        # Number of samples
+        preds_img[n] = {}                # Result of sample: n
+        for k in range(hms.shape[1]):    # Number of keypoints
+            preds_img[n][k] = []         # Result of keypoint: k
+            hm = hms[n][k]
+
+            candidate_points = findPeak(hm)
+
+            res_pt = processPeaks(candidate_points, hm,
+                                  pt1[n], pt2[n], inpH, inpW, resH, resW)
+
+            preds_img[n][k] = res_pt
+
+    return preds_img
+
+
 def getPrediction_batch(hms, pt1, pt2, inpH, inpW, resH, resW):
     '''
     Get keypoint location from heatmaps
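For orientation, a hedged sketch of how the new multi-peak decoder could be exercised on dummy data. Only getMultiPeakPrediction and the usual AlphaPose resolutions (320x256 input, 80x64 heatmaps) come from the repository; the tensor contents, joint count, and box corners below are illustrative assumptions.

import numpy as np
import torch
from SPPE.src.utils.eval import getMultiPeakPrediction

hms = torch.rand(2, 17, 80, 64)   # dummy heatmaps: (nSamples, nJoints, resH, resW)
pt1 = np.zeros((2, 2), dtype=np.float32)                        # per-sample box upper-left corners (x, y)
pt2 = np.array([[256., 320.], [256., 320.]], dtype=np.float32)  # per-sample box lower-right corners (x, y)

preds = getMultiPeakPrediction(hms, pt1, pt2, 320, 256, 80, 64)
# preds is a nested dict: preds[n][k] is a list of [x, y, score] entries,
# one per peak kept for joint k of sample n.
print(len(preds), len(preds[0]), len(preds[0][0]))

The dataloader change further down consumes this nested structure through candidate_reselect when opt.matching is enabled.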
69 changes: 67 additions & 2 deletions SPPE/src/utils/img.py
@@ -4,6 +4,7 @@
 import scipy.misc
 from torchvision import transforms
 import torch.nn.functional as F
+from scipy.ndimage import maximum_filter
 
 from PIL import Image
 from copy import deepcopy
@@ -198,7 +199,7 @@ def transformBox(pt, ul, br, inpH, inpW, resH, resW):


 def transformBoxInvert(pt, ul, br, inpH, inpW, resH, resW):
-    center = torch.zeros(2)
+    center = np.zeros(2)
     center[0] = (br[0] - 1 - ul[0]) / 2
     center[1] = (br[1] - 1 - ul[1]) / 2

@@ -209,7 +210,7 @@ def transformBoxInvert(pt, ul, br, inpH, inpW, resH, resW):
     _pt[0] = _pt[0] - max(0, (lenW - 1) / 2 - center[0])
     _pt[1] = _pt[1] - max(0, (lenH - 1) / 2 - center[1])
 
-    new_point = torch.zeros(2)
+    new_point = np.zeros(2)
     new_point[0] = _pt[0] + ul[0]
     new_point[1] = _pt[1] + ul[1]
     return new_point
@@ -430,3 +431,67 @@ def get_dir(src_point, rot_rad):
     src_result[1] = src_point[0] * sn + src_point[1] * cs
 
     return src_result
+
+
+def findPeak(hm):
+    mx = maximum_filter(hm, size=5)
+    idx = zip(*np.where((mx == hm) * (hm > 0.1)))
+    candidate_points = []
+    for (y, x) in idx:
+        candidate_points.append([x, y, hm[y][x]])
+    if len(candidate_points) == 0:
+        return torch.zeros(0)
+    candidate_points = np.array(candidate_points)
+    candidate_points = candidate_points[np.lexsort(-candidate_points.T)]
+    return torch.Tensor(candidate_points)
+
+
+def processPeaks(candidate_points, hm, pt1, pt2, inpH, inpW, resH, resW):
+    # type: (Tensor, Tensor, Tensor, Tensor, float, float, float, float) -> List[Tensor]
+
+    if candidate_points.shape[0] == 0:    # Low Response
+        maxval = np.max(hm.reshape(1, -1), 1)
+        idx = np.argmax(hm.reshape(1, -1), 1)
+
+        x = idx % resW
+        y = int(idx / resW)
+
+        candidate_points = np.zeros((1, 3))
+        candidate_points[0, 0:1] = x
+        candidate_points[0, 1:2] = y
+        candidate_points[0, 2:3] = maxval
+
+    res_pts = []
+    for i in range(candidate_points.shape[0]):
+        x, y, maxval = candidate_points[i][0], candidate_points[i][1], candidate_points[i][2]
+
+        if bool(maxval < 0.05) and len(res_pts) > 0:
+            pass
+        else:
+            if bool(x > 0) and bool(x < resW - 2):
+                if bool(hm[int(y)][int(x) + 1] - hm[int(y)][int(x) - 1] > 0):
+                    x += 0.25
+                elif bool(hm[int(y)][int(x) + 1] - hm[int(y)][int(x) - 1] < 0):
+                    x -= 0.25
+            if bool(y > 0) and bool(y < resH - 2):
+                if bool(hm[int(y) + 1][int(x)] - hm[int(y) - 1][int(x)] > 0):
+                    y += (0.25 * inpH / inpW)
+                elif bool(hm[int(y) + 1][int(x)] - hm[int(y) - 1][int(x)] < 0):
+                    y -= (0.25 * inpH / inpW)
+
+            # pt = torch.zeros(2)
+            pt = np.zeros(2)
+            pt[0] = x + 0.2
+            pt[1] = y + 0.2
+
+            pt = transformBoxInvert(pt, pt1, pt2, inpH, inpW, resH, resW)
+
+            res_pt = np.zeros(3)
+            res_pt[:2] = pt
+            res_pt[2] = maxval
+
+            res_pts.append(res_pt)
+
+            if maxval < 0.05:
+                break
+    return res_pts
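The peak search in findPeak is essentially non-maximum suppression: a 5x5 maximum filter followed by a 0.1 response threshold, after which processPeaks nudges each surviving peak a quarter pixel toward the stronger neighbouring bin and maps it back to box coordinates with transformBoxInvert. A small self-contained illustration of the filtering idea (toy array, not repository data):

import numpy as np
from scipy.ndimage import maximum_filter

hm = np.zeros((8, 8), dtype=np.float32)
hm[2, 3] = 0.9   # strong peak
hm[6, 5] = 0.4   # weaker, well-separated peak

local_max = maximum_filter(hm, size=5) == hm   # pixel equals the max of its 5x5 window
peaks = np.argwhere(local_max & (hm > 0.1))    # keep local maxima above the 0.1 threshold
print(peaks)                                   # rows (2, 3) and (6, 5), i.e. (y, x) indices of the two peaks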
17 changes: 11 additions & 6 deletions dataloader.py
@@ -8,7 +8,8 @@
 from opt import opt
 from yolo.preprocess import prep_image, prep_frame, inp_to_image
 from pPose_nms import pose_nms, write_json
-from SPPE.src.utils.eval import getPrediction
+from matching import candidate_reselect as matching
+from SPPE.src.utils.eval import getPrediction, getMultiPeakPrediction
 from yolo.util import write_results, dynamic_write_results
 from yolo.darknet import Darknet
 from tqdm import tqdm
@@ -656,11 +657,15 @@ def update(self):
                             self.stream.write(img)
                 else:
                     # location prediction (n, kp, 2) | score prediction (n, kp, 1)
-
-                    preds_hm, preds_img, preds_scores = getPrediction(
-                        hm_data, pt1, pt2, opt.inputResH, opt.inputResW, opt.outputResH, opt.outputResW)
-
-                    result = pose_nms(boxes, scores, preds_img, preds_scores)
+                    if opt.matching:
+                        preds = getMultiPeakPrediction(
+                            hm_data, pt1.numpy(), pt2.numpy(), opt.inputResH, opt.inputResW, opt.outputResH, opt.outputResW)
+                        result = matching(boxes, scores.numpy(), preds)
+                    else:
+                        preds_hm, preds_img, preds_scores = getPrediction(
+                            hm_data, pt1, pt2, opt.inputResH, opt.inputResW, opt.outputResH, opt.outputResW)
+                        result = pose_nms(
+                            boxes, scores, preds_img, preds_scores)
                     result = {
                         'imgname': im_name,
                         'result': result
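The new branch is gated on opt.matching, which is not defined in this diff; presumably it is a boolean command-line switch declared in opt.py. A minimal, hypothetical sketch of such a flag (the real name, default, and help text may differ):

import argparse

parser = argparse.ArgumentParser()
# Hypothetical flag; the actual definition lives in AlphaPose's opt.py and may differ.
parser.add_argument('--matching', default=False, action='store_true',
                    help='decode multiple peaks per joint and match candidates to boxes '
                         'instead of single-peak decoding followed by pose NMS')
opt = parser.parse_args([])
print(opt.matching)   # False unless --matching is passed on the command line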
(Diffs for the remaining 2 changed files are not shown here.)
