From daa55180f1e4f6f1451eaa352a2d7940b3424c2d Mon Sep 17 00:00:00 2001 From: Gary Wang <38331932+gwang111@users.noreply.github.com> Date: Mon, 6 Jan 2025 16:55:35 -0800 Subject: [PATCH] model server might have already done a serialization. honor that by not decoding the request again if it is not already bytes or bytestream (#4987) --- .../multi_model_server/inference.py | 22 ++++++++++++++----- .../model_server/torchserve/inference.py | 22 ++++++++++++++----- .../torchserve/xgboost_inference.py | 22 ++++++++++++++----- 3 files changed, 48 insertions(+), 18 deletions(-) diff --git a/src/sagemaker/serve/model_server/multi_model_server/inference.py b/src/sagemaker/serve/model_server/multi_model_server/inference.py index 908ffcc7aa..9361765da0 100644 --- a/src/sagemaker/serve/model_server/multi_model_server/inference.py +++ b/src/sagemaker/serve/model_server/multi_model_server/inference.py @@ -46,18 +46,28 @@ def input_fn(input_data, content_type, context=None): if hasattr(schema_builder, "custom_input_translator"): deserialized_data = schema_builder.custom_input_translator.deserialize( ( - io.BytesIO(input_data) - if type(input_data) == bytes - else io.BytesIO(input_data.encode("utf-8")) + io.BytesIO(input_data.encode("utf-8")) + if not any( + [ + isinstance(input_data, bytes), + isinstance(input_data, bytearray), + ] + ) + else io.BytesIO(input_data) ), content_type, ) else: deserialized_data = schema_builder.input_deserializer.deserialize( ( - io.BytesIO(input_data) - if type(input_data) == bytes - else io.BytesIO(input_data.encode("utf-8")) + io.BytesIO(input_data.encode("utf-8")) + if not any( + [ + isinstance(input_data, bytes), + isinstance(input_data, bytearray), + ] + ) + else io.BytesIO(input_data) ), content_type[0], ) diff --git a/src/sagemaker/serve/model_server/torchserve/inference.py b/src/sagemaker/serve/model_server/torchserve/inference.py index 489cc1bc1e..058103a1fd 100644 --- a/src/sagemaker/serve/model_server/torchserve/inference.py +++ b/src/sagemaker/serve/model_server/torchserve/inference.py @@ -68,18 +68,28 @@ def input_fn(input_data, content_type): if hasattr(schema_builder, "custom_input_translator"): deserialized_data = schema_builder.custom_input_translator.deserialize( ( - io.BytesIO(input_data) - if type(input_data) == bytes - else io.BytesIO(input_data.encode("utf-8")) + io.BytesIO(input_data.encode("utf-8")) + if not any( + [ + isinstance(input_data, bytes), + isinstance(input_data, bytearray), + ] + ) + else io.BytesIO(input_data) ), content_type, ) else: deserialized_data = schema_builder.input_deserializer.deserialize( ( - io.BytesIO(input_data) - if type(input_data) == bytes - else io.BytesIO(input_data.encode("utf-8")) + io.BytesIO(input_data.encode("utf-8")) + if not any( + [ + isinstance(input_data, bytes), + isinstance(input_data, bytearray), + ] + ) + else io.BytesIO(input_data) ), content_type[0], ) diff --git a/src/sagemaker/serve/model_server/torchserve/xgboost_inference.py b/src/sagemaker/serve/model_server/torchserve/xgboost_inference.py index 517c774bbc..49cec5aab5 100644 --- a/src/sagemaker/serve/model_server/torchserve/xgboost_inference.py +++ b/src/sagemaker/serve/model_server/torchserve/xgboost_inference.py @@ -71,18 +71,28 @@ def input_fn(input_data, content_type): if hasattr(schema_builder, "custom_input_translator"): return schema_builder.custom_input_translator.deserialize( ( - io.BytesIO(input_data) - if type(input_data) == bytes - else io.BytesIO(input_data.encode("utf-8")) + io.BytesIO(input_data.encode("utf-8")) + if not any( + [ + isinstance(input_data, bytes), + isinstance(input_data, bytearray), + ] + ) + else io.BytesIO(input_data) ), content_type, ) else: return schema_builder.input_deserializer.deserialize( ( - io.BytesIO(input_data) - if type(input_data) == bytes - else io.BytesIO(input_data.encode("utf-8")) + io.BytesIO(input_data.encode("utf-8")) + if not any( + [ + isinstance(input_data, bytes), + isinstance(input_data, bytearray), + ] + ) + else io.BytesIO(input_data) ), content_type[0], )