Skip to content

Commit

Permalink
Tflite tpu (#1449)
Browse files Browse the repository at this point in the history
* add tpu support to the tflite backend

* tflite conversion & edgetpu compilation

* add pycoral interpreter

* update preprocess name

* Update run_common.sh | fix wrong else case code

* Revert version.py change

---------

Co-authored-by: Arjun Suresh <[email protected]>
Co-authored-by: Miro <[email protected]>
  • Loading branch information
3 people authored Oct 1, 2024
1 parent e0fdec1 commit 518b454
Show file tree
Hide file tree
Showing 7 changed files with 119 additions and 12 deletions.
15 changes: 12 additions & 3 deletions vision/classification_and_detection/python/backend_tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
_version = tf.__version__
_git_version = tf.__git_version__


import numpy as np
import backend


Expand All @@ -39,8 +39,13 @@ def image_format(self):
# tflite is always NHWC
return "NHWC"

def load(self, model_path, inputs=None, outputs=None):
self.sess = tflite.Interpreter(model_path=model_path)
def load(self, model_path, inputs=None, outputs=None, use_tpu=False):
self.use_tpu = use_tpu
if use_tpu:
from pycoral.utils.edgetpu import make_interpreter
self.sess = make_interpreter(model_path)
else:
self.sess = tflite.Interpreter(model_path=model_path)
self.sess.allocate_tensors()
# keep input/output name to index mapping
self.input2index = {i["name"]: i["index"] for i in self.sess.get_input_details()}
Expand All @@ -54,6 +59,10 @@ def predict(self, feed):
self.lock.acquire()
# set inputs
for k, v in self.input2index.items():
if self.use_tpu and self.sess.get_input_details()[v]['dtype'] == np.uint8:
input_scale, input_zero_point = self.sess.get_input_details()[v]["quantization"]
feed[k] = feed[k] / input_scale + input_zero_point
feed[k] = feed[k].astype(np.uint8)
self.sess.set_tensor(v, feed[k])
self.sess.invoke()
# get results
Expand Down
12 changes: 12 additions & 0 deletions vision/classification_and_detection/python/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,18 @@ def pre_process_mobilenet(img, dims=None, need_transpose=False):
return img


def pre_process_imagenet_tflite_tpu(img, dims=None, need_transpose=False):
    """Prepare an image for the ResNet50 TFLite / Edge TPU model.

    The image is resized (aspect ratio preserved) and center-cropped to
    224x224, converted to float32, and has a fixed per-channel mean
    subtracted. ``dims`` and ``need_transpose`` are accepted for signature
    compatibility with the other pre-processing functions but are not used.

    Args:
        img: image array as consumed by ``cv2.cvtColor`` with
            ``COLOR_BGR2RGB`` (presumably a BGR image from ``cv2.imread`` --
            confirm against the caller).
        dims: ignored; the output size is fixed at 224x224.
        need_transpose: ignored; the channel axis stays last.

    Returns:
        float32 array with the channel means subtracted.
    """
    side = 224
    rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    resized = resize_with_aspectratio(rgb, side, side, inter_pol=cv2.INTER_LINEAR)
    cropped = center_crop(resized, side, side)

    # Reverse the channel axis again, so the means below are applied in the
    # reversed (pre-cvtColor) channel order.
    flipped = np.asarray(cropped, dtype='float32')[..., ::-1]
    channel_means = np.array([103.939, 116.779, 123.68], dtype=np.float32)
    flipped -= channel_means
    return flipped


def pre_process_imagenet_pytorch(img, dims=None, need_transpose=False):
from PIL import Image
import torchvision.transforms.functional as F
Expand Down
15 changes: 14 additions & 1 deletion vision/classification_and_detection/python/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@
"imagenet_mobilenet":
(imagenet.Imagenet, dataset.pre_process_mobilenet, dataset.PostProcessArgMax(offset=-1),
{"image_size": [224, 224, 3]}),
"imagenet_tflite_tpu":
(imagenet.Imagenet, dataset.pre_process_imagenet_tflite_tpu, dataset.PostProcessArgMax(offset=0),
{"image_size": [224, 224, 3]}),
"imagenet_pytorch":
(imagenet.Imagenet, dataset.pre_process_imagenet_pytorch, dataset.PostProcessArgMax(offset=0),
{"image_size": [224, 224, 3]}),
Expand Down Expand Up @@ -114,6 +117,12 @@
"backend": "ncnn",
"model-name": "resnet50",
},
"resnet50-tflite": {
"dataset": "imagenet_tflite_tpu",
"outputs": "ArgMax:0",
"backend": "tflite",
"model-name": "resnet50",
},

# mobilenet
"mobilenet-tf": {
Expand Down Expand Up @@ -231,6 +240,7 @@ def get_args():
parser.add_argument("--inputs", help="model inputs")
parser.add_argument("--outputs", help="model outputs")
parser.add_argument("--backend", help="runtime to use")
parser.add_argument("--device", help="device to use")
parser.add_argument("--model-name", help="name of the mlperf model, ie. resnet50")
parser.add_argument("--threads", default=os.cpu_count(), type=int, help="threads")
parser.add_argument("--qps", type=int, help="target qps")
Expand Down Expand Up @@ -500,7 +510,10 @@ def main():
threads=args.threads,
**kwargs)
# load model to backend
model = backend.load(args.model, inputs=args.inputs, outputs=args.outputs)
if args.device == "tpu":
model = backend.load(args.model, inputs=args.inputs, outputs=args.outputs, use_tpu=True)
else:
model = backend.load(args.model, inputs=args.inputs, outputs=args.outputs)
final_results = {
"runtime": model.name(),
"version": model.version(),
Expand Down
23 changes: 15 additions & 8 deletions vision/classification_and_detection/run_common.sh
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#!/bin/bash

if [ $# -lt 1 ]; then
echo "usage: $0 tf|onnxruntime|pytorch|tflite|tvm-onnx|tvm-pytorch|tvm-tflite [resnet50|mobilenet|ssd-mobilenet|ssd-resnet34|retinanet] [cpu|gpu]"
echo "usage: $0 tf|onnxruntime|pytorch|tflite|tvm-onnx|tvm-pytorch|tvm-tflite [resnet50|mobilenet|ssd-mobilenet|ssd-resnet34|retinanet] [cpu|gpu|tpu]"
exit 1
fi
if [ "x$DATA_DIR" == "x" ]; then
Expand All @@ -19,7 +19,7 @@ device="cpu"
for i in $* ; do
case $i in
tf|onnxruntime|tflite|pytorch|tvm-onnx|tvm-pytorch|tvm-tflite|ncnn) backend=$i; shift;;
cpu|gpu|rocm) device=$i; shift;;
cpu|gpu|tpu|rocm) device=$i; shift;;
gpu) device=gpu; shift;;
resnet50|mobilenet|ssd-mobilenet|ssd-resnet34|ssd-resnet34-tf|retinanet) model=$i; shift;;
esac
Expand Down Expand Up @@ -108,14 +108,21 @@ fi
#
# tflite
#
if [ $name == "resnet50-tflite" ] ; then
model_path="$MODEL_DIR/resnet50_v1.tflite"
profile=resnet50-tf
extra_args="$extra_args --backend tflite"
if [ "$name" = "resnet50-tflite" ]; then
if [ "$device" = "tpu" ]; then
model_path="$MODEL_DIR/resnet50_quant_full_mlperf_edgetpu.tflite"
profile="resnet50-tflite"
extra_args="$extra_args --backend tflite --device tpu"
else
model_path="$MODEL_DIR/resnet50_v1.tflite"
profile="resnet50-tf"
extra_args="$extra_args --backend tflite"
fi
fi
if [ $name == "mobilenet-tflite" ] ; then

if [ "$name" = "mobilenet-tflite" ]; then
model_path="$MODEL_DIR/mobilenet_v1_1.0_224.tflite"
profile=mobilenet-tf
profile="mobilenet-tf"
extra_args="$extra_args --backend tflite"
fi

Expand Down
5 changes: 5 additions & 0 deletions vision/classification_and_detection/run_local.sh
Original file line number Diff line number Diff line change
Expand Up @@ -21,5 +21,10 @@ done

cmd="python3 python/main.py --profile $profile $common_opt --model \"$model_path\" $dataset \
--output \"$OUTPUT_DIR\" $EXTRA_OPS ${ARGS}"

if [[ $EXTRA_OPS == *"tpu"* ]]; then
cmd="sudo $cmd"
fi

echo $cmd
eval $cmd
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
#!/bin/bash
#
# Install the Edge TPU compiler from Google's Coral apt repository.
# Every privileged step goes through sudo, so the script works when started
# by a regular user with sudo rights.

# Abort at the first failing step (including a failure inside a pipeline)
# instead of carrying on with a half-configured package source.
set -euo pipefail

# Register the repository signing key and the package source.
curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | sudo apt-key add -
echo "deb https://packages.cloud.google.com/apt coral-edgetpu-stable main" | sudo tee /etc/apt/sources.list.d/coral-edgetpu.list

# Refresh the package index so the newly added source is picked up, then
# install the compiler.
sudo apt-get update
sudo apt-get install edgetpu-compiler
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import argparse
import os
import cv2
import numpy as np

import tensorflow as tf
from tensorflow.keras.preprocessing import image
from tensorflow.keras.applications.resnet50 import ResNet50, decode_predictions, preprocess_input


def main(path_to_mlperf_calib_dataset, output_path='resnet50_quant_full_mlperf.tflite'):
    """Convert Keras ResNet50 into a fully integer-quantized TFLite model.

    The JPEG images found directly inside ``path_to_mlperf_calib_dataset`` are
    used as the representative dataset for post-training quantization. The
    converted model is written to ``output_path``.

    Args:
        path_to_mlperf_calib_dataset: directory holding the calibration
            images (``.jpg`` / ``.jpeg``, matched case-insensitively).
        output_path: file the quantized TFLite flatbuffer is written to.
            Defaults to the name the original script always used.

    Raises:
        ValueError: if the directory contains no JPEG files, since
            quantization cannot be calibrated without samples.
    """
    # load tf model
    model = ResNet50(weights='imagenet')

    # read jpeg files from mlperf calib dataset; os.listdir returns entries in
    # arbitrary order, so sort to make the calibration run reproducible
    jpeg_files = sorted(
        os.path.join(path_to_mlperf_calib_dataset, filename)
        for filename in os.listdir(path_to_mlperf_calib_dataset)
        if filename.lower().endswith(('.jpg', '.jpeg'))
    )
    if not jpeg_files:
        raise ValueError(
            "No .jpg/.jpeg files found in calibration directory: "
            + str(path_to_mlperf_calib_dataset)
        )

    # quantization
    def representative_data_gen():
        # Yield one pre-processed image at a time, each as a batch of size 1.
        for img_path in jpeg_files:
            img = image.load_img(img_path, target_size=(224, 224))
            x = image.img_to_array(img)
            x = np.expand_dims(x, axis=0)
            x = preprocess_input(x)
            yield [x]

    converter = tf.lite.TFLiteConverter.from_keras_model(model)
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    converter.representative_dataset = representative_data_gen
    # Ensure that if any ops can't be quantized, the converter throws an error
    converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]

    # full quant: the model's input and output tensors are uint8 as well
    converter.inference_input_type = tf.uint8
    converter.inference_output_type = tf.uint8

    tflite_model_quant = converter.convert()
    with open(output_path, 'wb') as f:
        f.write(tflite_model_quant)


if __name__ == "__main__":
    # Imported here because only this entry point needs it.
    import subprocess

    parser = argparse.ArgumentParser()
    parser.add_argument('--image-dir', type=str, default=None,
                        help="directory with the calibration JPEG images")
    args = parser.parse_args()
    print(args)
    # The path must be a directory: main() lists its contents.
    if args.image_dir is None or not os.path.isdir(args.image_dir):
        raise ValueError("Please provide a calibration dataset.")
    main(args.image_dir)

    # compile model for edge tpu; check=True makes a failed compilation fail
    # the script instead of being silently ignored
    subprocess.run(
        ["edgetpu_compiler", "resnet50_quant_full_mlperf.tflite"],
        check=True,
    )

0 comments on commit 518b454

Please sign in to comment.