From 518b454fd8647bfbd23a074e875e87353f33393e Mon Sep 17 00:00:00 2001 From: makaveli <39617050+makaveli10@users.noreply.github.com> Date: Tue, 1 Oct 2024 22:23:18 +0530 Subject: [PATCH] Tflite tpu (#1449) * add tpu support tflite backend * tflite conversion & edgetpu compilation * add pycoral interpreter * update preprocess name * Update run_common.sh | fix wrong else case code * Revert version.py change --------- Co-authored-by: Arjun Suresh Co-authored-by: Miro --- .../python/backend_tflite.py | 15 ++++- .../python/dataset.py | 12 ++++ .../python/main.py | 15 ++++- .../run_common.sh | 23 +++++--- .../classification_and_detection/run_local.sh | 5 ++ .../tools/install_edgetpu_compiler.sh | 6 ++ .../tools/resnet50_tflite_edgetpu.py | 55 +++++++++++++++++++ 7 files changed, 119 insertions(+), 12 deletions(-) create mode 100644 vision/classification_and_detection/tools/install_edgetpu_compiler.sh create mode 100644 vision/classification_and_detection/tools/resnet50_tflite_edgetpu.py diff --git a/vision/classification_and_detection/python/backend_tflite.py b/vision/classification_and_detection/python/backend_tflite.py index 7c8c78c13..bb8ed7d16 100755 --- a/vision/classification_and_detection/python/backend_tflite.py +++ b/vision/classification_and_detection/python/backend_tflite.py @@ -19,7 +19,7 @@ _version = tf.__version__ _git_version = tf.__git_version__ - +import numpy as np import backend @@ -39,8 +39,13 @@ def image_format(self): # tflite is always NHWC return "NHWC" - def load(self, model_path, inputs=None, outputs=None): - self.sess = tflite.Interpreter(model_path=model_path) + def load(self, model_path, inputs=None, outputs=None, use_tpu=False): + self.use_tpu = use_tpu + if use_tpu: + from pycoral.utils.edgetpu import make_interpreter + self.sess = make_interpreter(model_path) + else: + self.sess = tflite.Interpreter(model_path=model_path) self.sess.allocate_tensors() # keep input/output name to index mapping self.input2index = {i["name"]: i["index"] for i 
in self.sess.get_input_details()} @@ -54,6 +59,10 @@ def predict(self, feed): self.lock.acquire() # set inputs for k, v in self.input2index.items(): + if self.use_tpu and self.sess.get_input_details()[v]['dtype'] == np.uint8: + input_scale, input_zero_point = self.sess.get_input_details()[v]["quantization"] + feed[k] = feed[k] / input_scale + input_zero_point + feed[k] = feed[k].astype(np.uint8) self.sess.set_tensor(v, feed[k]) self.sess.invoke() # get results diff --git a/vision/classification_and_detection/python/dataset.py b/vision/classification_and_detection/python/dataset.py index 02aa654c2..0d82a99a3 100755 --- a/vision/classification_and_detection/python/dataset.py +++ b/vision/classification_and_detection/python/dataset.py @@ -202,6 +202,18 @@ def pre_process_mobilenet(img, dims=None, need_transpose=False): return img +def pre_process_imagenet_tflite_tpu(img, dims=None, need_transpose=False): + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + img = resize_with_aspectratio(img, 224, 224, inter_pol=cv2.INTER_LINEAR) + img = center_crop(img, 224, 224) + img = np.asarray(img, dtype='float32') + + img = img[..., ::-1] + means = np.array([103.939, 116.779, 123.68], dtype=np.float32) + img -= means + return img + + def pre_process_imagenet_pytorch(img, dims=None, need_transpose=False): from PIL import Image import torchvision.transforms.functional as F diff --git a/vision/classification_and_detection/python/main.py b/vision/classification_and_detection/python/main.py index 1c2cd9a5a..d6f16467a 100755 --- a/vision/classification_and_detection/python/main.py +++ b/vision/classification_and_detection/python/main.py @@ -41,6 +41,9 @@ "imagenet_mobilenet": (imagenet.Imagenet, dataset.pre_process_mobilenet, dataset.PostProcessArgMax(offset=-1), {"image_size": [224, 224, 3]}), + "imagenet_tflite_tpu": + (imagenet.Imagenet, dataset.pre_process_imagenet_tflite_tpu, dataset.PostProcessArgMax(offset=0), + {"image_size": [224, 224, 3]}), "imagenet_pytorch": (imagenet.Imagenet, 
dataset.pre_process_imagenet_pytorch, dataset.PostProcessArgMax(offset=0), {"image_size": [224, 224, 3]}), @@ -114,6 +117,12 @@ "backend": "ncnn", "model-name": "resnet50", }, + "resnet50-tflite": { + "dataset": "imagenet_tflite_tpu", + "outputs": "ArgMax:0", + "backend": "tflite", + "model-name": "resnet50", + }, # mobilenet "mobilenet-tf": { @@ -231,6 +240,7 @@ def get_args(): parser.add_argument("--inputs", help="model inputs") parser.add_argument("--outputs", help="model outputs") parser.add_argument("--backend", help="runtime to use") + parser.add_argument("--device", help="device to use") parser.add_argument("--model-name", help="name of the mlperf model, ie. resnet50") parser.add_argument("--threads", default=os.cpu_count(), type=int, help="threads") parser.add_argument("--qps", type=int, help="target qps") @@ -500,7 +510,10 @@ def main(): threads=args.threads, **kwargs) # load model to backend - model = backend.load(args.model, inputs=args.inputs, outputs=args.outputs) + if args.device == "tpu": + model = backend.load(args.model, inputs=args.inputs, outputs=args.outputs, use_tpu=True) + else: + model = backend.load(args.model, inputs=args.inputs, outputs=args.outputs) final_results = { "runtime": model.name(), "version": model.version(), diff --git a/vision/classification_and_detection/run_common.sh b/vision/classification_and_detection/run_common.sh index 37552c71e..0d872c85a 100755 --- a/vision/classification_and_detection/run_common.sh +++ b/vision/classification_and_detection/run_common.sh @@ -1,7 +1,7 @@ #!/bin/bash if [ $# -lt 1 ]; then - echo "usage: $0 tf|onnxruntime|pytorch|tflite|tvm-onnx|tvm-pytorch|tvm-tflite [resnet50|mobilenet|ssd-mobilenet|ssd-resnet34|retinanet] [cpu|gpu]" + echo "usage: $0 tf|onnxruntime|pytorch|tflite|tvm-onnx|tvm-pytorch|tvm-tflite [resnet50|mobilenet|ssd-mobilenet|ssd-resnet34|retinanet] [cpu|gpu|tpu]" exit 1 fi if [ "x$DATA_DIR" == "x" ]; then @@ -19,7 +19,7 @@ device="cpu" for i in $* ; do case $i in 
tf|onnxruntime|tflite|pytorch|tvm-onnx|tvm-pytorch|tvm-tflite|ncnn) backend=$i; shift;; - cpu|gpu|rocm) device=$i; shift;; + cpu|gpu|tpu|rocm) device=$i; shift;; gpu) device=gpu; shift;; resnet50|mobilenet|ssd-mobilenet|ssd-resnet34|ssd-resnet34-tf|retinanet) model=$i; shift;; esac @@ -108,14 +108,21 @@ fi # # tflite # -if [ $name == "resnet50-tflite" ] ; then - model_path="$MODEL_DIR/resnet50_v1.tflite" - profile=resnet50-tf - extra_args="$extra_args --backend tflite" +if [ "$name" = "resnet50-tflite" ]; then + if [ "$device" = "tpu" ]; then + model_path="$MODEL_DIR/resnet50_quant_full_mlperf_edgetpu.tflite" + profile="resnet50-tflite" + extra_args="$extra_args --backend tflite --device tpu" + else + model_path="$MODEL_DIR/resnet50_v1.tflite" + profile="resnet50-tf" + extra_args="$extra_args --backend tflite" + fi fi -if [ $name == "mobilenet-tflite" ] ; then + +if [ "$name" = "mobilenet-tflite" ]; then model_path="$MODEL_DIR/mobilenet_v1_1.0_224.tflite" - profile=mobilenet-tf + profile="mobilenet-tf" extra_args="$extra_args --backend tflite" fi diff --git a/vision/classification_and_detection/run_local.sh b/vision/classification_and_detection/run_local.sh index 16f905ce4..120b538c7 100755 --- a/vision/classification_and_detection/run_local.sh +++ b/vision/classification_and_detection/run_local.sh @@ -21,5 +21,10 @@ done cmd="python3 python/main.py --profile $profile $common_opt --model \"$model_path\" $dataset \ --output \"$OUTPUT_DIR\" $EXTRA_OPS ${ARGS}" + +if [[ $EXTRA_OPS == *"tpu"* ]]; then + cmd="sudo $cmd" +fi + echo $cmd eval $cmd diff --git a/vision/classification_and_detection/tools/install_edgetpu_compiler.sh b/vision/classification_and_detection/tools/install_edgetpu_compiler.sh new file mode 100644 index 000000000..dea1c9e18 --- /dev/null +++ b/vision/classification_and_detection/tools/install_edgetpu_compiler.sh @@ -0,0 +1,6 @@ +#!/bin/bash + +curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | sudo apt-key add - +echo "deb 
https://packages.cloud.google.com/apt coral-edgetpu-stable main" | sudo tee /etc/apt/sources.list.d/coral-edgetpu.list +apt-get update +apt install edgetpu-compiler \ No newline at end of file diff --git a/vision/classification_and_detection/tools/resnet50_tflite_edgetpu.py b/vision/classification_and_detection/tools/resnet50_tflite_edgetpu.py new file mode 100644 index 000000000..b99c989a2 --- /dev/null +++ b/vision/classification_and_detection/tools/resnet50_tflite_edgetpu.py @@ -0,0 +1,55 @@ +import argparse +import os +import cv2 +import numpy as np + +import tensorflow as tf +from tensorflow.keras.preprocessing import image +from tensorflow.keras.applications.resnet50 import ResNet50, decode_predictions, preprocess_input + + +def main(path_to_mlperf_calib_dataset): + # load tf model + model = ResNet50(weights='imagenet') + + # read jpeg files from mlperf calib dataset + jpeg_files = [] + for filename in os.listdir(path_to_mlperf_calib_dataset): + if filename.lower().endswith('.jpg') or filename.lower().endswith('.jpeg'): + jpeg_files.append(os.path.join(path_to_mlperf_calib_dataset, filename)) + + # quantization + def representative_data_gen(): + for img_path in jpeg_files: + img = image.load_img(img_path, target_size=(224, 224)) + x = image.img_to_array(img) + x = np.expand_dims(x, axis=0) + x = preprocess_input(x) + yield [x] + + converter = tf.lite.TFLiteConverter.from_keras_model(model) + converter.optimizations = [tf.lite.Optimize.DEFAULT] + converter.representative_dataset = representative_data_gen + # Ensure that if any ops can't be quantized, the converter throws an error + converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8] + + # full quant + converter.inference_input_type = tf.uint8 + converter.inference_output_type = tf.uint8 + + tflite_model_quant = converter.convert() + with open('resnet50_quant_full_mlperf.tflite', 'wb') as f: + f.write(tflite_model_quant) + + +if __name__=="__main__": + parser = argparse.ArgumentParser() 
+ parser.add_argument('--image-dir', type=str, default=None) + args = parser.parse_args() + print(args) + if args.image_dir is None or not os.path.exists(args.image_dir): + raise ValueError("Please provide a calibration dataset.") + main(args.image_dir) + + # compile model for edge tpu + os.system("edgetpu_compiler resnet50_quant_full_mlperf.tflite") \ No newline at end of file