diff --git a/gpytorch/lazy/lazy_evaluated_kernel_tensor.py b/gpytorch/lazy/lazy_evaluated_kernel_tensor.py
index 5c0e3f4cd..e597a3471 100644
--- a/gpytorch/lazy/lazy_evaluated_kernel_tensor.py
+++ b/gpytorch/lazy/lazy_evaluated_kernel_tensor.py
@@ -54,13 +54,18 @@ def _getitem(self, row_index, col_index, *batch_indices):
         x1 = self.x1
         x2 = self.x2
         num_outs_per_in = self.kernel.num_outputs_per_input(x1, x2)
+        if isinstance(num_outs_per_in, tuple):
+            num_outs_per_in_rows, num_outs_per_in_cols = num_outs_per_in
+        else:
+            num_outs_per_in_rows = num_outs_per_in
+            num_outs_per_in_cols = num_outs_per_in
 
         # The row index and col index should exactly correspond to which entries of x1 and x2 we need
         # So we'll basically call x1[*batch_indices, row_index, :], x2[*batch_indices, col_index, :]
         # However - if we have multiple outputs per input, then the indices won't directly
         # correspond to the entries of row/col. We'll have to do a little pre-processing
-        if num_outs_per_in != 1:
+        if num_outs_per_in_rows != 1 or num_outs_per_in_cols != 1:
             if not isinstance(x1, slice) or not isinstance(x2, slice):
                 # It's too complicated to deal with tensor indices in this case - we'll use the super method
                 return self.evaluate_kernel()._getitem(row_index, col_index, *batch_indices)
 
@@ -81,16 +86,16 @@ def _getitem(self, row_index, col_index, *batch_indices):
             if row_step is not None or col_step is not None:
                 return self.evaluate_kernel()._getitem(row_index, col_index, *batch_indices)
             if (
-                (row_start % num_outs_per_in)
-                or (col_start % num_outs_per_in)
-                or (row_end % num_outs_per_in)
-                or (col_end % num_outs_per_in)
+                (row_start % num_outs_per_in_rows)
+                or (col_start % num_outs_per_in_cols)
+                or (row_end % num_outs_per_in_rows)
+                or (col_end % num_outs_per_in_cols)
             ):
                 return self.evaluate_kernel()._getitem(row_index, col_index, *batch_indices)
 
             # Otherwise - let's divide the slices by the number of outputs per input
-            row_index = slice(row_start // num_outs_per_in, row_end // num_outs_per_in, None)
-            col_index = slice(col_start // num_outs_per_in, col_end // num_outs_per_in, None)
+            row_index = slice(row_start // num_outs_per_in_rows, row_end // num_outs_per_in_rows, None)
+            col_index = slice(col_start // num_outs_per_in_cols, col_end // num_outs_per_in_cols, None)
 
         # Define the index we're using for the last index
         # If the last index corresponds to a batch, then we'll use the appropriate batch_index
@@ -220,9 +225,14 @@ def _size(self):
         x1 = self.x1
         x2 = self.x2
 
-        num_outputs_per_input = self.kernel.num_outputs_per_input(x1, x2)
-        num_rows = x1.size(-2) * num_outputs_per_input
-        num_cols = x2.size(-2) * num_outputs_per_input
+        num_outs_per_in = self.kernel.num_outputs_per_input(x1, x2)
+        if isinstance(num_outs_per_in, tuple):
+            num_outs_per_in_rows, num_outs_per_in_cols = num_outs_per_in
+        else:
+            num_outs_per_in_rows = num_outs_per_in
+            num_outs_per_in_cols = num_outs_per_in
+        num_rows = x1.size(-2) * num_outs_per_in_rows
+        num_cols = x2.size(-2) * num_outs_per_in_cols
 
         # Default case - when we're not using broadcasting
         # We write this case special for efficiency