Merge pull request #297 from narduzzi/feat/ntidigits
Added the NTIDIGITS18 dataset
biphasic authored Jan 7, 2025
2 parents 82ff84c + 50a80d3 commit fa4fdfb
Showing 4 changed files with 209 additions and 0 deletions.
1 change: 1 addition & 0 deletions docs/datasets.rst
@@ -32,6 +32,7 @@ Audio event stream classification

SHD
SSC
NTIDIGITS18

Pose estimation, visual odometry, SLAM
--------------------------------------
62 changes: 62 additions & 0 deletions test/test_datasets.py
@@ -181,6 +181,68 @@ def inject_fake_data(self, tmpdir):
return {"n_samples": 1}


def create_ntidigits_data(filename, n_samples):
    with h5py.File(filename, mode="w") as write_file:
        # Generate random times and units
        times = np.random.random(size=(n_samples, 100)).astype(np.float16)
        units = np.random.randint(0, 64, size=times.shape, dtype=np.uint16)

        # Generate random sequences of symbols
        symbols = ["o", "1", "2", "3", "4", "5", "6", "7", "8", "9", "z"]
        sequences = np.array([
            "speaker-a-" + "".join(np.random.choice(symbols, size=np.random.randint(1, 7), replace=True))
            for _ in range(n_samples)
        ]).astype("|S18")

        for partition in ["train", "test"]:
            # Map each sample id to its addresses and timestamps
            addresses = {}
            timestamps = {}
            for i in range(n_samples):
                sample_id = sequences[i]
                if sample_id not in addresses:
                    addresses[sample_id] = units[i]
                    timestamps[sample_id] = times[i]

            # Create one group per partition for addresses and one for timestamps,
            # storing each sample's data as a dataset in the corresponding group
            address_group = write_file.create_group("{}_addresses".format(partition))
            for sample_id, addr in addresses.items():
                address_group.create_dataset(sample_id, data=np.array(addr, dtype=np.uint16))

            timestamps_group = write_file.create_group("{}_timestamps".format(partition))
            for sample_id, ts in timestamps.items():
                timestamps_group.create_dataset(sample_id, data=np.array(ts, dtype=np.float16))

            # Create a dataset for this partition's labels
            write_file.create_dataset("{}_labels".format(partition), data=sequences)


class NTIDIGITS18TestCaseTrain(dataset_utils.DatasetTestCase):
    DATASET_CLASS = datasets.NTIDIGITS18
    FEATURE_TYPES = (datasets.NTIDIGITS18.dtype,)
    TARGET_TYPES = (int,)
    KWARGS = {"train": True}

    def inject_fake_data(self, tmpdir):
        testfolder = os.path.join(tmpdir, "NTIDIGITS18/")
        os.makedirs(testfolder, exist_ok=True)
        create_ntidigits_data(testfolder + "n-tidigits.hdf5", n_samples=2)
        return {"n_samples": 2}


class NTIDIGITS18TestCaseTest(dataset_utils.DatasetTestCase):
    DATASET_CLASS = datasets.NTIDIGITS18
    FEATURE_TYPES = (datasets.NTIDIGITS18.dtype,)
    TARGET_TYPES = (int,)
    KWARGS = {"train": False}

    def inject_fake_data(self, tmpdir):
        testfolder = os.path.join(tmpdir, "NTIDIGITS18/")
        os.makedirs(testfolder, exist_ok=True)
        create_ntidigits_data(testfolder + "n-tidigits.hdf5", n_samples=2)
        return {"n_samples": 2}


def create_hsd_data(filename, n_samples):
    with h5py.File(filename, mode="w") as write_file:
        times = np.random.random(size=(n_samples, 100)).astype(np.float16)
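For reference, the fixture above mirrors the layout the dataset class expects from the real n-tidigits.hdf5 file: a <partition>_labels dataset of sample ids, plus <partition>_addresses and <partition>_timestamps groups keyed by those ids. A minimal sketch of inspecting such a file with h5py (the file path is a placeholder):

import h5py

# Placeholder path; point this at a generated or downloaded n-tidigits.hdf5.
with h5py.File("n-tidigits.hdf5", "r") as f:
    sample_id = f["train_labels"][0].decode()          # e.g. "speaker-a-123"
    addresses = f["train_addresses"][sample_id][()]    # cochlea channel indices (0..63)
    timestamps = f["train_timestamps"][sample_id][()]  # spike times in seconds
    print(sample_id, addresses.shape, timestamps.shape)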
1 change: 1 addition & 0 deletions tonic/datasets/__init__.py
@@ -5,6 +5,7 @@
from .dvs_lips import DVSLip
from .dvsgesture import DVSGesture
from .ebssa import EBSSA
from .ntidigits18 import NTIDIGITS18
from .hsd import SHD, SSC
from .mvsec import MVSEC
from .ncaltech101 import NCALTECH101
145 changes: 145 additions & 0 deletions tonic/datasets/ntidigits18.py
@@ -0,0 +1,145 @@
#!/usr/bin/env python

import os
from typing import Callable, Optional

import h5py
import numpy as np
import requests
from tqdm import tqdm

from tonic.dataset import Dataset
from tonic.io import make_structured_array

class NTIDIGITS18(Dataset):
    """`N-TIDIGITS18 Dataset <https://docs.google.com/document/d/1Uxe7GsKKXcy6SlDUX4hoJVAC0-UkH-8kr5UXp0Ndi1M/edit?tab=t.0#heading=h.sbnu5gtazqjq/>`_

    Cochlea Spike Dataset.

    ::

        @article{anumula2018feature,
          title={Feature representations for neuromorphic audio spike streams},
          author={Anumula, Jithendar and Neil, Daniel and Delbruck, Tobi and Liu, Shih-Chii},
          journal={Frontiers in Neuroscience},
          volume={12},
          pages={23},
          year={2018},
          publisher={Frontiers Media SA}
        }

    Parameters:
        save_to (string): Location to save files to on disk. Will put files in an 'NTIDIGITS18' subfolder.
        train (bool): If True, uses training subset, otherwise testing subset.
        single_digits (bool): If True, only returns samples with single digits (o, 1, 2, 3, 4, 5, 6, 7, 8, 9, z), with class 0 for 'o' and 10 for 'z'.
        transform (callable, optional): A callable of transforms to apply to the data.
        target_transform (callable, optional): A callable of transforms to apply to the targets/labels.

    Returns:
        A dataset object that can be indexed or iterated over. One sample returns a tuple of (events, targets).
    """

    base_url = "https://www.dropbox.com/scl/fi/1x4lxt9yyw25sc3tez8oi/n-tidigits.hdf5?e=2&rlkey=w8gi5udvib2zqzosusa5tr3wq&dl=1"
    filename = "n-tidigits.hdf5"
    file_md5 = "360a2d11e5429555c9197381cf6b58e0"
    folder_name = ""

    sensor_size = (64, 1, 1)
    dtype = np.dtype([("t", int), ("x", int), ("p", int)])
    ordering = dtype.names

    class_map = {
        "o": 0,
        "1": 1,
        "2": 2,
        "3": 3,
        "4": 4,
        "5": 5,
        "6": 6,
        "7": 7,
        "8": 8,
        "9": 9,
        "z": 10,
    }

    def __init__(
        self,
        save_to: str,
        train: bool = True,
        single_digits: bool = False,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ):
        super().__init__(
            save_to,
            transform=transform,
            target_transform=target_transform,
        )

        self.url = self.base_url
        # load the data
        self.file_path = os.path.join(self.location_on_system, self.filename)

        if not self._check_exists():
            self.download()

        self.data = h5py.File(self.file_path, "r")
        self.partition = "train" if train else "test"
        self._samples = [x.decode() for x in self.data[f"{self.partition}_labels"]]
        self.single_digits = single_digits

        if single_digits:
            # Keep only the samples whose digit sequence is a single symbol
            self.single_indices = [
                i for i, sample in enumerate(self._samples)
                if len(sample.split("-")[-1]) == 1
            ]
            self._samples = [self._samples[i] for i in self.single_indices]

        # Derive labels from the (possibly filtered) sample ids
        self.labels = [sample.split("-")[-1] for sample in self._samples]

    def download(self) -> None:
        response = requests.get(self.base_url, stream=True)
        if response.status_code == 200:
            print("Downloading N-TIDIGITS from Dropbox at {}...".format(self.base_url))
            # Total file size in bytes
            file_size = int(response.headers.get("Content-Length", 0))
            chunk_size = 8192

            os.makedirs(self.location_on_system, exist_ok=True)
            # Stream the file to disk with a progress bar
            with open(os.path.join(self.location_on_system, self.filename), "wb") as f, tqdm(
                total=file_size,
                unit="B",
                unit_scale=True,
                desc="Downloading",
                ascii=True,
            ) as pbar:
                for chunk in response.iter_content(chunk_size=chunk_size):
                    if chunk:  # filter out keep-alive chunks
                        f.write(chunk)
                        pbar.update(len(chunk))
        else:
            print("Failed to download N-TIDIGITS from Dropbox. Please try again later.")
            response.raise_for_status()

    def __getitem__(self, index):
        sample_id = self._samples[index]
        x = np.asarray(self.data[f"{self.partition}_addresses"][sample_id])
        t = np.asarray(self.data[f"{self.partition}_timestamps"][sample_id])
        events = make_structured_array(
            t * 1e6,  # convert timestamps from seconds to microseconds
            x,
            1,
            dtype=self.dtype,
        )

        target = sample_id.split("-")[-1]

        if self.single_digits:
            assert len(target) == 1, "Single digit samples requested, but target is not a single digit."
            target = self.class_map[target]

        if self.transform is not None:
            events = self.transform(events)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return events, target

    def __len__(self):
        return len(self._samples)

    def _check_exists(self):
        return self._is_file_present()
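With the class above in place, usage follows the usual Tonic dataset pattern; a minimal sketch (the save_to path is a placeholder):

from tonic import datasets

# Placeholder path; the HDF5 file is downloaded there on first use.
dataset = datasets.NTIDIGITS18(save_to="./data", train=True, single_digits=True)

events, target = dataset[0]
# events is a structured array with fields ("t", "x", "p"):
# t in microseconds, x the cochlea channel (0..63), p constant 1.
print(len(dataset), events.dtype.names, target)  # target is an int in 0..10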
