diff --git a/h2o_hf/utils_lm_eval/modify_llama.py b/h2o_hf/utils_lm_eval/modify_llama.py
index ea037a5..511eae3 100644
--- a/h2o_hf/utils_lm_eval/modify_llama.py
+++ b/h2o_hf/utils_lm_eval/modify_llama.py
@@ -131,20 +131,40 @@ def forward(
         heavy_budget = int(self.heavy_budget_ratio * attn_weights.shape[-1])
         recent_budget = int(self.recent_budget_ratio * attn_weights.shape[-1])
 
-        # Heavy Hitter Mask
-        if heavy_budget > 0:
-            mask_bottom = local_heavy_hitter_mask(attn_weights, heavy_budget) # Default: No padding applied to input
-        else:
-            mask_bottom = torch.zeros_like(attn_weights, dtype=torch.bool)
+
+        # # Heavy Hitter Mask (Based on local statistics)
+        # if heavy_budget > 0:
+        #     mask_bottom = local_heavy_hitter_mask(attn_weights, heavy_budget) # Default: No padding applied to input
+        # else:
+        #     mask_bottom = torch.zeros_like(attn_weights, dtype=torch.bool)
+
+        # ones = torch.ones_like(attn_weights, dtype=torch.bool)
+        # ones = torch.triu(ones, diagonal=-recent_budget)
+        # mask_bottom = torch.logical_or(mask_bottom, ones)
+
+        # mask_bottom = torch.tril(mask_bottom, diagonal=0)
+
+        # # mask_bottom = ones
+        # attn_weights[~mask_bottom] = torch.min(attention_mask)
+
+
+
+        # Heavy Hitter Mask (Based on global statistics)
+        tmp_attn = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(attn_weights.dtype)
+        tmp_sum = torch.sum(tmp_attn, dim=-2)
+        _, tmp_topk = tmp_sum.topk(k=heavy_budget, dim=-1)
+
+        zeros = torch.zeros_like(tmp_sum, dtype=torch.bool)
+        mask_bottom = zeros.scatter(-1, tmp_topk, True).unsqueeze(2)
+        mask_bottom = mask_bottom.expand(mask_bottom.shape[0], mask_bottom.shape[1], attn_weights.shape[-2], mask_bottom.shape[-1])
 
         ones = torch.ones_like(attn_weights, dtype=torch.bool)
+        ones = torch.tril(ones, diagonal=recent_budget)
         ones = torch.triu(ones, diagonal=-recent_budget)
         mask_bottom = torch.logical_or(mask_bottom, ones)
-
-        mask_bottom = torch.tril(mask_bottom, diagonal=0)
-
-        # mask_bottom = ones
-        attn_weights[~mask_bottom] = torch.min(attention_mask)
+        attn_weights[~mask_bottom] = torch.finfo(attn_weights.dtype).min
+
 
         # upcast attention to fp32
         attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
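
The new hunk replaces the local heavy-hitter selection with a global one: it softmaxes the raw attention logits, sums the probability mass each key column receives across all query rows, keeps the top `heavy_budget` columns per head, unions that column mask with a band of width `recent_budget` around the diagonal, and drops everything else by writing the dtype's minimum before the final softmax. The snippet below is a minimal standalone sketch of that logic for reference only; the helper name `h2_global_mask`, the `[batch, heads, q_len, kv_len]` shape convention, and the usage stub are assumptions, not code from the repo.

```python
import torch
import torch.nn as nn

def h2_global_mask(attn_weights: torch.Tensor,
                   heavy_budget: int,
                   recent_budget: int) -> torch.Tensor:
    """Boolean keep-mask: heavy-hitter columns plus a recent band.

    attn_weights: pre-softmax attention logits, assumed shape
    [batch, num_heads, q_len, kv_len] (hypothetical convention,
    not taken from the repo).
    """
    # Column importance: accumulated softmax mass each key receives
    # over all query positions (the "global statistics").
    probs = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32)
    col_score = probs.sum(dim=-2)                        # [B, H, kv_len]

    # Keep the heavy_budget highest-scoring key positions per head,
    # broadcast over the query dimension.
    _, topk_idx = col_score.topk(k=heavy_budget, dim=-1)
    heavy_mask = torch.zeros_like(col_score, dtype=torch.bool)
    heavy_mask = heavy_mask.scatter(-1, topk_idx, True)  # [B, H, kv_len]
    heavy_mask = heavy_mask.unsqueeze(2).expand(-1, -1, attn_weights.shape[-2], -1)

    # Recent band: keys within recent_budget positions of each query,
    # mirroring the tril/triu pair in the diff.
    band = torch.ones_like(attn_weights, dtype=torch.bool)
    band = torch.tril(band, diagonal=recent_budget)
    band = torch.triu(band, diagonal=-recent_budget)

    return heavy_mask | band

if __name__ == "__main__":
    # Usage sketch: mask out non-selected positions before the final softmax.
    B, H, T = 1, 2, 16
    logits = torch.randn(B, H, T, T)
    keep = h2_global_mask(logits, heavy_budget=4, recent_budget=4)
    logits = logits.masked_fill(~keep, torch.finfo(logits.dtype).min)
    attn = nn.functional.softmax(logits, dim=-1)
```

Note the other behavioral change carried by the hunk: masked positions are now filled with `torch.finfo(attn_weights.dtype).min` rather than `torch.min(attention_mask)`, so the fill value no longer depends on the padding mask's contents.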