From 9b60fad412795cdecd76397dea90285806dfd1b2 Mon Sep 17 00:00:00 2001 From: Sai-Suraj-27 Date: Wed, 5 Jul 2023 00:54:04 +0530 Subject: [PATCH 1/5] implemented a basic format for displaying DCSR_matrix. --- heat/sparse/_operations.py | 5 +---- heat/sparse/dcsr_matrix.py | 18 ++++++++++++++++++ 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/heat/sparse/_operations.py b/heat/sparse/_operations.py index 1d38114955..94946a1dde 100644 --- a/heat/sparse/_operations.py +++ b/heat/sparse/_operations.py @@ -135,13 +135,10 @@ def __binary_op_csr( else: result = operation(t1.larray.to(promoted_type), t2.larray.to(promoted_type), **fn_kwargs) + output_gnnz = torch.tensor(result._nnz()) if output_split is not None: - output_gnnz = torch.tensor(result._nnz()) output_comm.Allreduce(MPI.IN_PLACE, output_gnnz, MPI.SUM) output_gnnz = output_gnnz.item() - else: - output_gnnz = torch.tensor(result._nnz()) - output_type = types.canonical_heat_type(result.dtype) if out is None: diff --git a/heat/sparse/dcsr_matrix.py b/heat/sparse/dcsr_matrix.py index 468ab84c27..90253229a4 100644 --- a/heat/sparse/dcsr_matrix.py +++ b/heat/sparse/dcsr_matrix.py @@ -337,3 +337,21 @@ def __repr__(self) -> str: if self.comm.rank != 0: return "" return print_string + + def __str__(self) -> str: + """ + Computes a string representation of the passed ``DCSR_matrix``. + """ + size = self.__gshape + nnz = self.__gnnz + + print_string = ( + f"indices={self.indices}, " + f"data={self.data}, DCSR_matrix(size={size}, " + f"nnz={nnz}, split={self.__split})" + ) + + if self.__comm.rank != 0: + return "" + + return print_string From 70ea8995226ef9c4fbd866be3fbe70f9c69f4ba6 Mon Sep 17 00:00:00 2001 From: Sai-Suraj-27 Date: Wed, 5 Jul 2023 11:03:55 +0530 Subject: [PATCH 2/5] Updated the format slightly. --- heat/sparse/dcsr_matrix.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/heat/sparse/dcsr_matrix.py b/heat/sparse/dcsr_matrix.py index 90253229a4..ee121739ce 100644 --- a/heat/sparse/dcsr_matrix.py +++ b/heat/sparse/dcsr_matrix.py @@ -347,8 +347,8 @@ def __str__(self) -> str: print_string = ( f"indices={self.indices}, " - f"data={self.data}, DCSR_matrix(size={size}, " - f"nnz={nnz}, split={self.__split})" + f"data={self.data}, " + f"size={size}, nnz={nnz}, split={self.__split})" ) if self.__comm.rank != 0: From d4652ea1bce089ea14797944aa3dad8bf754b6ed Mon Sep 17 00:00:00 2001 From: Sai-Suraj-27 Date: Wed, 5 Jul 2023 11:17:58 +0530 Subject: [PATCH 3/5] Corrected a mistake in if condition. --- heat/sparse/dcsr_matrix.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/heat/sparse/dcsr_matrix.py b/heat/sparse/dcsr_matrix.py index ee121739ce..9c99abd25d 100644 --- a/heat/sparse/dcsr_matrix.py +++ b/heat/sparse/dcsr_matrix.py @@ -351,7 +351,7 @@ def __str__(self) -> str: f"size={size}, nnz={nnz}, split={self.__split})" ) - if self.__comm.rank != 0: + if self.comm.rank != 0: return "" return print_string From 3f85926061ebfc1de9f3a755a5d9fbd5ae715357 Mon Sep 17 00:00:00 2001 From: Sai-Suraj-27 Date: Sun, 9 Jul 2023 22:06:18 +0530 Subject: [PATCH 4/5] Updated the format slightly. --- heat/sparse/dcsr_matrix.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/heat/sparse/dcsr_matrix.py b/heat/sparse/dcsr_matrix.py index 9c99abd25d..577a36a489 100644 --- a/heat/sparse/dcsr_matrix.py +++ b/heat/sparse/dcsr_matrix.py @@ -346,9 +346,9 @@ def __str__(self) -> str: nnz = self.__gnnz print_string = ( - f"indices={self.indices}, " - f"data={self.data}, " - f"size={size}, nnz={nnz}, split={self.__split})" + f"DCSR_matrix(indices={self.indices},\n" + f" data={self.data},\n" + f" size={size}, nnz={nnz}, split={self.__split})" ) if self.comm.rank != 0: From 539130c6ec56e5186ba857f6a23e5ee5d8466e53 Mon Sep 17 00:00:00 2001 From: Sai-Suraj-27 Date: Sat, 22 Jul 2023 12:12:46 +0530 Subject: [PATCH 5/5] Updated the format to include indptr as well.. --- heat/sparse/dcsr_matrix.py | 1 + 1 file changed, 1 insertion(+) diff --git a/heat/sparse/dcsr_matrix.py b/heat/sparse/dcsr_matrix.py index 577a36a489..38c0245ea9 100644 --- a/heat/sparse/dcsr_matrix.py +++ b/heat/sparse/dcsr_matrix.py @@ -347,6 +347,7 @@ def __str__(self) -> str: print_string = ( f"DCSR_matrix(indices={self.indices},\n" + f" indptr={self.indptr},\n" f" data={self.data},\n" f" size={size}, nnz={nnz}, split={self.__split})" )