From bd31e431c8e27902a3635982ced802697e7d42bd Mon Sep 17 00:00:00 2001
From: Huiqiang Jiang
Date: Mon, 9 Oct 2023 14:17:58 +0000
Subject: [PATCH] Feature(LLMLingua): add KV-Cache Compression & HF Space Demo
---
README.md | 8 ++--
llmlingua/prompt_compressor.py | 70 ++++++++++++++++++++++++++++++----
llmlingua/version.py | 2 +-
3 files changed, 69 insertions(+), 11 deletions(-)
diff --git a/README.md b/README.md
index 7e353fe..4eef7ad 100644
--- a/README.md
+++ b/README.md
@@ -2,20 +2,22 @@
-# LLMLingua: Compressing Prompts for Accelerated Inference of Large Language Models [[paper]()] & LongLLMLingua [[paper]()]
+# LLMLingua: Compressing Prompts for Accelerated Inference of Large Language Models [paper] & LongLLMLingua [paper]
https://github.com/microsoft/LLMLingua/assets/30883354/ef52995c-ef3c-4eac-a9fd-1acb491c325b
+You can try the LLMLingua demo on [HF Space](https://huggingface.co/spaces/microsoft/LLMLingua).
+
## TL;DR
LLMLingua uses a well-trained small language model after alignment, such as GPT2-small or LLaMA-7B, to detect unimportant tokens in the prompt, enabling inference with the compressed prompt in black-box LLMs and achieving up to 20x compression with minimal performance loss.
-[LLMLingua: Compressing Prompts for Accelerated Inference of Large Language Models]() (EMNLP 2023).
+LLMLingua: Compressing Prompts for Accelerated Inference of Large Language Models (EMNLP 2023).
_Huiqiang Jiang, Qianhui Wu, Chin-Yew Lin, Yuqing Yang and Lili Qiu_
LongLLMLingua is a method that enhances LLMs' ability to perceive key information in long-context scenarios using prompt compression, achieving up to $28.5 in cost savings per 1,000 samples while also improving performance.
-[LongLLMLingua: Accelerating and Enhancing LLMs in Long Context Scenarios via Prompt Compression]() (Under Review).
+LongLLMLingua: Accelerating and Enhancing LLMs in Long Context Scenarios via Prompt Compression (Under Review).
_Huiqiang Jiang, Qianhui Wu, Xufang Luo, Dongsheng Li, Chin-Yew Lin, Yuqing Yang and Lili Qiu_
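+
+A minimal usage sketch (it assumes the `PromptCompressor` / `compress_prompt` interface exported by this package and its `compressed_prompt` output field; the prompt text and `target_token` budget are only illustrative):
+
+```python
+from llmlingua import PromptCompressor
+
+llm_lingua = PromptCompressor()  # loads the default small LM used to score tokens
+compressed = llm_lingua.compress_prompt(
+    ["Here is a long demonstration context ..."],  # context segments to compress
+    instruction="Answer the question based on the context.",
+    question="What is the capital of France?",
+    target_token=200,  # rough token budget for the compressed prompt
+)
+print(compressed["compressed_prompt"])
+```
+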
## 🎥 Overview
diff --git a/llmlingua/prompt_compressor.py b/llmlingua/prompt_compressor.py
index 37a544d..f2f7400 100644
--- a/llmlingua/prompt_compressor.py
+++ b/llmlingua/prompt_compressor.py
@@ -24,12 +24,13 @@ def __init__(
        self.load_model(model_name, device_map, use_auth_token)
        self.sbert = None
        self.open_api_config = open_api_config
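+        # Number of leading (BOS) cache positions always kept when older KV entries
+        # are evicted during KV-cache compression.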
+        self.cache_bos_num = 10

    def load_model(
        self, model_name: str, device_map: str = "cuda", use_auth_token: bool = False
    ):
-        config = AutoConfig.from_pretrained(model_name)
-        tokenizer = AutoTokenizer.from_pretrained(model_name)
+        config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
+        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        tokenizer.padding_side = "left"
        tokenizer.pad_token_id = (
            config.pad_token_id if config.pad_token_id else tokenizer.eos_token_id
@@ -40,12 +41,11 @@ def load_model(
        if "cuda" in device_map or "cpu" in device_map:
            model = AutoModelForCausalLM.from_pretrained(
                model_name,
-                torch_dtype="auto",
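+                # On CPU, load weights in float32; half-precision CPU inference is poorly supported.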
+                torch_dtype="auto" if device_map == "cuda" else torch.float32,
                config=config,
                ignore_mismatched_sizes=True,
+                trust_remote_code=True,
            ).to(device_map)
-            if device_map == "cpu":
-                model = model.type(torch.float32)
        else:
            model = AutoModelForCausalLM.from_pretrained(
                model_name,
@@ -56,10 +56,12 @@ def load_model(
                offload_state_dict=True,
                cache_dir="/tmp/cache",
                use_auth_token=use_auth_token,
+                trust_remote_code=True,
            )
        self.tokenizer = tokenizer
        self.model = model
        self.context_idxs = []
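+        # Context-window size of the model; used to cap get_ppl windows and to trigger
+        # KV-cache compression in iterative_compress_prompt.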
+        self.max_position_embeddings = config.max_position_embeddings

    def get_ppl(
        self,
@@ -83,7 +85,7 @@ def get_ppl(
            past_length = 0
        if end is None:
            end = input_ids.shape[1]
-        end = min(end, past_length + 4096)
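+        # Cap the scoring window at the model's context length rather than a hard-coded 4096.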
+        end = min(end, past_length + self.max_position_embeddings)
        with torch.no_grad():
            response = self.model(
                input_ids[:, past_length:end],
@@ -145,11 +147,17 @@ def compress_prompt(
        assert not (
            rank_method == "longllmlingua" and not question
        ), "In the LongLLMLingua, it is necessary to set a question."
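+        # condition_compare relies on the "_condition" suffix of condition_in_question,
+        # so append it here and keep it below even when rank_method resets the base mode.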
+        if condition_compare and "_condition" not in condition_in_question:
+            condition_in_question += "_condition"
        if rank_method == "longllmlingua":
            if condition_in_question == "none":
                condition_in_question = "after"
        elif rank_method == "llmlingua":
-            condition_in_question = "none"
+            condition_in_question = (
+                "none"
+                if "_condition" not in condition_in_question
+                else "none_condition"
+            )
        origin_tokens = len(
            encoding.encode("\n\n".join([instruction] + context + [question]).strip())
        )
@@ -653,8 +661,52 @@ def iterative_compress_prompt(
        keep_flag = torch.tensor(keep_flag).to(self.device)
        past_key_values, past_loss, ready_end = None, None, 0
        self_past_key_values, self_past_loss, self_ready_end = None, None, 0
+        pop_compressed_input_ids, pop_self_compressed_input_ids = None, None
        idx = 0
        while end <= compressed_input_ids.shape[1]:
+            if end > self.max_position_embeddings and past_key_values is not None:
+                # KV-Cache Compression
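+                # The window exceeds the model's context limit: evict the oldest e tokens,
+                # keep the first s (= cache_bos_num) BOS positions in the cache, and stash
+                # the evicted ids so they can be re-attached to the output at the end.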
+                e, s = end - self.max_position_embeddings, self.cache_bos_num
+                if pop_compressed_input_ids is None:
+                    pop_compressed_input_ids = compressed_input_ids[:, :e]
+                else:
+                    pop_compressed_input_ids = torch.cat(
+                        [pop_compressed_input_ids, compressed_input_ids[:, :e]], dim=-1
+                    )
+                compressed_input_ids = compressed_input_ids[:, e:]
+                compressed_attention_mask = compressed_attention_mask[:, e:]
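+                # Drop the evicted positions from each layer's key/value tensors along the
+                # sequence dimension (dim=-2), keeping the first s (BOS) entries.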
+                past_key_values = [
+                    [
+                        torch.cat([k[..., :s, :], k[..., s + e :, :]], dim=-2),
+                        torch.cat([v[..., :s, :], v[..., s + e :, :]], dim=-2),
+                    ]
+                    for k, v in past_key_values
+                ]
+                end, ready_end = end - e, ready_end - e
+                if condition_compare:
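+                    # Mirror the same eviction on the condition branch's ids, attention
+                    # mask, and KV cache so both scoring passes stay aligned.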
+                    self_ready_end -= e
+                    if pop_self_compressed_input_ids is None:
+                        pop_self_compressed_input_ids = self_compressed_input_ids[:, :e]
+                    else:
+                        pop_self_compressed_input_ids = torch.cat(
+                            [
+                                pop_self_compressed_input_ids,
+                                self_compressed_input_ids[:, :e],
+                            ],
+                            dim=-1,
+                        )
+                    self_compressed_input_ids = self_compressed_input_ids[:, e:]
+                    self_compressed_attention_mask = self_compressed_attention_mask[
+                        :, e:
+                    ]
+                    self_past_key_values = [
+                        [
+                            torch.cat([k[..., :s, :], k[..., s + e :, :]], dim=-2),
+                            torch.cat([v[..., :s, :], v[..., s + e :, :]], dim=-2),
+                        ]
+                        for k, v in self_past_key_values
+                    ]
+
            loss, past_key_values = self.get_ppl(
                "",
                "token",
@@ -762,6 +814,10 @@ def iterative_compress_prompt(
            )
            end += iterative_size
            idx += 1
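+        # Re-attach the tokens that were popped out of the rolling window so the caller
+        # gets back the full compressed sequence.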
+        if pop_compressed_input_ids is not None:
+            compressed_input_ids = torch.cat(
+                [pop_compressed_input_ids, compressed_input_ids], dim=-1
+            )
        return compressed_input_ids[:, start:], compressed_attention_mask[:, start:]

    def recover(
diff --git a/llmlingua/version.py b/llmlingua/version.py
index f3165f4..156b192 100644
--- a/llmlingua/version.py
+++ b/llmlingua/version.py
@@ -2,7 +2,7 @@
_MINOR = "1"
# On master and in a nightly release the patch should be one ahead of the last
# released build.
-_PATCH = "1"
+_PATCH = "2"
# This is mainly for nightly builds which have the suffix ".dev$DATE". See
# https://semver.org/#is-v123-a-semantic-version for the semantics.
_SUFFIX = ""