
KeyError: 'Cache only has 0 layers, attempted to access layer with index 0' #25

hasanar1f opened this issue Apr 2, 2024 · 3 comments

hasanar1f commented Apr 2, 2024

How to reproduce:

import torch
from transformers import AutoModelForCausalLM, LlamaForCausalLM, AutoTokenizer, AutoConfig, LlamaTokenizer, LlamaConfig
from transformers.modeling_utils import load_sharded_checkpoint
from accelerate import init_empty_weights, load_checkpoint_in_model, load_checkpoint_and_dispatch
from utils_hh.modify_llama import convert_kvcache_llama_heavy_recent, LlamaAttention_heavy_hitter

ENABLE_Heavy_Hitter_FUNCTIONS = {
    "llama": convert_kvcache_llama_heavy_recent,
}

model_name = 'meta-llama/Llama-2-7b-hf'
cache_dir = 'checkpoint/models--meta-llama--Llama-2-7b-hf/snapshots/8a0442e81540efaeb1a0fe3e95477b5e0edfd423'
heavy_ratio = 0.1
recent_ratio = 0.1
length = 64
seed = 42


torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device: {}".format(device))

tokenizer = LlamaTokenizer.from_pretrained(model_name)
config = LlamaConfig.from_pretrained(model_name)
config.heavy_ratio = heavy_ratio
config.recent_ratio = recent_ratio

# instantiate the model without weights (on the meta device)
with init_empty_weights():
    model = LlamaForCausalLM(config)

model = convert_kvcache_llama_heavy_recent(model, config)

print(model) 

# load the checkpoint into the empty model

model = load_checkpoint_and_dispatch(
    model,
    checkpoint=cache_dir,
    device_map="auto",
    offload_folder=cache_dir,
    dtype=torch.float16,
    offload_state_dict=True, 
)

prompt_text = 'Hello.'
input_ids = tokenizer(prompt_text, add_special_tokens=False, return_tensors='pt').input_ids.to(model.device)

generate_ids_hh = model.generate(
    input_ids,
    max_new_tokens=100,
    use_cache=True,
    )  # the KeyError above is raised here

hasanar1f commented Apr 2, 2024

Using transformers==4.33.0 solved it.
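
For anyone hitting the same thing, a quick environment check along these lines can confirm the workaround is in effect. This is only a sketch: the 4.36 boundary is where the cache format changed, and packaging is assumed to be available (it ships as a transformers dependency).

import transformers
from packaging import version

# the legacy tuple-based KV cache path is only taken below transformers 4.36
assert version.parse(transformers.__version__) < version.parse("4.36.0"), (
    f"transformers {transformers.__version__} uses the new Cache class; "
    "downgrade (e.g. to 4.33.0) for the legacy past_key_values format"
)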

hasanar1f reopened this Apr 3, 2024

hasanar1f commented Apr 3, 2024

Hi @Kyriection, as I mentioned, one workaround to avoid this error is to downgrade transformers to a version below 4.36.

But I am trying to test H2O on some of the latest LLMs, and for that I need a newer version of transformers (>=4.37.2). Do you have any suggestions?

The error is raised here, in the forward function of H2OLlamaAttention_streaming:

        # remake causal mask
        attention_mask = _make_causal_mask(
            bsz=bsz,
            tgt_len=q_len,
            past_key_values_length=past_key_value[0].shape[-2] if past_key_value is not None else 0,
            dtype=query_states.dtype,
            device=query_states.device,
        )
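
One way to make that spot tolerant of both cache formats is to branch on the Cache base class before reading the cached length. This is just a sketch, not the repo's fix: bsz, q_len, query_states, and _make_causal_mask are the names from the snippet above, while Cache and get_seq_length() exist in transformers >= 4.36.

from transformers.cache_utils import Cache

# compute the number of cached tokens without indexing an empty Cache
if past_key_value is None:
    past_key_values_length = 0
elif isinstance(past_key_value, Cache):
    # new API: the cache reports its own length (0 when no layers exist yet)
    past_key_values_length = past_key_value.get_seq_length()
else:
    # legacy API: a tuple of (key, value) tensors per layer
    past_key_values_length = past_key_value[0].shape[-2]

attention_mask = _make_causal_mask(
    bsz=bsz,
    tgt_len=q_len,
    past_key_values_length=past_key_values_length,
    dtype=query_states.dtype,
    device=query_states.device,
)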

Kyriection (Collaborator) commented

Hi @hasanibnarif, Hugging Face updated their cache implementation in version 4.36. Previously, past_key_value was a tuple of tensors containing the key and value embeddings; now a Cache instance maintains the KV cache. The definition of the KV cache is located in https://github.com/huggingface/transformers/blob/main/src/transformers/cache_utils.py#L76.
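
To illustrate the difference, here is a minimal sketch using DynamicCache, the default Cache subclass since 4.36 (the tensor shapes are arbitrary):

import torch
from transformers.cache_utils import DynamicCache

# legacy format: a tuple of (key, value) pairs, one per layer
key = torch.randn(1, 32, 5, 128)    # (batch, num_heads, seq_len, head_dim)
value = torch.randn(1, 32, 5, 128)
legacy_past = ((key, value),)
print(legacy_past[0][0].shape[-2])  # 5 cached tokens

# new format: a Cache instance filled layer by layer through update()
cache = DynamicCache()
print(cache.get_seq_length())       # 0 -- indexing cache[0] here is what raises the KeyError
cache.update(key, value, layer_idx=0)
print(cache.get_seq_length())       # 5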

I have an initial version of the new H2O KV cache implementation based on the Cache class (https://github.com/Kyriection/llama-recipes/blob/main/research/long-context-llama/H2O/utils/cache.py#L342). Please note that this version is still under development and I will release it once it is finished.
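
This is not the linked implementation, but roughly the shape such a cache can take: subclass DynamicCache (assuming the 4.37-era key_cache/value_cache lists) and evict down to a budget after every update. The sketch only keeps a few sink tokens plus a recent window; a real H2O cache would also keep heavy-hitter tokens ranked by accumulated attention scores, which requires the attention module to feed its weights back into the cache.

import torch
from transformers.cache_utils import DynamicCache

class BudgetedDynamicCache(DynamicCache):
    """Sketch only: per layer, keep the first `sink` tokens plus the most
    recent `recent` tokens. H2O additionally retains heavy hitters chosen
    by accumulated attention, which this simplified version omits."""

    def __init__(self, sink: int = 4, recent: int = 252):
        super().__init__()
        self.sink = sink
        self.recent = recent

    def update(self, key_states, value_states, layer_idx, cache_kwargs=None):
        # append the new keys/values exactly as DynamicCache does
        keys, values = super().update(key_states, value_states, layer_idx, cache_kwargs)
        budget = self.sink + self.recent
        if keys.shape[-2] > budget:
            # evict the middle of the sequence, keep sink + recent window
            keys = torch.cat([keys[..., :self.sink, :], keys[..., -self.recent:, :]], dim=-2)
            values = torch.cat([values[..., :self.sink, :], values[..., -self.recent:, :]], dim=-2)
            self.key_cache[layer_idx] = keys
            self.value_cache[layer_idx] = values
        return keys, values

Note that evicting entries also changes what get_seq_length() reports, so position ids and the causal mask have to be handled consistently by the attention code that uses such a cache.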
