Skip to content

Commit

Permalink
Handle dtypes during direcsum
Browse files Browse the repository at this point in the history
Signed-off-by: Daniel Ordonez <[email protected]>
  • Loading branch information
Danfoa committed May 30, 2023
1 parent 7340d57 commit 60ce09c
Showing 1 changed file with 11 additions and 4 deletions.
15 changes: 11 additions & 4 deletions escnn/group/representation.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,11 +562,16 @@ def directsum(reprs: List[escnn.group.Representation],
irreps += r.irreps

size = sum([r.size for r in reprs])

cob = np.zeros((size, size))
cob_inv = np.zeros((size, size))

# Determine the dtype for the change of basis diagonal matrix to avoid unsafe casting.
dtype = np.complex if np.any([np.iscomplexobj(rep.change_of_basis) for rep in reprs]) else np.float

cob = np.zeros((size, size), dtype=dtype)
cob_inv = np.zeros((size, size), dtype=dtype)
p = 0
for r in reprs:
assert np.can_cast(r.change_of_basis.dtype, cob.dtype), \
f"Cannot safely cast {r.change_of_basis.dtype} to {cob.dtype}"
cob[p:p + r.size, p:p + r.size] = r.change_of_basis
cob_inv[p:p + r.size, p:p + r.size] = r.change_of_basis_inv
p += r.size
Expand All @@ -580,7 +585,9 @@ def directsum(reprs: List[escnn.group.Representation],

supported_nonlinearities = set.intersection(*[r.supported_nonlinearities for r in reprs])

return Representation(group, name, irreps, change_of_basis, supported_nonlinearities, change_of_basis_inv=change_of_basis_inv)
return Representation(
group, name, irreps, change_of_basis, supported_nonlinearities, change_of_basis_inv=change_of_basis_inv
)


def disentangle(repr: Representation) -> Tuple[np.ndarray, List[Representation]]:
Expand Down

0 comments on commit 60ce09c

Please sign in to comment.