Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add allow partial argument to compute epoch intersects #43

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 34 additions & 8 deletions neuro_py/process/intervals.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,9 @@ def overlap_intersect(


@jit(nopython=True)
def _find_intersecting_intervals(set1: np.ndarray, set2: np.ndarray) -> List[float]:
def _find_intersecting_intervals(
set1: np.ndarray, set2: np.ndarray, allow_partial: bool
) -> List[float]:
"""
Find the amount of time two sets of intervals are intersecting each other for each interval in set1.

Expand All @@ -196,18 +198,27 @@ def _find_intersecting_intervals(set1: np.ndarray, set2: np.ndarray) -> List[flo
An array of intervals represented as pairs of start and end times.
set2 : ndarray
An array of intervals represented as pairs of start and end times.
allow_partial : bool
If True, allow partial intersections between intervals. If False, only
full intersections of set1 lying within set2 are considered.

Returns
-------
list of float
A list of floats, where each float represents the amount of time the
corresponding interval in set1 intersects with any interval in set2.
"""

def chk_intersect(start1, end1, start2, end2, allow_partial):
if allow_partial:
return start2 <= end1 and end2 >= start1
return start1 >= start2 and end1 <= end2 # set1 is within set2

intersecting_intervals = []
for i, (start1, end1) in enumerate(set1):
# Check if any of the intervals in set2 intersect with the current interval in set1
for start2, end2 in set2:
if start2 <= end1 and end2 >= start1:
if chk_intersect(start1, end1, start2, end2, allow_partial):
# Calculate the amount of intersection between the two intervals
intersection = min(end1, end2) - max(start1, start2)
intersecting_intervals.append(intersection)
Expand All @@ -219,7 +230,10 @@ def _find_intersecting_intervals(set1: np.ndarray, set2: np.ndarray) -> List[flo


def find_intersecting_intervals(
set1: nel.EpochArray, set2: nel.EpochArray, return_indices: bool = True
set1: nel.EpochArray,
set2: nel.EpochArray,
return_indices: bool = True,
allow_partial: bool = True,
) -> Union[np.ndarray, List[bool]]:
"""
Find the amount of time two sets of intervals are intersecting each other for each intersection.
Expand All @@ -231,14 +245,24 @@ def find_intersecting_intervals(
set2 : nelpy EpochArray
The second set of intervals to check for intersections.
return_indices : bool, optional
If True, return the indices of the intervals in set2 that intersect with each interval in set1.
If False, return the amount of time each interval in set1 intersects with any interval in set2.
If True, return the indices of the intervals in set2 that intersect with
each interval in set1.
If False, return the amount of time each interval in set1 intersects
with any interval in set2.
Default is True.
allow_partial : bool, optional
If True, allow partial intersections between intervals.
If False, only full intersections of set1 lying within set2 are
considered.
Default is True.

Returns
-------
Union[np.ndarray, List[bool]]
If return_indices is True, returns a boolean array indicating whether each interval in set1 intersects with any interval in set2.
If return_indices is False, returns a NumPy array with the amount of time each interval in set1 intersects with any interval in set2.
If return_indices is True, returns a boolean array indicating whether
each interval in set1 intersects with any interval in set2.
If return_indices is False, returns a NumPy array with the amount of
time each interval in set1 intersects with any interval in set2.

Examples
--------
Expand All @@ -252,7 +276,9 @@ def find_intersecting_intervals(
if not isinstance(set1, core.IntervalArray) & isinstance(set2, core.IntervalArray):
raise ValueError("only EpochArrays are supported")

intersection = np.array(_find_intersecting_intervals(set1.data, set2.data))
intersection = np.array(
_find_intersecting_intervals(set1.data, set2.data, allow_partial)
)
if return_indices:
return intersection > 0
return intersection
Expand Down
Loading