No support of GQA of Llama in real_drop #32

Open
Tomorrowdawn opened this issue May 27, 2024 · 1 comment

Tomorrowdawn commented May 27, 2024

In modify_llama.py, the hh_score of H2OCache is computed as attn_scores.sum(0).sum(1), which gives a shape of [num_heads, seq_len]. However, in Llama's GQA implementation (in the same file), the k/v cache has a shape of [B, num_key_value_heads, ...], which does not match hh_score.
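
To make the mismatch concrete, here is a minimal sketch with made-up sizes (the exact head counts depend on the model config):

import torch

# Hypothetical GQA sizes for illustration: 32 query heads sharing 8 kv heads.
bsz, num_heads, num_key_value_heads, q_len, k_len, head_dim = 1, 32, 8, 16, 16, 128

attn_scores = torch.rand(bsz, num_heads, q_len, k_len)              # attention weights
key_cache = torch.rand(bsz, num_key_value_heads, k_len, head_dim)   # GQA k/v cache entry

hh_score = attn_scores.sum(0).sum(1)     # [num_heads, k_len] -> torch.Size([32, 16])
print(hh_score.shape, key_cache.shape)   # 32 query heads vs. 8 kv heads: the head
                                         # dimension of hh_score cannot index the cache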

I implemented it manually with the eager torch attention kernel. Whether I "keep the first head of each group", "take the mean of each group", "take the sum of each group", or repeat the key/value states in the cache (a rough sketch of these variants follows the screenshot), H2O gives me an unacceptable result like:

[screenshot: unacceptable generation output]
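
For reference, a rough sketch of the three per-group aggregation variants I tried (a hypothetical helper, not the exact code; it assumes the query heads of one kv group are laid out contiguously, as in transformers' repeat_kv):

def group_hh_score(hh_score, num_key_value_heads, num_kv_groups, mode="mean"):
    # hh_score: [num_heads, k_len] -> grouped: [num_key_value_heads, num_kv_groups, k_len]
    grouped = hh_score.view(num_key_value_heads, num_kv_groups, -1)
    if mode == "first":  # keep the first query head of each group
        return grouped[:, 0]
    if mode == "mean":   # average over the query heads of each group
        return grouped.mean(dim=1)
    if mode == "sum":    # sum over the query heads of each group
        return grouped.sum(dim=1)
    raise ValueError(f"unknown mode: {mode}")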

I essentially copied these lines, so it does not seem to be an implementation mistake on my side.

Here is the implementation; DynamicCache is the same as transformers' DynamicCache.

import torch

class H2OCache(DynamicCache):
    ##inheritance from DynamicCache is not a bug.
    ##it matches the official code.
    def __init__(self, max_len, device, num_key_value_heads, num_kv_groups,
        hh_size=128,
        recent_size=512,):
        super().__init__(max_len, device)
        self.num_key_value_heads = num_key_value_heads
        self.num_kv_groups = num_kv_groups

        self.recent_size = recent_size

        self.hh_size = hh_size
        self.hh_scores = []
    def _update_hh_scores(self, attn_score_cache, layer_idx):
        ##check https://github.com/FMInference/H2O/blob/281ffef3f1432ceb1a6899362d2f20e1ef13aa94/h2o_hf/utils_real_drop/modify_llama.py#L290
        
        num_new_tokens = attn_score_cache.shape[2]
        if len(self.hh_scores) <= layer_idx:
            hh_score = attn_score_cache.sum(0).sum(1)
            self.hh_scores.append(hh_score)
        else:
            attn_score_cache = attn_score_cache.sum(0).sum(1)##[B, H, Q, K]->[H,Q,K]->[H, K]. 
            attn_score_cache[:, :-num_new_tokens] += self.hh_scores[layer_idx]##[H, K]
            ##can't work with GQA
            self.hh_scores[layer_idx] = attn_score_cache

    def evict(self, attn_scores, layer_idx):
        """
        attn_scores:[B, H, Q, K]
        """
        ##not quite the same as the paper:
        ##a top-k selection is used here.
        self._update_hh_scores(attn_scores, layer_idx)

        bsz, num_heads, q_len, k_len = attn_scores.shape##k_len = cache len(after store)
        seq_len = k_len
        if k_len <= self.recent_size + self.hh_size:
            return
        head_dim = self.key_cache[layer_idx].shape[-1]
        
        select_hh_scores = self.hh_scores[layer_idx][:, :seq_len - self.recent_size]
        _, keep_topk = torch.topk(select_hh_scores, self.hh_size, dim=-1)
        keep_topk = keep_topk.sort().values

        # keep_recent = torch.arange(seq_len - self.recent_size, seq_len).expand(keep_topk.shape[0], 1).to(keep_topk.device)
        keep_recent = torch.arange(seq_len - self.recent_size, seq_len, device=keep_topk.device).repeat(keep_topk.shape[0], 1)
        keep_idx = torch.cat([keep_topk, keep_recent], dim=-1)
        hh_score = self.hh_scores[layer_idx]
        mask = torch.zeros(hh_score.shape, dtype=torch.bool).to(self.key_cache[layer_idx].device)
        mask = mask.scatter(-1, keep_idx, 1)

        k_hh_recent = self.key_cache[layer_idx].squeeze()[mask].view(bsz, num_heads, -1, head_dim)
        v_hh_recent = self.value_cache[layer_idx].squeeze()[mask].view(bsz, num_heads, -1, head_dim)

        self.hh_scores[layer_idx] = hh_score[mask].view(num_heads, self.hh_size + self.recent_size)

        self.key_cache[layer_idx] = k_hh_recent
        self.value_cache[layer_idx] = v_hh_recent
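
For context, a hypothetical usage sketch (names and sizes are illustrative; it assumes the modified attention forward has the softmaxed weights available):

cache = H2OCache(max_len=4096, device="cuda",
                 num_key_value_heads=8, num_kv_groups=4,
                 hh_size=128, recent_size=512)

# Inside the modified per-layer attention forward, after the softmax:
#   attn_weights has shape [B, num_heads, q_len, k_len]
#   cache.evict(attn_weights, layer_idx)   # accumulates hh_scores, then drops evicted tokens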

Tomorrowdawn commented May 27, 2024

Supplementary:

Well, it was a bug hiding in the rotary embedding (which is not in this file, so I didn't find it at first...). Re-applying the rotary embedding at each step is necessary, so I modified the whole DynamicCache class.
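
A minimal sketch of the idea, assuming keys are cached without rotary embedding and re-rotated with fresh, consecutive positions after eviction (helper names here are made up; the real change lives in my DynamicCache variant):

import torch

def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def rerotate_keys(key_cache_unrotated, inv_freq):
    # key_cache_unrotated: [B, num_key_value_heads, k_len, head_dim], stored pre-RoPE
    k_len = key_cache_unrotated.shape[-2]
    # Positions restart from 0 over whatever survived eviction, so the rotation
    # always matches the compacted cache rather than the original token indices.
    positions = torch.arange(k_len, device=key_cache_unrotated.device, dtype=inv_freq.dtype)
    freqs = torch.outer(positions, inv_freq)   # [k_len, head_dim // 2]
    emb = torch.cat((freqs, freqs), dim=-1)    # [k_len, head_dim]
    cos, sin = emb.cos(), emb.sin()
    return key_cache_unrotated * cos + rotate_half(key_cache_unrotated) * sin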

Still, the performance is not outstanding, as shown below:

[screenshot: evaluation results]

Dedicated support for GQA is needed (the result above uses the 'repeat_interleave' implementation: I manually repeated the kv cache; it is closest to the paper but least efficient).
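
That fallback looks roughly like this (a sketch, not the exact code): expand the GQA cache to one kv head per query head, so the per-head hh_score and eviction logic above work unchanged.

def expand_kv_for_h2o(key_cache, value_cache, num_kv_groups):
    # [B, num_key_value_heads, k_len, head_dim] -> [B, num_heads, k_len, head_dim];
    # each kv head is repeated num_kv_groups times, matching transformers' repeat_kv layout.
    key_states = key_cache.repeat_interleave(num_kv_groups, dim=1)
    value_states = value_cache.repeat_interleave(num_kv_groups, dim=1)
    return key_states, value_states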
