Skip to content

Commit

Permalink
Merge pull request #74 from AlexImmer/fix-device-bug
Browse files Browse the repository at this point in the history
Fix device bug in eig_lowrank
  • Loading branch information
runame authored Dec 23, 2021
2 parents 7e42de8 + 0b2139b commit 6af87e4
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions laplace/curvature/asdl.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,8 +154,9 @@ def eig_lowrank(self, data_loader):
mask = (eigvals > EPS)
eigvecs = torch.stack([torch.cat([p.flatten() for p in params])
for params in eigvecs], dim=1)[:, mask]
eigvals = eigvals[mask].to(eigvecs.dtype).to(eigvecs.device)
loss = sum([self.lossfunc(self.model(x).detach(), y) for x, y in data_loader])
device = eigvecs.device
eigvals = eigvals[mask].to(eigvecs.dtype).to(device)
loss = sum([self.lossfunc(self.model(x.to(device)).detach(), y.to(device)) for x, y in data_loader])
return eigvecs, self.factor * eigvals, self.factor * loss


Expand Down

0 comments on commit 6af87e4

Please sign in to comment.