Showing 17 changed files with 874 additions and 0 deletions.
@@ -0,0 +1,28 @@
# How to Contribute

We'd love to accept your patches and contributions to this project. There are
just a few small guidelines you need to follow.

## Contributor License Agreement

Contributions to this project must be accompanied by a Contributor License
Agreement. You (or your employer) retain the copyright to your contribution;
this simply gives us permission to use and redistribute your contributions as
part of the project. Head over to <https://ambianic.ai/cla> to see
your current agreements on file or to sign a new one.

You generally only need to submit a CLA once, so if you've already submitted one
(even if it was for a different project), you probably don't need to do it
again.

## Code reviews

All submissions, including submissions by project members, require review. We
use GitHub pull requests for this purpose. Consult
[GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
information on using pull requests.

## Community Guidelines

This project follows [Google's Open Source Community
Guidelines](https://opensource.google.com/conduct/).
@@ -1,3 +1,4 @@

Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
@@ -0,0 +1,19 @@
#!/bin/sh
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

mkdir -p all_models
wget https://dl.google.com/coral/canned_models/all_models.tar.gz
tar -C all_models -xvzf all_models.tar.gz
rm -f all_models.tar.gz
@@ -0,0 +1,33 @@
This folder contains two examples that use GStreamer to obtain camera images. The
examples work on Linux using a webcam, on a Raspberry Pi with the Raspicam, and on
the Coral DevBoard using the Coral camera. For the first two you will also need a
Coral USB Accelerator to run the models.

## Installation

Make sure the GStreamer libraries are installed. On the Coral DevBoard this isn't
necessary, but on a Raspberry Pi or a general Linux system it will be:

```
sh install_requirements.sh
```
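
If you prefer to install the dependencies by hand instead of using the helper
script, something like the following should cover the GStreamer stack and the
Python bindings on a Debian-based system. This is only a sketch; the package
names are assumptions and may differ between distributions and releases:

```
# Assumed Debian/Raspbian package names; adjust for your distribution.
sudo apt-get update
sudo apt-get install -y \
    python3-gi python3-gst-1.0 gir1.2-gst-plugins-base-1.0 \
    gstreamer1.0-tools gstreamer1.0-plugins-good gstreamer1.0-plugins-bad \
    python3-pil
pip3 install svgwrite   # SVG drawing library imported by the demos
```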

## Classification Demo

```
python3 classify.py
```

You can change the model and the labels file using flags `--model` and
`--labels`.
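
For example, the following invocation simply spells out the defaults that
`classify.py` already uses (the `../all_models` directory, the MobileNet v2
classifier, and the ImageNet labels), so it behaves the same as running the
script with no flags:

```
python3 classify.py \
    --model ../all_models/mobilenet_v2_1.0_224_quant_edgetpu.tflite \
    --labels ../all_models/imagenet_labels.txt \
    --top_k 3 --threshold 0.1
```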

## Detection Demo (SSD models)

```
python3 detect.py
```

As before, you can change the model and the labels file using flags `--model`
and `--labels`.
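
As a concrete example, the invocation below restates the defaults from
`detect.py` (the COCO SSD model and labels under `../all_models`); the
face-detection model mentioned in the script's docstring can be substituted by
changing `--model`:

```
python3 detect.py \
    --model ../all_models/mobilenet_ssd_v2_coco_quant_postprocess_edgetpu.tflite \
    --labels ../all_models/coco_labels.txt
```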
@@ -0,0 +1,74 @@
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""A demo which runs object classification on camera frames."""
import argparse
import time
import re
import svgwrite
import imp
import os
from edgetpu.classification.engine import ClassificationEngine
import gstreamer

def load_labels(path):
    p = re.compile(r'\s*(\d+)(.+)')
    with open(path, 'r', encoding='utf-8') as f:
        lines = (p.match(line).groups() for line in f.readlines())
        return {int(num): text.strip() for num, text in lines}

def generate_svg(dwg, text_lines):
    for y, line in enumerate(text_lines):
        dwg.add(dwg.text(line, insert=(11, y*20+1), fill='black', font_size='20'))
        dwg.add(dwg.text(line, insert=(10, y*20), fill='white', font_size='20'))

def main():
    default_model_dir = "../all_models"
    default_model = 'mobilenet_v2_1.0_224_quant_edgetpu.tflite'
    default_labels = 'imagenet_labels.txt'
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', help='.tflite model path',
                        default=os.path.join(default_model_dir, default_model))
    parser.add_argument('--labels', help='label file path',
                        default=os.path.join(default_model_dir, default_labels))
    parser.add_argument('--top_k', type=int, default=3,
                        help='number of classes with highest score to display')
    parser.add_argument('--threshold', type=float, default=0.1,
                        help='class score threshold')
    args = parser.parse_args()

    print("Loading %s with %s labels." % (args.model, args.labels))
    engine = ClassificationEngine(args.model)
    labels = load_labels(args.labels)

    last_time = time.monotonic()
    def user_callback(image, svg_canvas):
        nonlocal last_time
        start_time = time.monotonic()
        results = engine.ClassifyWithImage(image, threshold=args.threshold, top_k=args.top_k)
        end_time = time.monotonic()
        text_lines = [
            'Inference: %.2f ms' % ((end_time - start_time) * 1000),
            'FPS: %.2f fps' % (1.0 / (end_time - last_time)),
        ]
        for index, score in results:
            text_lines.append('score=%.2f: %s' % (score, labels[index]))
        print(' '.join(text_lines))
        last_time = end_time
        generate_svg(svg_canvas, text_lines)

    result = gstreamer.run_pipeline(user_callback)

if __name__ == '__main__':
    main()
@@ -0,0 +1,99 @@
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""A demo which runs object detection on camera frames.

export TEST_DATA=/usr/lib/python3/dist-packages/edgetpu/test_data

Run face detection model:
python3 -m edgetpuvision.detect \
  --model ${TEST_DATA}/mobilenet_ssd_v2_face_quant_postprocess_edgetpu.tflite

Run coco model:
python3 -m edgetpuvision.detect \
  --model ${TEST_DATA}/mobilenet_ssd_v2_coco_quant_postprocess_edgetpu.tflite \
  --labels ${TEST_DATA}/coco_labels.txt
"""
import argparse
import time
import re
import svgwrite
import imp
import os
from edgetpu.detection.engine import DetectionEngine
import gstreamer

def load_labels(path):
    p = re.compile(r'\s*(\d+)(.+)')
    with open(path, 'r', encoding='utf-8') as f:
        lines = (p.match(line).groups() for line in f.readlines())
        return {int(num): text.strip() for num, text in lines}

def shadow_text(dwg, x, y, text, font_size=20):
    dwg.add(dwg.text(text, insert=(x+1, y+1), fill='black', font_size=font_size))
    dwg.add(dwg.text(text, insert=(x, y), fill='white', font_size=font_size))

def generate_svg(dwg, objs, labels, text_lines):
    width, height = dwg.attribs['width'], dwg.attribs['height']
    for y, line in enumerate(text_lines):
        shadow_text(dwg, 10, y*20, line)
    for obj in objs:
        x0, y0, x1, y1 = obj.bounding_box.flatten().tolist()
        x, y, w, h = x0, y0, x1 - x0, y1 - y0
        x, y, w, h = int(x * width), int(y * height), int(w * width), int(h * height)
        percent = int(100 * obj.score)
        label = '%d%% %s' % (percent, labels[obj.label_id])
        shadow_text(dwg, x, y - 5, label)
        dwg.add(dwg.rect(insert=(x, y), size=(w, h),
                         fill='red', fill_opacity=0.3, stroke='white'))

def main():
    default_model_dir = '../all_models'
    default_model = 'mobilenet_ssd_v2_coco_quant_postprocess_edgetpu.tflite'
    default_labels = 'coco_labels.txt'
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', help='.tflite model path',
                        default=os.path.join(default_model_dir, default_model))
    parser.add_argument('--labels', help='label file path',
                        default=os.path.join(default_model_dir, default_labels))
    parser.add_argument('--top_k', type=int, default=3,
                        help='number of classes with highest score to display')
    parser.add_argument('--threshold', type=float, default=0.1,
                        help='class score threshold')
    args = parser.parse_args()

    print("Loading %s with %s labels." % (args.model, args.labels))
    engine = DetectionEngine(args.model)
    labels = load_labels(args.labels)

    last_time = time.monotonic()
    def user_callback(image, svg_canvas):
        nonlocal last_time
        start_time = time.monotonic()
        objs = engine.DetectWithImage(image, threshold=args.threshold,
                                      keep_aspect_ratio=True, relative_coord=True,
                                      top_k=args.top_k)
        end_time = time.monotonic()
        text_lines = [
            'Inference: %.2f ms' % ((end_time - start_time) * 1000),
            'FPS: %.2f fps' % (1.0 / (end_time - last_time)),
        ]
        print(' '.join(text_lines))
        last_time = end_time
        generate_svg(svg_canvas, objs, labels, text_lines)

    result = gstreamer.run_pipeline(user_callback)

if __name__ == '__main__':
    main()
@@ -0,0 +1,118 @@
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an 'AS IS' BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
from functools import partial
import svgwrite

import gi
gi.require_version('Gst', '1.0')
gi.require_version('GstBase', '1.0')
from gi.repository import GLib, GObject, Gst, GstBase
from PIL import Image

GObject.threads_init()
Gst.init(None)

def on_bus_message(bus, message, loop):
    t = message.type
    if t == Gst.MessageType.EOS:
        loop.quit()
    elif t == Gst.MessageType.WARNING:
        err, debug = message.parse_warning()
        sys.stderr.write('Warning: %s: %s\n' % (err, debug))
    elif t == Gst.MessageType.ERROR:
        err, debug = message.parse_error()
        sys.stderr.write('Error: %s: %s\n' % (err, debug))
        loop.quit()
    return True

def on_new_sample(sink, overlay, screen_size, appsink_size, user_function):
    sample = sink.emit('pull-sample')
    buf = sample.get_buffer()
    result, mapinfo = buf.map(Gst.MapFlags.READ)
    if result:
        img = Image.frombytes('RGB', (appsink_size[0], appsink_size[1]), mapinfo.data, 'raw')
        svg_canvas = svgwrite.Drawing('', size=(screen_size[0], screen_size[1]))
        user_function(img, svg_canvas)
        overlay.set_property('data', svg_canvas.tostring())
    buf.unmap(mapinfo)
    return Gst.FlowReturn.OK

def detectCoralDevBoard():
    try:
        if 'MX8MQ' in open('/sys/firmware/devicetree/base/model').read():
            print('Detected Edge TPU dev board.')
            return True
    except: pass
    return False

def run_pipeline(user_function,
                 src_size=(640, 480),
                 appsink_size=(320, 180)):
    PIPELINE = 'v4l2src device=/dev/video0 ! {src_caps} ! {leaky_q} '
    if detectCoralDevBoard():
        SRC_CAPS = 'video/x-raw,format=YUY2,width={width},height={height},framerate=30/1'
        PIPELINE += """ ! glupload ! tee name=t
            t. ! {leaky_q} ! glfilterbin filter=glcolorscale
               ! {dl_caps} ! videoconvert ! {sink_caps} ! {sink_element}
            t. ! {leaky_q} ! glfilterbin filter=glcolorscale
               ! rsvgoverlay name=overlay ! waylandsink
            """
    else:
        SRC_CAPS = 'video/x-raw,width={width},height={height},framerate=30/1'
        PIPELINE += """ ! tee name=t
            t. ! {leaky_q} ! videoconvert ! videoscale ! {sink_caps} ! {sink_element}
            t. ! {leaky_q} ! videoconvert
               ! rsvgoverlay name=overlay ! videoconvert ! ximagesink
            """

    SINK_ELEMENT = 'appsink name=appsink sync=false emit-signals=true max-buffers=1 drop=true'
    DL_CAPS = 'video/x-raw,format=RGBA,width={width},height={height}'
    SINK_CAPS = 'video/x-raw,format=RGB,width={width},height={height}'
    LEAKY_Q = 'queue max-size-buffers=1 leaky=downstream'

    src_caps = SRC_CAPS.format(width=src_size[0], height=src_size[1])
    dl_caps = DL_CAPS.format(width=appsink_size[0], height=appsink_size[1])
    sink_caps = SINK_CAPS.format(width=appsink_size[0], height=appsink_size[1])
    pipeline = PIPELINE.format(leaky_q=LEAKY_Q,
                               src_caps=src_caps, dl_caps=dl_caps, sink_caps=sink_caps,
                               sink_element=SINK_ELEMENT)

    print('Gstreamer pipeline: ', pipeline)
    pipeline = Gst.parse_launch(pipeline)

    overlay = pipeline.get_by_name('overlay')
    appsink = pipeline.get_by_name('appsink')
    appsink.connect('new-sample', partial(on_new_sample,
                                          overlay=overlay, screen_size=src_size,
                                          appsink_size=appsink_size, user_function=user_function))
    loop = GObject.MainLoop()

    # Set up a pipeline bus watch to catch errors.
    bus = pipeline.get_bus()
    bus.add_signal_watch()
    bus.connect('message', on_bus_message, loop)

    # Run pipeline.
    pipeline.set_state(Gst.State.PLAYING)
    try:
        loop.run()
    except:
        pass

    # Clean up.
    pipeline.set_state(Gst.State.NULL)
    while GLib.MainContext.default().iteration(False):
        pass