raw to lfp converter #38

Draft · wants to merge 1 commit into base: main
4 changes: 2 additions & 2 deletions neuro_py/raw/__init__.pyi
@@ -1,3 +1,3 @@
__all__ = ["remove_artifacts"]
__all__ = ["remove_artifacts","downsample_binary"]

from .preprocessing import remove_artifacts
from .preprocessing import remove_artifacts, downsample_binary
152 changes: 148 additions & 4 deletions neuro_py/raw/preprocessing.py
@@ -1,9 +1,11 @@
+import gc
 import os
 import warnings
-from typing import List, Tuple, Optional
+from typing import List, Optional, Tuple
 
 import numba as nb
 import numpy as np
+from scipy.signal import butter, firwin, sosfiltfilt
 
 
 def remove_artifacts(
@@ -96,9 +98,7 @@ def remove_artifacts(
                     data[start, ch],
                     data[end, ch],
                     end - start,
-                ).astype(
-                    data.dtype
-                )  # Ensure consistent dtype
+                ).astype(data.dtype)  # Ensure consistent dtype
                 data[start:end, ch] = interpolated
             else:
                 warnings.warn(
@@ -172,3 +172,147 @@ def remove_artifacts(
f.write(f"Zeroed intervals: {zero_intervals.tolist()}\n")
except Exception as e:
warnings.warn(f"Failed to create log file: {e}")


def downsample_binary(
filepath: str,
n_channels: int,
original_fs: int = 20000,
target_fs: int = 1250,
precision: str = "int16",
filter_order: int = 4,
) -> str:
"""
Optimized function to downsample raw binary data.
"""
if original_fs % target_fs != 0:
raise ValueError(
"Original sampling frequency must be an integer multiple of the target frequency."
)

downsample_factor = original_fs // target_fs
nyquist = target_fs / 2

    # Zero-phase Butterworth low-pass with cutoff at the target Nyquist frequency
    sos = butter(filter_order, nyquist / (original_fs / 2), btype="low", output="sos")

    downsampled_filepath = os.path.splitext(filepath)[0] + ".lfp"

    bytes_size = np.dtype(precision).itemsize
    # Keep the chunk length a multiple of the downsampling factor so the
    # decimation phase stays aligned across chunk boundaries.
    chunk_size = max(1, 10_000 // downsample_factor) * downsample_factor
with open(filepath, "rb") as infile, open(downsampled_filepath, "wb") as outfile:
infile.seek(0, 2)
n_samples = infile.tell() // (n_channels * bytes_size)
infile.seek(0, 0)

for start_idx in range(0, n_samples, chunk_size):
end_idx = min(start_idx + chunk_size, n_samples)
n_chunk_samples = end_idx - start_idx

# Load chunk
data = np.fromfile(
infile, dtype=precision, count=n_chunk_samples * n_channels
)
data = data.reshape((n_chunk_samples, n_channels))

            # Filter (zero-phase) and decimate. Each chunk is filtered
            # independently, so brief transients can appear at chunk boundaries.
            filtered_data = sosfiltfilt(sos, data, axis=0)
            downsampled_data = filtered_data[::downsample_factor, :]

# Write to output file
downsampled_data.astype(precision).tofile(outfile)

del data, filtered_data, downsampled_data
gc.collect()

return downsampled_filepath
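
# A minimal usage sketch (editorial, not part of the PR); the file name
# `example.dat` and channel count are hypothetical. The returned .lfp can be
# memory-mapped to check its shape without loading it into RAM:
#
#     lfp_path = downsample_binary("example.dat", n_channels=64)
#     lfp = np.memmap(lfp_path, dtype="int16", mode="r").reshape(-1, 64)
#     print(lfp.shape)  # 20 kHz -> 1250 Hz gives n_samples // 16 rows
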

@nb.jit(nopython=True, parallel=True, fastmath=True)
def filter_and_downsample(data, fir_coeffs, downsample_factor):
    """
    JIT-compiled helper: FIR low-pass each channel, then decimate by
    ``downsample_factor``. Returns an array of shape
    ``(n_samples // downsample_factor, n_channels)`` in the input dtype.
    """
    n_samples, n_channels = data.shape
    n_output_samples = n_samples // downsample_factor
    n_taps = fir_coeffs.shape[0]
    output = np.zeros((n_output_samples, n_channels), dtype=data.dtype)

    for ch in nb.prange(n_channels):
        # Full convolution with the linear-phase (symmetric) FIR kernel;
        # numba supports np.convolve in its default "full" mode.
        filtered = np.convolve(data[:, ch].astype(np.float64), fir_coeffs)
        # Drop the group delay (n_taps // 2 samples), decimate, and truncate
        # so the length matches the preallocated output buffer.
        centered = filtered[n_taps // 2 : n_taps // 2 + n_samples]
        output[:, ch] = centered[::downsample_factor][:n_output_samples]

    return output
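
# Sanity-check sketch (editorial, not part of the PR): on synthetic data the
# kernel should emit one output sample per `downsample_factor` inputs per channel.
#
#     rng = np.random.default_rng(0)
#     x = rng.integers(-1000, 1000, size=(20_000, 4)).astype(np.int16)
#     taps = firwin(65, 625 / 10_000, pass_zero="lowpass")
#     assert filter_and_downsample(x, taps, 16).shape == (1250, 4)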


def downsample_binary_ultrafast(
filepath: str,
n_channels: int,
original_fs: int = 20000,
target_fs: int = 1250,
precision: str = "int16",
filter_order: int = 64,
) -> str:
"""
Ultrafast function to downsample raw binary data.
"""
if original_fs % target_fs != 0:
raise ValueError("Original sampling frequency must be an integer multiple of the target frequency.")

downsample_factor = original_fs // target_fs
nyquist = target_fs / 2

# Design FIR filter
fir_coeffs = firwin(filter_order + 1, nyquist / (original_fs / 2), pass_zero="lowpass")

# Output file
downsampled_filepath = os.path.splitext(filepath)[0] + ".lfp"

    # Chunked buffered I/O (no memory mapping). Chunk length is kept a multiple
    # of the downsampling factor so decimation phase stays aligned across chunks.
    bytes_size = np.dtype(precision).itemsize
    chunk_size = max(1, 10_000_000 // downsample_factor) * downsample_factor
with open(filepath, "rb") as infile, open(downsampled_filepath, "wb") as outfile:
infile.seek(0, 2)
n_samples = infile.tell() // (n_channels * bytes_size)
infile.seek(0, 0)

for start_idx in range(0, n_samples, chunk_size):
end_idx = min(start_idx + chunk_size, n_samples)
n_chunk_samples = end_idx - start_idx

# Load chunk
data = np.fromfile(infile, dtype=precision, count=n_chunk_samples * n_channels)
data = data.reshape((n_chunk_samples, n_channels))

            # Filter and decimate (per chunk; the convolution is zero-padded at
            # chunk edges, so brief boundary transients are possible)
            downsampled_data = filter_and_downsample(data, fir_coeffs, downsample_factor)

# Write to output file
downsampled_data.astype(precision).tofile(outfile)

del data, downsampled_data
gc.collect()

return downsampled_filepath
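
# Comparison sketch (editorial, not part of the PR): both converters derive the
# output name from the input, so compare them on two copies of the same file.
# The paths below are hypothetical; expect similar but not identical samples
# (zero-phase IIR vs. linear-phase FIR).
#
#     a = np.memmap(downsample_binary("copy_a.dat", 128), dtype="int16", mode="r")
#     b = np.memmap(downsample_binary_ultrafast("copy_b.dat", 128), dtype="int16", mode="r")
#     print(np.abs(a[:10_000].astype(np.int32) - b[:10_000].astype(np.int32)).max())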


if __name__ == "__main__":
    # Quick timing check when run as a script
import time

start = time.time()
downsample_binary_ultrafast(
filepath=r"U:\data\hpc_ctx_project\HP13\HP13_day1_20241030\HP13_probe_241030_111814\amplifier - Copy.dat",
n_channels=128,
original_fs=20000,
target_fs=1250,
precision="int16",
        filter_order=64,  # FIR order (order + 1 taps); a very short filter would barely attenuate aliasing
)
print(f"Elapsed time: {time.time() - start:.2f} s")