Kernels can return tuples from num_outputs_per_input (#1849)
* Kernels can now return tuples from num_outputs_per_input

* Caught a couple of issues that we missed earlier

Co-authored-by: Misha Padidar <[email protected]>
jacobrgardner and mishapadidar authored Dec 3, 2021
1 parent 0e74f2b commit bf13e7a
Showing 1 changed file with 20 additions and 10 deletions.
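For context, Kernel.num_outputs_per_input(x1, x2) tells LazyEvaluatedKernelTensor how many rows and columns of the kernel matrix each entry of x1 and x2 contributes. Before this commit it had to be a single int; with this change it may also be a (rows_per_input, cols_per_input) tuple. A minimal sketch of a kernel that exercises the tuple return (the AsymmetricOutputKernel name, its 2 and 3 output counts, and its placeholder forward are hypothetical, purely for illustration):

import torch
from gpytorch.kernels import Kernel


class AsymmetricOutputKernel(Kernel):
    """Hypothetical kernel: every x1 point contributes 2 rows and every x2 point 3 columns."""

    def num_outputs_per_input(self, x1, x2):
        # New in this commit: a (rows_per_input, cols_per_input) tuple is accepted
        # wherever a single int was accepted before.
        return 2, 3

    def forward(self, x1, x2, diag=False, **params):
        n, m = x1.size(-2), x2.size(-2)
        # Placeholder covariance with the advertised shape (2n x 3m); a real kernel
        # would compute actual cross-covariances between its outputs here.
        return torch.zeros(*x1.shape[:-2], 2 * n, 3 * m)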
gpytorch/lazy/lazy_evaluated_kernel_tensor.py
@@ -54,13 +54,18 @@ def _getitem(self, row_index, col_index, *batch_indices):
         x1 = self.x1
         x2 = self.x2
         num_outs_per_in = self.kernel.num_outputs_per_input(x1, x2)
+        if isinstance(num_outs_per_in, tuple):
+            num_outs_per_in_rows, num_outs_per_in_cols = num_outs_per_in
+        else:
+            num_outs_per_in_rows = num_outs_per_in
+            num_outs_per_in_cols = num_outs_per_in
 
         # The row index and col index should exactly correspond to which entries of x1 and x2 we need
         # So we'll basically call x1[*batch_indices, row_index, :], x2[*batch_indices, col_index, :]
 
         # However - if we have multiple outputs per input, then the indices won't directly
         # correspond to the entries of row/col. We'll have to do a little pre-processing
-        if num_outs_per_in != 1:
+        if num_outs_per_in_rows != 1 or num_outs_per_in_cols != 1:
             if not isinstance(x1, slice) or not isinstance(x2, slice):
                 # It's too complicated to deal with tensor indices in this case - we'll use the super method
                 return self.evaluate_kernel()._getitem(row_index, col_index, *batch_indices)
@@ -81,16 +86,16 @@ def _getitem(self, row_index, col_index, *batch_indices):
             if row_step is not None or col_step is not None:
                 return self.evaluate_kernel()._getitem(row_index, col_index, *batch_indices)
             if (
-                (row_start % num_outs_per_in)
-                or (col_start % num_outs_per_in)
-                or (row_end % num_outs_per_in)
-                or (col_end % num_outs_per_in)
+                (row_start % num_outs_per_in_rows)
+                or (col_start % num_outs_per_in_cols)
+                or (row_end % num_outs_per_in_rows)
+                or (col_end % num_outs_per_in_cols)
             ):
                 return self.evaluate_kernel()._getitem(row_index, col_index, *batch_indices)
 
             # Otherwise - let's divide the slices by the number of outputs per input
-            row_index = slice(row_start // num_outs_per_in, row_end // num_outs_per_in, None)
-            col_index = slice(col_start // num_outs_per_in, col_end // num_outs_per_in, None)
+            row_index = slice(row_start // num_outs_per_in_rows, row_end // num_outs_per_in_rows, None)
+            col_index = slice(col_start // num_outs_per_in_cols, col_end // num_outs_per_in_cols, None)
 
         # Define the index we're using for the last index
         # If the last index corresponds to a batch, then we'll use the appropriate batch_index
@@ -220,9 +225,14 @@ def _size(self):
 
         x1 = self.x1
         x2 = self.x2
-        num_outputs_per_input = self.kernel.num_outputs_per_input(x1, x2)
-        num_rows = x1.size(-2) * num_outputs_per_input
-        num_cols = x2.size(-2) * num_outputs_per_input
+        num_outs_per_in = self.kernel.num_outputs_per_input(x1, x2)
+        if isinstance(num_outs_per_in, tuple):
+            num_outs_per_in_rows, num_outs_per_in_cols = num_outs_per_in
+        else:
+            num_outs_per_in_rows = num_outs_per_in
+            num_outs_per_in_cols = num_outs_per_in
+        num_rows = x1.size(-2) * num_outs_per_in_rows
+        num_cols = x2.size(-2) * num_outs_per_in_cols
 
         # Default case - when we're not using broadcasting
         # We write this case special for efficiency
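With the tuple in place, the slicing guard in _getitem and the size computation in _size use the row and column counts independently. A rough numeric check of that arithmetic, standalone and outside gpytorch (the (2, 3) tuple and the point counts are made-up values): with 5 x1 points and 4 x2 points, a (2, 3) return should yield a 10 x 12 kernel matrix.

import torch

# Mirrors the arithmetic of the patched _size(): row and column counts now scale
# independently when num_outputs_per_input returns a tuple.
x1 = torch.randn(5, 1)
x2 = torch.randn(4, 1)
num_outs_per_in = (2, 3)  # hypothetical (rows_per_input, cols_per_input) return value

if isinstance(num_outs_per_in, tuple):
    rows_per_in, cols_per_in = num_outs_per_in
else:
    rows_per_in = cols_per_in = num_outs_per_in

num_rows = x1.size(-2) * rows_per_in  # 5 * 2 = 10
num_cols = x2.size(-2) * cols_per_in  # 4 * 3 = 12
print(num_rows, num_cols)  # prints: 10 12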
