Skip to content

Commit

Permalink
Merge pull request #13 from hoechenberger/next-stim
Browse files Browse the repository at this point in the history
NF, RF: Enable min_n_entropy stimulus selection, turn stim_history into a list
  • Loading branch information
hoechenberger authored May 21, 2019
2 parents 06d8f70 + 459b1ef commit b4b5c4c
Showing 1 changed file with 40 additions and 27 deletions.
67 changes: 40 additions & 27 deletions questplus/qp.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,11 @@ def __init__(self, *,
stim_selection_options
Use this argument to specify options for the stimulus selection
method specified via `stim_selection_method`. Currently, this is
only used to specify the number of `n` stimuli that will yield the
`n` smallest entropies `stim_selection_method=min_n_entropy`.
method specified via `stim_selection_method`. Currently, this can
be used to specify the number of `n` stimuli that will yield the
`n` smallest entropies if `stim_selection_method=min_n_entropy`,
and`max_consecutive_reps`, the number of times the same stimulus
can be presented consecutively.
param_estimation_method
The method to use when deriving the final parameter estimate.
Expand All @@ -84,12 +86,17 @@ def __init__(self, *,
self.likelihoods = self._gen_likelihoods()

self.stim_selection = stim_selection_method
self.stim_selection_options = stim_selection_options

if (self.stim_selection == 'min_n_entropy' and
stim_selection_options is None):
self.stim_selection_options = dict(n=4, max_consecutive_reps=2)
else:
self.stim_selection_options = stim_selection_options

self.param_estimation_method = param_estimation_method

self.resp_history = list()
self.stim_history = {p: [] for p in self.stim_domain.keys()}
self.stim_history = list()
self.entropy = np.nan

@staticmethod
Expand Down Expand Up @@ -183,8 +190,7 @@ def update(self, *,
self.posterior /= self.posterior.sum()

# Log the results, too.
for stim_property, stim_val in stim.items():
self.stim_history[stim_property].append(stim_val)
self.stim_history.append(stim)
self.resp_history.append(outcome)

@property
Expand All @@ -200,7 +206,6 @@ def next_stim(self) -> dict:
The stimulus to present next.
"""
stim_selection = self.stim_selection
new_posterior = self.posterior * self.likelihoods

# Probability.
Expand All @@ -216,31 +221,39 @@ def next_stim(self) -> dict:
# Expected entropies for all possible stimulus parameters.
EH = (pk * H).sum(dim=list(self.outcome_domain.keys()))

if stim_selection == 'min_entropy':
if self.stim_selection == 'min_entropy':
# Get coordinates of stimulus properties that minimize entropy.
index = np.unravel_index(EH.argmin(), EH.shape)
coords = EH[index].coords
stim = {stim_property: stim_val.item()
for stim_property, stim_val in coords.items()}
self.entropy = EH.min().item()
# FIXME: currently disabled, need to adopt above method for
# finding correct coordinates!
# elif stim_selection == 'min_n_entropy':
# index = np.argsort(EH)[:4]
# while True:
# stim_candidates = self.stim_domain['intensity'][index.values]
# stim = np.random.choice(stim_candidates)
#
# if len(self.stim_history['intensity']) < 2:
# break
# elif (np.isclose(stim, self.stim_history['intensity'][-1]) and
# np.isclose(stim, self.stim_history['intensity'][-2])):
# print('\n ==> shuffling again... <==\n')
# continue
# else:
# break
#
# print(f'options: {self.stim_domain["intensity"][index.values]} -> {stim}')
elif self.stim_selection == 'min_n_entropy':
# Number of stimuli to include (the n stimuli that yield the lowest
# entropies)
n_stim = self.stim_selection_options['n']

indices = np.unravel_index(EH.argsort(), EH.shape)[0]
indices = indices[:n_stim]

while True:
# Randomly pick one index and retrieve its coordinates
# (stimulus parameters).
candidate_index = np.random.choice(indices)
coords = EH[candidate_index].coords
stim = {stim_property: stim_val.item()
for stim_property, stim_val in coords.items()}

max_reps = self.stim_selection_options['max_consecutive_reps']

if len(self.stim_history) < 2:
break
elif all([stim == prev_stim
for prev_stim in self.stim_history[-max_reps:]]):
# Shuffle again.
continue
else:
break
else:
raise ValueError('Unknown stim_selection supplied.')

Expand Down

0 comments on commit b4b5c4c

Please sign in to comment.