From 60ce09c8db16c86834dc256a7cd2b907c44c6a56 Mon Sep 17 00:00:00 2001 From: Daniel Ordonez Date: Tue, 30 May 2023 11:13:11 +0200 Subject: [PATCH] Handle dtypes during direcsum Signed-off-by: Daniel Ordonez --- escnn/group/representation.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/escnn/group/representation.py b/escnn/group/representation.py index 1b59be8f..73174bf6 100644 --- a/escnn/group/representation.py +++ b/escnn/group/representation.py @@ -562,11 +562,16 @@ def directsum(reprs: List[escnn.group.Representation], irreps += r.irreps size = sum([r.size for r in reprs]) - - cob = np.zeros((size, size)) - cob_inv = np.zeros((size, size)) + + # Determine the dtype for the change of basis diagonal matrix to avoid unsafe casting. + dtype = np.complex if np.any([np.iscomplexobj(rep.change_of_basis) for rep in reprs]) else np.float + + cob = np.zeros((size, size), dtype=dtype) + cob_inv = np.zeros((size, size), dtype=dtype) p = 0 for r in reprs: + assert np.can_cast(r.change_of_basis.dtype, cob.dtype), \ + f"Cannot safely cast {r.change_of_basis.dtype} to {cob.dtype}" cob[p:p + r.size, p:p + r.size] = r.change_of_basis cob_inv[p:p + r.size, p:p + r.size] = r.change_of_basis_inv p += r.size @@ -580,7 +585,9 @@ def directsum(reprs: List[escnn.group.Representation], supported_nonlinearities = set.intersection(*[r.supported_nonlinearities for r in reprs]) - return Representation(group, name, irreps, change_of_basis, supported_nonlinearities, change_of_basis_inv=change_of_basis_inv) + return Representation( + group, name, irreps, change_of_basis, supported_nonlinearities, change_of_basis_inv=change_of_basis_inv + ) def disentangle(repr: Representation) -> Tuple[np.ndarray, List[Representation]]: