Update group attention; optimize GPU memory usage.
Alexander0Yang committed Jan 30, 2024
1 parent 39918a2 commit 25f4557
Showing 2 changed files with 33 additions and 3 deletions.
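For context, the new `group_attn` method added below buckets queries by how many valid keys each one has, then runs attention per bucket with keys padded only to that bucket's upper bound, rather than padding every query to the global maximum key count. A minimal, self-contained sketch of that bucketing pattern (the function name, tensor shapes, and the assumption that each row's valid keys occupy its leading slots are all illustrative, not taken from the commit):

```python
import torch

def bucketed_attention(Q, K, V, valid, bounds=(20, 40, 80, 120)):
    """Q: (N, 1, C); K, V: (N, Lmax, C) with Lmax >= bounds[-1];
    valid: (N,) count of real keys per query, assumed to occupy the
    leading key slots and to be at most bounds[-1]."""
    out = Q.new_zeros(Q.shape)
    s = 0
    for e in bounds:
        sel = (valid > s) & (valid <= e)               # queries with s < #keys <= e
        if sel.any():
            q, k, v = Q[sel], K[sel, :e], V[sel, :e]   # pad only to e, not Lmax
            keep = torch.arange(e, device=Q.device)[None, :] < valid[sel][:, None]
            w = (q @ k.transpose(1, 2)) / q.shape[-1] ** 0.5
            w = w.masked_fill(~keep.unsqueeze(1), float('-inf')).softmax(-1)
            out[sel] = w @ v
        s = e
    return out
```

Because each bucket's score tensor is (n, 1, e) instead of (N, 1, Lmax), peak activation memory tracks the bucket bounds; the real method additionally gathers scattered valid keys into a dense layout before calling `self.learnedAlign`.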
36 changes: 33 additions & 3 deletions projects/mmdet3d_plugin/models/utils/encoder_utils.py
@@ -139,8 +139,6 @@ class BEVWarp(nn.Module):
    def __init__(self):
        super().__init__()

    # can run with no grad
    @torch.no_grad()
    def forward(self, lidar_feats, img_feats, img_metas, pts_metas, **kwargs):
        batch_size, num_views, I_C, I_H, I_W = img_feats.shape
        lidar2img = []
@@ -224,6 +222,37 @@ def __init__(self, pts_channels, img_channels, dropout):
        self.dropout = dropout
        self.learnedAlign = nn.MultiheadAttention(pts_channels, 1, dropout=dropout,
                                                  kdim=img_channels, vdim=img_channels, batch_first=True)

    def group_attn(self, Q, K, V, attn_mask=None, groups=[20, 40, 80, 120]):
        # Bucket queries by their number of valid keys and run attention per
        # bucket, padding keys only to the bucket's upper bound instead of to
        # the global maximum key count.
        out_tensor = Q.new_zeros(Q.shape)
        group_sum = (~attn_mask).sum(-1).squeeze(-1)  # valid keys per query
        s = 0
        for e in groups:
            group_mask = (group_sum > s) & (group_sum <= e)
            if group_mask.sum() == 0:
                s = e
                continue
            group_Q, group_K, group_V = Q[group_mask], K[group_mask], V[group_mask]
            group_attn_mask = attn_mask[group_mask]
            # Append e zero rows per sample as padding material.
            new_group_K = torch.cat([group_K, group_K.new_zeros(1, 1, 1).expand(group_K.shape[0], e, group_K.shape[2])], 1)
            new_group_V = torch.cat([group_V, group_V.new_zeros(1, 1, 1).expand(group_V.shape[0], e, group_V.shape[2])], 1)

            new_group_sum = group_sum[group_mask]
            group_padding_num = e - new_group_sum
            # True marks appended rows to drop, so each sample keeps exactly e
            # rows: its valid keys plus group_padding_num zeros.
            padding_mask = group_attn_mask.new_zeros(group_attn_mask.shape[0], e + 1)
            padding_mask[torch.arange(padding_mask.shape[0], device=padding_mask.device).long(),
                         group_padding_num.long()] = 1
            padding_mask = padding_mask.cumsum(dim=1).bool().unsqueeze(1)[..., :-1]
            padded_group_mask = torch.cat([group_attn_mask, padding_mask], -1)

            # Gather the kept rows into dense (N, e, C) key/value tensors;
            # valid keys come first, zero padding trails.
            new_group_K = new_group_K[~padded_group_mask.squeeze(1)].reshape(group_Q.shape[0], e, -1)
            new_group_V = new_group_V[~padded_group_mask.squeeze(1)].reshape(group_Q.shape[0], e, -1)

            # True (= do not attend) over the padded tail positions.
            new_group_attn_mask = (~padding_mask).flip(-1)
            group_out = self.learnedAlign(group_Q, new_group_K, new_group_V, attn_mask=new_group_attn_mask)[0]
            out_tensor[group_mask] = group_out
            s = e
        return out_tensor

    def forward(self, lidar_feat, img_feat, img_metas, pts_metas, **kwargs):
        batch_size = len(img_metas)
@@ -284,7 +313,8 @@ def forward(self, lidar_feat, img_feat, img_metas, pts_metas, **kwargs):
            Q = lidar_feat[b,:,voxel_coor[:,2].long(),voxel_coor[:,3].long()].t().unsqueeze(1)
            valid = mask[...,0].sum(dim=1) > 0
            attn_output = lidar_feat.new_zeros(num_voxels, 1, self.pts_channels)
            attn_output[valid] = self.group_attn(Q[valid],K[valid],V[valid],attn_mask=(~mask[valid]).permute(0,2,1))
            # attn_output[valid] = self.learnedAlign(Q[valid],K[valid],V[valid],attn_mask=(~mask[valid]).permute(0,2,1))[0]
            decorated_lidar_feat[b,:,voxel_coor[:,2].long(),voxel_coor[:,3].long()] = attn_output.squeeze(1).t()
            cur_start = cur_end
        return decorated_lidar_feat
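The forward path above now routes the voxel-to-image cross-attention through `group_attn` instead of one `learnedAlign` call over the fully padded key set; the commented-out line preserves the old call for reference. A hedged sanity check that the two paths should agree, using a standalone `nn.MultiheadAttention` configured like the one in `__init__` (all sizes and names here are made up for illustration):

```python
import torch
import torch.nn as nn

pts_channels, img_channels, Lmax, N = 64, 128, 120, 500
align = nn.MultiheadAttention(pts_channels, 1, dropout=0.0,
                              kdim=img_channels, vdim=img_channels, batch_first=True)

Q = torch.randn(N, 1, pts_channels)
K = torch.randn(N, Lmax, img_channels)
V = torch.randn(N, Lmax, img_channels)
valid = torch.randint(1, Lmax + 1, (N,))                # real keys per query
mask = torch.arange(Lmax)[None, :] < valid[:, None]     # True = real key
attn_mask = (~mask).unsqueeze(1)                        # (N, 1, Lmax), True = ignore

# Reference path: one padded call, as in the commented-out line above.
ref = align(Q, K, V, attn_mask=attn_mask)[0]
```

With dropout disabled and identical weights, `group_attn(Q, K, V, attn_mask=attn_mask)` should match `ref` to floating-point tolerance, since masked softmax ignores the extra padding either way; only the peak memory of the two paths differs.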
tools/dist_train.sh: file mode changed 100644 → 100755 (no content changes).
