feat: split_gaze_data by column values #859

Open · wants to merge 32 commits into base: main

Changes from all commits (32 commits)
1c6c769
feat: split_gaze_data into trial
SiQube Oct 23, 2024
976695b
docs: Add missing modules to documentation (#866)
dkrako Oct 23, 2024
953ade3
hotfix: check whether public dataset has gaze files (#872)
SiQube Oct 24, 2024
b842bdb
docs: correctly add EyeTracker class to gaze module (#876)
dkrako Oct 24, 2024
cc1bae1
feat: add support for .ias files in stimulus.text.from_file() (#858)
SiQube Oct 24, 2024
5417804
dataset: beijing sentence corpus (#857)
SiQube Oct 24, 2024
92b49a7
dataset: add InteRead dataset (#862)
SiQube Oct 24, 2024
f0b69a9
fix: copy event resource files instead of moving them to events direc…
SiQube Oct 24, 2024
e6a9ced
hotfix: CopCo dataset precomputed events loading (#873)
SiQube Oct 24, 2024
1b8c4bd
ci: ignore too-many-public-methods (#882)
dkrako Oct 25, 2024
69ef837
ci: pre-commit autoupdate (#889)
pre-commit-ci[bot] Oct 29, 2024
cfbce95
ci: pre-commit autoupdate (#890)
pre-commit-ci[bot] Nov 5, 2024
47e734d
build: add support for python 3.13 (#845)
SiQube Nov 7, 2024
166b076
build: update nbsphinx requirement from <0.9.5,>=0.8.8 to >=0.8.8,<0.…
dependabot[bot] Nov 7, 2024
495e5d9
ci: pre-commit autoupdate (#896)
pre-commit-ci[bot] Nov 12, 2024
b691e6d
build: update setuptools-git-versioning requirement from <2 to <3 (#895)
dependabot[bot] Nov 12, 2024
88113c8
hotfix: download link fakenewsperception dataset (#897)
SiQube Nov 13, 2024
21fd0d2
feat: Store metadata from ASC in experiment metadata (#884)
saeub Nov 14, 2024
0856658
move split method to gaze dataframe
SiQube Nov 17, 2024
4751e41
Merge branch 'main' into split-gaze-files-into-trial-dataframes
SiQube Nov 17, 2024
b47ad31
ci: pre-commit autoupdate (#899)
pre-commit-ci[bot] Nov 18, 2024
5f5525a
ci: pre-commit autoupdate (#900)
pre-commit-ci[bot] Nov 27, 2024
c30bd9e
Add trial_columns argument in from_asc() (#898)
saeub Nov 27, 2024
7a25297
ci: pre-commit autoupdate (#902)
pre-commit-ci[bot] Dec 3, 2024
96141d5
docs: add CITATION.cff (#901)
SiQube Dec 8, 2024
5bf55f1
ci: pre-commit autoupdate (#904)
pre-commit-ci[bot] Dec 10, 2024
e4b3e8f
ci: add dataset section to release drafter (#903)
dkrako Dec 10, 2024
eb8aee5
move split method to gaze dataframe
SiQube Nov 17, 2024
ecd6b5c
add tests for number of split files
SiQube Dec 29, 2024
7229cd9
Merge branch 'main' into split-gaze-files-into-trial-dataframes
SiQube Dec 29, 2024
52902af
Merge branch 'main' into split-gaze-files-into-trial-dataframes
SiQube Dec 29, 2024
8994bb1
Merge branch 'main' into split-gaze-files-into-trial-dataframes
SiQube Jan 8, 2025
24 changes: 24 additions & 0 deletions src/pymovements/dataset/dataset.py
@@ -231,6 +231,30 @@ def load_precomputed_reading_measures(self) -> None:
self.paths,
)

def split_gaze_data(
self,
by: list[str] | str,
Review comment (Contributor): let's use Sequence here from collections, this way it's more in line with the polars signature: https://docs.pola.rs/api/python/stable/reference/dataframe/api/polars.DataFrame.partition_by.html

) -> None:
"""Split gaze data into separated GazeDataFrame's.

Parameters
----------
by: list[str] | str
Column(s) to split dataframe by.
"""
fileinfo_dicts = self.fileinfo['gaze'].to_dicts()

all_gaze_frames = []
all_fileinfo_rows = []

for frame, fileinfo_row in zip(self.gaze, fileinfo_dicts):
split_frames = frame.split(by=by)
all_gaze_frames.extend(split_frames)
all_fileinfo_rows.extend([fileinfo_row] * len(split_frames))

self.gaze = all_gaze_frames
self.fileinfo['gaze'] = pl.concat([pl.from_dict(row) for row in all_fileinfo_rows])

def split_precomputed_events(
self,
by: list[str] | str,
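As a usage illustration of the new Dataset.split_gaze_data() method above, here is a minimal sketch. The dataset name 'ToyDataset' and the 'trial_id' column are placeholder assumptions, not part of this diff:

import pymovements as pm

# Placeholder dataset definition and path; substitute your own.
dataset = pm.Dataset('ToyDataset', path='data/ToyDataset')
dataset.load()

# Before the split: one GazeDataFrame per gaze file.
n_files = len(dataset.gaze)

# Split every loaded GazeDataFrame by trial. Afterwards, dataset.gaze holds one
# GazeDataFrame per (file, trial) combination and dataset.fileinfo['gaze']
# repeats each file's row once per resulting partition.
dataset.split_gaze_data(by='trial_id')
assert len(dataset.gaze) >= n_files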
28 changes: 28 additions & 0 deletions src/pymovements/gaze/gaze_dataframe.py
@@ -285,6 +285,7 @@ def __init__(

# Remove this attribute once #893 is fixed
self._metadata: dict[str, Any] | None = None
self.auto_column_detect = auto_column_detect

def apply(
self,
@@ -307,6 +308,33 @@ def apply(
else:
raise ValueError(f"unsupported method '{function}'")

def split(self, by: list[str] | str) -> list[GazeDataFrame]:
Review comment (Contributor): let's use Sequence[str]

"""Split the GazeDataFrame into multiple frames based on specified column(s).

Parameters
----------
by: list[str] | str
Column name(s) to split the DataFrame by. If a single string is provided,
it will be used as a single column name. If a list is provided, the DataFrame
will be split by unique combinations of values in all specified columns.

Returns
-------
list[GazeDataFrame]
A list of new GazeDataFrame instances, each containing a partition of the
original data with all metadata and configurations preserved.
"""
return [
GazeDataFrame(
new_frame,
experiment=self.experiment,
trial_columns=self.trial_columns,
time_column='time',
distance_column='distance',
)
for new_frame in self.frame.partition_by(by=by)
]

def transform(
self,
transform_method: str | Callable[..., pl.Expr],
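GazeDataFrame.split() delegates the partitioning to polars' DataFrame.partition_by(). The sketch below shows that underlying behaviour and, as a hypothetical signature only, the collections.abc.Sequence typing suggested in the review comments; neither snippet is part of this diff:

from collections.abc import Sequence

import polars as pl

df = pl.DataFrame({
    'trial_id': [0, 0, 1, 1, 2],
    'x': [0.0, 1.0, 2.0, 3.0, 4.0],
})

# partition_by returns one DataFrame per unique value (or value combination).
partitions = df.partition_by('trial_id')
assert len(partitions) == 3

# Hypothetical signature following the review suggestion; it mirrors the
# polars partition_by signature more closely than `list[str] | str`.
def split(self, by: Sequence[str] | str) -> list['GazeDataFrame']:
    ...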
53 changes: 51 additions & 2 deletions tests/unit/dataset/dataset_test.py
@@ -146,6 +148,8 @@ def mock_toy(
'y_left_pix': np.zeros(1000),
'x_right_pix': np.zeros(1000),
'y_right_pix': np.zeros(1000),
'trial_id_1': np.concatenate([np.zeros(500), np.ones(500)]),
'trial_id_2': ['a'] * 200 + ['b'] * 200 + ['c'] * 600,
},
schema={
'subject_id': pl.Int64,
@@ -154,6 +156,8 @@
'y_left_pix': pl.Float64,
'x_right_pix': pl.Float64,
'y_right_pix': pl.Float64,
'trial_id_1': pl.Float64,
'trial_id_2': pl.Utf8,
},
)
pixel_columns = ['x_left_pix', 'y_left_pix', 'x_right_pix', 'y_right_pix']
@@ -169,6 +173,8 @@
'y_right_pix': np.zeros(1000),
'x_avg_pix': np.zeros(1000),
'y_avg_pix': np.zeros(1000),
'trial_id_1': np.concatenate([np.zeros(500), np.ones(500)]),
'trial_id_2': ['a'] * 200 + ['b'] * 200 + ['c'] * 600,
},
schema={
'subject_id': pl.Int64,
@@ -179,6 +185,8 @@
'y_right_pix': pl.Float64,
'x_avg_pix': pl.Float64,
'y_avg_pix': pl.Float64,
'trial_id_1': pl.Float64,
'trial_id_2': pl.Utf8,
},
)
pixel_columns = [
@@ -192,12 +200,16 @@
'time': np.arange(1000),
'x_left_pix': np.zeros(1000),
'y_left_pix': np.zeros(1000),
'trial_id_1': np.concatenate([np.zeros(500), np.ones(500)]),
'trial_id_2': ['a'] * 200 + ['b'] * 200 + ['c'] * 600,
},
schema={
'subject_id': pl.Int64,
'time': pl.Int64,
'x_left_pix': pl.Float64,
'y_left_pix': pl.Float64,
'trial_id_1': pl.Float64,
'trial_id_2': pl.Utf8,
},
)
pixel_columns = ['x_left_pix', 'y_left_pix']
@@ -208,12 +220,16 @@
'time': np.arange(1000),
'x_right_pix': np.zeros(1000),
'y_right_pix': np.zeros(1000),
'trial_id_1': np.concatenate([np.zeros(500), np.ones(500)]),
'trial_id_2': ['a'] * 200 + ['b'] * 200 + ['c'] * 600,
},
schema={
'subject_id': pl.Int64,
'time': pl.Int64,
'x_right_pix': pl.Float64,
'y_right_pix': pl.Float64,
'trial_id_1': pl.Float64,
'trial_id_2': pl.Utf8,
},
)
pixel_columns = ['x_right_pix', 'y_right_pix']
@@ -224,12 +240,16 @@
'time': np.arange(1000),
'x_pix': np.zeros(1000),
'y_pix': np.zeros(1000),
'trial_id_1': np.concatenate([np.zeros(500), np.ones(500)]),
'trial_id_2': ['a'] * 200 + ['b'] * 200 + ['c'] * 600,
},
schema={
'subject_id': pl.Int64,
'time': pl.Int64,
'x_pix': pl.Float64,
'y_pix': pl.Float64,
'trial_id_1': pl.Float64,
'trial_id_2': pl.Utf8,
},
)
pixel_columns = ['x_pix', 'y_pix']
@@ -1000,7 +1020,8 @@ def test_detect_events_attribute_error(gaze_dataset_configuration):
},
(
"Column 'position' not found. Available columns are: "
"['time', 'subject_id', 'pixel', 'custom_position', 'velocity']"
"['time', 'trial_id_1', 'trial_id_2', 'subject_id', "
"'pixel', 'custom_position', 'velocity']"
),
id='no_position',
),
@@ -1012,7 +1033,8 @@ def test_detect_events_attribute_error(gaze_dataset_configuration):
},
(
"Column 'velocity' not found. Available columns are: "
"['time', 'subject_id', 'pixel', 'position', 'custom_velocity']"
"['time', 'trial_id_1', 'trial_id_2', 'subject_id', "
"'pixel', 'position', 'custom_velocity']"
),
id='no_velocity',
),
@@ -1930,3 +1952,30 @@ def test_load_split_precomputed_events(precomputed_dataset_configuration, by, ex
dataset.load()
dataset.split_precomputed_events(by)
assert len(dataset.precomputed_events) == expected_len


@pytest.mark.parametrize(
('by', 'expected_len'),
[
pytest.param(
'trial_id_1',
40,
id='subset_int',
),
pytest.param(
'trial_id_2',
60,
id='subset_str',
),
pytest.param(
['trial_id_1', 'trial_id_2'],
80,
id='subset_int_str',
),
],
)
def test_load_split_gaze(gaze_dataset_configuration, by, expected_len):
dataset = pm.Dataset(**gaze_dataset_configuration['init_kwargs'])
dataset.load()
dataset.split_gaze_data(by)
assert len(dataset.gaze) == expected_len
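For reference, a quick sanity check of the expected lengths above (a sketch; the 20-file count is inferred from the ratios 40/2, 60/3 and 80/4 rather than stated explicitly in the fixture):

import numpy as np
import polars as pl

# Trial columns exactly as added to the mock fixture above.
frame = pl.DataFrame({
    'trial_id_1': np.concatenate([np.zeros(500), np.ones(500)]),
    'trial_id_2': ['a'] * 200 + ['b'] * 200 + ['c'] * 600,
})

n_files = 20  # inferred number of gaze files in the toy dataset
assert len(frame.partition_by('trial_id_1')) * n_files == 40
assert len(frame.partition_by('trial_id_2')) * n_files == 60
# Unique combinations: (0, 'a'), (0, 'b'), (0, 'c'), (1, 'c') -> 4 per file.
assert len(frame.partition_by(['trial_id_1', 'trial_id_2'])) * n_files == 80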
22 changes: 22 additions & 0 deletions tests/unit/gaze/gaze_dataframe_test.py
@@ -205,3 +205,25 @@ def test_gaze_dataframe_copy_no_experiment():

# We want to have separate experiment instances but the same values.
assert gaze.experiment is gaze_copy.experiment


def test_gaze_dataframe_split():
gaze = pm.GazeDataFrame(
pl.DataFrame(
{
'x': [0, 1, 2, 3],
'y': [1, 1, 0, 0],
'trial_id': [0, 1, 1, 2],
},
schema={'x': pl.Float64, 'y': pl.Float64, 'trial_id': pl.Int8},
),
experiment=None,
position_columns=['x', 'y'],
)

split_gaze = gaze.split('trial_id')
assert all(gaze_df.frame.n_unique('trial_id') == 1 for gaze_df in split_gaze)
assert len(split_gaze) == 3
assert_frame_equal(gaze.frame.filter(pl.col('trial_id') == 0), split_gaze[0].frame)
assert_frame_equal(gaze.frame.filter(pl.col('trial_id') == 1), split_gaze[1].frame)
assert_frame_equal(gaze.frame.filter(pl.col('trial_id') == 2), split_gaze[2].frame)
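For readability: on this fixture the split yields one row for trial_id 0 (x=0.0, y=1.0), two rows for trial_id 1 ((1.0, 1.0) and (2.0, 0.0)), and one row for trial_id 2 (3.0, 0.0), which is what the three assert_frame_equal checks verify.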