Skip to content

Commit

Permalink
[Text Generation][Fix] Assert that tokens have the same shape on output (#1143)
Browse files Browse the repository at this point in the history

* initial implementation

* initial commit

* Update src/deepsparse/transformers/engines/nl_decoder_engine.py

* add split_engine_inputs

* reuse already existing softmax

* Revert "reuse already existing softmax"

This reverts commit 8234524.

* Reuse already existing softmax

* initial implementation
  • Loading branch information
dbogunowicz authored Jul 26, 2023
1 parent 982938b commit 3cab6a3
Showing 1 changed file with 16 additions and 2 deletions.
18 changes: 16 additions & 2 deletions src/deepsparse/transformers/pipelines/text_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,9 +468,8 @@ def has_cache(self) -> bool:
"""
return self.multitoken_engine.kv_cache_enabled

@staticmethod
def join_engine_outputs(
batch_outputs: List[List[numpy.ndarray]], orig_batch_size: int
self, batch_outputs: List[List[numpy.ndarray]], orig_batch_size: int
) -> List[numpy.ndarray]:
"""
Takes a list of outputs (batches) from the engine
Expand All @@ -483,7 +482,22 @@ def join_engine_outputs(
:return: A list of joined outputs
"""
tokens, logits = zip(*batch_outputs)

# find the longest sequence in the batch of tokens
max_len = max([token.shape[1] for token in tokens])

# pad all tokens to the same length
tokens = [
pad_to_fixed_length(
array=prediction,
max_len=max_len,
value=self.tokenizer.pad_token_id,
axis=1,
)
for prediction in tokens
]
tokens = numpy.concatenate(tokens, axis=0)

# find the longest sequence in the batch of logits
max_len = max([logits.shape[1] for logits in logits])
# pad all logits to the same length
Expand Down

0 comments on commit 3cab6a3

Please sign in to comment.