-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
178 lines (142 loc) · 5.19 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
import os
import numpy as np
import pickle
import numpy as np
import torch
def find_npy_files(root_dir):
    """
    Recursively collect every file whose name ends in '.npy'.

    Parameters:
    - root_dir: Directory at which the recursive walk starts.

    Returns:
    - A list of paths (walked directory joined with the filename), in the
      order in which os.walk yields them. Empty when nothing matches.
    """
    return [
        os.path.join(folder, name)
        for folder, _subdirs, names in os.walk(root_dir)
        for name in names
        if name.endswith('.npy')
    ]
import os
def parse_bids_filename(filepath, desc_offset=14):
    """
    Parse a BIDS-style filepath and extract its entities.

    The filename (without directory and extension) is split on underscores
    into 'key-value' components. The 'sub', 'ses', 'task' and 'run' entities
    are stored as-is. The 'desc' entity is handled specially because its value
    may itself contain underscores: everything after 'desc-' up to the '_eeg'
    suffix (or the end of the name when no such suffix exists) is taken, and
    the first `desc_offset` characters of that text are then dropped.

    Parameters:
    - filepath: Full path to the file following BIDS convention.
    - desc_offset: Number of leading characters to drop from the descriptor.
      The default of 14 preserves the historical behaviour of keeping "just
      the audio descriptor" (presumably skipping a fixed-width prefix such as
      'preproc-audio-' - TODO confirm against the dataset naming scheme).
      Pass 0 to keep the full descriptor.

    Returns:
    - Dictionary with the identified BIDS components. Keys are a subset of
      'sub', 'ses', 'task', 'run' and 'desc'.
    """
    filename = os.path.basename(filepath)
    base = os.path.splitext(filename)[0]

    bids_components = {}
    position = 0  # index in `base` at which the current component starts
    for component in base.split("_"):
        # partition keeps any further hyphens inside the value
        key, sep, value = component.partition("-")
        if key in ('sub', 'ses', 'task', 'run'):
            bids_components[key] = value
        elif key == 'desc' and sep:
            # Locate the descriptor relative to THIS component rather than
            # searching the whole name, so an earlier 'desc-' or '_eeg'
            # substring cannot shift the slice.
            start = position + len("desc-")
            end = base.find("_eeg", start)
            if end == -1:
                # No modality suffix: the descriptor runs to the end of the name.
                end = len(base)
            bids_components['desc'] = base[start:end][desc_offset:]
            # desc is the last entity before the modality suffix; stop here.
            break
        position += len(component) + 1  # +1 for the underscore separator
    return bids_components
def find_files_with_prefix_suffix(search_dir, prefix, suffix):
    """
    Recursively collect files whose name has the given prefix and suffix.

    Parameters:
    - search_dir: Directory at which the recursive walk starts.
    - prefix: String the filename must start with.
    - suffix: String the filename must end with.

    Returns:
    - A list of paths (walked directory joined with the filename) for every
      file satisfying both conditions, in the order os.walk yields them.
    """
    hits = []
    for folder, _subdirs, names in os.walk(search_dir):
        hits.extend(
            os.path.join(folder, name)
            for name in names
            if name.startswith(prefix) and name.endswith(suffix)
        )
    return hits
def save_pickle(data, filename):
    """
    Serialize an object to disk with pickle.

    Parameters:
    - data: The object to serialize.
    - filename: Path of the file to write; it is opened in binary write mode,
      so an existing file is overwritten.
    """
    with open(filename, 'wb') as handle:
        pickle.Pickler(handle).dump(data)
def load_pickle(filename):
    """
    Deserialize an object from a pickle file.

    Parameters:
    - filename: Path of the pickle file; it is opened in binary read mode.

    Returns:
    - The object reconstructed by pickle.load.

    Note: unpickling can execute arbitrary code, so only use this on files
    from a trusted source.
    """
    with open(filename, 'rb') as handle:
        return pickle.load(handle)
def window_data(data, window_length, hop):
    """Window data into overlapping windows.

    Window ``i`` covers samples ``[i * hop, i * hop + window_length)``. Every
    window that fits completely inside the data is returned, including the one
    that ends exactly on the last sample. Trailing samples that do not fill a
    whole window are discarded.

    Parameters
    ----------
    data: np.ndarray
        Data to window. Shape (n_samples, n_channels)
    window_length: int
        Length of the window in samples.
    hop: int
        Hop size in samples.

    Returns
    -------
    np.ndarray
        Windowed data. Shape (n_windows, window_length, n_channels), where
        n_windows = (n_samples - window_length) // hop + 1, or 0 when the
        data is shorter than one window. The array is allocated with
        np.empty's default dtype (float64).
    """
    n_samples = data.shape[0]
    if n_samples < window_length:
        # Not even one full window fits: return an empty batch rather than
        # requesting a negative dimension from np.empty.
        n_windows = 0
    else:
        # +1 accounts for the window starting at sample 0.
        n_windows = (n_samples - window_length) // hop + 1

    new_data = np.empty((n_windows, window_length, data.shape[1]))
    for i in range(n_windows):
        start = i * hop
        new_data[i] = data[start:start + window_length]
    return new_data
class EarlyStopping:
    """
    Track a validation score and signal when it has stopped improving.

    The stop flag is raised once `patience` non-improving steps have been
    counted since the last improvement. A score equal to the best one counts
    as an improvement and resets the counter. The flag is never cleared once
    set.

    Args:
        patience (int): Number of non-improving steps tolerated before
            stopping.
        mode (str): 'max' (default) when a higher score is better, e.g.
            accuracy or correlation; 'min' when a lower score is better,
            e.g. a loss. The default matches the historical behaviour of
            this class.

    Raises:
        ValueError: If `mode` is neither 'max' nor 'min'.
    """

    def __init__(self, patience=5, mode="max"):
        if mode not in ("max", "min"):
            raise ValueError(f"mode must be 'max' or 'min', got {mode!r}")
        self.patience = patience
        self.mode = mode
        # Non-improving steps counted since the last improvement.
        self.counter = 0
        # Best score seen so far; None until the first call to step().
        self.best_score = None
        # Set to True once `counter` reaches `patience`.
        self.early_stop = False

    def step(self, current_score):
        """
        Check if early stopping criteria is met.

        Args:
            current_score (float): Current validation score.

        Returns:
            bool: Whether training should be stopped.
        """
        if self.best_score is None:
            # First observation only initialises the reference score.
            self.best_score = current_score
            return self.early_stop

        if self.mode == "max":
            improved = current_score >= self.best_score
            worsened = current_score < self.best_score
        else:
            improved = current_score <= self.best_score
            worsened = current_score > self.best_score

        if improved:
            self.best_score = current_score
            self.counter = 0
        elif worsened:
            # Both comparisons are kept explicit so that a score that is
            # neither better nor worse (e.g. NaN) leaves the state untouched.
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        return self.early_stop
def load_model(model, model_path, device):
    """
    Restore saved parameters into a model and switch it to evaluation mode.

    Parameters:
    - model: The model that receives the parameters (modified in place).
    - model_path: Path of the file read with torch.load; its content is
      passed to model.load_state_dict.
    - device: Passed to torch.load as map_location for the loaded tensors.
      The model itself is not moved by this function.

    Returns:
    - The same model object, after model.eval() has been called.
    """
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict)
    model.eval()
    return model