forked from neurodynamics-ai/travnet
-
Notifications
You must be signed in to change notification settings - Fork 0
/
example_sorter.py
37 lines (25 loc) · 857 Bytes
/
example_sorter.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
# Copyright (c) Neurodynamics Solutions LLC.
# See LICENSE file for licensing details.
import numpy as np
import fire
from travnet import TravNet, CNN, ModelArgs, Prep, DataArgs
def main(
    # TODO: expose the remaining run options as CLI arguments, e.g.:
    # ckpt_dir: str,
    # data_fn: str,
    # model_path: str = ...,  # TODO: convert the model to a JSON file
    # num_channels: int = 0,
    # sample_rate: int = 30000,
    # threshold: float = -8.0,
    batch_size: int = 40000,
):
    """Load the configured dataset, run the TravNet spike sorter, and return its output.

    Data-loading and model settings come from the ``DataArgs`` / ``ModelArgs``
    objects imported from ``travnet``, not from this function's arguments.

    Args:
        batch_size: Accepted on the command line but currently unused.
            NOTE(review): presumably meant to be forwarded into ``DataArgs``;
            confirm and wire it through.

    Returns:
        The spike sorter's result after ``.cpu().numpy()``. This is presumably
        a NumPy array; the sorter's raw output looks like a torch tensor.
        TODO confirm.
    """
    # Read the raw recording through the travnet preprocessing helper.
    raw_data = Prep(DataArgs).import_data()
    # prepare_data returns an indexable result; its first element is used as
    # the evaluation loader fed to the sorter.
    loader = TravNet(DataArgs).prepare_data(raw_data)[0]
    sorter_result = TravNet.spike_sorter(ModelArgs, loader)
    # TODO: write the results to an open file format (e.g. Neuroshare) so they
    # can be loaded into other software.
    # `.cpu()` suggests a torch tensor: copy to host memory before converting.
    return sorter_result.cpu().numpy()
if __name__ == '__main__':
    # Expose main() as a command-line tool via python-fire.
    fire.Fire(component=main)