Introduced a format for displaying the DCSR_matrix #1176

Closed
wants to merge 9 commits into from
5 changes: 1 addition & 4 deletions heat/sparse/_operations.py
@@ -135,13 +135,10 @@ def __binary_op_csr(
     else:
         result = operation(t1.larray.to(promoted_type), t2.larray.to(promoted_type), **fn_kwargs)

+    output_gnnz = torch.tensor(result._nnz())
     if output_split is not None:
-        output_gnnz = torch.tensor(result._nnz())
         output_comm.Allreduce(MPI.IN_PLACE, output_gnnz, MPI.SUM)
         output_gnnz = output_gnnz.item()
-    else:
-        output_gnnz = torch.tensor(result._nnz())

     output_type = types.canonical_heat_type(result.dtype)

     if out is None:
19 changes: 19 additions & 0 deletions heat/sparse/dcsr_matrix.py
@@ -337,3 +337,22 @@ def __repr__(self) -> str:
         if self.comm.rank != 0:
             return ""
         return print_string
+
+    def __str__(self) -> str:
Collaborator

It would be better to move this function to the core/printing module. There are cases to handle there, such as local versus global printing; see the __str__ method of the DNDarray class for more information.

It is also usually best to follow the printing format used by PyTorch, since it gives the user more information. The output of __str__ should prioritize readability above everything else, and indptr and indices on their own would be too difficult to understand. Personally, I would prefer to see the exact coordinates printed, because users of Heat shouldn't be expected to know the CSR matrix format. They should be able to use this more efficient data structure without having to learn much about it.
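
To make the readability point concrete, here is a minimal sketch in plain Python (no Heat or PyTorch calls) of how CSR arrays map to exact coordinates. The helper name `csr_to_coordinates` is hypothetical and not part of Heat; the sketch only assumes the standard CSR convention that the entries of row `i` sit in the range `indptr[i]:indptr[i + 1]`.

```python
from typing import List, Sequence, Tuple


def csr_to_coordinates(
    indptr: Sequence[int], indices: Sequence[int], data: Sequence[float]
) -> List[Tuple[int, int, float]]:
    """Expand CSR arrays into (row, column, value) triples.

    Hypothetical helper for illustration only; not part of Heat.
    """
    coords = []
    # Row i owns the entries in the half-open range indptr[i]:indptr[i + 1].
    for row in range(len(indptr) - 1):
        for k in range(indptr[row], indptr[row + 1]):
            coords.append((row, indices[k], data[k]))
    return coords


# The 3x3 matrix [[1, 0, 2], [0, 0, 3], [4, 5, 0]] in CSR form.
indptr = [0, 2, 3, 5]
indices = [0, 2, 2, 0, 1]
data = [1.0, 2.0, 3.0, 4.0, 5.0]

for row, col, value in csr_to_coordinates(indptr, indices, data):
    print(f"({row}, {col}) -> {value}")
```

A reader can interpret the coordinate lines without knowing what indptr means, which is the argument for printing them instead of the raw arrays.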

Collaborator

What @mrfh92 says also seems acceptable to me; maybe users won't want to see the actual data points. My only problem with that solution is that users wouldn't be able to see the data points even if they wanted to: there is no other way for them to get the coordinates if we don't show them in the print output.
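
One possible middle ground, sketched here in plain Python: always print a short summary line, then list only the first few coordinates. This assumes the alternative under discussion is a summary-style output, which is not spelled out in this thread. The function name `format_sparse_summary`, the `max_entries` parameter, and the exact output layout are all hypothetical and not part of Heat.

```python
from bisect import bisect_right
from typing import Sequence, Tuple


def format_sparse_summary(
    indptr: Sequence[int],
    indices: Sequence[int],
    data: Sequence[float],
    shape: Tuple[int, int],
    max_entries: int = 3,
) -> str:
    """Build a summary line followed by at most ``max_entries`` coordinates.

    Hypothetical helper for illustration only; not part of Heat.
    """
    nnz = len(data)
    lines = [f"DCSR_matrix(size={shape}, nnz={nnz})"]
    shown = min(nnz, max_entries)
    for k in range(shown):
        # Entry k belongs to the last row whose start offset is <= k.
        row = bisect_right(indptr, k) - 1
        lines.append(f"  ({row}, {indices[k]}) -> {data[k]}")
    if nnz > shown:
        lines.append(f"  ... {nnz - shown} more entries")
    return "\n".join(lines)


# Same 3x3 example: [[1, 0, 2], [0, 0, 3], [4, 5, 0]] in CSR form.
summary = format_sparse_summary(
    [0, 2, 3, 5], [0, 2, 2, 0, 1], [1.0, 2.0, 3.0, 4.0, 5.0], (3, 3)
)
print(summary)
```

With a cap like this, users who want the data points can raise `max_entries`, while the default output stays short for large matrices.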

+        """
+        Computes a string representation of the passed ``DCSR_matrix``.
+        """
+        size = self.__gshape
+        nnz = self.__gnnz
+
+        print_string = (
+            f"DCSR_matrix(indices={self.indices},\n"
+            f" indptr={self.indptr},\n"
+            f" data={self.data},\n"
+            f" size={size}, nnz={nnz}, split={self.__split})"
+        )
+
+        if self.comm.rank != 0:
+            return ""
+
+        return print_string