diff --git a/src/deepsparse/transformers/pipelines/text_generation.py b/src/deepsparse/transformers/pipelines/text_generation.py index cf519586a4..bc32d2f7ad 100644 --- a/src/deepsparse/transformers/pipelines/text_generation.py +++ b/src/deepsparse/transformers/pipelines/text_generation.py @@ -468,9 +468,8 @@ def has_cache(self) -> bool: """ return self.multitoken_engine.kv_cache_enabled - @staticmethod def join_engine_outputs( - batch_outputs: List[List[numpy.ndarray]], orig_batch_size: int + self, batch_outputs: List[List[numpy.ndarray]], orig_batch_size: int ) -> List[numpy.ndarray]: """ Takes a list of outputs (batches) from the engine @@ -483,7 +482,22 @@ def join_engine_outputs( :return: A list of joined outputs """ tokens, logits = zip(*batch_outputs) + + # find the longest sequence in the batch of tokens + max_len = max([token.shape[1] for token in tokens]) + + # pad all tokens to the same length + tokens = [ + pad_to_fixed_length( + array=prediction, + max_len=max_len, + value=self.tokenizer.pad_token_id, + axis=1, + ) + for prediction in tokens + ] tokens = numpy.concatenate(tokens, axis=0) + # find the longest sequence in the batch of logits max_len = max([logits.shape[1] for logits in logits]) # pad all logits to the same length