Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Assert generators do not mutate inputs #8

Merged
merged 4 commits into from
Aug 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 26 additions & 12 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,14 @@ packages = [{ include = "ezmsg", from = "src" }]

[tool.poetry.dependencies]
python = "^3.8"
ezmsg = { git = "https://github.com/iscoe/ezmsg.git", branch = "dev" }
ezmsg = "^3.4.0"
numpy = "^1.19.5"
scipy = "^1.6.3"

[tool.poetry.group.test.dependencies]
pytest = "^7.0.0"
pytest-cov = "*"
pytest-asyncio = "*"
flake8 = "*"

[tool.pytest.ini_options]
Expand Down
4 changes: 3 additions & 1 deletion src/ezmsg/sigproc/affinetransform.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
import numpy.typing as npt
import ezmsg.core as ez
from ezmsg.util.messages.axisarray import AxisArray
from ezmsg.util.generator import consumer, GenAxisArray
from ezmsg.util.generator import consumer

from .base import GenAxisArray


@consumer
Expand Down
4 changes: 3 additions & 1 deletion src/ezmsg/sigproc/aggregate.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@

import numpy as np
import ezmsg.core as ez
from ezmsg.util.generator import consumer, GenAxisArray
from ezmsg.util.generator import consumer
from ezmsg.util.messages.axisarray import AxisArray, slice_along_axis
from ezmsg.sigproc.spectral import OptionsEnum

from .base import GenAxisArray


class AggregationFunction(OptionsEnum):
"""Enum for aggregation functions available to be used in :obj:`ranged_aggregate` operation."""
Expand Down
3 changes: 2 additions & 1 deletion src/ezmsg/sigproc/bandpower.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@
import numpy as np
import ezmsg.core as ez
from ezmsg.util.messages.axisarray import AxisArray
from ezmsg.util.generator import consumer, compose, GenAxisArray
from ezmsg.util.generator import consumer, compose

from .spectrogram import spectrogram, SpectrogramSettings
from .aggregate import ranged_aggregate, AggregationFunction
from .base import GenAxisArray


@consumer
Expand Down
38 changes: 38 additions & 0 deletions src/ezmsg/sigproc/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import typing

import ezmsg.core as ez
from ezmsg.util.messages.axisarray import AxisArray
from ezmsg.util.generator import GenState


class GenAxisArray(ez.Unit):
STATE: GenState

INPUT_SIGNAL = ez.InputStream(AxisArray)
OUTPUT_SIGNAL = ez.OutputStream(AxisArray)
INPUT_SETTINGS = ez.InputStream(ez.Settings)

def initialize(self) -> None:
self.construct_generator()

# Method to be implemented by subclasses to construct the specific generator
def construct_generator(self):
raise NotImplementedError

@ez.subscriber(INPUT_SETTINGS)
async def on_settings(self, msg: ez.Settings) -> None:
self.apply_settings(msg)
self.construct_generator()

@ez.subscriber(INPUT_SIGNAL, zero_copy=True)
@ez.publisher(OUTPUT_SIGNAL)
async def on_signal(self, message: AxisArray) -> typing.AsyncGenerator:
try:
ret = self.STATE.gen.send(message)
if ret.data.size > 0:
yield self.OUTPUT_SIGNAL, ret
except (StopIteration, GeneratorExit):
ez.logger.debug(f"Generator closed in {self.address}")
except Exception:
ez.logger.info(traceback.format_exc())

83 changes: 29 additions & 54 deletions src/ezmsg/sigproc/downsample.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,17 @@
import typing

import numpy as np

from ezmsg.util.messages.axisarray import AxisArray, slice_along_axis
from ezmsg.util.generator import consumer
import ezmsg.core as ez

from .base import GenAxisArray


@consumer
def downsample(
axis: typing.Optional[str] = None, factor: int = 1
axis: typing.Optional[str] = None,
factor: int = 1
) -> typing.Generator[AxisArray, AxisArray, None]:
"""
Construct a generator that yields a downsampled version of the data .send() to it.
Expand All @@ -35,6 +37,9 @@ def downsample(
axis_arr_in = AxisArray(np.array([]), dims=[""])
axis_arr_out = AxisArray(np.array([]), dims=[""])

if factor < 1:
raise ValueError("Downsample factor must be at least 1 (no downsampling)")

# state variables
s_idx: int = 0 # Index of the next msg's first sample into the virtual rotating ds_factor counter.
template: typing.Optional[AxisArray] = None
Expand Down Expand Up @@ -63,20 +68,23 @@ def downsample(

pub_samples = np.where(samples == 0)[0]
if len(pub_samples) > 0:
# Update the template directly, because we want
# future size-0 msgs to have approx. correct offset.
update_ax = template.axes[axis]
update_ax.offset = axis_info.offset + axis_info.gain * pub_samples[0].item()
axis_arr_out = replace(
template,
data=slice_along_axis(axis_arr_in.data, pub_samples, axis=axis_idx),
axes={**template.axes, axis: replace(update_ax, offset=update_ax.offset)}
)
template.axes[axis].offset = axis_info.offset + axis_info.gain * (n_samples + 1)
n_step = pub_samples[0].item()
data_slice = pub_samples
else:
# This iteration did not yield any samples. Return a size-0 array
# with time offset expected for _next_ sample.
axis_arr_out = template
n_step = 0
data_slice = slice(None, 0, None)
axis_arr_out = replace(
axis_arr_in,
data=slice_along_axis(axis_arr_in.data, data_slice, axis=axis_idx),
axes={
**axis_arr_in.axes,
axis: replace(
axis_info,
gain=axis_info.gain * factor,
offset=axis_info.offset + axis_info.gain * n_step
)
}
)


class DownsampleSettings(ez.Settings):
Expand All @@ -88,45 +96,12 @@ class DownsampleSettings(ez.Settings):
factor: int = 1


class DownsampleState(ez.State):
cur_settings: DownsampleSettings
gen: typing.Generator


class Downsample(ez.Unit):
"""
:obj:`Unit` for :obj:`downsample`.
"""
class Downsample(GenAxisArray):
""":obj:`Unit` for :obj:`bandpower`."""
SETTINGS: DownsampleSettings
STATE: DownsampleState

INPUT_SETTINGS = ez.InputStream(DownsampleSettings)
INPUT_SIGNAL = ez.InputStream(AxisArray)
OUTPUT_SIGNAL = ez.OutputStream(AxisArray)

def construct_generator(self):
self.STATE.gen = downsample(axis=self.STATE.cur_settings.axis, factor=self.STATE.cur_settings.factor)

def initialize(self) -> None:
self.STATE.cur_settings = self.SETTINGS
self.construct_generator()

@ez.subscriber(INPUT_SETTINGS)
async def on_settings(self, msg: DownsampleSettings) -> None:
self.STATE.cur_settings = msg
self.construct_generator()

@ez.subscriber(INPUT_SIGNAL, zero_copy=True)
@ez.publisher(OUTPUT_SIGNAL)
async def on_signal(self, msg: AxisArray) -> typing.AsyncGenerator:
if self.STATE.cur_settings.factor < 1:
raise ValueError("Downsample factor must be at least 1 (no downsampling)")

try:
out_msg = self.STATE.gen.send(msg)
if out_msg.data.size > 0:
yield self.OUTPUT_SIGNAL, out_msg
except (StopIteration, GeneratorExit):
ez.logger.debug(f"Downsample closed in {self.address}")
except Exception:
ez.logger.info(traceback.format_exc())
self.STATE.gen = downsample(
axis=self.SETTINGS.axis,
factor=self.SETTINGS.factor
)
2 changes: 1 addition & 1 deletion src/ezmsg/sigproc/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ async def on_settings(self, msg: SamplerSettings) -> None:
async def on_trigger(self, msg: SampleTriggerMessage) -> None:
_ = self.STATE.gen.send(msg)

@ez.subscriber(INPUT_SIGNAL)
@ez.subscriber(INPUT_SIGNAL, zero_copy=True)
@ez.publisher(OUTPUT_SAMPLE)
async def on_signal(self, msg: AxisArray) -> AsyncGenerator:
pub_samples = self.STATE.gen.send(msg)
Expand Down
4 changes: 3 additions & 1 deletion src/ezmsg/sigproc/scaler.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@

import ezmsg.core as ez
from ezmsg.util.messages.axisarray import AxisArray
from ezmsg.util.generator import consumer, GenAxisArray
from ezmsg.util.generator import consumer

from .base import GenAxisArray


def _tau_from_alpha(alpha: float, dt: float) -> float:
Expand Down
27 changes: 13 additions & 14 deletions src/ezmsg/sigproc/slicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
import numpy as np
import ezmsg.core as ez
from ezmsg.util.messages.axisarray import AxisArray, slice_along_axis
from ezmsg.util.generator import consumer, GenAxisArray
from ezmsg.util.generator import consumer

from .base import GenAxisArray


"""
Expand Down Expand Up @@ -71,29 +73,26 @@ def slicer(
indices = np.hstack([indices[_] for _ in _slices])
_slice = np.s_[indices]
# Create the output axis.
if axis in axis_arr_in.axes and hasattr(axis_arr_in.axes[axis], "labels"):
if (axis in axis_arr_in.axes
and hasattr(axis_arr_in.axes[axis], "labels")
and len(axis_arr_in.axes[axis].labels) > 0):
new_labels = axis_arr_in.axes[axis].labels[_slice]
new_axis = replace(
axis_arr_in.axes[axis],
labels=new_labels
)

replace_kwargs = {}
if b_change_dims:
out_dims = [_ for dim_ix, _ in enumerate(axis_arr_in.dims) if dim_ix != axis_idx]
out_axes = axis_arr_in.axes.copy()
out_axes.pop(axis, None)
else:
out_dims = axis_arr_in.dims
if new_axis is not None:
out_axes = {k: (v if k != axis else new_axis) for k, v in axis_arr_in.axes.items()}
else:
out_axes = axis_arr_in.axes

# Dropping the target axis
replace_kwargs["dims"] = [_ for dim_ix, _ in enumerate(axis_arr_in.dims) if dim_ix != axis_idx]
replace_kwargs["axes"] = {k: v for k, v in axis_arr_in.axes.items() if k != axis}
elif new_axis is not None:
replace_kwargs["axes"] = {k: (v if k != axis else new_axis) for k, v in axis_arr_in.axes.items()}
axis_arr_out = replace(
axis_arr_in,
dims=out_dims,
axes=out_axes,
data=slice_along_axis(axis_arr_in.data, _slice, axis_idx),
**replace_kwargs
)


Expand Down
4 changes: 3 additions & 1 deletion src/ezmsg/sigproc/spectrogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,16 @@

import ezmsg.core as ez
from ezmsg.util.messages.axisarray import AxisArray
from ezmsg.util.generator import consumer, GenAxisArray, compose
from ezmsg.util.generator import consumer, compose
from ezmsg.util.messages.modify import modify_axis
from ezmsg.sigproc.window import windowing
from ezmsg.sigproc.spectrum import (
spectrum,
WindowFunction, SpectralTransform, SpectralOutput
)

from .base import GenAxisArray


@consumer
def spectrogram(
Expand Down
9 changes: 5 additions & 4 deletions src/ezmsg/sigproc/spectrum.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
import numpy as np
import ezmsg.core as ez
from ezmsg.util.messages.axisarray import AxisArray, slice_along_axis
from ezmsg.util.generator import consumer, GenAxisArray
from ezmsg.util.generator import consumer

from .base import GenAxisArray


class OptionsEnum(enum.Enum):
Expand Down Expand Up @@ -128,9 +130,8 @@ def spectrum(
else:
f_transform = f1

new_axes = {**axis_arr_in.axes, **{out_axis: freq_axis}}
if out_axis != axis_name:
new_axes.pop(axis_name, None)
new_axes = {k: v for k, v in axis_arr_in.axes.items() if k not in [out_axis, axis_name]}
new_axes[out_axis] = freq_axis

spec = np.fft.fft(axis_arr_in.data * window, axis=axis_idx) / n_time
spec = np.fft.fftshift(spec, axes=axis_idx)
Expand Down
Loading