diff --git a/makani/models/common/spectral_convolution.py b/makani/models/common/spectral_convolution.py index 57d768f..69064c0 100644 --- a/makani/models/common/spectral_convolution.py +++ b/makani/models/common/spectral_convolution.py @@ -99,7 +99,7 @@ def __init__(self, forward_transform, inverse_transform, in_channels, out_channe # seemingly the first weight is not really complex, so we need to account for that scale[0] *= math.sqrt(2.0) init = scale * torch.randn(*weight_shape, dtype=torch.complex64) - self.weight = nn.Parameter(init) + self.weight = nn.Parameter(torch.view_as_real(init)) if self.operator_type == "dhconv": self.weight.is_shared_mp = ["matmul", "w"] @@ -134,7 +134,7 @@ def forward(self, x): residual = residual.to(dtype) # approach with unpadded weights - xp = self._contract(x, self.weight, separable=self.separable, operator_type=self.operator_type) + xp = self._contract(x, torch.view_as_complex(self.weight), separable=self.separable, operator_type=self.operator_type) x = xp.contiguous() with amp.autocast(enabled=False):