Skip to content

Commit

Permalink
Merge pull request #258 from anindyaghosh/event_downsampling
Browse files Browse the repository at this point in the history
Event based spatio-temporal downsampling
  • Loading branch information
biphasic authored Aug 8, 2023
2 parents 7be656b + bba009c commit 639d469
Show file tree
Hide file tree
Showing 5 changed files with 308 additions and 18 deletions.
31 changes: 31 additions & 0 deletions docs/gallery/transformations/plot_event_downsampling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
"""
==========
EventDownsampling
==========
The :class:`~tonic.transforms.EventDownsampling` applies
spatio-temporal downsampling to events as per the downsampling method chosen.
"""

import tonic

nmnist = tonic.datasets.NMNIST("../../tutorials/data", train=False)
events, label = nmnist[0]

transform = tonic.transforms.Compose(
[
tonic.transforms.EventDownsampling(sensor_size=nmnist.sensor_size,
target_size=(12, 12),
dt=0.01,
downsampling_method="differentiator",
noise_threshold=0,
differentiator_time_bins=2),
tonic.transforms.ToFrame(
sensor_size=(12, 12, 2),
time_window=10000,
),
]
)

frames = transform(events)

ani = tonic.utils.plot_animation(frames)
49 changes: 39 additions & 10 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,23 +247,52 @@ def test_transform_drop_pixel_raster(coordinates, hot_pixel_frequency):
assert not merged_polarity_raster[merged_polarity_raster > 5000].sum().sum()


@pytest.mark.parametrize("time_factor, spatial_factor", [(1, 0.25), (1e-3, 1)])
def test_transform_downsample(time_factor, spatial_factor):
@pytest.mark.parametrize("time_factor, spatial_factor, target_size", [(1, 0.25, None), (1e-3, (1, 2), None), (1, 1, (5, 5))])
def test_transform_downsample(time_factor, spatial_factor, target_size):
orig_events, sensor_size = create_random_input()

transform = transforms.Downsample(
time_factor=time_factor, spatial_factor=spatial_factor
sensor_size=sensor_size, time_factor=time_factor, spatial_factor=spatial_factor, target_size=target_size
)

events = transform(orig_events)

assert np.array_equal(
(orig_events["t"] * time_factor).astype(orig_events["t"].dtype), events["t"]
)
assert np.array_equal(np.floor(orig_events["x"] * spatial_factor), events["x"])
assert np.array_equal(np.floor(orig_events["y"] * spatial_factor), events["y"])

if not isinstance(spatial_factor, tuple):
spatial_factor = (spatial_factor, spatial_factor)

if target_size is None:
assert np.array_equal(
(orig_events["t"] * time_factor).astype(orig_events["t"].dtype), events["t"]
)
assert np.array_equal(np.floor(orig_events["x"] * spatial_factor[0]), events["x"])
assert np.array_equal(np.floor(orig_events["y"] * spatial_factor[1]), events["y"])

else:
spatial_factor_test = np.asarray(target_size) / sensor_size[:-1]
assert np.array_equal(np.floor(orig_events["x"] * spatial_factor_test[0]), events["x"])
assert np.array_equal(np.floor(orig_events["y"] * spatial_factor_test[1]), events["y"])

assert events is not orig_events



@pytest.mark.parametrize("target_size, dt, downsampling_method, noise_threshold, differentiator_time_bins",
[((50, 50), 0.05, 'integrator', 1, None),
((20, 15), 5, 'differentiator', 3, 1)])
def test_transform_event_downsampling(target_size, dt, downsampling_method, noise_threshold,
differentiator_time_bins):

orig_events, sensor_size = create_random_input()

transform = transforms.EventDownsampling(sensor_size=sensor_size, target_size=target_size, dt=dt,
downsampling_method=downsampling_method, noise_threshold=noise_threshold,
differentiator_time_bins=differentiator_time_bins)

events = transform(orig_events)

assert len(events) <= len(orig_events)
assert np.logical_and(np.all(events["x"] <= target_size[0]), np.all(events["y"] <= target_size[1]))
assert events is not orig_events


@pytest.mark.parametrize("target_size", [(50, 50), (10, 5)])
def test_transform_random_crop(target_size):
Expand Down
6 changes: 6 additions & 0 deletions tonic/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@
identify_hot_pixel,
identify_hot_pixel_raster,
)
from .event_downsampling import (
integrator_downsample,
differentiator_downsample,
)
from .refractory_period import refractory_period_numpy
from .spatial_jitter import spatial_jitter_numpy
from .time_jitter import time_jitter_numpy
Expand Down Expand Up @@ -38,4 +42,6 @@
"to_voxel_grid_numpy",
"to_bina_rep_numpy",
"uniform_noise_numpy",
"integrator_downsample",
"differentiator_downsample",
]
157 changes: 157 additions & 0 deletions tonic/functional/event_downsampling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
import numpy as np
from numpy.lib.recfunctions import unstructured_to_structured

from tonic.functional.to_frame import to_frame_numpy

def differentiator_downsample(events: np.ndarray, sensor_size: tuple, target_size: tuple, dt: float,
differentiator_time_bins: int = 2, noise_threshold: int = 0):
"""Spatio-temporally downsample using the integrator method coupled with a differentiator to effectively
downsample large object sizes relative to downsampled pixel resolution in the DVS camera's visual field.
Incorporates the paper Ghosh et al. 2023, Insect-inspired Spatio-temporal Downsampling of Event-based Input,
https://doi.org/10.1145/3589737.3605994
Parameters:
events (ndarray): ndarray of shape [num_events, num_event_channels].
sensor_size (tuple): a 3-tuple of x,y,p for sensor_size.
target_size (tuple): a 2-tuple of x,y denoting new down-sampled size for events to be
re-scaled to (new_width, new_height).
dt (float): step size for simulation, in ms.
differentiator_time_bins (int): number of equally spaced time bins with respect to the dt
to be used for the differentiator.
noise_threshold (int): number of events before a spike representing a new event is emitted.
Returns:
the spatio-temporally downsampled input events using the differentiator method.
"""

assert "x" and "y" and "t" in events.dtype.names
assert np.logical_and(np.remainder(differentiator_time_bins, 1) == 0, differentiator_time_bins >= 1)

events = events.copy()

# Call integrator method
dt_scaling, events_integrated = integrator_downsample(events, sensor_size=sensor_size, target_size=target_size,
dt=(dt / differentiator_time_bins),
noise_threshold=noise_threshold, differentiator_call=True)

if dt_scaling:
dt *= 1000

num_frames = int(events_integrated[-1][0] // dt + 1)
frame_histogram = np.zeros((num_frames, *np.flip(target_size), 2))

for event in events_integrated:
differentiated_time, event_histogram = event
time = int(differentiated_time // dt)

# Separate events based on polarity and apply Heaviside
event_hist_pos = (np.maximum(event_histogram >= noise_threshold, 0)).clip(max=1)
event_hist_neg = (-np.minimum(-event_histogram >= noise_threshold, 0)).clip(max=1)

frame_histogram[time,...,1] += event_hist_pos
frame_histogram[time,...,0] += event_hist_neg

# Differences between subsequent frames
frame_differences = (np.diff(frame_histogram, axis=0)).clip(min=0)

# Restructuring numpy array to structured array
time_index, y_new, x_new, polarity_new = np.nonzero(frame_differences)

events_new = np.column_stack((x_new, y_new, polarity_new.astype(dtype=bool), time_index * dt))

names = ["x", "y", "p", "t"]
formats = ['i4', 'i4', 'i4', 'i4']

dtype = np.dtype({'names': names, 'formats': formats})

return unstructured_to_structured(events_new.copy(), dtype=dtype)

def integrator_downsample(events: np.ndarray, sensor_size: tuple, target_size: tuple, dt: float, noise_threshold: int = 0,
differentiator_call: bool = False):
"""Spatio-temporally downsample using with the following steps:
1. Differencing of ON and OFF events to counter camera shake or jerk.
2. Use an integrate-and-fire (I-F) neuron model with a noise threshold similar to
the membrane potential threshold in the I-F model to eliminate high-frequency noise.
Multiply x/y values by a spatial_factor obtained by dividing sensor size by the target size.
Parameters:
events (ndarray): ndarray of shape [num_events, num_event_channels].
sensor_size (tuple): a 3-tuple of x,y,p for sensor_size.
target_size (tuple): a 2-tuple of x,y denoting new down-sampled size for events to be
re-scaled to (new_width, new_height).
dt (float): temporal resolution of events in milliseconds.
noise_threshold (int): number of events before a spike representing a new event is emitted.
differentiator_call (bool): Preserve frame spikes for differentiator method in order to optimise
differentiator method.
Returns:
the spatio-temporally downsampled input events using the integrator method.
"""

assert "x" and "y" and "t" in events.dtype.names
assert isinstance(noise_threshold, int)
assert dt is not None

events = events.copy()

if np.issubdtype(events["t"].dtype, np.integer):
dt *= 1000
dt_scaling = True

if differentiator_call:
assert dt // events["t"][-1] == 0

# Downsample
spatial_factor = np.asarray(target_size) / sensor_size[:-1]

events["x"] = events["x"] * spatial_factor[0]
events["y"] = events["y"] * spatial_factor[1]

# Compute all histograms at once
all_frame_histograms = to_frame_numpy(events, sensor_size=(*target_size, 2), time_window=dt)

# Subtract the channels for ON/OFF differencing
frame_histogram_diffs = all_frame_histograms[:, 1] - all_frame_histograms[:, 0]

frame_spike = np.zeros(np.flip(target_size))
event_histogram = []

events_new = []

for time, frame_histogram in enumerate(frame_histogram_diffs):

frame_spike += frame_histogram

coordinates_pos = np.stack(np.nonzero(np.maximum(frame_spike >= noise_threshold, 0))).T
coordinates_neg = np.stack(np.nonzero(np.maximum(-frame_spike >= noise_threshold, 0))).T

if np.logical_or(coordinates_pos.size, coordinates_neg.size).sum():

# For optimising differentiator
event_histogram.append((time*dt, frame_spike.copy()))

# Reset spiking coordinates to zero
frame_spike[coordinates_pos[:,0], coordinates_pos[:,1]] = 0
frame_spike[coordinates_neg[:,0], coordinates_neg[:,1]] = 0

# Restructure events
events_new.append(np.column_stack((np.flip(coordinates_pos, axis=1), np.ones((coordinates_pos.shape[0],1)).astype(dtype=bool),
(time*dt)*np.ones((coordinates_pos.shape[0],1)))))

events_new.append(np.column_stack((np.flip(coordinates_neg, axis=1), np.zeros((coordinates_neg.shape[0],1)).astype(dtype=bool),
(time*dt)*np.ones((coordinates_neg.shape[0],1)))))

if differentiator_call:
return dt_scaling, event_histogram
else:
events_new = np.concatenate(events_new.copy())

names = ["x", "y", "p", "t"]
formats = ['i4', 'i4', 'i4', 'i4']

dtype = np.dtype({'names': names, 'formats': formats})

return unstructured_to_structured(events_new.copy(), dtype=dtype)
83 changes: 75 additions & 8 deletions tonic/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,27 +257,47 @@ class Downsample:
Parameters:
time_factor (float): value to multiply timestamps with. Default is 1.
spatial_factor (float): value to multiply pixel coordinates with. Default is 1.
Note that when using subsequential transforms that require
sensor_size, you must change the spatial values for the later
transformation.
spatial_factor (float or tuple of floats): values to multiply pixel coordinates with. Default is 1.
Note that when using subsequential transforms that require
sensor_size, you must change the spatial values for the later
transformation.
sensor_size (tuple): size of the sensor that was used [W,H,P]
target_size (tuple): size of the desired resolution [W,H]
Example:
>>> from tonic.transforms import Downsample
>>> transform1 = Downsample(time_factor=0.001) # change us to ms
>>> transform2 = Downsample(spatial_factor=0.25) # reduce focal plane to 1/4.
>>> transform3 = Downsample(sensor_size=(40, 20, 2), target_size=(10, 5)) # reduce focal plane to 1/4.
"""

time_factor: float = 1
spatial_factor: float = 1

spatial_factor: Union[float, Tuple[float, float]] = 1
sensor_size: Optional[Tuple[int, int, int]] = None
target_size: Optional[Tuple[int, int]] = None

@staticmethod
def get_params(spatial_factor: Union[int, Tuple[int, int]]):
if not type(spatial_factor) == tuple:
spatial_factor = (spatial_factor, spatial_factor)
return spatial_factor

def __call__(self, events):
events = events.copy()

if self.target_size is not None:
# Ensure sensor_size is not None when target_size is not None
assert self.sensor_size is not None
# If both target_size and spatial_factor declared, override spatial_factor value in argument
spatial_factor = np.asarray(self.target_size) / self.sensor_size[:-1]
else:
spatial_factor = self.get_params(spatial_factor=self.spatial_factor)

events = functional.time_skew_numpy(events, coefficient=self.time_factor)
if "x" in events.dtype.names:
events["x"] = events["x"] * self.spatial_factor
events["x"] = events["x"] * spatial_factor[0]
if "y" in events.dtype.names:
events["y"] = events["y"] * self.spatial_factor
events["y"] = events["y"] * spatial_factor[1]
return events


Expand Down Expand Up @@ -316,6 +336,53 @@ def __call__(self, events):
return functional.drop_event_numpy(events, ratio)


@dataclass(frozen=True)
class EventDownsampling:
"""Applies EventDownsampling from the paper "Insect-inspired Spatio-temporal Downsampling of Event-based Input."
Allows:
1. Integrator based method to perform spatio-temporal event-based downsampling
2. Differentiator based method to perform spatio-temporal event-based downsampling
Parameters:
sensor_size (Tuple): size of the sensor that was used [W,H,P]
target_size (Tuple): size of the desired resolution [W,H]
dt (float): temporal resolution of events in ms
downsampling_method (str): string stating downsampling method. Choose from ['naive', 'integrator', 'differentiator']
noise_threshold (int): set number of events in downsampled pixel required to emit spike. Zero by default.
differentiator_time_bins (int): number of differentiator time bins within dt. Two by default.
Example:
>>> transform1 = tonic.transforms.EventDownsampling(sensor_size=(640,480,2), target_size=(20,15), dt=0.5,
downsampling_method='integrator')
>>> transform2 = tonic.transforms.EventDownsampling(sensor_size=(640,480,2), target_size=(20,15), dt=0.5,
downsampling_method='differentiator', noise_threshold=2,
differentiator_time_bins=3)
"""

sensor_size: Tuple[int, int, int]
target_size: Tuple[int, int]
downsampling_method: str
dt: Optional[float] = None
noise_threshold: Optional[int] = None
differentiator_time_bins: Optional[int] = None

def __call__(self, events):
assert self.downsampling_method in ['integrator', 'differentiator']

if self.downsampling_method == 'integrator':
return functional.integrator_downsample(
events=events, sensor_size=self.sensor_size, target_size=self.target_size,
dt=self.dt, noise_threshold=self.noise_threshold
)

elif self.downsampling_method == 'differentiator':
return functional.differentiator_downsample(
events=events, sensor_size=self.sensor_size, target_size=self.target_size,
dt=self.dt, noise_threshold=self.noise_threshold,
differentiator_time_bins=self.differentiator_time_bins
)


@dataclass(frozen=True)
class MergePolarities:
"""Sets all polarities to zero. This transform does not have any parameters.
Expand Down

0 comments on commit 639d469

Please sign in to comment.