Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

draft: linear -> spline interp for waveform data #162

Draft
wants to merge 16 commits into
base: dev
Choose a base branch
from
Draft
1 change: 1 addition & 0 deletions ml4gw/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,6 @@
from .snr_rescaler import SnrRescaler
from .spectral import SpectralDensity
from .spectrogram import MultiResolutionSpectrogram
from .spline_interpolation import SplineInterpolate
from .waveforms import WaveformProjector, WaveformSampler
from .whitening import FixedWhiten, Whiten
176 changes: 134 additions & 42 deletions ml4gw/transforms/qtransform.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import math
from typing import List, Optional, Tuple
import warnings
from typing import List, Tuple

import torch
import torch.nn.functional as F
from jaxtyping import Float, Int
from torch import Tensor

from ml4gw.transforms.spline_interpolation import SplineInterpolate
from ml4gw.types import FrequencySeries1to3d, TimeSeries1to3d, TimeSeries3d

"""
Expand Down Expand Up @@ -38,7 +40,6 @@ class QTile(torch.nn.Module):
mismatch:
The maximum fractional mismatch between neighboring tiles


"""

def __init__(
Expand Down Expand Up @@ -100,7 +101,9 @@ def get_data_indices(self) -> Int[Tensor, " windowsize"]:
).type(torch.long)

def forward(
self, fseries: FrequencySeries1to3d, norm: str = "median"
self,
fseries: FrequencySeries1to3d,
norm: str = "median",
) -> TimeSeries1to3d:
"""
Compute the transform for this row
Expand Down Expand Up @@ -144,7 +147,7 @@ def forward(
energy /= means
else:
raise ValueError("Invalid normalisation %r" % norm)
return energy.type(torch.float32)
energy = energy.type(torch.float32)
return energy


Expand Down Expand Up @@ -172,6 +175,19 @@ class SingleQTransform(torch.nn.Module):
be chosen based on q, sample_rate, and duration
mismatch:
The maximum fractional mismatch between neighboring tiles
interpolation_method:
The method by which to interpolate each `QTile` to the specified
number of time and frequency bins. The acceptable values are
"bilinear", "bicubic", and "spline". The "bilinear" and "bicubic"
options will use PyTorch's built-in interpolation modes, while
"spline" will use the custom Torch-based implementation in
`ml4gw`, as PyTorch does not have spline-based intertpolation.
The "spline" mode is most similar to the results of GWpy's
Q-transform, which uses `scipy` to do spline interpolation.
However, it is also the slowest and most memory intensive due to
the matrix equation solving steps. Therefore, the default method
is "bicubic" as it produces the most similar results while
optimizing for computing performance.
"""

def __init__(
Expand All @@ -182,6 +198,7 @@ def __init__(
q: float = 12,
frange: List[float] = [0, torch.inf],
mismatch: float = 0.2,
interpolation_method: str = "bicubic",
) -> None:
super().__init__()
self.q = q
Expand All @@ -190,20 +207,87 @@ def __init__(
self.duration = duration
self.mismatch = mismatch

# If q is too large, the minimum of the frange computed
# below will be larger than the maximum
max_q = torch.pi * duration * sample_rate / 50 - 11 ** (0.5)
if q >= max_q:
raise ValueError(
"The given q value is too large for the given duration and "
f"sample rate. The maximum allowable value is {max_q}"
)

if interpolation_method not in ["bilinear", "bicubic", "spline"]:
raise ValueError(
"Interpolation method must be either 'bilinear', 'bicubic', "
f"or 'spline'; got {interpolation_method}"
)
self.interpolation_method = interpolation_method

qprime = self.q / 11 ** (1 / 2.0)
if self.frange[0] <= 0: # set non-zero lower frequency
self.frange[0] = 50 * self.q / (2 * torch.pi * duration)
if math.isinf(self.frange[1]): # set non-infinite upper frequency
self.frange[1] = sample_rate / 2 / (1 + 1 / qprime)

self.freqs = self.get_freqs()
self.qtile_transforms = torch.nn.ModuleList(
[
QTile(self.q, freq, self.duration, sample_rate, self.mismatch)
QTile(
q=self.q,
frequency=freq,
duration=self.duration,
sample_rate=sample_rate,
mismatch=self.mismatch,
)
for freq in self.freqs
]
)
self.qtiles = None

if self.interpolation_method == "spline":
self._set_up_spline_interp()

def _set_up_spline_interp(self):
ntiles = [qtile.ntiles() for qtile in self.qtile_transforms]
# For efficiency, we'll stack all qtiles of the same length before
# interpolating, so we need to figure out which those are
unique_ntiles = sorted(list(set(ntiles)))
idx = torch.arange(len(ntiles))
self.stack_idx = [idx[Tensor(ntiles) == n] for n in unique_ntiles]

t_out = torch.arange(
0, self.duration, self.duration / self.spectrogram_shape[1]
)
self.qtile_interpolators = torch.nn.ModuleList(
[
SplineInterpolate(
kx=3,
x_in=torch.arange(0, self.duration, self.duration / tiles),
y_in=torch.arange(len(idx)),
x_out=t_out,
y_out=torch.arange(len(idx)),
)
for tiles, idx in zip(unique_ntiles, self.stack_idx)
]
)

t_in = t_out
f_in = self.freqs
f_out = torch.logspace(
math.log10(self.frange[0]),
math.log10(self.frange[-1]),
self.spectrogram_shape[0],
)

self.interpolator = SplineInterpolate(
kx=3,
ky=3,
x_in=t_in,
y_in=f_in,
x_out=t_out,
y_out=f_out,
)

def get_freqs(self) -> Float[Tensor, " nfreq"]:
"""
Calculate the frequencies that will be used in this transform.
Expand All @@ -220,7 +304,8 @@ def get_freqs(self) -> Float[Tensor, " nfreq"]:

freq_base = math.exp(2 / ((2 + self.q**2) ** (1 / 2.0)) * fstep)
freqs = torch.Tensor([freq_base ** (i + 0.5) for i in range(nfreq)])
freqs = (minf * freqs // fstepmin) * fstepmin
# Cast freqs to float64 to avoid off-by-ones from rounding
freqs = (minf * freqs.double() // fstepmin) * fstepmin
return torch.unique(freqs)

def get_max_energy(
Expand Down Expand Up @@ -268,7 +353,11 @@ def get_max_energy(
if dimension == "batch":
return torch.max(max_across_ft, dim=-1).values

def compute_qtiles(self, X: TimeSeries1to3d, norm: str = "median") -> None:
def compute_qtiles(
self,
X: TimeSeries1to3d,
norm: str = "median",
) -> None:
"""
Take the FFT of the input timeseries and calculate the transform
for each `QTile`
Expand All @@ -278,36 +367,47 @@ def compute_qtiles(self, X: TimeSeries1to3d, norm: str = "median") -> None:
X[..., 1:] *= 2
self.qtiles = [qtile(X, norm) for qtile in self.qtile_transforms]

def interpolate(self, num_f_bins: int, num_t_bins: int) -> TimeSeries3d:
"""
Interpolate each `QTile` to the specified number of time and
frequency bins. Note that PyTorch does not have the same
interpolation methods that GWpy uses, and so the interpolated
spectrograms will be different even though the uninterpolated
values match. The `bicubic` interpolation method is used as
it seems to match GWpy most closely.
"""
def interpolate(self) -> TimeSeries3d:
if self.qtiles is None:
raise RuntimeError(
"Q-tiles must first be computed with .compute_qtiles()"
)
if self.interpolation_method == "spline":
qtiles = [
torch.stack([self.qtiles[i] for i in idx], dim=-2)
for idx in self.stack_idx
]
time_interped = torch.cat(
[
interpolator(qtile)
for qtile, interpolator in zip(
qtiles, self.qtile_interpolators
)
],
dim=-2,
)
return self.interpolator(time_interped)
num_f_bins, num_t_bins = self.spectrogram_shape
resampled = [
F.interpolate(
qtile[None], (qtile.shape[-2], num_t_bins), mode="bicubic"
qtile[None],
(qtile.shape[-2], num_t_bins),
mode=self.interpolation_method,
)
for qtile in self.qtiles
]
resampled = torch.stack(resampled, dim=-2)
resampled = F.interpolate(
resampled[0], (num_f_bins, num_t_bins), mode="bicubic"
resampled[0],
(num_f_bins, num_t_bins),
mode=self.interpolation_method,
)
return torch.squeeze(resampled)

def forward(
self,
X: TimeSeries1to3d,
norm: str = "median",
spectrogram_shape: Optional[Tuple[int, int]] = None,
):
"""
Compute the Q-tiles and interpolate
Expand All @@ -321,24 +421,15 @@ def forward(
three-dimensional, axes will be added during Q-tile
computation.
norm:
The method of interpolation used by each QTile
spectrogram_shape:
The shape of the interpolated spectrogram, specified as
`(num_f_bins, num_t_bins)`. Because the
frequency spacing of the Q-tiles is in log-space, the frequency
interpolation is log-spaced as well. If not given, the shape
used to initialize the transform will be used.
The method of normalization used by each QTile

Returns:
The interpolated Q-transform for the batch of data. Output will
have one more dimension than the input
"""

if spectrogram_shape is None:
spectrogram_shape = self.spectrogram_shape
num_f_bins, num_t_bins = spectrogram_shape
self.compute_qtiles(X, norm)
return self.interpolate(num_f_bins, num_t_bins)
return self.interpolate()


class QScan(torch.nn.Module):
Expand Down Expand Up @@ -376,14 +467,22 @@ def __init__(
spectrogram_shape: Tuple[int, int],
qrange: List[float] = [4, 64],
frange: List[float] = [0, torch.inf],
interpolation_method="bicubic",
mismatch: float = 0.2,
) -> None:
super().__init__()
self.qrange = qrange
self.mismatch = mismatch
self.qs = self.get_qs()
self.frange = frange
self.spectrogram_shape = spectrogram_shape
max_q = torch.pi * duration * sample_rate / 50 - 11 ** (0.5)
self.qs = self.get_qs()
if self.qs[-1] >= max_q:
warnings.warn(
"Some Q values exceed the maximum allowable Q value of "
f"{max_q}. The list of Q values to be tested in this "
"scan will be truncated to avoid those values."
)

# Deliberately doing something different from GWpy here.
# Their final frange is the intersection of the frange
Expand All @@ -397,9 +496,11 @@ def __init__(
spectrogram_shape=spectrogram_shape,
q=q,
frange=self.frange.copy(),
interpolation_method=interpolation_method,
mismatch=self.mismatch,
)
for q in self.qs
if q < max_q
]
)

Expand All @@ -415,14 +516,14 @@ def get_qs(self) -> List[float]:
self.qrange[0] * math.exp(2 ** (1 / 2.0) * dq * (i + 0.5))
for i in range(nplanes)
]

return qs

def forward(
self,
X: TimeSeries1to3d,
fsearch_range: List[float] = None,
norm: str = "median",
spectrogram_shape: Optional[Tuple[int, int]] = None,
):
"""
Compute the set of QTiles for each Q transform and determine which
Expand All @@ -442,12 +543,6 @@ def forward(
for the maximum energy
norm:
The method of interpolation used by each QTile
spectrogram_shape:
The shape of the interpolated spectrogram, specified as
`(num_f_bins, num_t_bins)`. Because the
frequency spacing of the Q-tiles is in log-space, the frequency
interpolation is log-spaced as well. If not given, the shape
used to initialize the transform will be used.

Returns:
An interpolated Q-transform for the batch of data. Output will
Expand All @@ -463,7 +558,4 @@ def forward(
]
)
)
if spectrogram_shape is None:
spectrogram_shape = self.spectrogram_shape
num_f_bins, num_t_bins = spectrogram_shape
return self.q_transforms[idx].interpolate(num_f_bins, num_t_bins)
return self.q_transforms[idx].interpolate()
Loading
Loading