In modify_llama.py, the hh_score of H2OCache is computed by attn_scores.sum(0).sum(1), resulting in a shape of [num_heads, seq_len]. However, in Llama's GQA implementation (in the same file), the k/v cache has a shape of [B, num_key_value_heads, ...], which mismatches the hh_score.
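To make the mismatch concrete, here is a minimal shape check (the sizes are hypothetical, roughly a Llama-2-70B-style GQA config, not taken from the repo):

import torch

B, num_heads, num_kv_heads, Q, K, D = 1, 64, 8, 1, 1024, 128  # hypothetical GQA config

attn_scores = torch.rand(B, num_heads, Q, K)   # post-softmax attention weights
hh_score = attn_scores.sum(0).sum(1)           # sum over batch, then queries -> [64, 1024]
key_cache = torch.rand(B, num_kv_heads, K, D)  # [1, 8, 1024, 128]
# hh_score has 64 per-head rows, but the cache holds only 8 KV heads,
# so per-head top-k indices from hh_score cannot index the cache directly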
I manually implemented it with an eager torch attention kernel. Whether I "keep the first head of each group", "take the mean of each group", "take the sum of each group" (a sketch of that variant follows the code below), or repeat the key/value states in the cache, h2o gives me an unacceptable result.
I almost copied these lines verbatim, so it does not seem to be an implementation mistake.
Here is the implementation; DynamicCache is the same as transformers' DynamicCache.
class H2OCache(DynamicCache):
    ## inheritance from DynamicCache is not a bug.
    ## it matches the official code.
    def __init__(self, max_len, device, num_key_value_heads, num_kv_groups,
                 hh_size=128,
                 recent_size=512):
        super().__init__(max_len, device)
        self.num_key_value_heads = num_key_value_heads
        self.num_kv_groups = num_kv_groups
        self.recent_size = recent_size
        self.hh_size = hh_size
        self.hh_scores = []

    def _update_hh_scores(self, attn_score_cache, layer_idx):
        ## check https://github.com/FMInference/H2O/blob/281ffef3f1432ceb1a6899362d2f20e1ef13aa94/h2o_hf/utils_real_drop/modify_llama.py#L290
        num_new_tokens = attn_score_cache.shape[2]
        if len(self.hh_scores) <= layer_idx:
            hh_score = attn_score_cache.sum(0).sum(1)
            self.hh_scores.append(hh_score)
        else:
            attn_score_cache = attn_score_cache.sum(0).sum(1)  ## [B, H, Q, K] -> [H, Q, K] -> [H, K]
            attn_score_cache[:, :-num_new_tokens] += self.hh_scores[layer_idx]  ## [H, K]; can't work with GQA
            self.hh_scores[layer_idx] = attn_score_cache

    def evict(self, attn_scores, layer_idx):
        """attn_scores: [B, H, Q, K]"""
        ## not quite the same as the paper.
        ## here a top-k selection is used.
        self._update_hh_scores(attn_scores, layer_idx)
        bsz, num_heads, q_len, k_len = attn_scores.shape  ## k_len = cache len (after store)
        seq_len = k_len
        if k_len <= self.recent_size + self.hh_size:
            return
        head_dim = self.key_cache[layer_idx].shape[-1]
        select_hh_scores = self.hh_scores[layer_idx][:, :seq_len - self.recent_size]
        _, keep_topk = torch.topk(select_hh_scores, self.hh_size, dim=-1)
        keep_topk = keep_topk.sort().values
        # keep_recent = torch.arange(seq_len - self.recent_size, seq_len).expand(keep_topk.shape[0], 1).to(keep_topk.device)
        keep_recent = torch.arange(seq_len - self.recent_size, seq_len, device=keep_topk.device).repeat(keep_topk.shape[0], 1)
        keep_idx = torch.cat([keep_topk, keep_recent], dim=-1)
        hh_score = self.hh_scores[layer_idx]
        mask = torch.zeros(hh_score.shape, dtype=torch.bool).to(self.key_cache[layer_idx].device)
        mask = mask.scatter(-1, keep_idx, 1)
        k_hh_recent = self.key_cache[layer_idx].squeeze()[mask].view(bsz, num_heads, -1, head_dim)
        v_hh_recent = self.value_cache[layer_idx].squeeze()[mask].view(bsz, num_heads, -1, head_dim)
        self.hh_scores[layer_idx] = hh_score[mask].view(num_heads, self.hh_size + self.recent_size)
        self.key_cache[layer_idx] = k_hh_recent
        self.value_cache[layer_idx] = v_hh_recent
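For reference, the "sum of each group" variant mentioned above could look like this (a sketch, not official code; it relies on transformers' repeat_kv laying out the query heads group-contiguously, so a reshape regroups them):

import torch

def grouped_hh_score(attn_scores, num_key_value_heads, num_kv_groups):
    # attn_scores: [B, num_heads, Q, K] with num_heads = num_key_value_heads * num_kv_groups
    per_head = attn_scores.sum(0).sum(1)  # [num_heads, K]
    # fold the query heads back into their KV groups and reduce to one row per KV head
    return per_head.view(num_key_value_heads, num_kv_groups, -1).sum(1)  # [num_key_value_heads, K]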
Well, that was a bug hiding in the rotary embedding (which is not in the code above, so I didn't find it at first...). Re-applying the rotary embedding at each step is necessary, so I modified the whole DynamicCache class.
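The patched class isn't shown here, but the shape of the fix is roughly this (a sketch only, under the assumption that keys are cached un-rotated; rotate_half matches the standard Llama helper): after each eviction the kept entries occupy contiguous positions 0..S-1 again, so the rotation is recomputed and applied fresh every step instead of being baked into the cache.

import torch

def rotate_half(x):
    # standard RoPE helper: (x1, x2) -> (-x2, x1) on the last dim
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def reapply_rope(keys, inv_freq):
    # keys: [B, H_kv, S, D], cached *without* rotation;
    # re-rotate with contiguous positions 0..S-1 after eviction
    S = keys.shape[-2]
    pos = torch.arange(S, device=keys.device, dtype=inv_freq.dtype)
    freqs = torch.outer(pos, inv_freq)       # [S, D/2]
    emb = torch.cat((freqs, freqs), dim=-1)  # [S, D]
    return keys * emb.cos() + rotate_half(keys) * emb.sin()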
That said, the performance is still not outstanding.
Dedicated support for GQA is still needed (these results are from the 'repeat_interleave' implementation: I manually repeated the kv cache, sketched below; it is closest to the paper but least efficient).
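What "manually repeated the kv cache" means, roughly (an illustrative sketch, not the exact code):

# key_cache / value_cache: [B, num_key_value_heads, S, D]
# expand each KV head to its num_kv_groups query heads, so hh_score
# ([num_heads, K]) and the cache entries line up one-to-one
keys = key_cache.repeat_interleave(num_kv_groups, dim=1)      # [B, num_heads, S, D]
values = value_cache.repeat_interleave(num_kv_groups, dim=1)
# downside: every cache entry is stored num_kv_groups times,
# which is why this variant is the least efficient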