Skip to content

Commit

Permalink
type check with is instead of ==
Browse files Browse the repository at this point in the history
  • Loading branch information
kushaangupta committed Dec 24, 2024
1 parent d73d5bb commit d05bb38
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions neuro_py/ensemble/decoding/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,19 +262,19 @@ def preprocess_data(hyperparams, ohe, nsv_train, nsv_val, nsv_test, bv_train, bv
num_workers=hyperparams['num_workers'], modeltype=hyperparams['model'])
hyperparams['model_args']['in_dim'] = X_train.shape[-1]
else:
if type(bv_train[0]) == pd.DataFrame:
if type(bv_train[0]) is pd.DataFrame:
y_train = [y.values[:, hyperparams['behaviors']] for y in bv_train]
else:
y_train = [y[:, hyperparams['behaviors']] for y in bv_train]
nbins_per_tseg = [len(y) for y in y_train] # number of time bins in each trial
tseg_bounds_train = np.cumsum([0] + nbins_per_tseg)
if type(bv_val[0]) == pd.DataFrame:
if type(bv_val[0]) is pd.DataFrame:
y_val = [y.values[:, hyperparams['behaviors']] for y in bv_val]
else:
y_val = [y[:, hyperparams['behaviors']] for y in bv_val]
nbins_per_tseg = [len(y) for y in y_val]
tseg_bounds_val = np.cumsum([0] + nbins_per_tseg)
if type(bv_test[0]) == pd.DataFrame:
if type(bv_test[0]) is pd.DataFrame:
y_test = [y.values[:, hyperparams['behaviors']] for y in bv_test]
else:
y_test = [y[:, hyperparams['behaviors']] for y in bv_test]
Expand Down

0 comments on commit d05bb38

Please sign in to comment.