Skip to content

Commit

Permalink
Merge pull request #16 from microsoft/hjiang/fix_prefix_dim_mismatch
Browse files Browse the repository at this point in the history
Fixed(LLMLingua): fix the prefix dimension mismatch.

Co-authored-by: Qianhui Wu <[email protected]>
Co-authored-by: Xufang Luo <[email protected]>
  • Loading branch information
3 people authored Nov 15, 2023
2 parents 6d34053 + 6c58f7d commit 70bbd02
Showing 1 changed file with 20 additions and 6 deletions.
26 changes: 20 additions & 6 deletions llmlingua/prompt_compressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def __init__(
self.retrieval_model_name = None
self.open_api_config = open_api_config
self.cache_bos_num = 10
self.prefix_bos_num = 100

def load_model(
self, model_name: str, device_map: str = "cuda", use_auth_token: bool = False
Expand Down Expand Up @@ -214,12 +215,24 @@ def compress_prompt(
)

if condition_flag:
if add_instruction:
context = [question + "\n\n" + instruction] + context
start = self.get_token_length(question + "\n\n" + instruction) + 2
else:
context = [question] + context
start = self.get_token_length(question) + 2
prefix = question + "\n\n" + instruction if add_instruction else question
if (
self.get_token_length(prefix) + 2 + iterative_size * 2
> self.max_position_embeddings
):
tokens = self.tokenizer(prefix, add_special_tokens=False).input_ids
prefix = self.tokenizer.decode(
tokens[: self.prefix_bos_num]
+ tokens[
len(tokens)
- self.max_position_embeddings
+ 2
+ self.prefix_bos_num
+ 2 * iterative_size :
]
)
start = self.get_token_length(prefix) + 2
context = [prefix] + context
else:
start = 0

Expand Down Expand Up @@ -692,6 +705,7 @@ def iterative_compress_prompt(
]
end, ready_end = end - e, ready_end - e
if condition_compare:
s = min(s, self_past_key_values[0][0].shape[2] - e)
self_ready_end -= e
if pop_self_compressed_input_ids is None:
pop_self_compressed_input_ids = self_compressed_input_ids[:, :e]
Expand Down

0 comments on commit 70bbd02

Please sign in to comment.