Commit

Merge pull request #1518 from cornellius-gp/lazy_evaluated_kernel_tensor_grad

Ensure LazyEvaluatedKernelTensor requires grad.
Balandat authored Mar 18, 2021
2 parents dbfd5ac + 878dca9 commit 011679a
Showing 2 changed files with 25 additions and 9 deletions.
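For context, this is what the fix guarantees at the user level: toggling requires_grad on a lazily evaluated kernel matrix also toggles the kernel hyperparameters it closes over. The snippet below is a minimal sketch, not part of the commit; it assumes a stock ScaleKernel/RBFKernel and gpytorch's default lazy kernel evaluation.

```python
import torch
import gpytorch

train_x = torch.linspace(0, 1, 10).unsqueeze(-1)
kernel = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())

# Calling the kernel returns a LazyEvaluatedKernelTensor (the matrix is not materialized yet)
lazy_covar = kernel(train_x)

# With this fix, the in-place toggle also reaches the raw lengthscale / outputscale parameters
lazy_covar.requires_grad_(False)
print(any(p.requires_grad for p in kernel.parameters()))  # expected: False

lazy_covar.requires_grad_(True)
print(lazy_covar.requires_grad)  # expected: True, reflecting the kernel parameters again
```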
18 changes: 14 additions & 4 deletions gpytorch/lazy/lazy_evaluated_kernel_tensor.py
@@ -37,6 +37,16 @@ def dtype(self):
     def device(self):
         return self.x1.device

+    @property
+    def requires_grad(self):
+        return super().requires_grad or any(param.requires_grad for param in self.kernel.parameters())
+
+    def _set_requires_grad(self, val):
+        super()._set_requires_grad(val)
+        # The behavior that differs from the base LazyTensor setter
+        for param in self.kernel.parameters():
+            param.requires_grad_(val)
+
     def _expand_batch(self, batch_shape):
         return self.evaluate_kernel()._expand_batch(batch_shape)

@@ -91,8 +101,8 @@ def _getitem(self, row_index, col_index, *batch_indices):
         except IndexError:
             if any(not isinstance(bi, slice) for bi in batch_indices):
                 raise RuntimeError(
-                    f"Attempting to tensor index a non-batch matrix's batch dimensions. "
-                    "Got batch index {batch_indices} but my shape was {self.shape}"
+                    "Attempting to tensor index a non-batch matrix's batch dimensions. "
+                    f"Got batch index {batch_indices} but my shape was {self.shape}"
                 )
             x1 = x1.expand(*([1] * (len(batch_indices) - self.x1.dim() + 2)), *self.x1.shape)
             x1 = x1[(*batch_indices, row_index, dim_index)]
@@ -105,8 +115,8 @@ def _getitem(self, row_index, col_index, *batch_indices):
         except IndexError:
             if any([not isinstance(bi, slice) for bi in batch_indices]):
                 raise RuntimeError(
-                    f"Attempting to tensor index a non-batch matrix's batch dimensions. "
-                    "Got batch index {batch_indices} but my shape was {self.shape}"
+                    "Attempting to tensor index a non-batch matrix's batch dimensions. "
+                    f"Got batch index {batch_indices} but my shape was {self.shape}"
                 )
             x2 = x2.expand(*([1] * (len(batch_indices) - self.x1.dim() + 2)), *self.x2.shape)
             x2 = x2[(*batch_indices, col_index, dim_index)]
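The two error-message fixes above are easy to misread: only a string literal carrying the f prefix interpolates its {...} placeholders, while adjacent plain literals keep them verbatim, so the prefix has to sit on the fragment that contains the braces. A standalone illustration (the variables are made up for the example):

```python
batch_indices = (0, 1)
shape = (3, 3)

# f on the wrong fragment: the placeholders are printed literally
msg_old = (
    f"Attempting to tensor index a non-batch matrix's batch dimensions. "
    "Got batch index {batch_indices} but my shape was {shape}"
)

# f on the fragment that actually contains the placeholders
msg_new = (
    "Attempting to tensor index a non-batch matrix's batch dimensions. "
    f"Got batch index {batch_indices} but my shape was {shape}"
)

print(msg_old)  # ...Got batch index {batch_indices} but my shape was {shape}
print(msg_new)  # ...Got batch index (0, 1) but my shape was (3, 3)
```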
16 changes: 11 additions & 5 deletions gpytorch/lazy/lazy_tensor.py
@@ -1529,22 +1529,28 @@ def requires_grad(self):
             if hasattr(arg, "requires_grad")
         )

-    @requires_grad.setter
-    def requires_grad(self, val):
+    def _set_requires_grad(self, val):
+        # Note: subclasses should overwrite this method, not the requires_grad.setter
         for arg in self._args:
             if hasattr(arg, "requires_grad"):
                 if arg.dtype in (torch.float, torch.double, torch.half):
-                    arg.requires_grad = val
+                    arg.requires_grad_(val)
         for arg in self._kwargs.values():
             if hasattr(arg, "requires_grad"):
-                arg.requires_grad = val
+                arg.requires_grad_(val)

+    @requires_grad.setter
+    def requires_grad(self, val):
+        # Note: subclasses cannot overwrite this method
+        # To change the setter behavior, overwrite the _set_requires_grad method instead
+        self._set_requires_grad(val)
+
     def requires_grad_(self, val):
         """
         Sets `requires_grad=val` on all the Tensors that make up the LazyTensor
         This is an inplace operation.
         """
-        self.requires_grad = val
+        self._set_requires_grad(val)
         return self

     @cached(name="diagonalization")
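The refactor above replaces the bare property setter with an overridable _set_requires_grad hook, so subclasses such as LazyEvaluatedKernelTensor can extend the setter behavior without redeclaring the property. A minimal, dependency-free sketch of the pattern (class and attribute names are made up, not gpytorch's):

```python
class Base:
    def __init__(self):
        self._flags = {"a": False}

    @property
    def requires_grad(self):
        return any(self._flags.values())

    def _set_requires_grad(self, val):
        # Subclasses overwrite this hook, not the property setter
        for key in self._flags:
            self._flags[key] = val

    @requires_grad.setter
    def requires_grad(self, val):
        # Defined once here; dispatches to the subclass hook at runtime
        self._set_requires_grad(val)


class Child(Base):
    def __init__(self):
        super().__init__()
        self._extra = False

    def _set_requires_grad(self, val):
        super()._set_requires_grad(val)
        self._extra = val  # extra state the base setter doesn't know about


c = Child()
c.requires_grad = True   # Base's setter runs, but it calls Child's hook
print(c._extra)          # True
print(c.requires_grad)   # True
```

The same hook is what requires_grad_ now calls, so the assignment form and the in-place method stay in sync.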
