Skip to content

Commit

Permalink
linear -> spline interp for waveform data
Browse files Browse the repository at this point in the history
  • Loading branch information
deepchatterjeeligo committed Nov 5, 2024
1 parent 032639d commit b9f6863
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 9 deletions.
5 changes: 4 additions & 1 deletion ml4gw/waveforms/phenom_d.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from jaxtyping import Float

from ml4gw.constants import MTSUN_SI, PI
from ml4gw.transforms import SplineInterpolate
from ml4gw.types import BatchTensor, FrequencySeries1d

from .phenom_d_data import QNMData_a, QNMData_fdamp, QNMData_fring
Expand Down Expand Up @@ -583,7 +584,9 @@ def fring_fdamp(self, eta, eta2, chi1, chi2):
finspin = self.FinalSpin0815(eta, eta2, chi1, chi2)
Erad = self.PhenomInternal_EradRational0815(eta, eta2, chi1, chi2)

fRD, fDM = self._linear_interp_finspin(finspin)
spline_interpolant = SplineInterpolate(self.qnmdata_a, kx=3, sx=1e-8)
fRD = spline_interpolant(self.qnmdata_fring, x_out=finspin).squeeze()
fDM = spline_interpolant(self.qnmdata_fdamp, x_out=finspin).squeeze()
fRD /= 1.0 - Erad
fDM /= 1.0 - Erad

Expand Down
16 changes: 8 additions & 8 deletions tests/waveforms/test_cbc_waveforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,16 +225,16 @@ def test_phenom_d(
hc_torch = hc_torch[torch_mask]

assert np.allclose(
1e21 * hp_lal_data.real, 1e21 * hp_torch.real.numpy(), atol=2e-4
1e21 * hp_lal_data.real, 1e21 * hp_torch.real.numpy(), atol=3e-4
)
assert np.allclose(
1e21 * hp_lal_data.imag, 1e21 * hp_torch.imag.numpy(), atol=2e-4
1e21 * hp_lal_data.imag, 1e21 * hp_torch.imag.numpy(), atol=3e-4
)
assert np.allclose(
1e21 * hc_lal_data.real, 1e21 * hc_torch.real.numpy(), atol=2e-4
1e21 * hc_lal_data.real, 1e21 * hc_torch.real.numpy(), atol=3e-4
)
assert np.allclose(
1e21 * hc_lal_data.imag, 1e21 * hc_torch.imag.numpy(), atol=2e-4
1e21 * hc_lal_data.imag, 1e21 * hc_torch.imag.numpy(), atol=3e-4
)


Expand Down Expand Up @@ -342,14 +342,14 @@ def test_phenom_p(chirp_mass, mass_ratio, chi1z, chi2z, distance, sample_rate):
hc_torch = hc_torch[torch_mask]

assert np.allclose(
1e21 * hp_lal_data.real, 1e21 * hp_torch.real.numpy(), atol=2e-3
1e21 * hp_lal_data.real, 1e21 * hp_torch.real.numpy(), atol=3e-3
)
assert np.allclose(
1e21 * hp_lal_data.imag, 1e21 * hp_torch.imag.numpy(), atol=2e-3
1e21 * hp_lal_data.imag, 1e21 * hp_torch.imag.numpy(), atol=3e-3
)
assert np.allclose(
1e21 * hc_lal_data.real, 1e21 * hc_torch.real.numpy(), atol=2e-3
1e21 * hc_lal_data.real, 1e21 * hc_torch.real.numpy(), atol=3e-3
)
assert np.allclose(
1e21 * hc_lal_data.imag, 1e21 * hc_torch.imag.numpy(), atol=2e-3
1e21 * hc_lal_data.imag, 1e21 * hc_torch.imag.numpy(), atol=3e-3
)

0 comments on commit b9f6863

Please sign in to comment.