Skip to content

Commit

Permalink
Merge pull request #361 from kaeldai/fix/test_psg_variable
Browse files Browse the repository at this point in the history
fixing minor bug with spike-trains unit test
  • Loading branch information
kaeldai authored Apr 23, 2024
2 parents 44256e3 + f31ed36 commit afd8fbb
Showing 1 changed file with 21 additions and 9 deletions.
30 changes: 21 additions & 9 deletions bmtk/tests/utils/reports/spike_trains/test_spikes_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,18 +30,29 @@ def test_psg_variable():
times = np.linspace(0.0, 3.0, 1000)
fr = np.exp(-np.power(times - 1.0, 2) / (2*np.power(.5, 2)))*5

psg = PoissonSpikeGenerator(population='test', seed=0.0)
psg = PoissonSpikeGenerator(population='test', seed=100)
psg.add(node_ids=range(10), firing_rate=fr, times=times)

assert(psg.populations == ['test'])
assert(np.all(psg.node_ids() == list(range(10))))
assert(psg.n_spikes() == 59)
assert(np.allclose(psg.time_range(), (139.32107933711294, 2901.3003727909172)))
assert(psg.to_dataframe().shape == (59, 3))
assert(np.allclose(psg.get_times(node_id=0), [442.8378, 520.3624, 640.3880, 1099.0661, 1393.0794, 1725.6109],
atol=1.0e-3))
assert(np.allclose(psg.get_times(node_id=9), [729.6267, 885.2469, 1047.7728, 1276.3554, 1543.6557, 1669.9070,
1881.3605], atol=1.0e-3))
assert(psg.n_spikes() == 54)
assert(np.allclose(psg.time_range(), (170.22331575431056, 2004.337420574704)))
assert(psg.to_dataframe().shape == (54, 3))
assert(np.allclose(psg.get_times(node_id=0), [268.2470, 519.8341, 963.7072, 1004.7012, 1054.3159, 1388.2418, 1727.0501], atol=1.0e-3))
assert(np.allclose(psg.get_times(node_id=9), [302.6400, 706.5132, 719.6730, 897.8392, 1192.0589, 1201.5878, 2004.3374], atol=1.0e-3))


def test_psg_none_seed():
times = np.linspace(0.0, 3.0, 1000)
fr = np.exp(-np.power(times - 1.0, 2) / (2*np.power(.5, 2)))*5

psg0 = PoissonSpikeGenerator(population='test', seed=0)
psg0.add(node_ids=range(10), firing_rate=fr, times=times)

psg_none = PoissonSpikeGenerator(population='test', seed=None)
psg_none.add(node_ids=range(10), firing_rate=fr, times=times)

assert(psg0.n_spikes() != psg_none.n_spikes() and not np.allclose(psg0.n_spikes(), psg_none.n_spikes()))


def test_equals():
Expand Down Expand Up @@ -78,7 +89,8 @@ def test_subset():


if __name__ == '__main__':
test_psg_fixed()
# test_psg_fixed()
# test_psg_variable()
test_psg_none_seed()
# test_equals()
# test_subset()

0 comments on commit afd8fbb

Please sign in to comment.