test: Enhance Python gRPC streaming test to send multiple requests (#…
kthui authored Oct 7, 2024
1 parent 9bbee48 commit 71a285a
Showing 1 changed file with 25 additions and 47 deletions.
qa/L0_python_api/testing_utils.py
@@ -26,6 +26,7 @@
 
 import os
 import queue
+from functools import partial
 from typing import Union
 
 import numpy as np
@@ -101,59 +102,36 @@ def send_and_test_inference_identity(frontend_client, url: str) -> bool:
     return input_data[0] == output_data[0].decode()
 
 
-# Sends a streaming inference request to test_model_repository/identity model
-# and verifies input == output
+# Sends multiple streaming requests to the "delayed_identity" model with negligible delays,
+# and verifies that the inputs match the outputs and that the ordering is preserved.
 def send_and_test_stream_inference(frontend_client, url: str) -> bool:
-    model_name = "identity"
-
-    # Setting up the gRPC client stream
-    results = queue.Queue()
-    callback = lambda error, result: results.put(error or result)
-    client = frontend_client.InferenceServerClient(url=url)
-
-    client.start_stream(callback=callback)
-
-    # Preparing Input Data
-    text_input = "testing"
-    input_tensor = frontend_client.InferInput(
-        name="INPUT0", shape=[1], datatype="BYTES"
-    )
-    input_tensor.set_data_from_numpy(np.array([text_input.encode()], dtype=np.object_))
+    num_requests = 100
+    requests = []
+    for i in range(num_requests):
+        input0_np = np.array([[float(i) / 1000]], dtype=np.float32)
+        inputs = [frontend_client.InferInput("INPUT0", input0_np.shape, "FP32")]
+        inputs[0].set_data_from_numpy(input0_np)
+        requests.append(inputs)
 
-    # Sending Streaming Inference Request
-    client.async_stream_infer(
-        model_name=model_name, inputs=[input_tensor], enable_empty_final_response=True
-    )
-
-    # Looping through until exception thrown or request completed
-    completed_requests, num_requests = 0, 1
-    text_output, is_final = None, None
-    while completed_requests != num_requests:
-        result = results.get()
-        if isinstance(result, InferenceServerException):
-            if result.status() == "StatusCode.CANCELLED":
-                completed_requests += 1
-            raise result
-
-        # Processing Response
-        text_output = result.as_numpy("OUTPUT0")[0].decode()
+    responses = []
 
-        triton_final_response = result.get_response().parameters.get(
-            "triton_final_response", {}
-        )
+    def callback(responses, result, error):
+        responses.append({"result": result, "error": error})
 
-        is_final = False
-        if triton_final_response.HasField("bool_param"):
-            is_final = triton_final_response.bool_param
-
-        # Request Completed
-        if is_final:
-            completed_requests += 1
+    client = frontend_client.InferenceServerClient(url=url)
+    client.start_stream(partial(callback, responses))
+    for inputs in requests:
+        client.async_stream_infer("delayed_identity", inputs)
+    client.stop_stream()
+    teardown_client(client)
 
-    # Tearing down gRPC client stream
-    client.stop_stream(cancel_requests=True)
+    assert len(responses) == num_requests
+    for i in range(len(responses)):
+        assert responses[i]["error"] is None
+        output0_np = responses[i]["result"].as_numpy(name="OUTPUT0")
+        assert np.allclose(output0_np, [[float(i) / 1000]])
 
-    return is_final and (text_input == text_output)
+    return True  # test passed
 
 
 def send_and_test_generate_inference() -> bool:
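For context beyond the diff: the helper's `frontend_client` parameter is the Triton client module itself (the code calls `frontend_client.InferenceServerClient` and `frontend_client.InferInput`), and the shared `responses` list is bound into the stream callback via `functools.partial`, so every callback invocation appends to the same list in arrival order. A minimal usage sketch, assuming the gRPC client module and a server already serving the QA model repository on the default gRPC port (the URL and invocation are assumptions, not part of this commit):

# Hedged usage sketch: assumes tritonclient[grpc] is installed and a Triton
# server exposing the "delayed_identity" model is listening on localhost:8001.
import tritonclient.grpc as grpcclient

from testing_utils import send_and_test_stream_inference

# The client module object itself is passed in, mirroring how the QA harness
# parameterizes the frontend under test.
assert send_and_test_stream_inference(grpcclient, "localhost:8001")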

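The diff does not include the `delayed_identity` model itself. Judging from the test (FP32 input of shape [1, 1], output echoing the input, delays of i / 1000 seconds), it is presumably a Python-backend model that sleeps for the number of seconds given in INPUT0 and then returns the value unchanged. A hypothetical sketch of such a model, inferred from the test rather than taken from the repository:

# Hypothetical "delayed_identity"-style Python-backend model. The real model
# in the QA model repository may differ; this only illustrates the behavior
# the test relies on: sleep for INPUT0 seconds, then echo INPUT0 as OUTPUT0.
import time

import triton_python_backend_utils as pb_utils


class TritonPythonModel:
    def execute(self, requests):
        responses = []
        for request in requests:
            input0 = pb_utils.get_input_tensor_by_name(request, "INPUT0")
            delay_np = input0.as_numpy()  # FP32 tensor of shape [1, 1]
            time.sleep(float(delay_np[0][0]))  # "negligible" delay: i / 1000 s
            responses.append(
                pb_utils.InferenceResponse(
                    output_tensors=[pb_utils.Tensor("OUTPUT0", delay_np)]
                )
            )
        return responses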