From 6c58f7dd38c2de80a78e196e9ae99e16e1a92233 Mon Sep 17 00:00:00 2001
From: Huiqiang Jiang
Date: Wed, 15 Nov 2023 17:56:50 +0000
Subject: [PATCH] Fixed(LLMLingua): fix the prefix dimension mismatch

---
 llmlingua/prompt_compressor.py | 26 ++++++++++++++++++++------
 1 file changed, 20 insertions(+), 6 deletions(-)

diff --git a/llmlingua/prompt_compressor.py b/llmlingua/prompt_compressor.py
index aef46cb..68a0c4d 100644
--- a/llmlingua/prompt_compressor.py
+++ b/llmlingua/prompt_compressor.py
@@ -26,6 +26,7 @@ def __init__(
         self.retrieval_model_name = None
         self.open_api_config = open_api_config
         self.cache_bos_num = 10
+        self.prefix_bos_num = 100

     def load_model(
         self, model_name: str, device_map: str = "cuda", use_auth_token: bool = False
@@ -214,12 +215,24 @@ def compress_prompt(
         )

         if condition_flag:
-            if add_instruction:
-                context = [question + "\n\n" + instruction] + context
-                start = self.get_token_length(question + "\n\n" + instruction) + 2
-            else:
-                context = [question] + context
-                start = self.get_token_length(question) + 2
+            prefix = question + "\n\n" + instruction if add_instruction else question
+            if (
+                self.get_token_length(prefix) + 2 + iterative_size * 2
+                > self.max_position_embeddings
+            ):
+                tokens = self.tokenizer(prefix, add_special_tokens=False).input_ids
+                prefix = self.tokenizer.decode(
+                    tokens[: self.prefix_bos_num]
+                    + tokens[
+                        len(tokens)
+                        - self.max_position_embeddings
+                        + 2
+                        + self.prefix_bos_num
+                        + 2 * iterative_size :
+                    ]
+                )
+            start = self.get_token_length(prefix) + 2
+            context = [prefix] + context
         else:
             start = 0

@@ -692,6 +705,7 @@ def iterative_compress_prompt(
             ]
             end, ready_end = end - e, ready_end - e
             if condition_compare:
+                s = min(s, self_past_key_values[0][0].shape[2] - e)
                 self_ready_end -= e
                 if pop_self_compressed_input_ids is None:
                     pop_self_compressed_input_ids = self_compressed_input_ids[:, :e]
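Note (not part of the patch): the compress_prompt hunk truncates the question/instruction prefix whenever its token length, plus the 2 extra tokens and the 2 * iterative_size lookahead, would exceed max_position_embeddings. It keeps the first prefix_bos_num tokens and as much of the tail as still fits, so the result lands exactly on the window size. A minimal sketch of that slicing rule, operating on plain token-id lists rather than the real tokenizer (the function name and toy numbers below are assumptions for illustration, not part of the library):

    def truncate_prefix(tokens, max_position_embeddings, iterative_size, prefix_bos_num):
        """Keep the first prefix_bos_num tokens plus a tail of the prefix so that
        len(result) + 2 + 2 * iterative_size == max_position_embeddings."""
        if len(tokens) + 2 + iterative_size * 2 <= max_position_embeddings:
            return tokens  # already fits; nothing to drop
        tail_start = (
            len(tokens)
            - max_position_embeddings
            + 2
            + prefix_bos_num
            + 2 * iterative_size
        )
        # Head keeps the BOS-adjacent tokens; tail keeps the most recent context.
        return tokens[:prefix_bos_num] + tokens[tail_start:]

    # Toy check: 5000 prefix tokens squeezed into a 4096-token window.
    tokens = list(range(5000))
    out = truncate_prefix(
        tokens, max_position_embeddings=4096, iterative_size=200, prefix_bos_num=100
    )
    assert len(out) + 2 + 2 * 200 == 4096

In the patch itself the kept ids are round-tripped through tokenizer.decode, so the re-encoded length can drift by a token or two at the splice point; the sketch above shows only the exact-arithmetic intent.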
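Note (not part of the patch): the one-line change in iterative_compress_prompt clamps s, the number of BOS-adjacent positions retained during KV-cache eviction, so it never exceeds what the comparison model's cache still holds after dropping e positions. Assuming the cache is then compressed with a keep-the-first-s, skip-the-next-e concatenation like the one used for past_key_values elsewhere in this file, here is a toy illustration of the sequence-dimension mismatch the clamp prevents (the shapes and numbers are made up):

    import torch

    cache_len = 30                         # seq length held in self_past_key_values
    e = 25                                 # positions being evicted this step
    s = 10                                 # BOS positions we would like to keep
    k = torch.zeros(1, 8, cache_len, 64)   # [batch, heads, seq, head_dim]

    s = min(s, cache_len - e)              # the patched clamp
    kept = torch.cat([k[..., :s, :], k[..., s + e :, :]], dim=-2)
    assert kept.shape[2] == cache_len - e  # 5 positions, matching the other caches

Without the clamp, s = 10 would keep 10 head positions while the tail slice is empty, leaving this cache 5 positions longer than its siblings; hence the "prefix dimension mismatch" in the subject line.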