diff --git a/ml4gw/waveforms/phenom_d.py b/ml4gw/waveforms/phenom_d.py index 6630e925..fe3e2160 100644 --- a/ml4gw/waveforms/phenom_d.py +++ b/ml4gw/waveforms/phenom_d.py @@ -2,6 +2,7 @@ from jaxtyping import Float from ml4gw.constants import MTSUN_SI, PI +from ml4gw.transforms import SplineInterpolate from ml4gw.types import BatchTensor, FrequencySeries1d from .phenom_d_data import QNMData_a, QNMData_fdamp, QNMData_fring @@ -583,7 +584,9 @@ def fring_fdamp(self, eta, eta2, chi1, chi2): finspin = self.FinalSpin0815(eta, eta2, chi1, chi2) Erad = self.PhenomInternal_EradRational0815(eta, eta2, chi1, chi2) - fRD, fDM = self._linear_interp_finspin(finspin) + spline_interpolant = SplineInterpolate(self.qnmdata_a, kx=3, sx=1e-8) + fRD = spline_interpolant(self.qnmdata_fring, x_out=finspin).squeeze() + fDM = spline_interpolant(self.qnmdata_fdamp, x_out=finspin).squeeze() fRD /= 1.0 - Erad fDM /= 1.0 - Erad diff --git a/tests/waveforms/test_cbc_waveforms.py b/tests/waveforms/test_cbc_waveforms.py index a8881089..3d880663 100644 --- a/tests/waveforms/test_cbc_waveforms.py +++ b/tests/waveforms/test_cbc_waveforms.py @@ -225,16 +225,16 @@ def test_phenom_d( hc_torch = hc_torch[torch_mask] assert np.allclose( - 1e21 * hp_lal_data.real, 1e21 * hp_torch.real.numpy(), atol=2e-4 + 1e21 * hp_lal_data.real, 1e21 * hp_torch.real.numpy(), atol=3e-4 ) assert np.allclose( - 1e21 * hp_lal_data.imag, 1e21 * hp_torch.imag.numpy(), atol=2e-4 + 1e21 * hp_lal_data.imag, 1e21 * hp_torch.imag.numpy(), atol=3e-4 ) assert np.allclose( - 1e21 * hc_lal_data.real, 1e21 * hc_torch.real.numpy(), atol=2e-4 + 1e21 * hc_lal_data.real, 1e21 * hc_torch.real.numpy(), atol=3e-4 ) assert np.allclose( - 1e21 * hc_lal_data.imag, 1e21 * hc_torch.imag.numpy(), atol=2e-4 + 1e21 * hc_lal_data.imag, 1e21 * hc_torch.imag.numpy(), atol=3e-4 ) @@ -342,14 +342,14 @@ def test_phenom_p(chirp_mass, mass_ratio, chi1z, chi2z, distance, sample_rate): hc_torch = hc_torch[torch_mask] assert np.allclose( - 1e21 * hp_lal_data.real, 1e21 * hp_torch.real.numpy(), atol=2e-3 + 1e21 * hp_lal_data.real, 1e21 * hp_torch.real.numpy(), atol=3e-3 ) assert np.allclose( - 1e21 * hp_lal_data.imag, 1e21 * hp_torch.imag.numpy(), atol=2e-3 + 1e21 * hp_lal_data.imag, 1e21 * hp_torch.imag.numpy(), atol=3e-3 ) assert np.allclose( - 1e21 * hc_lal_data.real, 1e21 * hc_torch.real.numpy(), atol=2e-3 + 1e21 * hc_lal_data.real, 1e21 * hc_torch.real.numpy(), atol=3e-3 ) assert np.allclose( - 1e21 * hc_lal_data.imag, 1e21 * hc_torch.imag.numpy(), atol=2e-3 + 1e21 * hc_lal_data.imag, 1e21 * hc_torch.imag.numpy(), atol=3e-3 )