#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon Dec 4 14:46:44 2023
@author: julienballbe
"""
import pandas as pd
import numpy as np
import plotnine as p9
import logging
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import Ordinary_functions as ordifunc
def create_cell_Full_SF_dict_table(original_Full_TVC_table, original_cell_sweep_info_table,do_filter=True,BE_correct =True):
'''
Identify, for every TVC table contained in a Full_TVC_table, all spike-related features.
Parameters
----------
original_Full_TVC_table : pd.DataFrame
Two-column DataFrame containing in column 'Sweep' the sweep id and in column 'TVC' the corresponding three-column DataFrame of Time, Current and Potential traces.
original_cell_sweep_info_table : pd.DataFrame
DataFrame containing the information about the different traces for each sweep (one row per sweep).
do_filter : bool, optional
If True, use the low-pass-filtered traces returned by ordifunc.get_filtered_TVC_table. The default is True.
BE_correct : bool, optional
If True and a Bridge Error is available for the sweep, subtract Bridge_Error_GOhms * Input_current_pA from the membrane potential before spike detection. The default is True.
Returns
-------
Full_SF_dict_table : pd.DataFrame
Two-column DataFrame. For each row, the first column ('Sweep') contains the sweep id, and
the second column ('SF_dict') contains a Dict in which each key corresponds to a spike-related feature and each value to an array of time indices (referring to the sweep-related TVC table).
'''
Full_TVC_table = original_Full_TVC_table.copy()
cell_sweep_info_table = original_cell_sweep_info_table.copy()
Full_TVC_table = Full_TVC_table.sort_values(by=['Sweep'])
cell_sweep_info_table = cell_sweep_info_table.sort_values(by=["Sweep"])
sweep_list = np.array(Full_TVC_table['Sweep'])
Full_SF_dict_table=pd.DataFrame(columns=['Sweep',"SF_dict"])
for current_sweep in sweep_list:
current_TVC=ordifunc.get_filtered_TVC_table(Full_TVC_table,current_sweep,do_filter=do_filter,filter=5.,do_plot=False)
current_TVC_copy = current_TVC.copy()
if BE_correct ==True:
BE=original_cell_sweep_info_table.loc[current_sweep,'Bridge_Error_GOhms']
if not np.isnan(BE):
current_TVC_copy.loc[:,'Membrane_potential_mV'] = current_TVC_copy.loc[:,'Membrane_potential_mV']-BE*current_TVC_copy.loc[:,'Input_current_pA']
membrane_trace = np.array(current_TVC_copy['Membrane_potential_mV'])
time_trace = np.array(current_TVC_copy['Time_s'])
current_trace = np.array(current_TVC_copy ['Input_current_pA'])
stim_start = cell_sweep_info_table.loc[current_sweep, 'Stim_start_s']
stim_end = cell_sweep_info_table.loc[current_sweep, 'Stim_end_s']
SF_dict = identify_spike(
membrane_trace, time_trace, current_trace, stim_start, stim_end, do_plot=False)
new_line = pd.DataFrame([str(current_sweep), SF_dict]).T
new_line.columns=['Sweep', 'SF_dict']
Full_SF_dict_table=pd.concat([Full_SF_dict_table,new_line],ignore_index=True)
Full_SF_dict_table.index = Full_SF_dict_table.loc[:, 'Sweep']
Full_SF_dict_table.index = Full_SF_dict_table.index.astype(str)
Full_SF_dict_table.index.name = 'Index'
return Full_SF_dict_table
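# Hypothetical usage sketch (not part of the analysis pipeline, never called at import): illustrates
# the table layout this function expects. The sweep id, stimulus timing and flat synthetic trace
# below are invented for illustration only; real tables are built upstream and must match whatever
# ordifunc.get_filtered_TVC_table expects.
def _example_create_cell_full_sf_dict_table():
    time_s = np.arange(0., 1., 2e-4)
    tvc = pd.DataFrame({'Time_s': time_s,
                        'Membrane_potential_mV': np.full_like(time_s, -70.),
                        'Input_current_pA': np.zeros_like(time_s)})
    full_tvc_table = pd.DataFrame({'Sweep': ['1'], 'TVC': [tvc]})
    # The sweep info table is indexed by sweep id, as assumed by the .loc lookups above.
    sweep_info = pd.DataFrame({'Sweep': ['1'],
                               'Stim_start_s': [.1],
                               'Stim_end_s': [.7],
                               'Bridge_Error_GOhms': [np.nan]},
                              index=['1'])
    # Returns a two-column DataFrame ('Sweep', 'SF_dict') indexed by sweep id.
    return create_cell_Full_SF_dict_table(full_tvc_table, sweep_info, do_filter=False, BE_correct=True)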
def create_Full_SF_table(original_Full_TVC_table, original_Full_SF_dict, cell_sweep_info_table,do_filter=True,BE_correct =True):
'''
Create, for each Dict contained in original_Full_SF_dict, a DataFrame containing the time, voltage and current values of the spike-related features (if any; otherwise an empty DataFrame).
Parameters
----------
original_Full_TVC_table : pd.DataFrame
Two-column DataFrame containing in column 'Sweep' the sweep id and in column 'TVC' the corresponding three-column DataFrame of Time, Current and Potential traces.
original_Full_SF_dict : pd.DataFrame
Two-column DataFrame. For each row, the first column ('Sweep') contains the sweep id, and
the second column ('SF_dict') contains a Dict in which each key corresponds to a spike-related feature and each value to an array of time indices (referring to the sweep-related TVC table).
cell_sweep_info_table : pd.DataFrame
DataFrame containing the information about the different traces for each sweep (one row per sweep).
do_filter : bool, optional
If True, use the low-pass-filtered traces returned by ordifunc.get_filtered_TVC_table. The default is True.
BE_correct : bool, optional
If True and a Bridge Error is available for the sweep, subtract Bridge_Error_GOhms * Input_current_pA from the membrane potential. The default is True.
Returns
-------
Full_SF_table : pd.DataFrame
Two-column DataFrame containing in column 'Sweep' the sweep id and
in column 'SF' a pd.DataFrame with the voltage and time values of the different spike features if any (otherwise an empty DataFrame).
'''
Full_TVC_table = original_Full_TVC_table.copy()
Full_SF_dict = original_Full_SF_dict.copy()
sweep_list = np.array(Full_TVC_table['Sweep'])
Full_SF_table = pd.DataFrame(columns=["Sweep", "SF"])
for current_sweep in sweep_list:
current_sweep = str(current_sweep)
current_TVC=ordifunc.get_filtered_TVC_table(Full_TVC_table,current_sweep,do_filter=do_filter,filter=5.,do_plot=False)
BE = cell_sweep_info_table.loc[cell_sweep_info_table['Sweep'] == current_sweep,'Bridge_Error_GOhms'].values[0]
if BE_correct == True and not np.isnan(BE) :
current_TVC.loc[:,'Membrane_potential_mV'] -= BE*current_TVC.loc[:,'Input_current_pA']
current_SF_table = create_SF_table(current_TVC, Full_SF_dict.loc[current_sweep, 'SF_dict'].copy())
new_line = pd.DataFrame(
[str(current_sweep), current_SF_table]).T
new_line.columns=["Sweep", "SF"]
Full_SF_table=pd.concat([Full_SF_table,new_line],ignore_index=True)
Full_SF_table.index = Full_SF_table['Sweep']
Full_SF_table.index = Full_SF_table.index.astype(str)
Full_SF_table.index.name = 'Index'
return Full_SF_table
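# Sketch of how the two builders above are typically chained. full_tvc_table and
# cell_sweep_info_table are assumed to come from the upstream pipeline (same layout as described in
# the docstrings above); nothing here is specific to a given cell.
def _example_full_sf_pipeline(full_tvc_table, cell_sweep_info_table):
    # First pass: per-sweep dictionaries of spike-feature sample indices.
    full_sf_dict_table = create_cell_Full_SF_dict_table(full_tvc_table, cell_sweep_info_table,
                                                        do_filter=True, BE_correct=True)
    # Second pass: expand each index dictionary into a table of feature times, voltages and currents.
    return create_Full_SF_table(full_tvc_table, full_sf_dict_table, cell_sweep_info_table,
                                do_filter=True, BE_correct=True)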
def create_SF_table(original_TVC_table, SF_dict):
'''
From the traces and the spike feature indices, return the Time, Potential and Current values of the different spike features.
Parameters
----------
original_TVC_table : pd.DataFrame
Contains the Time, Voltage and Current traces arranged in columns.
SF_dict : Dict
Dict in which each key corresponds to a spike-related feature and each value to an array of time indices (referring to the sweep-related TVC table).
Returns
-------
SF_table : pd.DataFrame
DataFrame containing the voltage and time values of the different spike features if any (otherwise an empty DataFrame).
'''
TVC_table = original_TVC_table.copy()
SF_table = pd.DataFrame(columns=['Time_s', 'Membrane_potential_mV', 'Input_current_pA',
'Potential_first_time_derivative_mV/s', 'Potential_second_time_derivative_mV/s/s', 'Feature'])
SF_table = SF_table.astype({'Time_s':'float',
'Membrane_potential_mV' : 'float',
'Input_current_pA' : 'float',
'Potential_first_time_derivative_mV/s' : 'float',
'Potential_second_time_derivative_mV/s/s' : 'float',
'Feature':'object'})
for feature in SF_dict.keys():
current_feature_table = TVC_table.loc[SF_dict[feature], :].copy()
current_feature_table['Feature'] = feature
if current_feature_table.shape[0]!=0:
SF_table = pd.concat([SF_table, current_feature_table],ignore_index = True)
SF_table = SF_table.sort_values(by=['Time_s'])
if SF_table.shape[0]!=0:
spike_index=0
SF_table.loc[:,"Spike_index"] = 0
had_threshold = False
for elt in range(SF_table.shape[0]):
if SF_table.iloc[elt,5]=="Threshold" and had_threshold==False:
had_threshold = True
SF_table.iloc[elt,6] = spike_index
elif SF_table.iloc[elt,5]!="Threshold":
SF_table.iloc[elt,6] = spike_index
elif SF_table.iloc[elt,5]=="Threshold" and had_threshold==True:
spike_index+=1
SF_table.iloc[elt,6] = spike_index
SF_table = get_spike_half_width(TVC_table, SF_table)
else:
SF_table = pd.DataFrame(columns=['Time_s', 'Membrane_potential_mV', 'Input_current_pA',
'Potential_first_time_derivative_mV/s', 'Potential_second_time_derivative_mV/s/s', 'Feature','Spike_index'])
return SF_table
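# Minimal sketch (assumed data) of the SF_dict -> SF_table conversion on a synthetic, single-spike
# trace. The waveform, sampling rate and feature indices are invented; note that create_SF_table
# also calls get_spike_half_width internally, which relies on ordifunc.find_time_index.
def _example_create_sf_table():
    time_s = np.arange(0., 0.01, 5e-5)                     # 200 samples at 20 kHz
    v = np.full_like(time_s, -70.)
    v[95:105] = np.linspace(-70., 30., 10)                 # fast upstroke
    v[105:125] = np.linspace(30., -75., 20)                # repolarisation down to a trough
    dvdt = np.gradient(v, time_s)
    tvc = pd.DataFrame({'Time_s': time_s,
                        'Membrane_potential_mV': v,
                        'Input_current_pA': np.full_like(time_s, 50.),
                        'Potential_first_time_derivative_mV/s': dvdt,
                        'Potential_second_time_derivative_mV/s/s': np.gradient(dvdt, time_s)})
    # Feature indices are row positions in the TVC table, as produced by identify_spike.
    sf_dict = {'Threshold': np.array([95]), 'Upstroke': np.array([100]),
               'Peak': np.array([104]), 'Trough': np.array([124])}
    return create_SF_table(tvc, sf_dict)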
def get_spike_half_width(TVC_table_original, SF_table_original):
'''
Compute, for each spike in SF_table_original, the spike height (threshold-to-peak amplitude, stored in the 'Membrane_potential_mV' column of a 'Spike_heigth' row)
and the spike width at half height (duration stored in the 'Time_s' column of a 'Spike_width_at_half_heigth' row), and append these rows to the table.
Parameters
----------
TVC_table_original : pd.DataFrame
Contains the Time, Potential and Current traces arranged in columns.
SF_table_original : pd.DataFrame
Spike feature table produced by create_SF_table, containing at least 'Threshold', 'Upstroke', 'Peak' and 'Trough' rows.
Returns
-------
SF_table_original : pd.DataFrame
The input table with the additional 'Spike_heigth' and 'Spike_width_at_half_heigth' rows.
'''
TVC_table = TVC_table_original.copy()
stim_amp_pA = np.nanmean(SF_table_original.loc[:,'Input_current_pA'])
threshold_table = SF_table_original.loc[SF_table_original['Feature'] == 'Threshold',:]
threshold_table = threshold_table.sort_values(by=['Time_s'])
threshold_time = list(threshold_table.loc[:,'Time_s'])
trough_table = SF_table_original.loc[SF_table_original['Feature'] == 'Trough',:]
trough_table = trough_table.sort_values(by=['Time_s'])
trough_time = list(trough_table.loc[:,'Time_s'])
peak_table = SF_table_original.loc[SF_table_original['Feature'] == 'Peak',:]
peak_table = peak_table.sort_values(by=['Time_s'])
peak_time = list(peak_table.loc[:,'Time_s'])
upstroke_table = SF_table_original.loc[SF_table_original['Feature'] == 'Upstroke',:]
upstroke_table = upstroke_table.sort_values(by=['Time_s'])
spike_time_list = list(upstroke_table.loc[:,'Time_s'])
if len(threshold_time) != len(trough_time):
if len(threshold_time)> len(trough_time):
while len(threshold_time)> len(trough_time):
threshold_time=threshold_time[:-1]
elif len(threshold_time) < len(trough_time):
while len(threshold_time) < len(trough_time):
trough_time=trough_time[:-1]
spike_index = 0
for threshold, trough, peak, spike_time in zip(threshold_time, trough_time, peak_time, spike_time_list):
threshold_to_peak_table = TVC_table.loc[(TVC_table['Time_s']>=threshold)&(TVC_table['Time_s']<=peak), :]
membrane_voltage_array = np.array(threshold_to_peak_table.loc[:,'Membrane_potential_mV'])
time_array = np.array(threshold_to_peak_table.loc[:,'Time_s'])
spike_heigth = threshold_to_peak_table.loc[threshold_to_peak_table['Time_s'] == peak,'Membrane_potential_mV'].values[0] - threshold_to_peak_table.loc[threshold_to_peak_table['Time_s'] == threshold,'Membrane_potential_mV'].values[0]
spike_height_line = pd.DataFrame([spike_time, spike_heigth, stim_amp_pA, np.nan, np.nan, "Spike_heigth", spike_index ]).T
spike_height_line.columns = SF_table_original.columns
spike_height_line = spike_height_line.astype({'Time_s':'float',
'Membrane_potential_mV' : 'float',
'Input_current_pA' : 'float',
'Potential_first_time_derivative_mV/s' : 'float',
'Potential_second_time_derivative_mV/s/s' : 'float',
'Feature':'object',
"Spike_index" : 'int'})
SF_table_original = pd.concat([SF_table_original,spike_height_line ],ignore_index=True)
putative_half_spike_heigth = (threshold_to_peak_table.loc[threshold_to_peak_table['Time_s'] == peak,'Membrane_potential_mV'].values[0]+threshold_to_peak_table.loc[threshold_to_peak_table['Time_s'] == threshold,'Membrane_potential_mV'].values[0])/2
putative_half_width_start_index = ordifunc.find_time_index(membrane_voltage_array, putative_half_spike_heigth)
putative_half_width_start = time_array[putative_half_width_start_index]
peak_to_trough_table = TVC_table.loc[(TVC_table['Time_s']>=peak)&(TVC_table['Time_s']<=trough), :]
membrane_voltage_array = np.array(peak_to_trough_table.loc[:,'Membrane_potential_mV'])
time_array = np.array(peak_to_trough_table.loc[:,'Time_s'])
if putative_half_spike_heigth < membrane_voltage_array[-1]: # if the spike does not repolarize enough to reach the half-height computed on the ascending phase, fall back to a level the descending phase actually reaches
half_spike_heigth_end = membrane_voltage_array[-2]
half_width_start_index = ordifunc.find_time_index(membrane_voltage_array, half_spike_heigth_end)
half_width_start = time_array[half_width_start_index]
else:
half_spike_heigth_end = putative_half_spike_heigth
half_width_start = putative_half_width_start
#assert membrane_voltage_array[0] >= half_spike_heigth_end >= membrane_voltage_array[-1], "Given potential ({:f}) is outside of potential range ({:f}, {:f})".format(half_spike_heigth, membrane_voltage_array[0], membrane_voltage_array[-1])
half_width_end_index = np.argmin(abs(membrane_voltage_array - half_spike_heigth_end))
#half_width_end_index = ordifunc.find_time_index(membrane_voltage_array, half_spike_heigth)
half_width_end = time_array[half_width_end_index]
half_height_width = half_width_end - half_width_start
half_width_line = pd.DataFrame([half_height_width, np.nan, stim_amp_pA, np.nan, np.nan, "Spike_width_at_half_heigth",spike_index ]).T
half_width_line.columns = SF_table_original.columns
half_width_line = half_width_line.astype({'Time_s':'float',
'Membrane_potential_mV' : 'float',
'Input_current_pA' : 'float',
'Potential_first_time_derivative_mV/s' : 'float',
'Potential_second_time_derivative_mV/s/s' : 'float',
'Feature':'object',
"Spike_index" : 'int'})
SF_table_original = pd.concat([SF_table_original,half_width_line ],ignore_index=True)
spike_index+=1
return SF_table_original
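# Sketch of reading the derived features back out of an SF table. As implemented above,
# get_spike_half_width stores the spike height (mV) in the 'Membrane_potential_mV' column of the
# 'Spike_heigth' rows and the width at half height (s) in the 'Time_s' column of the
# 'Spike_width_at_half_heigth' rows.
def _example_read_spike_shape_features(sf_table):
    heights_mV = sf_table.loc[sf_table['Feature'] == 'Spike_heigth', 'Membrane_potential_mV']
    half_widths_s = sf_table.loc[sf_table['Feature'] == 'Spike_width_at_half_heigth', 'Time_s']
    return heights_mV.to_numpy(), half_widths_s.to_numpy()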
def identify_spike(membrane_trace_array,time_array, current_trace, stim_start_time, stim_end_time, do_plot=False):
'''
Based on AllenSDK.EPHYS_EPHYS_FEATURES module
Identify spike and their related features based on membrane voltage trace and time traces
Parameters
----------
membrane_trace_array : np.array
Time-varying membrane voltage in mV.
time_array : np.array
Time array in s.
current_trace : np.array
Time-varying input current in pA.
stim_start_time : float
Stimulus start time.
stim_end_time : float
Stimulus end time.
do_plot : Boolean, optional
Do plot. The default is False.
Returns
-------
spike_feature_dict : Dict
Dict in which each key corresponds to a spike-related feature and each value to an array of time indices (referring to the sweep-related TVC table).
'''
first_derivative=ordifunc.get_derivative(membrane_trace_array,time_array)
second_derivative=ordifunc.get_derivative(first_derivative,time_array)
filtered_second_derivative = ordifunc.filter_trace(second_derivative,time_array,filter=1.,do_plot=False)
TVC_table=pd.DataFrame({'Time_s':time_array,
'Membrane_potential_mV':membrane_trace_array,
'Input_current_pA':current_trace,
'Potential_first_time_derivative_mV/s':first_derivative,
'Potential_second_time_derivative_mV/s/s':second_derivative},
dtype=np.float64)
time_derivative=first_derivative
second_time_derivative=second_derivative
preliminary_spike_index=detect_putative_spikes(v=membrane_trace_array,
t=time_array,
start=stim_start_time,
end=stim_end_time,
filter=5.,
dv_cutoff=18.,
dvdt=time_derivative)
if do_plot:
preliminary_spike_table=TVC_table.iloc[preliminary_spike_index[~np.isnan(preliminary_spike_index)],:]
preliminary_spike_table['Feature']='A-Preliminary_spike_threshold'
current_plot=p9.ggplot(TVC_table,p9.aes(x='Time_s',y="Membrane_potential_mV"))+p9.geom_line()+p9.geom_point(preliminary_spike_table,p9.aes(x='Time_s',y="Membrane_potential_mV",color='Feature'))+p9.xlim(stim_start_time,stim_end_time)
current_plot += p9.xlab('1 - After detect_putative_spikes')
print(current_plot)
peak_index_array=find_peak_indexes(v=membrane_trace_array,
t=time_array,
spike_indexes=preliminary_spike_index,
end=stim_end_time)
if do_plot:
peak_spike_table=TVC_table.iloc[peak_index_array[~np.isnan(peak_index_array)],:]
peak_spike_table['Feature']='B-Preliminary_spike_peak'
Full_table=pd.concat([preliminary_spike_table,peak_spike_table])
current_plot=p9.ggplot(TVC_table,p9.aes(x='Time_s',y="Membrane_potential_mV"))+p9.geom_line()+p9.geom_point(Full_table,p9.aes(x='Time_s',y="Membrane_potential_mV",color='Feature'))+p9.xlim(stim_start_time,stim_end_time)
current_plot += p9.xlab('2 - After find_peak_index')
print(current_plot)
spike_threshold_index,peak_index_array=filter_putative_spikes(v=membrane_trace_array,
t=time_array,
spike_indexes=preliminary_spike_index,
peak_indexes=peak_index_array,
min_height=20.,
min_peak=-30.,
filter=5.,
dvdt=time_derivative)
if do_plot:
peak_spike_table=TVC_table.iloc[peak_index_array[~np.isnan(peak_index_array)],:];peak_spike_table['Feature']='A-Spike_peak'
spike_threshold_table=TVC_table.iloc[spike_threshold_index[~np.isnan(spike_threshold_index)],:]; spike_threshold_table['Feature']='B-Spike_threshold'
Full_table=pd.concat([spike_threshold_table,peak_spike_table])
current_plot=p9.ggplot(TVC_table,p9.aes(x='Time_s',y="Membrane_potential_mV"))+p9.geom_line()
current_plot+=p9.geom_point(Full_table,p9.aes(x='Time_s',y="Membrane_potential_mV",color='Feature'))
current_plot+=p9.xlim(stim_start_time-.01,stim_end_time+.01)
current_plot += p9.xlab('3 - After filter_putative_spikes')
print(current_plot)
upstroke_index=find_upstroke_indexes(v=membrane_trace_array,
t=time_array,
spike_indexes=spike_threshold_index,
peak_indexes=peak_index_array,filter=5.,
dvdt=time_derivative)
if do_plot:
upstroke_table=TVC_table.iloc[upstroke_index[~np.isnan(upstroke_index)],:]; upstroke_table['Feature']='C-Upstroke'
Full_table=pd.concat([Full_table,upstroke_table])
current_plot=p9.ggplot(TVC_table,p9.aes(x='Time_s',y="Membrane_potential_mV"))+p9.geom_line()
current_plot+=p9.geom_point(Full_table,p9.aes(x='Time_s',y="Membrane_potential_mV",color='Feature'))
current_plot+=p9.xlim(stim_start_time-.01,stim_end_time+.01)
current_plot += p9.xlab('4 - After find_upstroke_indexes')
print(current_plot)
spike_threshold_index=refine_threshold_indexes(v=membrane_trace_array,
t=time_array,
peak_indexes=peak_index_array,
upstroke_indexes=upstroke_index,
method = 'Mean_upstroke_fraction',
thresh_frac=0.05,
filter=5.,
dvdt=time_derivative,
dv2dt2=filtered_second_derivative)
if do_plot:
refined_threshold_table=TVC_table.iloc[spike_threshold_index[~np.isnan(spike_threshold_index)],:]; refined_threshold_table['Feature']='D-Refined_spike_threshold'
Full_table=pd.concat([Full_table,refined_threshold_table])
current_plot=p9.ggplot(TVC_table,p9.aes(x='Time_s',y="Membrane_potential_mV"))+p9.geom_line()
current_plot+=p9.geom_point(Full_table,p9.aes(x='Time_s',y="Membrane_potential_mV",color='Feature'))
current_plot+=p9.xlim(stim_start_time-.01,stim_end_time+.01)
current_plot += p9.xlab('5 - After refine_threshold_indexes')
print(current_plot)
spike_threshold_index,peak_index_array,upstroke_index,clipped=check_thresholds_and_peaks(v=membrane_trace_array,
t=time_array,
spike_indexes=spike_threshold_index,
peak_indexes=peak_index_array,
upstroke_indexes=upstroke_index,
start=stim_start_time,
end=stim_end_time,
max_interval=0.01,
thresh_frac=0.05,
filter=5.,
dvdt=time_derivative,
tol=1.0,
reject_at_stim_start_interval=0.)
if do_plot:
spike_threshold_table=TVC_table.iloc[spike_threshold_index[~np.isnan(spike_threshold_index)],:]; spike_threshold_table['Feature']='B-Spike_threshold'
peak_spike_table=TVC_table.iloc[peak_index_array[~np.isnan(peak_index_array)],:];peak_spike_table['Feature']='A-Spike_peak'
upstroke_table=TVC_table.iloc[upstroke_index[~np.isnan(upstroke_index)],:]; upstroke_table['Feature']='C-Upstroke'
Full_table=pd.concat([spike_threshold_table,peak_spike_table,upstroke_table])
current_plot=p9.ggplot(TVC_table,p9.aes(x='Time_s',y="Membrane_potential_mV"))+p9.geom_line()
current_plot+=p9.geom_point(Full_table,p9.aes(x='Time_s',y="Membrane_potential_mV",color='Feature'))
current_plot+=p9.xlim(stim_start_time-.01,stim_end_time+.01)
current_plot += p9.xlab('6 - After check_thresholds_and_peaks')
print(current_plot)
if len(clipped)==1:
if clipped[0] == True:
spike_feature_dict={'Threshold':np.array([]).astype(int),
'Peak':np.array([]).astype(int),
'Upstroke':np.array([]).astype(int),
'Downstroke':np.array([]).astype(int),
'Trough':np.array([]).astype(int),
'Fast_Trough':np.array([]).astype(int),
'Slow_Trough':np.array([]).astype(int),
'ADP':np.array([]).astype(int),
'fAHP':np.array([]).astype(int)}
return spike_feature_dict
trough_index=find_trough_indexes(v=membrane_trace_array,
t=time_array,
spike_indexes=spike_threshold_index,
peak_indexes=peak_index_array,
clipped=clipped,
end=stim_end_time)
if do_plot:
trough_table=TVC_table.iloc[trough_index[~np.isnan(trough_index)].astype(int),:]; trough_table['Feature']='D-Trough'
Full_table=pd.concat([Full_table,trough_table])
current_plot=p9.ggplot(TVC_table,p9.aes(x='Time_s',y="Membrane_potential_mV"))+p9.geom_line()
current_plot+=p9.geom_point(Full_table,p9.aes(x='Time_s',y="Membrane_potential_mV",color='Feature'))
current_plot+=p9.xlim(stim_start_time-.01,stim_end_time+.01)
current_plot += p9.xlab('7 - After find_trough_indexes')
print(current_plot)
fast_AHP_index=find_fast_AHP_indexes(v=membrane_trace_array,
t=time_array,
spike_indexes=spike_threshold_index,
peak_indexes=peak_index_array,
clipped=clipped,
end=stim_end_time,
dvdt=time_derivative)
if do_plot:
fast_AHP_table=TVC_table.iloc[fast_AHP_index[~np.isnan(fast_AHP_index)],:]; fast_AHP_table['Feature']='E-Fast_AHP'
Full_table=pd.concat([Full_table,fast_AHP_table])
current_plot=p9.ggplot(TVC_table,p9.aes(x='Time_s',y="Membrane_potential_mV"))+p9.geom_line()
current_plot+=p9.geom_point(Full_table,p9.aes(x='Time_s',y="Membrane_potential_mV",color='Feature'))
current_plot+=p9.xlim(stim_start_time-.01,stim_end_time+.01)
current_plot+=p9.xlab('8 - After find_fast_AHP_indexes')
print(current_plot)
downstroke_index=find_downstroke_indexes(v=membrane_trace_array,
t=time_array,
peak_indexes=peak_index_array,
trough_indexes=trough_index,
clipped=clipped,
filter=5.,
dvdt=time_derivative)
if do_plot:
downstroke_table=TVC_table.iloc[downstroke_index[~np.isnan(downstroke_index)],:]; downstroke_table['Feature']='F-Downstroke'
Full_table=pd.concat([Full_table,downstroke_table])
current_plot=p9.ggplot(TVC_table,p9.aes(x='Time_s',y="Membrane_potential_mV"))+p9.geom_line()
current_plot+=p9.geom_point(Full_table,p9.aes(x='Time_s',y="Membrane_potential_mV",color='Feature'))
current_plot+=p9.xlim(stim_start_time-.01,stim_end_time+.01)
current_plot+=p9.xlab('9 - After find_downstroke_indexes')
print(current_plot)
fast_trough_index, adp_index, slow_trough_index, clipped=find_fast_trough_adp_slow_trough(v=membrane_trace_array,
t=time_array,
spike_indexes=spike_threshold_index,
peak_indexes=peak_index_array,
downstroke_indexes=downstroke_index,
clipped=clipped,
end=stim_end_time,
filter=5.,
heavy_filter=1.,
downstroke_frac=.01,
adp_thresh=1.5,
tol=1.,
flat_interval=.002,
adp_max_delta_t=.005,
adp_max_delta_v=10,
dvdt=time_derivative)
if do_plot:
fast_trough_table=TVC_table.iloc[fast_trough_index[~np.isnan(fast_trough_index)],:]; fast_trough_table['Feature']='G-Fast_Trough'
Full_table=pd.concat([Full_table,fast_trough_table])
adp_table=TVC_table.iloc[adp_index[~np.isnan(adp_index)],:]; adp_table['Feature']='H-ADP'
Full_table=pd.concat([Full_table,adp_table])
slow_trough_table=TVC_table.iloc[slow_trough_index[~np.isnan(slow_trough_index)],:]; slow_trough_table['Feature']='I-Slow_Trough'
Full_table=pd.concat([Full_table,slow_trough_table])
current_plot=p9.ggplot(TVC_table,p9.aes(x='Time_s',y="Membrane_potential_mV"))+p9.geom_line()
current_plot+=p9.geom_point(Full_table,p9.aes(x='Time_s',y="Membrane_potential_mV",color='Feature'))
current_plot+=p9.xlim(stim_start_time-.01,stim_end_time+.01)
current_plot+=p9.xlab('10 - After find_fast_trough/ADP/slow_trough_indexes')
print(current_plot)
spike_threshold_index = spike_threshold_index[~np.isnan(spike_threshold_index)].astype(int)
peak_index_array = peak_index_array[~np.isnan(peak_index_array)].astype(int)
upstroke_index = upstroke_index[~np.isnan(upstroke_index)].astype(int)
downstroke_index = downstroke_index[~np.isnan(downstroke_index)].astype(int)
trough_index = trough_index[~np.isnan(trough_index)].astype(int)
fast_trough_index = fast_trough_index[~np.isnan(fast_trough_index)].astype(int)
slow_trough_index = slow_trough_index[~np.isnan(slow_trough_index)].astype(int)
adp_index = adp_index[~np.isnan(adp_index)].astype(int)
fast_AHP_index = fast_AHP_index[~np.isnan(fast_AHP_index)].astype(int)
spike_feature_dict={'Threshold':np.array(spike_threshold_index),
'Peak':np.array(peak_index_array),
'Upstroke':np.array(upstroke_index),
'Downstroke':np.array(downstroke_index),
'Trough':np.array(trough_index),
'Fast_Trough':np.array(fast_trough_index),
'Slow_Trough':np.array(slow_trough_index),
'ADP':np.array(adp_index),
'fAHP':np.array(fast_AHP_index)}
return spike_feature_dict
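# Hypothetical call sketch for identify_spike on one sweep of a Full_TVC_table, mirroring how the
# table builders above use it. full_tvc_table, cell_sweep_info_table and sweep_id are assumed to
# come from the upstream pipeline; the returned dict maps feature names ('Threshold', 'Peak',
# 'Upstroke', ...) to integer sample indices into the sweep's TVC table.
def _example_identify_spike_for_sweep(full_tvc_table, cell_sweep_info_table, sweep_id):
    tvc = ordifunc.get_filtered_TVC_table(full_tvc_table, sweep_id,
                                          do_filter=True, filter=5., do_plot=False)
    return identify_spike(np.array(tvc['Membrane_potential_mV']),
                          np.array(tvc['Time_s']),
                          np.array(tvc['Input_current_pA']),
                          cell_sweep_info_table.loc[sweep_id, 'Stim_start_s'],
                          cell_sweep_info_table.loc[sweep_id, 'Stim_end_s'],
                          do_plot=False)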
def detect_putative_spikes(v, t, start=None, end=None, filter=5., dv_cutoff=20., dvdt=None):
"""
From AllenSDK.EPHYS.EPHYS_FEATURES Module
Perform initial detection of spikes and return their indexes.
Parameters
----------
v : numpy array of voltage time series in mV
t : numpy array of times in seconds
start : start of time window for spike detection (optional)
end : end of time window for spike detection (optional)
filter : cutoff frequency for 4-pole low-pass Bessel filter in kHz (optional, default 5)
dv_cutoff : minimum dV/dt to qualify as a spike in V/s (optional, default 20)
dvdt : pre-calculated time-derivative of voltage (optional)
Returns
-------
putative_spikes : numpy array of preliminary spike indexes
"""
if not isinstance(v, np.ndarray):
raise TypeError("v is not an np.ndarray")
if not isinstance(t, np.ndarray):
raise TypeError("t is not an np.ndarray")
if start is None:
start = t[0]
if end is None:
end = t[-1]
start_index = ordifunc.find_time_index(t, start)
end_index = ordifunc.find_time_index(t, end)
v_window = v[start_index:end_index + 1]
t_window = t[start_index:end_index + 1]
if dvdt is None:
dvdt = ordifunc.get_derivative(v_window, t_window)
else:
dvdt = dvdt[start_index:end_index]
# Find positive-going crossings of dV/dt cutoff level
putative_spikes = np.flatnonzero(np.diff(np.greater_equal(dvdt, dv_cutoff).astype(int)) == 1)
if len(putative_spikes) <= 1:
# Set back to original index space (not just window)
return np.array(putative_spikes) + start_index
# Only keep spike times if dV/dt has dropped all the way to zero between putative spikes
putative_spikes = [putative_spikes[0]] + [s for i, s in enumerate(putative_spikes[1:])
if np.any(dvdt[putative_spikes[i]:s] < 0)]
# Set back to original index space (not just window)
return np.array(putative_spikes) + start_index
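# Self-contained sketch of the dV/dt crossing detection on a synthetic trace. The derivative is
# supplied explicitly, so its units only need to be consistent with dv_cutoff (both treated here as
# V/s, as in the docstring above); ordifunc.find_time_index is still used internally to resolve the
# analysis window. Waveform and sampling rate are invented.
def _example_detect_putative_spikes():
    t = np.arange(0., 0.2, 1e-4)                    # 10 kHz sampling, 200 ms
    v = np.full_like(t, -70.)
    for onset in (500, 1200):                       # two artificial spike-like events
        v[onset:onset + 10] = np.linspace(-70., 30., 10)
        v[onset + 10:onset + 30] = np.linspace(30., -70., 20)
    dvdt_V_per_s = np.gradient(v, t) * 1e-3         # mV/s -> V/s
    # Expected: two putative spike indices, each near the onset of an upstroke.
    return detect_putative_spikes(v, t, dv_cutoff=20., dvdt=dvdt_V_per_s)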
def find_peak_indexes(v, t, spike_indexes, end=None):
"""
From AllenSDK.EPHYS.EPHYS_FEATURES Module
Find indexes of spike peaks.
Parameters
----------
v : numpy array of voltage time series in mV
t : numpy array of times in seconds
spike_indexes : numpy array of preliminary spike indexes
end : end of time window for spike detection (optional)
"""
if not end:
end = t[-1]
end_index = ordifunc.find_time_index(t, end)
spks_and_end = np.append(spike_indexes, end_index)
peak_indexes = [np.argmax(v[spk:next]) + spk for spk, next in
zip(spks_and_end[:-1], spks_and_end[1:])]
#finds the index of the maximum of v between two consecutive spike indexes
#spk is the first spike index (iterating over the list of spike indexes without the last one)
#next is the following spike index (iterating over the list of spike indexes without the first one)
#np.argmax(v[spk:next]) + spk --> spk is added back because indexing inside v[spk:next] starts at 0
peak_indexes = np.array(peak_indexes)
#peak_indexes = peak_indexes.astype(int)
return np.array(peak_indexes)
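# Sketch: find_peak_indexes locates the maximum of v between each putative threshold crossing and
# the next one (or the end of the window). Data are synthetic; only the relative positions matter.
def _example_find_peak_indexes():
    t = np.arange(0., 0.1, 1e-4)
    v = np.full_like(t, -70.)
    v[300:310] = np.linspace(-70., 20., 10)         # first event, maximum at index 309
    v[310:330] = np.linspace(20., -70., 20)
    v[600:610] = np.linspace(-70., 10., 10)         # second event, maximum at index 609
    v[610:630] = np.linspace(10., -70., 20)
    return find_peak_indexes(v, t, spike_indexes=np.array([300, 600]))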
def filter_putative_spikes(v, t, spike_indexes, peak_indexes, min_height=2.,
min_peak=-30., filter=5., dvdt=None):
"""
From AllenSDK.EPHYS.EPHYS_FEATURES Module
Filter out events that are unlikely to be spikes based on:
* Height (threshold to peak)
* Absolute peak level
Parameters
----------
v : numpy array of voltage time series in mV
t : numpy array of times in seconds
spike_indexes : numpy array of preliminary spike indexes
peak_indexes : numpy array of indexes of spike peaks
min_height : minimum acceptable height from threshold to peak in mV (optional, default 2)
min_peak : minimum acceptable absolute peak level in mV (optional, default -30)
filter : cutoff frequency for 4-pole low-pass Bessel filter in kHz (optional, default 5)
dvdt : pre-calculated time-derivative of voltage (optional)
Returns
-------
spike_indexes : numpy array of threshold indexes
peak_indexes : numpy array of peak indexes
"""
if not spike_indexes.size or not peak_indexes.size:
return np.array([]), np.array([])
if dvdt is None:
dvdt = ordifunc.get_derivative(v, t)
diff_mask = [np.any(dvdt[peak_ind:spike_ind] < 0)
for peak_ind, spike_ind
in zip(peak_indexes[:-1], spike_indexes[1:])]
#diff_mask --> check whether the derivative between a peak and the next threshold ever goes negative
peak_indexes = peak_indexes[np.array(diff_mask + [True])]
spike_indexes = spike_indexes[np.array([True] + diff_mask)]
#keep only peak indexes where diff mask was True (same for spike index)
peak_level_mask = v[peak_indexes] >= min_peak
#check if identified peaks are higher than minimum peak values (defined)
spike_indexes = spike_indexes[peak_level_mask]
peak_indexes = peak_indexes[peak_level_mask]
#keep only spike and peaks if spike_peak is higher than minimum value
height_mask = (v[peak_indexes] - v[spike_indexes]) >= min_height
spike_indexes = spike_indexes[height_mask]
peak_indexes = peak_indexes[height_mask]
#keep only events where the voltage difference between peak and threshold is higher than minimum height (defined)
spike_indexes = np.array(spike_indexes)
#spike_indexes = spike_indexes.astype(int)
peak_indexes = np.array(peak_indexes)
#peak_indexes = peak_indexes.astype(int)
return spike_indexes, peak_indexes
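# Self-contained sketch: a genuine spike-like event is kept while a small sub-threshold blip is
# rejected by the peak-level and height criteria (min_peak / min_height values mirror those used in
# identify_spike above). The derivative is passed explicitly, so no filtering is redone here.
def _example_filter_putative_spikes():
    t = np.arange(0., 0.1, 1e-4)
    v = np.full_like(t, -70.)
    v[300:310] = np.linspace(-70., 20., 10)         # real spike: peak +20 mV, height 90 mV
    v[310:330] = np.linspace(20., -70., 20)
    v[600:605] = np.linspace(-70., -55., 5)         # small blip: peak -55 mV, should be rejected
    v[605:615] = np.linspace(-55., -70., 10)
    dvdt = np.gradient(v, t) * 1e-3
    spikes, peaks = filter_putative_spikes(v, t,
                                           spike_indexes=np.array([300, 600]),
                                           peak_indexes=np.array([309, 604]),
                                           min_height=20., min_peak=-30., dvdt=dvdt)
    return spikes, peaks                            # expected: only the first event remains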
def find_upstroke_indexes(v, t, spike_indexes, peak_indexes, filter=5., dvdt=None):
"""
From AllenSDK.EPHYS.EPHYS_FEATURES Module
Find indexes of maximum upstroke of spike.
Parameters
----------
v : numpy array of voltage time series in mV
t : numpy array of times in seconds
spike_indexes : numpy array of preliminary spike indexes
peak_indexes : numpy array of indexes of spike peaks
filter : cutoff frequency for 4-pole low-pass Bessel filter in kHz (optional, default 5)
dvdt : pre-calculated time-derivative of voltage (optional)
Returns
-------
upstroke_indexes : numpy array of upstroke indexes
"""
if dvdt is None:
dvdt = ordifunc.get_derivative(v, t)
upstroke_indexes = [np.argmax(dvdt[spike:peak]) + spike for spike, peak in
zip(spike_indexes, peak_indexes)]
upstroke_indexes = np.array(upstroke_indexes)
return upstroke_indexes
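# Sketch: the upstroke is the sample of maximum dV/dt between each threshold and its peak. Indices
# and waveform are synthetic; dvdt is passed in, so no filtering is involved.
def _example_find_upstroke_indexes():
    t = np.arange(0., 0.05, 1e-4)
    v = np.full_like(t, -70.)
    v[200:211] = -70. + 100. * np.linspace(0., 1., 11) ** 2    # accelerating rise to +30 mV
    dvdt = np.gradient(v, t)
    # Expected: an index just before the peak, where the rise is steepest.
    return find_upstroke_indexes(v, t,
                                 spike_indexes=np.array([200]),
                                 peak_indexes=np.array([210]),
                                 dvdt=dvdt)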
def refine_threshold_indexes(v, t, peak_indexes,upstroke_indexes,method="Mean_upstroke_fraction", thresh_frac=0.05, filter=5., dvdt=None,dv2dt2=None):
"""
From AllenSDK.EPHYS.EPHYS_FEATURES Module
Refine threshold detection of previously-found spikes.
Parameters
----------
v : numpy array of voltage time series in mV
t : numpy array of times in seconds
peak_indexes : numpy array of indexes of spike peaks
upstroke_indexes : numpy array of indexes of spike upstrokes (for threshold target calculation)
method : threshold refinement method, either "Mean_upstroke_fraction" or "Second_Derivative" (optional, default "Mean_upstroke_fraction")
thresh_frac : fraction of average upstroke for threshold calculation (optional, default 0.05)
filter : cutoff frequency for 4-pole low-pass Bessel filter in kHz (optional, default 5)
dvdt : pre-calculated time-derivative of voltage (optional)
dv2dt2 : pre-calculated (filtered) second time-derivative of voltage, used by the "Second_Derivative" method (optional)
Returns
-------
threshold_indexes : numpy array of threshold indexes
"""
if not peak_indexes.size:
return np.array([])
if dvdt is None:
dvdt = ordifunc.get_derivative(v, t)
if method == "Second_Derivative":
##########
# Here the threshold is defined as the local maximum of the second voltage derivative
opening_window=np.append(np.array([0]), peak_indexes[:-1])
closing_window=peak_indexes
threshold_indexes = [np.argmax(dv2dt2[prev_peak:nex_peak])+prev_peak for prev_peak, nex_peak in
zip(opening_window, closing_window)]
return np.array(threshold_indexes)
##########
elif method == "Mean_upstroke_fraction":
## Here the threshold is defined as the last index before the upstroke at which dvdt drops to or below avg_upstroke * thresh_frac
avg_upstroke = dvdt[upstroke_indexes].mean()
target = avg_upstroke * thresh_frac
upstrokes_and_start = np.append(np.array([0]), upstroke_indexes)
threshold_indexes = []
for upstk, upstk_prev in zip(upstrokes_and_start[1:], upstrokes_and_start[:-1]):
voltage_indexes = np.flatnonzero(dvdt[upstk:upstk_prev:-1] <= target)
if not voltage_indexes.size:
# couldn't find a matching value for threshold,
# so just going to the start of the search interval
threshold_indexes.append(upstk_prev)
else:
threshold_indexes.append(upstk - voltage_indexes[0])
threshold_indexes = np.array(threshold_indexes)
return threshold_indexes
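# Sketch of the "Mean_upstroke_fraction" refinement: walk backwards from each upstroke until dV/dt
# drops to thresh_frac of the average upstroke value. Synthetic waveform; dvdt is supplied, so no
# filtering or external helpers are involved.
def _example_refine_threshold_indexes():
    t = np.arange(0., 0.05, 1e-4)
    v = np.full_like(t, -70.)
    v[195:200] = np.linspace(-70., -65., 5)         # slow depolarisation before the spike
    v[200:211] = np.linspace(-65., 35., 11)         # fast upstroke
    v[211:231] = np.linspace(35., -70., 20)
    dvdt = np.gradient(v, t)
    # Expected: a threshold index at or just before the foot of the upstroke.
    return refine_threshold_indexes(v, t,
                                    peak_indexes=np.array([210]),
                                    upstroke_indexes=np.array([205]),
                                    method='Mean_upstroke_fraction',
                                    thresh_frac=0.05,
                                    dvdt=dvdt)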
def check_thresholds_and_peaks(v, t, spike_indexes, peak_indexes, upstroke_indexes, start=None, end=None,
max_interval=0.01, thresh_frac=0.05, filter=5., dvdt=None,
tol=1.0, reject_at_stim_start_interval=0.):
"""
From AllenSDK.EPHYS.EPHYS_FEATURES Module
Validate thresholds and peaks for set of spikes
Check that peaks and thresholds for consecutive spikes do not overlap
Spikes with overlapping thresholds and peaks will be merged.
Check that peaks and thresholds for a given spike are not too far apart.
Parameters
----------
v : numpy array of voltage time series in mV
t : numpy array of times in seconds
spike_indexes : numpy array of spike indexes
peak_indexes : numpy array of indexes of spike peaks
upstroke_indexes : numpy array of indexes of spike upstrokes
start : start of time window for feature analysis (optional)
end : end of time window for feature analysis (optional)
max_interval : maximum allowed time between start of spike and time of peak in sec (optional, default 0.01)
thresh_frac : fraction of average upstroke for threshold calculation (optional, default 0.05)
filter : cutoff frequency for 4-pole low-pass Bessel filter in kHz (optional, default 5)
dvdt : pre-calculated time-derivative of voltage (optional)
tol : tolerance for returning to threshold in mV (optional, default 1)
reject_at_stim_start_interval : duration of window after start to reject voltage spikes (optional, default 0)
Returns
-------
spike_indexes : numpy array of modified spike indexes
peak_indexes : numpy array of modified spike peak indexes
upstroke_indexes : numpy array of modified spike upstroke indexes
clipped : numpy array of clipped status of spikes
"""
if start is not None and reject_at_stim_start_interval > 0:
mask = t[spike_indexes] > (start + reject_at_stim_start_interval)
spike_indexes = spike_indexes[mask]
peak_indexes = peak_indexes[mask]
upstroke_indexes = upstroke_indexes[mask]
overlaps = np.flatnonzero(spike_indexes[1:] <= peak_indexes[:-1] + 1)
if overlaps.size:
spike_mask = np.ones_like(spike_indexes, dtype=bool)
spike_mask[overlaps + 1] = False
spike_indexes = spike_indexes[spike_mask]
peak_mask = np.ones_like(peak_indexes, dtype=bool)
peak_mask[overlaps] = False
peak_indexes = peak_indexes[peak_mask]
upstroke_mask = np.ones_like(upstroke_indexes, dtype=bool)
upstroke_mask[overlaps] = False
upstroke_indexes = upstroke_indexes[upstroke_mask]
# Validate that peaks don't occur too long after the threshold
# If they do, try to re-find threshold from the peak
too_long_spikes = []
for i, (spk, peak) in enumerate(zip(spike_indexes, peak_indexes)):
if t[peak] - t[spk] >= max_interval:
logging.info("Need to recalculate threshold-peak pair that exceeds maximum allowed interval ({:f} s)".format(max_interval))
too_long_spikes.append(i)
if too_long_spikes:
if dvdt is None:
dvdt = ordifunc.get_derivative(v, t)
avg_upstroke = dvdt[upstroke_indexes].mean()
target = avg_upstroke * thresh_frac
drop_spikes = []
for i in too_long_spikes:
# First guessing that threshold is wrong and peak is right
peak = peak_indexes[i]
t_0 = ordifunc.find_time_index(t, t[peak] - max_interval)
below_target = np.flatnonzero(dvdt[upstroke_indexes[i]:t_0:-1] <= target)
if not below_target.size:
# Now try to see if threshold was right but peak was wrong
# Find the peak in a window twice the size of our allowed window
spike = spike_indexes[i]
if t[spike] + 2 * max_interval >= t[-1]:
t_0=ordifunc.find_time_index(t, t[-1])
else:
t_0 = ordifunc.find_time_index(t, t[spike] + 2 * max_interval)
new_peak = np.argmax(v[spike:t_0]) + spike
# If that peak is okay (not outside the allowed window, not past the next spike)
# then keep it
if t[new_peak] - t[spike] < max_interval and \
(i == len(spike_indexes) - 1 or t[new_peak] < t[spike_indexes[i + 1]]):
peak_indexes[i] = new_peak
else:
# Otherwise, log and get rid of the spike
logging.info("Could not redetermine threshold-peak pair - dropping that pair")
drop_spikes.append(i)
# raise FeatureError("Could not redetermine threshold")
else:
spike_indexes[i] = upstroke_indexes[i] - below_target[0]
if drop_spikes:
spike_indexes = np.delete(spike_indexes, drop_spikes)
peak_indexes = np.delete(peak_indexes, drop_spikes)
upstroke_indexes = np.delete(upstroke_indexes, drop_spikes)
if not end:
end = t[-1]
end_index = ordifunc.find_time_index(t, end)
clipped = find_clipped_spikes(v, t, spike_indexes, peak_indexes, end_index, tol)
spike_indexes = np.array(spike_indexes)
peak_indexes = np.array(peak_indexes)
upstroke_indexes = np.array(upstroke_indexes)
return spike_indexes, peak_indexes, upstroke_indexes, clipped
def find_clipped_spikes(v, t, spike_indexes, peak_indexes, end_index, tol, time_tol=0.005):
"""
From AllenSDK.EPHYS.EPHYS_FEATURES Module
Check that last spike was not cut off too early by end of stimulus
by checking that the membrane voltage returned to at least the threshold
voltage - otherwise, drop it
Parameters
----------
v : numpy array of voltage time series in mV
t : numpy array of times in seconds
spike_indexes : numpy array of spike indexes
peak_indexes : numpy array of indexes of spike peaks
end_index: int index of the end of time window for feature analysis
tol: float, tolerance for returning to threshold voltage, in mV
time_tol: float, duration of the window before the end of the analysis interval within which the last spike is checked for clipping, in s
Returns
-------
clipped: Boolean np.array
"""
clipped = np.zeros_like(spike_indexes, dtype=bool)