diff --git a/heat/core/linalg/basics.py b/heat/core/linalg/basics.py index 53e5e94e8..2f57fe774 100644 --- a/heat/core/linalg/basics.py +++ b/heat/core/linalg/basics.py @@ -430,9 +430,9 @@ def matmul(a: DNDarray, b: DNDarray, allow_resplit: bool = False) -> DNDarray: Parameters ----------- a : DNDarray - matrix :math:`L \\times P` or vector :math:`P` or batch of matrices/vectors: :math:`B_1 \\times ... \\times B_k [\\times L] \\times P` + matrix :math:`L \\times P` or vector :math:`P` or batch of matrices: :math:`B_1 \\times ... \\times B_k \\times L \\times P` b : DNDarray - matrix :math:`P \\times Q` or vector :math:`P` or batch of matrices/vectors: :math:`B_1 \\times ... \\times B_k \\times P [\\times Q]` + matrix :math:`P \\times Q` or vector :math:`P` or batch of matrices: :math:`B_1 \\times ... \\times B_k \\times P \\times Q` allow_resplit : bool, optional Whether to distribute ``a`` in the case that both ``a.split is None`` and ``b.split is None``. Default is ``False``. If ``True``, if both are not split then ``a`` will be distributed in-place along axis 0. @@ -440,7 +440,7 @@ def matmul(a: DNDarray, b: DNDarray, allow_resplit: bool = False) -> DNDarray: Notes ----------- - For batched inputs, batch dimensions must coincide and if one matrix is split along a batch axis the other must be split along the same axis. - - If ``a`` or ``b`` is a (possibly batched) vector the result will also be a (possibly batched) vector. + - If ``a`` or ``b`` is a vector the result will also be a vector. - We recommend to avoid the particular split combinations ``1``-``0``, ``None``-``0``, and ``1``-``None`` (for ``a.split``-``b.split``) due to their comparably high memory consumption, if possible. Applying ``DNDarray.resplit_`` or ``heat.resplit`` on one of the two factors before calling ``matmul`` in these situations might improve performance of your code / might avoid memory bottlenecks. References @@ -529,6 +529,10 @@ def matmul(a: DNDarray, b: DNDarray, allow_resplit: bool = False) -> DNDarray: raise NotImplementedError( "Both input matrices have to be split along the same batch axis!" ) + if vector_flag: # batched matrix vector multiplication not supported + raise NotImplementedError( + "Batched matrix-vector multiplication is not supported, try using expand_dims to make it a batched matrix-matrix multiplication." + ) comm = a.comm ndim = max(a.ndim, b.ndim) diff --git a/heat/core/linalg/tests/test_basics.py b/heat/core/linalg/tests/test_basics.py index 39fc2583b..870d671f6 100644 --- a/heat/core/linalg/tests/test_basics.py +++ b/heat/core/linalg/tests/test_basics.py @@ -827,14 +827,16 @@ def test_matmul(self): a = ht.zeros((3, 3, 3), split=0) b = ht.zeros((4, 3, 3), split=0) ht.matmul(a, b) - # not implemented split - """ - todo + # split along different batch dimension with self.assertRaises(NotImplementedError): - a = ht.zeros((3, 3, 3)) - b = ht.zeros((3, 3, 3)) + a = ht.zeros((4, 3, 3, 3), split=0) + b = ht.zeros((4, 3, 3, 3), split=1) + ht.matmul(a, b) + # batched matrix-vector multiplication + with self.assertRaises(NotImplementedError): + a = ht.zeros((3, 3, 3), split=0) + b = ht.zeros((3, 3), split=0) ht.matmul(a, b) - """ # batched, split batch n = 11 # number of batches