Speed difference ONNX vs TensorRT with samples sorted by sequence length #55

Open
v1nc3nt27 opened this issue Feb 23, 2022 · 6 comments


@v1nc3nt27

I noticed something unexpected when comparing two scenarios for a model converted to both ONNX and TensorRT (DistilRoBERTa with a classification head):

  1. I use a dataset with varying sentence lengths (~20-60 tokens) and run it through both models in random order.
  2. I use the same dataset but sort the sentences by length (decreasing) before running them through both models.

Result: The TensorRT model does not seem to care about the sequence lengths and keeps the same speed for both scenarios. The ONNX model, however, gets almost twice as fast when I use the second scenario.

I was wondering whether TensorRT's optimization somehow requires padding to the max length internally. I searched for a parameter or a reason for this behavior but couldn't find anything useful. For conversion, I set the seq-len parameter to 1 60 60.

I was wondering if perhaps someone else has already observed this and knows the reason / a solution.
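
For context, the three seq-len values presumably correspond to the min/opt/max shapes of a TensorRT optimization profile. Below is a minimal sketch of such a profile using the TensorRT Python API; the tensor names and the batch dimension of 64 are illustrative assumptions, and the network definition and engine build are omitted.

import tensorrt as trt

# Sketch of a dynamic-shape optimization profile matching "seq-len 1 60 60".
# Tensor names and batch size 64 are assumptions; network/engine build omitted.
logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
config = builder.create_builder_config()
profile = builder.create_optimization_profile()
for name in ("input_ids", "attention_mask"):
    profile.set_shape(
        name,
        min=(1, 1),    # smallest accepted (batch, seq_len)
        opt=(64, 60),  # shape TensorRT tunes its kernels for
        max=(64, 60),  # largest accepted shape
    )
config.add_optimization_profile(profile)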

@pommedeterresautee
Member

Is there some batching applied?

@v1nc3nt27
Author

Oh, I completely forgot to mention that. Yes, I use a batch size of 64. This behavior only shows up when batching is used.

@pommedeterresautee
Member

How is each batch built? Is it made of sequences of the exact same length?

@v1nc3nt27
Author

v1nc3nt27 commented Mar 3, 2022

The samples are just ordered by character length and then batched, so they may still vary within a batch (but much less than before). The speed-up simply comes from the fact that fewer batches are padded to the model_max_length in that case.

I replaced

tokens: Dict[str, np.ndarray] = self.tokenizer(text=query, return_tensors=TensorType.NUMPY)

with

tokens: Dict[str, np.ndarray] = self.tokenizer(query_question, query_answer, return_tensors=TensorType.NUMPY, padding="longest", truncation=True)

and added self.tokenizer.model_max_length = 60 as the last line of initialize().
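
To make the effect concrete, here is a small self-contained sketch (tokenizer only, no model; the sentences, checkpoint, and batch size are made up) showing how padding="longest" yields smaller padded shapes once the samples are sorted by length before batching:

from transformers import AutoTokenizer, TensorType

tokenizer = AutoTokenizer.from_pretrained("distilroberta-base")
tokenizer.model_max_length = 60

sentences = ["short one",
             "a noticeably longer example sentence with quite a few more tokens",
             "a mid length sentence here"] * 8
batch_size = 4

def padded_shapes(samples):
    # With padding="longest", the sequence dimension of each batch equals the
    # longest sample in that batch instead of model_max_length.
    shapes = []
    for i in range(0, len(samples), batch_size):
        batch = samples[i:i + batch_size]
        tokens = tokenizer(batch, return_tensors=TensorType.NUMPY,
                           padding="longest", truncation=True)
        shapes.append(tokens["input_ids"].shape)
    return shapes

print("unsorted:     ", padded_shapes(sentences))
print("sorted by len:", padded_shapes(sorted(sentences, key=len)))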

@pommedeterresautee
Member

Can you provide me with some reproducible code so I can test it on my side?

@v1nc3nt27
Author

v1nc3nt27 commented Mar 8, 2022

Hey @pommedeterresautee, sorry for the long wait - I was on a holiday trip.

I based my script on your demo scripts, but I cannot disclose the model and/or dataset. You can basically use any dataset with two inputs, e.g. a QA dataset. I hope you can make use of it anyway.

I attached the script that calls the inference ensemble hosted in Triton (transformer_onnx_inference or transformer_trt_inference) and the slightly modified model.py for the tokenize endpoint in Triton.

If you experience the same thing I do, calling the ONNX model's inference endpoint should be slower when you comment out the length sorting in triton_inference_qa_test.py, while doing the same for the TensorRT model's inference should make no difference.

triton_inference_qa_test.py

import argparse
import math
import time
import numpy as np
import tritonclient.http
from tqdm import tqdm
from scipy.special import softmax
from transformer_deploy.benchmarks.utils import print_timings, setup_logging, track_infer_time


def _batch(iterable, n=1):
    # yield successive slices of size n (the last one may be smaller)
    l = len(iterable)
    for ndx in range(0, l, n):
        yield iterable[ndx:min(ndx + n, l)]


def run_http_sync():
    all_gold = []
    all_pred = []
    triton_client = tritonclient.http.InferenceServerClient(url=url, verbose=False)

    assert triton_client.is_model_ready(
        model_name=model_name, model_version=model_version
    ), f"model {model_name} not yet ready"
    setup_logging()

    model_score = tritonclient.http.InferRequestedOutput(name="output", binary_data=True)
    for b in tqdm(_batch(list(zip(questions, answers, golds)), batch_size), total=math.ceil(len(answers)/batch_size)):
        with track_infer_time(time_buffer):
            topic_b, sent_b, gold_b = zip(*b) 
            all_gold.extend(gold_b)

            query_sent = tritonclient.http.InferInput(name="sent", shape=(len(b),), datatype="BYTES")
            query_topic = tritonclient.http.InferInput(name="topic", shape=(len(b),), datatype="BYTES")

            query_sent.set_data_from_numpy(np.asarray(sent_b, dtype=object))
            query_topic.set_data_from_numpy(np.asarray(topic_b, dtype=object))

            response = triton_client.infer(
                model_name=model_name, model_version=model_version, inputs=[query_topic, query_sent], outputs=[model_score],
                response_compression_algorithm="gzip", request_compression_algorithm="gzip"
            )
            res = response.as_numpy("output")
            scores = softmax(res, axis=1)
            all_pred.extend([np.argmax(pred) for pred in scores])

    return all_gold, all_pred


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="require inference", formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument("--triton-model-name", help="Model name in triton server", type=str)
    parser.add_argument("--url", help="", type=str, default="127.0.0.1:8000")
    parser.add_argument("--batch-size", help="", type=int, default=64)
    args, _ = parser.parse_known_args()

    setup_logging()
    model_name = args.triton_model_name
    url = args.url
    model_version = "1"
    batch_size = args.batch_size

    # todo read data
    data = read_data(...)

    answers = data["answers"].tolist()
    golds = data["label"].tolist()
    questions = data["questions"].tolist()
    t = time.time()

    # comment out the following 4 lines to switch off bucketing
    length_sorted_idx = np.argsort([len(q+a) for q, a in zip(questions, answers)])
    answers = [answers[idx] for idx in length_sorted_idx]
    golds = [golds[idx] for idx in length_sorted_idx]
    questions = [questions[idx] for idx in length_sorted_idx]
    time_buffer = list()

    all_gold, all_pred = run_http_sync()

    print_timings(name="triton transformers", timings=time_buffer)
    total_time = time.time() - t
    print("Total time: " + str(total_time))

model.py

import os
from typing import Dict, List

import numpy as np


try:
    # noinspection PyUnresolvedReferences
    import triton_python_backend_utils as pb_utils
except ImportError:
    pass  # triton_python_backend_utils exists only inside Triton Python backend.

from transformers import AutoTokenizer, PreTrainedTokenizer, TensorType


class TritonPythonModel:
    is_tensorrt: bool
    tokenizer: PreTrainedTokenizer

    def initialize(self, args: Dict[str, str]) -> None:
        """
        Initialize the tokenization process
        :param args: arguments from Triton config file
        """
        path: str = os.path.join(args["model_repository"], args["model_version"])
        model_name: str = args["model_name"]
        self.is_tensorrt = "trt" in model_name or "tensorrt" in model_name
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.tokenizer.model_max_length = 60

    def execute(self, requests) -> "List[List[pb_utils.Tensor]]":
        """
        Parse and tokenize each request
        :param requests: 1 or more requests received by Triton server.
        :return: text as input tensors
        """
        responses = []
        # for loop for batch requests (disabled in our case)
        for request in requests:
            # binary data typed back to string
            query_topic = [t.decode("UTF-8") for t in pb_utils.get_input_tensor_by_name(request, "topic").as_numpy().tolist()]
            query_sent = [t.decode("UTF-8") for t in pb_utils.get_input_tensor_by_name(request, "sent").as_numpy().tolist()]

            tokens: Dict[str, np.ndarray] = self.tokenizer(query_topic, query_sent, return_tensors=TensorType.NUMPY,
                                                           padding="longest", truncation=True)
            if self.is_tensorrt:
                # tensorrt uses int32 as input type, ort uses int64
                tokens = {k: v.astype(np.int32) for k, v in tokens.items()}
            # communicate the tokenization results to Triton server
            outputs = list()
            for input_name in self.tokenizer.model_input_names:
                tensor_input = pb_utils.Tensor(input_name, tokens[input_name])
                outputs.append(tensor_input)

            inference_response = pb_utils.InferenceResponse(output_tensors=outputs)
            responses.append(inference_response)

        return responses
