Skip to content

Commit

Permalink
fixing 3d tensor error
Browse files Browse the repository at this point in the history
  • Loading branch information
piergiaj committed Oct 4, 2018
1 parent 78b74aa commit c7ba392
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 4 deletions.
9 changes: 6 additions & 3 deletions super_event.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,11 @@ def __init__(self, classes=65):
# to take 2xD*3 to D*3
self.cls_wts = nn.Parameter(torch.Tensor(classes))

self.sup_mat = nn.Parameter(torch.Tensor(1, classes, 1024))
self.sup_mat = nn.Parameter(torch.Tensor(1, classes, 512*3))
stdv = 1./np.sqrt(1024+1024)
self.sup_mat.data.uniform_(-stdv, stdv)

self.per_frame = nn.Conv3d(1024, classes, (1,1,1))
self.per_frame = nn.Conv3d(512, classes, (1,1,1))
self.per_frame.weight.data.uniform_(-stdv, stdv)
self.per_frame.bias.data.uniform_(-stdv, stdv)
self.add_module('pf', self.per_frame)
Expand All @@ -42,19 +42,22 @@ def forward(self, inp):
val = True
dim = 0

#print inp[0].size()
super_event = self.dropout(torch.stack([self.super_event(inp).squeeze(), self.super_event2(inp).squeeze()], dim=dim))
if val:
super_event = super_event.unsqueeze(0)
# we have B x 2 x D*3
# we want B x C x D*3

#print super_event.size()
# now we have C x 2 matrix
cls_wts = torch.stack([torch.sigmoid(self.cls_wts), 1-torch.sigmoid(self.cls_wts)], dim=1)

# now we do a bmm to get B x C x D*3
#print cls_wts.expand(inp[0].size()[0], -1, -1).size(), super_event.size()
super_event = torch.bmm(cls_wts.expand(inp[0].size()[0], -1, -1), super_event)
del cls_wts

print super_event.size()
# apply the super-event weights
super_event = torch.sum(self.sup_mat * super_event, dim=2)
#super_event = self.sup_mat(super_event.view(-1, 1024)).view(-1, self.classes)
Expand Down
2 changes: 1 addition & 1 deletion temporal_structure_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def forward(self, inp):
o = torch.bmm(f, vid.squeeze(2))
del f
del vid
o = o.view(batch, channels, self.Ni).unsqueeze(3).unsqueeze(3)
o = o.view(batch, channels*self.Ni)#.unsqueeze(3).unsqueeze(3)
return o


Expand Down

0 comments on commit c7ba392

Please sign in to comment.