-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathfilterbank.py
71 lines (55 loc) · 2.39 KB
/
filterbank.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
# -*- coding: utf-8 -*-
"""
Created on Fri Nov 1 10:16:21 2019
@author: ALU
"""
import warnings
import scipy.signal
import numpy as np
def filterbank(eeg, fs, idx_fb=None):
    """Zero-phase band-pass filter EEG data for one sub-band of a filter bank.

    Python port of a MATLAB ``filterbank`` routine. The filter design
    constants and the ``filtfilt`` padding length are chosen so the output
    matches the MATLAB reference.

    Parameters
    ----------
    eeg : array_like
        EEG data with channels on axis 0 and samples on axis 1, e.g.
        ``(num_chans, num_samples)`` or ``(num_chans, num_samples, num_trials)``.
        Any further axes (trials, ...) are filtered independently.
    fs : float
        Sampling rate in Hz. The upper stop-band edge is fixed at 100 Hz and
        is normalised by ``fs / 2``, so ``fs`` should exceed 200 Hz.
    idx_fb : int or None, optional
        Sub-band index, ``0 <= idx_fb <= 9``; selects the lower pass/stop-band
        edges from the tables below. ``None`` (the default) emits a warning
        and uses 0.

    Returns
    -------
    numpy.ndarray
        float64 array of the same shape as ``eeg``, filtered forward and
        backward (zero phase) along axis 1.

    Raises
    ------
    ValueError
        If ``idx_fb`` is outside ``[0, 9]`` or ``eeg`` has fewer than two
        dimensions. ``scipy.signal.filtfilt`` also raises ValueError when the
        signal is not longer than the padding length.
    """
    if idx_fb is None:
        warnings.warn('stats:filterbank:MissingInput '
                      'Missing filter index. Default value (idx_fb = 0) will be used.',
                      stacklevel=2)
        idx_fb = 0
    elif idx_fb < 0 or 9 < idx_fb:
        raise ValueError('stats:filterbank:InvalidInput '
                         'The number of sub-bands must be 0 <= idx_fb <= 9.')

    eeg = np.asarray(eeg)
    if eeg.ndim < 2:
        raise ValueError('eeg must have at least 2 dimensions '
                         '(channels x samples[, trials]); got shape %r' % (eeg.shape,))

    nyq = fs / 2  # Nyquist frequency
    # Lower band edges (Hz) per sub-band; the upper edges (90 Hz pass,
    # 100 Hz stop) are shared by every sub-band.
    passband = [6, 14, 22, 30, 38, 46, 54, 62, 70, 78]
    stopband = [4, 10, 16, 24, 32, 40, 48, 56, 64, 72]
    wp = [passband[idx_fb] / nyq, 90 / nyq]
    ws = [stopband[idx_fb] / nyq, 100 / nyq]

    # Order is selected for 3 dB pass-band ripple / 40 dB attenuation, but the
    # filter is then designed with 0.5 dB ripple; kept as-is to reproduce the
    # MATLAB reference exactly.
    order, wn = scipy.signal.cheb1ord(wp, ws, 3, 40)
    b, a = scipy.signal.cheby1(order, 0.5, wn, 'bandpass')

    # Pad with 3 * (filter length - 1) samples instead of scipy's default so
    # the result matches MATLAB's filtfilt.
    padlen = 3 * (max(len(b), len(a)) - 1)
    # One vectorised call filters every channel/trial along the sample axis;
    # this also handles 3-D input with a single trial correctly.
    y = scipy.signal.filtfilt(b, a, eeg, axis=1, padtype='odd', padlen=padlen)
    return np.asarray(y, dtype=np.float64)
if __name__ == '__main__':
    # Sanity check: compare this port against outputs saved from the MATLAB
    # implementation. Requires sample.mat, y1_from_matlab.mat and
    # y2_from_matlab.mat in the working directory.
    from scipy.io import loadmat

    fs = 250
    # Hard-coded sample window used when the MATLAB reference was produced.
    # NOTE(review): presumably a latency offset of 33 samples and a 125-sample
    # (0.5 s at 250 Hz) segment -- confirm against the reference script.
    start, length = 33, 125

    data = loadmat("sample.mat")
    eeg = data['eeg']
    eeg = eeg[:, :, start:start + length, :]
    # NOTE(review): axis meanings come from the original comments; assumes
    # eeg is (targets, channels, samples, blocks) -- confirm against sample.mat.
    eeg = eeg[:, :, :, 0]  # index 0 of the last axis ("first bank")
    eeg = eeg[0, :, :]     # first target -> (channels, samples)

    for idx_fb, ref_file, ref_key in ((0, "y1_from_matlab.mat", 'y1'),
                                      (9, "y2_from_matlab.mat", 'y2')):
        y = filterbank(eeg, fs, idx_fb)
        y_matlab = loadmat(ref_file)[ref_key]
        # A signed sum lets positive and negative errors cancel, so report
        # the largest absolute deviation instead.
        max_abs_diff = np.max(np.abs(y - y_matlab))
        print("Max abs difference between matlab and python (idx_fb = %d) = "
              % idx_fb, max_abs_diff)