From bd31e431c8e27902a3635982ced802697e7d42bd Mon Sep 17 00:00:00 2001
From: Huiqiang Jiang
Date: Mon, 9 Oct 2023 14:17:58 +0000
Subject: [PATCH] Feature(LLMLingua): add KV-Cache Compression & HF Space Demo

---
 README.md                      |  8 ++--
 llmlingua/prompt_compressor.py | 70 ++++++++++++++++++++++++++++++----
 llmlingua/version.py           |  2 +-
 3 files changed, 69 insertions(+), 11 deletions(-)

diff --git a/README.md b/README.md
index 7e353fe..4eef7ad 100644
--- a/README.md
+++ b/README.md
@@ -2,20 +2,22 @@
 LLMLingua

-# LLMLingua: Compressing Prompts for Accelerated Inference of Large Language Models [[paper]()] & LongLLMLingua [[paper]()]
+# LLMLingua: Compressing Prompts for Accelerated Inference of Large Language Models [paper] & LongLLMLingua [paper]
 
 https://github.com/microsoft/LLMLingua/assets/30883354/ef52995c-ef3c-4eac-a9fd-1acb491c325b
 
+You can try the LLMLingua demo in [HF Space](https://huggingface.co/spaces/microsoft/LLMLingua).
+
 ## Tl;DR
 
 LLMLingua uses a well-trained small language model after alignment, such as GPT2-small or LLaMA-7B, to detect the unimportant tokens in the prompt, enabling inference with the compressed prompt in black-box LLMs and achieving up to 20x compression with minimal performance loss.
 
-[LLMLingua: Compressing Prompts for Accelerated Inference of Large Language Models]() (EMNLP 2023).
+LLMLingua: Compressing Prompts for Accelerated Inference of Large Language Models (EMNLP 2023).
 _Huiqiang Jiang, Qianhui Wu, Chin-Yew Lin, Yuqing Yang and Lili Qiu_
 
 LongLLMLingua is a method that enhances LLMs' ability to perceive key information in long-context scenarios using prompt compression, achieving up to $28.5 in cost savings per 1,000 samples while also improving performance.
 
-[LongLLMLingua: Accelerating and Enhancing LLMs in Long Context Scenarios via Prompt Compression]() (Under Review).
+LongLLMLingua: Accelerating and Enhancing LLMs in Long Context Scenarios via Prompt Compression (Under Review).
 _Huiqiang Jiang, Qianhui Wu, Xufang Luo, Dongsheng Li, Chin-Yew Lin, Yuqing Yang and Lili Qiu_
 
 ## 🎥 Overview
diff --git a/llmlingua/prompt_compressor.py b/llmlingua/prompt_compressor.py
index 37a544d..f2f7400 100644
--- a/llmlingua/prompt_compressor.py
+++ b/llmlingua/prompt_compressor.py
@@ -24,12 +24,13 @@ def __init__(
         self.load_model(model_name, device_map, use_auth_token)
         self.sbert = None
         self.open_api_config = open_api_config
+        self.cache_bos_num = 10
 
     def load_model(
         self, model_name: str, device_map: str = "cuda", use_auth_token: bool = False
     ):
-        config = AutoConfig.from_pretrained(model_name)
-        tokenizer = AutoTokenizer.from_pretrained(model_name)
+        config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
+        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
         tokenizer.padding_side = "left"
         tokenizer.pad_token_id = (
             config.pad_token_id if config.pad_token_id else tokenizer.eos_token_id
@@ -40,12 +41,11 @@ def load_model(
         if "cuda" in device_map or "cpu" in device_map:
             model = AutoModelForCausalLM.from_pretrained(
                 model_name,
-                torch_dtype="auto",
+                torch_dtype="auto" if device_map == "cuda" else torch.float32,
                 config=config,
                 ignore_mismatched_sizes=True,
+                trust_remote_code=True,
             ).to(device_map)
-            if device_map == "cpu":
-                model = model.type(torch.float32)
         else:
             model = AutoModelForCausalLM.from_pretrained(
                 model_name,
@@ -56,10 +56,12 @@ def load_model(
                 offload_state_dict=True,
                 cache_dir="/tmp/cache",
                 use_auth_token=use_auth_token,
+                trust_remote_code=True,
             )
         self.tokenizer = tokenizer
         self.model = model
         self.context_idxs = []
+        self.max_position_embeddings = config.max_position_embeddings
 
     def get_ppl(
         self,
@@ -83,7 +85,7 @@ def get_ppl(
             past_length = 0
         if end is None:
             end = input_ids.shape[1]
-        end = min(end, past_length + 4096)
+        end = min(end, past_length + self.max_position_embeddings)
         with torch.no_grad():
             response = self.model(
                 input_ids[:, past_length:end],
@@ -145,11 +147,17 @@ def compress_prompt(
         assert not (
             rank_method == "longllmlingua" and not question
         ), "In the LongLLMLingua, it is necessary to set a question."
+        if condition_compare and "_condition" not in condition_in_question:
+            condition_in_question += "_condition"
         if rank_method == "longllmlingua":
             if condition_in_question == "none":
                 condition_in_question = "after"
         elif rank_method == "llmlingua":
-            condition_in_question = "none"
+            condition_in_question = (
+                "none"
+                if "_condition" not in condition_in_question
+                else "none_condition"
+            )
         origin_tokens = len(
             encoding.encode("\n\n".join([instruction] + context + [question]).strip())
         )
@@ -653,8 +661,52 @@ def iterative_compress_prompt(
             keep_flag = torch.tensor(keep_flag).to(self.device)
         past_key_values, past_loss, ready_end = None, None, 0
         self_past_key_values, self_past_loss, self_ready_end = None, None, 0
+        pop_compressed_input_ids, pop_self_compressed_input_ids = None, None
         idx = 0
         while end <= compressed_input_ids.shape[1]:
+            if end > self.max_position_embeddings and past_key_values is not None:
+                # KV-Cache Compression
+                e, s = end - self.max_position_embeddings, self.cache_bos_num
+                if pop_compressed_input_ids is None:
+                    pop_compressed_input_ids = compressed_input_ids[:, :e]
+                else:
+                    pop_compressed_input_ids = torch.cat(
+                        [pop_compressed_input_ids, compressed_input_ids[:, :e]], dim=-1
+                    )
+                compressed_input_ids = compressed_input_ids[:, e:]
+                compressed_attention_mask = compressed_attention_mask[:, e:]
+                past_key_values = [
+                    [
+                        torch.cat([k[..., :s, :], k[..., s + e :, :]], dim=-2),
+                        torch.cat([v[..., :s, :], v[..., s + e :, :]], dim=-2),
+                    ]
+                    for k, v in past_key_values
+                ]
+                end, ready_end = end - e, ready_end - e
+                if condition_compare:
+                    self_ready_end -= e
+                    if pop_self_compressed_input_ids is None:
+                        pop_self_compressed_input_ids = self_compressed_input_ids[:, :e]
+                    else:
+                        pop_self_compressed_input_ids = torch.cat(
+                            [
+                                pop_self_compressed_input_ids,
+                                self_compressed_input_ids[:, :e],
+                            ],
+                            dim=-1,
+                        )
+                    self_compressed_input_ids = self_compressed_input_ids[:, e:]
+                    self_compressed_attention_mask = self_compressed_attention_mask[
+                        :, e:
+                    ]
+                    self_past_key_values = [
+                        [
+                            torch.cat([k[..., :s, :], k[..., s + e :, :]], dim=-2),
+                            torch.cat([v[..., :s, :], v[..., s + e :, :]], dim=-2),
+                        ]
+                        for k, v in self_past_key_values
+                    ]
+
             loss, past_key_values = self.get_ppl(
                 "",
                 "token",
@@ -762,6 +814,10 @@ def iterative_compress_prompt(
             )
             end += iterative_size
             idx += 1
+        if pop_compressed_input_ids is not None:
+            compressed_input_ids = torch.cat(
+                [pop_compressed_input_ids, compressed_input_ids], dim=-1
+            )
         return compressed_input_ids[:, start:], compressed_attention_mask[:, start:]
 
     def recover(
diff --git a/llmlingua/version.py b/llmlingua/version.py
index f3165f4..156b192 100644
--- a/llmlingua/version.py
+++ b/llmlingua/version.py
@@ -2,7 +2,7 @@
 _MINOR = "1"
 # On master and in a nightly release the patch should be one ahead of the last
 # released build.
-_PATCH = "1"
+_PATCH = "2"
 # This is mainly for nightly builds which have the suffix ".dev$DATE". See
 # https://semver.org/#is-v123-a-semantic-version for the semantics.
 _SUFFIX = ""
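Below is a minimal, self-contained sketch of the KV-Cache compression step introduced in the `iterative_compress_prompt` hunk above. The helper name `compress_kv_cache`, the tuple-based cache layout, and the toy tensor shapes are illustrative assumptions rather than part of the library; only the slicing pattern mirrors the patch: once the evaluation pointer `end` exceeds `max_position_embeddings`, the oldest `e` positions are popped from the working ids and attention mask and cut out of every layer's cached keys and values, while the first `s = cache_bos_num` cached positions are always kept.

```python
import torch

def compress_kv_cache(past_key_values, input_ids, attention_mask,
                      end, max_position_embeddings, cache_bos_num=10):
    # Illustrative sketch, not the library API. Assumes end > max_position_embeddings
    # (otherwise there is nothing to trim).
    e = end - max_position_embeddings   # stale positions to drop
    s = cache_bos_num                   # leading (BOS) cache entries that are always kept

    popped_ids = input_ids[:, :e]       # already-scored tokens, set aside for later
    input_ids = input_ids[:, e:]
    attention_mask = attention_mask[:, e:]

    # Cut the same `e` positions (right after the kept BOS block) out of every
    # layer's cached keys and values; the cache length dimension is -2.
    past_key_values = [
        (
            torch.cat([k[..., :s, :], k[..., s + e:, :]], dim=-2),
            torch.cat([v[..., :s, :], v[..., s + e:, :]], dim=-2),
        )
        for k, v in past_key_values
    ]
    return past_key_values, input_ids, attention_mask, popped_ids, end - e

# Toy example: 2 layers, batch=1, 4 heads, cached length 4100, head_dim=8,
# with a 4096-token position limit -> 4 positions are dropped.
cache = [(torch.randn(1, 4, 4100, 8), torch.randn(1, 4, 4100, 8)) for _ in range(2)]
ids = torch.randint(0, 32000, (1, 4100))
mask = torch.ones(1, 4100, dtype=torch.long)
cache, ids, mask, popped, end = compress_kv_cache(
    cache, ids, mask, end=4100, max_position_embeddings=4096
)
print(ids.shape, cache[0][0].shape, popped.shape, end)
# torch.Size([1, 4096]) torch.Size([1, 4, 4096, 8]) torch.Size([1, 4]) 4096
```

Because the popped ids are set aside and concatenated back after the loop (as the patch does before returning), the compressed prompt that is finally produced is unchanged; only the working window and the cache are kept within the model's position limit.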