Skip to content

Commit

Permalink
Merge pull request #38 from hoechenberger/rng
Browse files Browse the repository at this point in the history
ENH: Allow passing a random seed via stim_selection_options
  • Loading branch information
hoechenberger authored Dec 19, 2019
2 parents 146e721 + 15e63bf commit 1ef53b3
Show file tree
Hide file tree
Showing 6 changed files with 105 additions and 9 deletions.
5 changes: 3 additions & 2 deletions .appveyor.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ install:
- python setup.py build

# Build & install sdist.
- python setup.py sdist
- python setup.py sdist --formats=zip
# - pip install --no-deps dist/questplus-*.zip
# - pip uninstall --yes questplus

Expand All @@ -30,7 +30,8 @@ install:
# - pip install --no-deps dist/questplus-*.whl
# - pip uninstall --yes questplus

- pip install .
- ps: Remove-Item –path dist, build –recurse
- pip install --no-deps .

test_script:
- py.test
2 changes: 1 addition & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ install:

- rm -rf dist/ build/

- pip install .
- pip install --no-deps .

script:
- py.test
7 changes: 7 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
v2019.2
-------
* Allow passing a random seed via `stim_selection_options` keyword
argument
* Better handling of `stim_selection_options` defaults (now allows
to supply only a subset of options)

v2019.1
-------
* Allow to pass priors for only some parameters
Expand Down
3 changes: 3 additions & 0 deletions questplus/_constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
DEFAULT_N = 4
DEFAULT_MAX_CONSECUTIVE_REPS = 2
DEFAULT_RANDOM_SEED = None
35 changes: 29 additions & 6 deletions questplus/qp.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,9 @@ def __init__(self, *,
method specified via `stim_selection_method`. Currently, this can
be used to specify the number of `n` stimuli that will yield the
`n` smallest entropies if `stim_selection_method=min_n_entropy`,
and`max_consecutive_reps`, the number of times the same stimulus
can be presented consecutively.
and `max_consecutive_reps`, the number of times the same stimulus
can be presented consecutively. A random number generator seed
may be passed via `random_seed=12345`.
param_estimation_method
The method to use when deriving the final parameter estimate.
Expand All @@ -88,11 +89,33 @@ def __init__(self, *,

self.stim_selection = stim_selection_method

if (self.stim_selection == 'min_n_entropy' and
stim_selection_options is None):
self.stim_selection_options = dict(n=4, max_consecutive_reps=2)
if self.stim_selection == 'min_n_entropy':
from ._constants import (DEFAULT_N, DEFAULT_RANDOM_SEED,
DEFAULT_MAX_CONSECUTIVE_REPS)

if stim_selection_options is None:
self.stim_selection_options = dict(
n=DEFAULT_N,
max_consecutive_reps=DEFAULT_MAX_CONSECUTIVE_REPS,
random_seed=DEFAULT_RANDOM_SEED)
else:
self.stim_selection_options = stim_selection_options.copy()

if 'n' not in stim_selection_options:
self.stim_selection_options['n'] = DEFAULT_N
if 'max_consecutive_reps' not in stim_selection_options:
self.stim_selection_options['max_consecutive_reps'] = DEFAULT_MAX_CONSECUTIVE_REPS
if 'random_seed' not in stim_selection_options:
self.stim_selection_options['random_seed'] = DEFAULT_RANDOM_SEED

del DEFAULT_N, DEFAULT_MAX_CONSECUTIVE_REPS, DEFAULT_RANDOM_SEED

seed = self.stim_selection_options['random_seed']
self._rng = np.random.RandomState(seed=seed)
del seed
else:
self.stim_selection_options = stim_selection_options
self._rng = None

self.param_estimation_method = param_estimation_method

Expand Down Expand Up @@ -271,7 +294,7 @@ def next_stim(self) -> dict:
while True:
# Randomly pick one index and retrieve its coordinates
# (stimulus parameters).
candidate_index = np.random.choice(indices)
candidate_index = self._rng.choice(indices)
coords = EH[candidate_index].coords
stim = {stim_property: stim_val.item()
for stim_property, stim_val in coords.items()}
Expand Down
62 changes: 62 additions & 0 deletions questplus/tests/test_qp.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import scipy.stats
import numpy as np
from questplus.qp import QuestPlus, QuestPlusWeibull
from questplus import _constants


def test_threshold():
Expand Down Expand Up @@ -597,6 +598,66 @@ def test_prior_for_parameter_subset():
'lower_asymptote']).sum())


def test_stim_selection_options():
threshold = np.arange(-40, 0 + 1)
slope, guess, lapse = 3.5, 0.5, 0.02
contrasts = threshold.copy()

stim_domain = dict(intensity=contrasts)
param_domain = dict(threshold=threshold, slope=slope,
lower_asymptote=guess, lapse_rate=lapse)
outcome_domain = dict(response=['Correct', 'Incorrect'])

f = 'weibull'
scale = 'dB'
stim_selection_method = 'min_n_entropy'
param_estimation_method = 'mode'

common_params = dict(stim_domain=stim_domain, param_domain=param_domain,
outcome_domain=outcome_domain, func=f,
stim_scale=scale,
stim_selection_method=stim_selection_method,
param_estimation_method=param_estimation_method)

stim_selection_options = None
q = QuestPlus(**common_params,
stim_selection_options=stim_selection_options)
expected = dict(n=_constants.DEFAULT_N,
max_consecutive_reps=_constants.DEFAULT_MAX_CONSECUTIVE_REPS,
random_seed=_constants.DEFAULT_RANDOM_SEED)
assert expected == q.stim_selection_options

stim_selection_options = dict(n=5)
q = QuestPlus(**common_params,
stim_selection_options=stim_selection_options)
expected = dict(n=5,
max_consecutive_reps=_constants.DEFAULT_MAX_CONSECUTIVE_REPS,
random_seed=_constants.DEFAULT_RANDOM_SEED)
assert expected == q.stim_selection_options

stim_selection_options = dict(max_consecutive_reps=4)
q = QuestPlus(**common_params,
stim_selection_options=stim_selection_options)
expected = dict(n=_constants.DEFAULT_N,
max_consecutive_reps=4,
random_seed=_constants.DEFAULT_RANDOM_SEED)
assert expected == q.stim_selection_options

stim_selection_options = dict(random_seed=999)
q = QuestPlus(**common_params,
stim_selection_options=stim_selection_options)
expected = dict(n=_constants.DEFAULT_N,
max_consecutive_reps=_constants.DEFAULT_MAX_CONSECUTIVE_REPS,
random_seed=999)
assert expected == q.stim_selection_options

stim_selection_options = dict(n=5, max_consecutive_reps=4, random_seed=999)
q = QuestPlus(**common_params,
stim_selection_options=stim_selection_options)
expected = stim_selection_options.copy()
assert expected == q.stim_selection_options


if __name__ == '__main__':
test_threshold()
test_threshold_slope()
Expand All @@ -609,3 +670,4 @@ def test_prior_for_parameter_subset():
test_marginal_posterior()
test_prior_for_unknown_parameter()
test_prior_for_parameter_subset()
test_stim_selection_options()

0 comments on commit 1ef53b3

Please sign in to comment.