diff --git a/neuro_py/ensemble/decoding/pipeline.py b/neuro_py/ensemble/decoding/pipeline.py index 9c2aa21..c2473fb 100644 --- a/neuro_py/ensemble/decoding/pipeline.py +++ b/neuro_py/ensemble/decoding/pipeline.py @@ -262,19 +262,19 @@ def preprocess_data(hyperparams, ohe, nsv_train, nsv_val, nsv_test, bv_train, bv num_workers=hyperparams['num_workers'], modeltype=hyperparams['model']) hyperparams['model_args']['in_dim'] = X_train.shape[-1] else: - if type(bv_train[0]) == pd.DataFrame: + if type(bv_train[0]) is pd.DataFrame: y_train = [y.values[:, hyperparams['behaviors']] for y in bv_train] else: y_train = [y[:, hyperparams['behaviors']] for y in bv_train] nbins_per_tseg = [len(y) for y in y_train] # number of time bins in each trial tseg_bounds_train = np.cumsum([0] + nbins_per_tseg) - if type(bv_val[0]) == pd.DataFrame: + if type(bv_val[0]) is pd.DataFrame: y_val = [y.values[:, hyperparams['behaviors']] for y in bv_val] else: y_val = [y[:, hyperparams['behaviors']] for y in bv_val] nbins_per_tseg = [len(y) for y in y_val] tseg_bounds_val = np.cumsum([0] + nbins_per_tseg) - if type(bv_test[0]) == pd.DataFrame: + if type(bv_test[0]) is pd.DataFrame: y_test = [y.values[:, hyperparams['behaviors']] for y in bv_test] else: y_test = [y[:, hyperparams['behaviors']] for y in bv_test]