Commit: update yolo

jeffffffli committed Jul 9, 2018
1 parent 3880736 commit 33ef1b2
Showing 47 changed files with 3,393 additions and 5 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -18,5 +18,6 @@ ssd/examples
*.json
*.h5
*.zip
+*.weights

coco-minival/
4 changes: 2 additions & 2 deletions README.md
@@ -28,13 +28,13 @@ To match poses that correspond to the same person across frames, we also provide
./install.sh
```

-1. Download the models manually: **ssd_coco.pth**([Google Drive](https://drive.google.com/open?id=1ifUNnxSDOP7InBtKy5CGnRAMLD4-d4Ct) | [Baidu pan](https://pan.baidu.com/s/1mon7Ht6ObqrS2ZY955swpg)), **pyra_4.pth** ([Google Drive](https://drive.google.com/open?id=1oG1Fxj4oBfKwD1W_2QObxltWybuIk7Y6) | [Baidu pan](https://pan.baidu.com/s/14ONL_T_d1twm9Lxac5x-Ew)). Place them into `./models/ssd` and `./models/sppe` respectively.
+1. Download the models manually: **ssd_coco.pth**([Google Drive](https://drive.google.com/open?id=1ifUNnxSDOP7InBtKy5CGnRAMLD4-d4Ct) | [Baidu pan](https://pan.baidu.com/s/1mon7Ht6ObqrS2ZY955swpg)), **pyra_4.pth** ([Google Drive](https://drive.google.com/open?id=1oG1Fxj4oBfKwD1W_2QObxltWybuIk7Y6) | [Baidu pan](https://pan.baidu.com/s/14ONL_T_d1twm9Lxac5x-Ew)), **yolov3.weights**([Google Drive](https://drive.google.com/open?id=1yjrziA2RzFqWAQG4Qq7XN0vumsMxwSjS) | [Baidu pan](https://pan.baidu.com/s/108SjV-uIJpxnqDMT19v-Aw)). Place them into `./models/ssd`, `./models/sppe` and `./models/yolo` respectively.


## Quick Start
- **Demo**: Run AlphaPose for all images in a folder and visualize the results with:
```
-python demo_fast.py \
+python demo_yolo.py \
--inputlist ./list-coco-minival500.txt \
--imgpath ${img_directory} \
--outputpath ./coco-minival
2 changes: 1 addition & 1 deletion SPPE/src/utils/eval.py
@@ -141,7 +141,7 @@ def getPrediction(hms, pt1, pt2, inpH, inpW, resH, resW):
    # preds += 0.5

-    preds_tf = torch.zeros(preds.size())
+    preds_tf = transformBoxInvert_batch(preds, pt1, pt2, inpH, inpW, resH, resW)

    return preds, preds_tf, maxval
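For orientation, here is a minimal sketch of what a batched inverse box transform could look like, assuming a plain linear map from heatmap space back into each detection box; the real `transformBoxInvert_batch` may handle padding and aspect ratio differently, so treat this as illustrative only:

```
import torch

def transform_box_invert_batch_sketch(preds, pt1, pt2, inpH, inpW, resH, resW):
    # preds: (N, J, 2) joint coords in heatmap space (resW x resH)
    # pt1, pt2: (N, 2) top-left / bottom-right corners of each box
    # inpH, inpW are unused in this simplified version
    box_w = (pt2[:, 0] - pt1[:, 0]).unsqueeze(1)  # (N, 1)
    box_h = (pt2[:, 1] - pt1[:, 1]).unsqueeze(1)
    out = preds.clone().float()
    out[:, :, 0] = preds[:, :, 0].float() / resW * box_w + pt1[:, 0].unsqueeze(1)
    out[:, :, 1] = preds[:, :, 1].float() / resH * box_h + pt1[:, 1].unsqueeze(1)
    return out
```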
2 changes: 1 addition & 1 deletion demo_yolo.py
@@ -34,7 +34,7 @@
# Load YOLO model
print('Loading YOLO model..')
det_model = Darknet("yolo/cfg/yolov3.cfg")
-det_model.load_weights('yolo/yolov3.weights')
+det_model.load_weights('models/yolo/yolov3.weights')
det_model.net_info['height'] = args.inp_dim
det_inp_dim = int(det_model.net_info['height'])
assert det_inp_dim % 32 == 0
Empty file added models/yolo/.gitkeep
Empty file.
1 change: 0 additions & 1 deletion yolo
Submodule yolo deleted from fbb4ef
93 changes: 93 additions & 0 deletions yolo/README.md
@@ -0,0 +1,93 @@
# A PyTorch implementation of a YOLO v3 Object Detector

[UPDATE]: This repo serves as driver code for my research. I just graduated from college and am busy looking for research internship/fellowship roles before eventually applying for a master's, so I won't have time to look into issues for the time being. Thank you.


This repository contains code for an object detector based on [YOLOv3: An Incremental Improvement](https://pjreddie.com/media/files/papers/YOLOv3.pdf), implemented in PyTorch. The code is based on the official code of [YOLO v3](https://github.com/pjreddie/darknet), as well as a PyTorch
port of the original code by [marvis](https://github.com/marvis/pytorch-yolo2). One of the goals of this code is to improve
upon the original port by removing redundant parts (the official code is essentially a full-blown deep learning
library and includes things like sequence models that are not used in YOLO). I've also tried to keep the code minimal and
to document it as well as I can.

### Tutorial for building this detector from scratch
If you want to understand how to implement this detector yourself from scratch, you can go through the very detailed five-part tutorial series I wrote on Paperspace. It is perfect for someone who wants to move from beginner to intermediate PyTorch skills.

[Implement YOLO v3 from scratch](https://blog.paperspace.com/how-to-implement-a-yolo-object-detector-in-pytorch/)

As of now, the code only contains the detection module, but you should expect the training module soon. :)

## Requirements
1. Python 3.5
2. OpenCV
3. PyTorch 0.4

Using PyTorch 0.3 will break the detector.



## Detection Example

![Detection Example](https://i.imgur.com/m2jwneng.png)
## Running the detector

### On single or multiple images

Clone the repo and `cd` into its directory. The first thing you need to do is get the weights file. This time around, for v3, the authors have supplied a weights file only for COCO ([here](https://pjreddie.com/media/files/yolov3.weights)); download it and place it in your repo directory. Or, if you're on Linux, you can just type:

```
wget https://pjreddie.com/media/files/yolov3.weights
python detect.py --images imgs --det det
```


The `--images` flag defines the directory to load images from, or a single image file (the script will figure it out), and `--det` is the directory
to save images to. Other settings, such as the batch size (the `--bs` flag) and the object confidence threshold, can be tweaked with flags that can be looked up with:

```
python detect.py -h
```
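Under the hood, the flag handling presumably boils down to an argparse setup roughly like the sketch below; the defaults shown here are illustrative, taken from the values quoted in this README rather than from the actual source:

```
import argparse

parser = argparse.ArgumentParser(description='YOLO v3 detection')
parser.add_argument('--images', default='imgs', help='image directory or a single image file')
parser.add_argument('--det', default='det', help='directory to save detections to')
parser.add_argument('--bs', default=1, type=int, help='batch size')
parser.add_argument('--reso', default=416, type=int, help='input resolution, a multiple of 32')
args = parser.parse_args()
```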

### Speed-Accuracy Tradeoff
You can change the resolution of the input image with the `--reso` flag. The default value is 416. Whatever value you choose, remember that **it should be a multiple of 32 and greater than 32**. Weird things will happen if you don't. You've been warned.

```
python detect.py --images imgs --det det --reso 320
```
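To see why the multiple-of-32 rule exists: YOLO v3 makes its predictions on feature maps downsampled by strides of 32, 16 and 8, so the input side must divide evenly by the coarsest stride. A tiny illustrative helper (hypothetical, not part of this repo):

```
def grid_sizes(reso):
    # YOLO v3 detects on three feature maps, downsampled by these strides
    assert reso % 32 == 0 and reso > 32, "reso must be a multiple of 32, greater than 32"
    return [reso // stride for stride in (32, 16, 8)]

print(grid_sizes(416))  # [13, 26, 52]
print(grid_sizes(320))  # [10, 20, 40]
```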

### On Video
For this, run video_demo.py with the `--video` flag specifying the video file. The video file should be in .avi format,
since that is the only input format OpenCV accepts here.

```
python video_demo.py --video video.avi
```

Tweakable settings can be seen with the `-h` flag.

### Speeding up Video Inference

To speed up video inference, you can try video_demo_half.py instead, which does all the inference with 16-bit half-precision
floats instead of 32-bit floats. I haven't seen big improvements, but I attribute that to having an older card
(Tesla K80, Kepler architecture). If you have one of the cards with fast float16 support, try it out and, if possible, benchmark it.
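The half-precision variant presumably boils down to casting the weights and the inputs to float16, roughly as in the sketch below; the model and frame here are stand-ins, not the repo's actual objects:

```
import torch
import torch.nn as nn

model = nn.Conv2d(3, 16, 3, padding=1)   # stand-in for the Darknet model
frame = torch.randn(1, 3, 416, 416)      # stand-in for one video frame batch

if torch.cuda.is_available():
    model = model.cuda().half()          # cast weights to float16
    frame = frame.cuda().half()          # inputs must match the weight dtype

with torch.no_grad():                    # inference only, no gradients
    out = model(frame)
```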

### On a Camera
Same as the video module, but you don't have to specify a video file, since the feed is taken from your camera. To be precise,
the feed is taken from what OpenCV recognises as camera 0. The default image resolution here is 160, though you can change it with the `--reso` flag.

```
python cam_demo.py
```
You can easily tweak the code to use different weights files, available at the [YOLO website](https://pjreddie.com/darknet/yolo/).

NOTE: The scales feature has been disabled for better refactoring.

### Detection across different scales
YOLO v3 makes detections across different scales, each of which specialises in detecting objects of a different size, depending on whether it captures coarse features, fine-grained features, or something in between. You can experiment with these scales with the `--scales` flag.

```
python detect.py --scales 1,3
```


Empty file added yolo/__init__.py
Empty file.
115 changes: 115 additions & 0 deletions yolo/bbox.py
@@ -0,0 +1,115 @@
from __future__ import division

import torch
import random

import numpy as np
import cv2

def confidence_filter(result, confidence):
    # zero out every box whose objectness score falls below the threshold
    conf_mask = (result[:,:,4] > confidence).float().unsqueeze(2)
    result = result*conf_mask

    return result

def confidence_filter_cls(result, confidence):
    # keepdim=True keeps a trailing dim so the max scores can be concatenated below
    max_scores = torch.max(result[:,:,5:25], 2, keepdim=True)[0]
    res = torch.cat((result, max_scores), 2)
    print(res.shape)

    cond_1 = (res[:,:,4] > confidence).float()
    cond_2 = (res[:,:,25] > 0.995).float()

    conf = cond_1 + cond_2
    conf = torch.clamp(conf, 0.0, 1.0)
    conf = conf.unsqueeze(2)
    result = result*conf
    return result



def get_abs_coord(box):
    # convert (center x, center y, w, h) to (x1, y1, x2, y2) corner coordinates
    box[2], box[3] = abs(box[2]), abs(box[3])
    x1 = (box[0] - box[2]/2) - 1
    y1 = (box[1] - box[3]/2) - 1
    x2 = (box[0] + box[2]/2) - 1
    y2 = (box[1] + box[3]/2) - 1
    return x1, y1, x2, y2


def sanity_fix(box):
    # make sure (x1, y1) is the top-left corner and (x2, y2) the bottom-right
    if (box[0] > box[2]):
        box[0], box[2] = box[2], box[0]

    if (box[1] > box[3]):
        box[1], box[3] = box[3], box[1]

    return box

def bbox_iou(box1, box2):
    """
    Returns the IoU of two bounding boxes
    """
    # Get the coordinates of the bounding boxes
    b1_x1, b1_y1, b1_x2, b1_y2 = box1[:,0], box1[:,1], box1[:,2], box1[:,3]
    b2_x1, b2_y1, b2_x2, b2_y2 = box2[:,0], box2[:,1], box2[:,2], box2[:,3]

    # Get the coordinates of the intersection rectangle
    inter_rect_x1 = torch.max(b1_x1, b2_x1)
    inter_rect_y1 = torch.max(b1_y1, b2_y1)
    inter_rect_x2 = torch.min(b1_x2, b2_x2)
    inter_rect_y2 = torch.min(b1_y2, b2_y2)

    # Intersection area (clamped at zero for non-overlapping boxes)
    if torch.cuda.is_available():
        inter_area = torch.max(inter_rect_x2 - inter_rect_x1 + 1, torch.zeros(inter_rect_x2.shape).cuda()) * \
                     torch.max(inter_rect_y2 - inter_rect_y1 + 1, torch.zeros(inter_rect_x2.shape).cuda())
    else:
        inter_area = torch.max(inter_rect_x2 - inter_rect_x1 + 1, torch.zeros(inter_rect_x2.shape)) * \
                     torch.max(inter_rect_y2 - inter_rect_y1 + 1, torch.zeros(inter_rect_x2.shape))

    # Union area
    b1_area = (b1_x2 - b1_x1 + 1)*(b1_y2 - b1_y1 + 1)
    b2_area = (b2_x2 - b2_x1 + 1)*(b2_y2 - b2_y1 + 1)

    iou = inter_area / (b1_area + b2_area - inter_area)

    return iou


def pred_corner_coord(prediction):
    #Get indices of non-zero confidence bboxes
    ind_nz = torch.nonzero(prediction[:,:,4]).transpose(0,1).contiguous()

    box = prediction[ind_nz[0], ind_nz[1]]

    box_a = box.new(box.shape)
    box_a[:,0] = (box[:,0] - box[:,2]/2)
    box_a[:,1] = (box[:,1] - box[:,3]/2)
    box_a[:,2] = (box[:,0] + box[:,2]/2)
    box_a[:,3] = (box[:,1] + box[:,3]/2)
    box[:,:4] = box_a[:,:4]

    prediction[ind_nz[0], ind_nz[1]] = box

    return prediction




def write(x, batches, results, colors, classes):
    # x holds one detection: [batch index, x1, y1, x2, y2, ..., class index]
    c1 = tuple(x[1:3].int())
    c2 = tuple(x[3:5].int())
    img = results[int(x[0])]
    cls = int(x[-1])
    label = "{0}".format(classes[cls])
    color = random.choice(colors)
    cv2.rectangle(img, c1, c2, color, 1)
    t_size = cv2.getTextSize(label, cv2.FONT_HERSHEY_PLAIN, 1, 1)[0]
    c2 = c1[0] + t_size[0] + 3, c1[1] + t_size[1] + 4
    cv2.rectangle(img, c1, c2, color, -1)
    cv2.putText(img, label, (c1[0], c1[1] + t_size[1] + 4), cv2.FONT_HERSHEY_PLAIN, 1, [225,255,255], 1)
    return img
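As a quick usage reference for `bbox_iou` above, with two hypothetical boxes in (x1, y1, x2, y2) corner format (the +1 pixel convention in the function makes areas inclusive):

```
import torch
from bbox import bbox_iou  # assuming yolo/bbox.py is importable

box1 = torch.tensor([[10.0, 10.0, 50.0, 50.0]])  # shape (N, 4)
box2 = torch.tensor([[30.0, 30.0, 70.0, 70.0]])

print(bbox_iou(box1, box2))  # tensor([0.1510]): about 15% overlap
```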
