diff --git a/heat/sparse/_operations.py b/heat/sparse/_operations.py index 1d38114955..94946a1dde 100644 --- a/heat/sparse/_operations.py +++ b/heat/sparse/_operations.py @@ -135,13 +135,10 @@ def __binary_op_csr( else: result = operation(t1.larray.to(promoted_type), t2.larray.to(promoted_type), **fn_kwargs) + output_gnnz = torch.tensor(result._nnz()) if output_split is not None: - output_gnnz = torch.tensor(result._nnz()) output_comm.Allreduce(MPI.IN_PLACE, output_gnnz, MPI.SUM) output_gnnz = output_gnnz.item() - else: - output_gnnz = torch.tensor(result._nnz()) - output_type = types.canonical_heat_type(result.dtype) if out is None: diff --git a/heat/sparse/dcsr_matrix.py b/heat/sparse/dcsr_matrix.py index 468ab84c27..38c0245ea9 100644 --- a/heat/sparse/dcsr_matrix.py +++ b/heat/sparse/dcsr_matrix.py @@ -337,3 +337,22 @@ def __repr__(self) -> str: if self.comm.rank != 0: return "" return print_string + + def __str__(self) -> str: + """ + Computes a string representation of the passed ``DCSR_matrix``. + """ + size = self.__gshape + nnz = self.__gnnz + + print_string = ( + f"DCSR_matrix(indices={self.indices},\n" + f" indptr={self.indptr},\n" + f" data={self.data},\n" + f" size={size}, nnz={nnz}, split={self.__split})" + ) + + if self.comm.rank != 0: + return "" + + return print_string