Commit

Update modify_llama.py

Kyriection authored Apr 4, 2024
1 parent 96c5ad0 commit 281ffef
Showing 1 changed file with 29 additions and 9 deletions.
38 changes: 29 additions & 9 deletions h2o_hf/utils_lm_eval/modify_llama.py
@@ -131,20 +131,40 @@ def forward(
heavy_budget = int(self.heavy_budget_ratio * attn_weights.shape[-1])
recent_budget = int(self.recent_budget_ratio * attn_weights.shape[-1])

# Heavy Hitter Mask
if heavy_budget > 0:
    mask_bottom = local_heavy_hitter_mask(attn_weights, heavy_budget) # Default: No padding applied to input
else:
    mask_bottom = torch.zeros_like(attn_weights, dtype=torch.bool)

# # Heavy Hitter Mask (Based on local statistics)
# if heavy_budget > 0:
#     mask_bottom = local_heavy_hitter_mask(attn_weights, heavy_budget) # Default: No padding applied to input
# else:
#     mask_bottom = torch.zeros_like(attn_weights, dtype=torch.bool)

# ones = torch.ones_like(attn_weights, dtype=torch.bool)
# ones = torch.triu(ones, diagonal=-recent_budget)
# mask_bottom = torch.logical_or(mask_bottom, ones)

# mask_bottom = torch.tril(mask_bottom, diagonal=0)

# # mask_bottom = ones
# attn_weights[~mask_bottom] = torch.min(attention_mask)



# Heavy Hitter Mask (Based on global statistics)
tmp_attn = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(attn_weights.dtype)
tmp_sum = torch.sum(tmp_attn, dim=-2)
_, tmp_topk = tmp_sum.topk(k=heavy_budget, dim=-1)

zeros = torch.zeros_like(tmp_sum, dtype=torch.bool)
mask_bottom = zeros.scatter(-1, tmp_topk, True).unsqueeze(2)
mask_bottom = mask_bottom.expand(mask_bottom.shape[0], mask_bottom.shape[1], attn_weights.shape[-2], mask_bottom.shape[-1])

ones = torch.ones_like(attn_weights, dtype=torch.bool)
ones = torch.tril(ones, diagonal=recent_budget)
ones = torch.triu(ones, diagonal=-recent_budget)
mask_bottom = torch.logical_or(mask_bottom, ones)

mask_bottom = torch.tril(mask_bottom, diagonal=0)

# mask_bottom = ones
attn_weights[~mask_bottom] = torch.min(attention_mask)
attn_weights[~mask_bottom] = torch.finfo(attn_weights.dtype).min


# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
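
For reference, the global-statistics heavy-hitter mask introduced in this diff can be read in isolation as the sketch below. It is not the repository's API: the standalone function name h2o_global_mask, its signature, and the toy shapes in the usage lines are illustrative assumptions; only the masking steps mirror the added code.

# A minimal sketch of the H2O mask construction added in this commit, assuming
# attention scores of shape (batch, heads, q_len, k_len). The wrapper function
# and the toy example below are hypothetical; the masking steps follow the diff.
import torch
import torch.nn as nn


def h2o_global_mask(attn_weights, heavy_budget, recent_budget):
    # Global statistics: softmax the raw scores, sum the probability mass each
    # key position receives over all queries, and keep the top-k "heavy hitters".
    tmp_attn = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(attn_weights.dtype)
    tmp_sum = torch.sum(tmp_attn, dim=-2)                 # (batch, heads, k_len)
    _, tmp_topk = tmp_sum.topk(k=heavy_budget, dim=-1)

    zeros = torch.zeros_like(tmp_sum, dtype=torch.bool)
    mask_bottom = zeros.scatter(-1, tmp_topk, True).unsqueeze(2)
    mask_bottom = mask_bottom.expand(-1, -1, attn_weights.shape[-2], -1)

    # Recent window: a band of width `recent_budget` around the diagonal.
    ones = torch.ones_like(attn_weights, dtype=torch.bool)
    ones = torch.tril(ones, diagonal=recent_budget)
    ones = torch.triu(ones, diagonal=-recent_budget)
    mask_bottom = torch.logical_or(mask_bottom, ones)

    # Enforce causality: no query may attend to a future key.
    return torch.tril(mask_bottom, diagonal=0)


# Toy usage: 1 sequence, 2 heads, 16 tokens, budgets of 3 tokens each.
scores = torch.randn(1, 2, 16, 16)
keep = h2o_global_mask(scores, heavy_budget=3, recent_budget=3)
scores[~keep] = torch.finfo(scores.dtype).min             # dropped positions get the dtype minimum
probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32)

Compared with the commented-out local-statistics path (which delegates to local_heavy_hitter_mask), this variant selects heavy hitters from attention mass accumulated over every query position, and the new code fills dropped positions with torch.finfo(attn_weights.dtype).min rather than torch.min(attention_mask) so masked entries sit at the dtype's minimum before the final softmax.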
