Skip to content

Commit

Permalink
Implement Fixed Parameter Asks (facebookresearch#512)
Browse files Browse the repository at this point in the history
Summary:

Allow asks to provide some some parameter values to fix when generating points. When set only the unset parameters will be automatically generated.

This more or less only works for OptimizeAcqfGenerator. Other generators will simply throw a warning and ignore the directive.

To conform API, every generator's gen method now includes the **kwargs but most of them will simply ignore it.

Differential Revision: D67956510
  • Loading branch information
JasonKChow authored and facebook-github-bot committed Jan 9, 2025
1 parent f046f93 commit 66d4a05
Show file tree
Hide file tree
Showing 14 changed files with 155 additions and 12 deletions.
2 changes: 1 addition & 1 deletion aepsych/generators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def __init__(
pass

@abc.abstractmethod
def gen(self, num_points: int, model: AEPsychModelType) -> torch.Tensor:
def gen(self, num_points: int, model: AEPsychModelType, **kwargs) -> torch.Tensor:
pass

@classmethod
Expand Down
5 changes: 3 additions & 2 deletions aepsych/generators/epsilon_greedy_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,17 +55,18 @@ def from_config(cls, config: Config) -> "EpsilonGreedyGenerator":
epsilon = config.getfloat(classname, "epsilon", fallback=0.1)
return cls(lb=lb, ub=ub, subgenerator=subgen, epsilon=epsilon)

def gen(self, num_points: int, model: ModelProtocol) -> torch.Tensor:
def gen(self, num_points: int, model: ModelProtocol, **kwargs) -> torch.Tensor:
"""Query next point(s) to run by sampling from the subgenerator with probability 1-epsilon, and randomly otherwise.
Args:
num_points (int): Number of points to query.
model (ModelProtocol): Model to use for generating points.
**kwargs: Passed to subgenerator if not exploring
"""
if num_points > 1:
raise NotImplementedError("Epsilon-greedy batched gen is not implemented!")
if np.random.uniform() < self.epsilon:
sample = np.random.uniform(low=self.lb, high=self.ub)
return torch.tensor(sample).reshape(1, -1)
else:
return self.subgenerator.gen(num_points, model)
return self.subgenerator.gen(num_points, model, **kwargs)
2 changes: 2 additions & 0 deletions aepsych/generators/manual_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,13 @@ def gen(
self,
num_points: int = 1,
model: Optional[AEPsychMixin] = None, # included for API compatibility
**kwargs,
) -> torch.Tensor:
"""Query next point(s) to run by quasi-randomly sampling the parameter space.
Args:
num_points (int): Number of points to query. Defaults to 1.
model (AEPsychMixin, optional): Model to use for generating points. Not used in this generator. Defaults to None.
**kwargs: Ignored, API compatibility
Returns:
torch.Tensor: Next set of point(s) to evaluate, [num_points x dim].
"""
Expand Down
2 changes: 2 additions & 0 deletions aepsych/generators/monotonic_rejection_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,11 +93,13 @@ def gen(
self,
num_points: int, # Current implementation only generates 1 point at a time
model: MonotonicRejectionGP,
**kwargs,
) -> torch.Tensor:
"""Query next point(s) to run by optimizing the acquisition function.
Args:
num_points (int): Number of points to query. Currently only supports 1.
model (AEPsychMixin): Fitted model of the data.
**kwargs: Ignored, API compatibility
Returns:
torch.Tensor: Next set of point(s) to evaluate, [num_points x dim].
"""
Expand Down
2 changes: 2 additions & 0 deletions aepsych/generators/monotonic_thompson_sampler_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,13 @@ def gen(
self,
num_points: int, # Current implementation only generates 1 point at a time
model: MonotonicRejectionGP,
**kwargs,
) -> torch.Tensor:
"""Query next point(s) to run by optimizing the acquisition function.
Args:
num_points (int): Number of points to query. current implementation only generates 1 point at a time.
model (MonotonicRejectionGP): Fitted model of the data.
**kwargs: Ignored, API compatibility
Returns:
torch.Tensor: Next set of point(s) to evaluate, [num_points x dim].
"""
Expand Down
3 changes: 3 additions & 0 deletions aepsych/generators/random_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,14 @@ def gen(
self,
num_points: int = 1,
model: Optional[AEPsychMixin] = None, # included for API compatibility.
**kwargs,
) -> torch.Tensor:
"""Query next point(s) to run by randomly sampling the parameter space.
Args:
num_points (int): Number of points to query. Currently, only 1 point can be queried at a time.
model (AEPsychMixin, optional): Model to use for generating points. Not used in this generator.
**kwargs: Ignored, API compatibility
Returns:
torch.Tensor: Next set of point(s) to evaluate, [num_points x dim].
"""
Expand Down
6 changes: 5 additions & 1 deletion aepsych/generators/semi_p.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,21 +34,25 @@ def gen( # type: ignore[override]
num_points: int,
model: SemiParametricGPModel, # type: ignore[override]
context_objective: Type = SemiPThresholdObjective,
**kwargs,
) -> torch.Tensor:
"""Query next point(s) to run by optimizing the acquisition function for both context and intensity.
Args:
num_points (int): Number of points to query.
model (SemiParametricGPModel): Fitted semi-parametric model of the data.
context_objective (Type): The objective function used for context. Defaults to SemiPThresholdObjective.
**kwargs: Passed to generator
Returns:
torch.Tensor: Next set of point(s) to evaluate, [num_points x dim].
"""
if "fixed_features" in kwargs:
raise ValueError("Fixed features not supported for semi_p generators")

fixed_features = {model.stim_dim: 0}
next_x = super().gen(
num_points=num_points, model=model, fixed_features=fixed_features
num_points=num_points, model=model, fixed_features=fixed_features, **kwargs
)
# to compute intensity, we need the point where f is at the
# threshold as a function of context. self.acqf_kwargs should contain
Expand Down
2 changes: 2 additions & 0 deletions aepsych/generators/sobol_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,13 @@ def gen(
self,
num_points: int = 1,
model: Optional[AEPsychMixin] = None, # included for API compatibility
**kwargs,
) -> torch.Tensor:
"""Query next point(s) to run by quasi-randomly sampling the parameter space.
Args:
num_points (int): Number of points to query. Defaults to 1.
moodel (AEPsychMixin, optional): Model to use for generating points. Not used in this generator. Defaults to None.
**kwargs: Ignored, API compatibility
Returns:
torch.Tensor: Next set of point(s) to evaluate, [num_points x dim] or [num_points x dim x stimuli_per_trial] if stimuli_per_trial != 1.
"""
Expand Down
11 changes: 9 additions & 2 deletions aepsych/server/message_handlers/handle_ask.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def handle_ask(server, request):
return new_config


def ask(server, num_points=1):
def ask(server, num_points=1, **kwargs):
"""get the next point to query from the model
Returns:
dict -- new config dict (keys are strings, values are floats)
Expand All @@ -49,7 +49,14 @@ def ask(server, num_points=1):
server.strat._make_next_strat()
return None

# The fixed_pars kwargs name is purposefully differend to the fixed_features
# expected by botorch's optimize acqf to avoid doubling up ever while allowing other
# kwargs to pass through
if "fixed_pars" in kwargs:
fixed_pars = kwargs.pop("fixed_pars")
kwargs["fixed_features"] = server._fixed_to_idx(fixed_pars)

# index by [0] is temporary HACK while serverside
# doesn't handle batched ask
next_x = server.strat.gen()[0]
next_x = server.strat.gen(num_points=num_points, **kwargs)[0]
return server._tensor_to_config(next_x)
20 changes: 20 additions & 0 deletions aepsych/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import threading
import traceback
import warnings
from typing import Dict, Union

import aepsych.database.db as db
import aepsych.utils_logging as utils_logging
Expand Down Expand Up @@ -270,6 +271,25 @@ def _config_to_tensor(self, config):

return x

def _fixed_to_idx(self, fixed: Dict[str, Union[float, str]]):
# Given a dictionary of fixed parameters, turn the parameters names into indices
# transforms values as necessary
dummy = np.zeros(len(self.parnames)).astype("O")
for key, value in fixed.items():
idx = self.parnames.index(key)
dummy[idx] = value
dummy = np.expand_dims(dummy, 0)
dummy = self.strat.transforms.str_to_indices(dummy)
dummy = self.strat.transforms.transform(dummy)[0]

# Turn the dummy back into a dict
fixed_features = {}
for key in fixed.keys():
idx = self.parnames.index(key)
fixed_features[idx] = dummy[idx].item()

return fixed_features

def __getstate__(self):
# nuke the socket since it's not pickleble
state = self.__dict__.copy()
Expand Down
14 changes: 11 additions & 3 deletions aepsych/strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
)
from aepsych.acquisition.monotonic_rejection import MonotonicMCAcquisition
from aepsych.config import Config
from aepsych.generators import OptimizeAcqfGenerator
from aepsych.generators.base import AEPsychGenerator
from aepsych.models.base import AEPsychMixin
from aepsych.models.utils import get_extremum, get_jnd, get_max, get_min, inv_query
Expand Down Expand Up @@ -296,14 +297,14 @@ def normalize_inputs(

return x, y, n

# TODO: allow user to pass in generator options
@ensure_model_is_fresh
def gen(self, num_points: int = 1) -> torch.Tensor:
def gen(self, num_points: int = 1, **kwargs) -> torch.Tensor:
"""Query next point(s) to run by optimizing the acquisition function.
Args:
num_points (int): Number of points to query. Defaults to 1.
Other arguments are forwared to underlying model.
**kwargs: Kwargs to send to pass to the underlying generator.
Returns:
torch.Tensor: Next set of point(s) to evaluate, [num_points x dim].
Expand All @@ -313,8 +314,14 @@ def gen(self, num_points: int = 1) -> torch.Tensor:
original_device = self.model.device
self.model.to(self.generator_device) # type: ignore

if "fixed_features" in kwargs and not isinstance(
self.generator, OptimizeAcqfGenerator
):
logger.warning(
f"You cannot generate points with specific values using {self.generator.__class__.__name__}, fixed_features/fixed_pars options ignored."
)
self._count = self._count + num_points
points = self.generator.gen(num_points, self.model)
points = self.generator.gen(num_points, self.model, **kwargs)

if original_device is not None:
self.model.to(original_device) # type: ignore
Expand Down Expand Up @@ -796,6 +803,7 @@ def gen(self, num_points: int = 1, **kwargs) -> torch.Tensor:
if self._strat.finished:
self._make_next_strat()
self._suggest_count = self._suggest_count + num_points

return self._strat.gen(num_points=num_points, **kwargs)

def finish(self) -> None:
Expand Down
5 changes: 4 additions & 1 deletion aepsych/transforms/ops/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,9 @@ def str_to_indices(self, obj_arr: np.ndarray) -> np.ndarray:

if self.string_map is not None:
for idx, cats in self.string_map.items():
obj_arr[:, idx] = [cats.index(cat) for cat in obj_arr[:, idx]]
obj_arr[:, idx] = [
cats.index(cat) if isinstance(cat, str) else cat
for cat in obj_arr[:, idx]
]

return obj_arr
5 changes: 3 additions & 2 deletions aepsych/transforms/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,20 +385,21 @@ def __init__(
self.max_asks = self._base_obj.max_asks

def gen(
self, num_points: int = 1, model: Optional[AEPsychMixin] = None
self, num_points: int = 1, model: Optional[AEPsychMixin] = None, **kwargs
) -> torch.Tensor:
r"""Query next point(s) to run from the generator and return them untransformed.
Args:
num_points (int): Number of points to query, defaults to 1.
model (AEPsychMixin, optional): The model to use to generate points, can be
None if no model is needed.
**kwargs: Kwargs to pass to the generator's generator.
Returns:
torch.Tensor: Next set of point(s) to evaluate, `[num_points x dim]` or
`[num_points x dim x stimuli_per_trial]` if `self.stimuli_per_trial != 1`,
which will be untransformed.
"""
x = self._base_obj.gen(num_points, model)
x = self._base_obj.gen(num_points, model, **kwargs)
return self.transforms.untransform(x)

def _get_acqf_options(self, acqf: AcquisitionFunction, config: Config):
Expand Down
88 changes: 88 additions & 0 deletions tests/server/message_handlers/test_ask_handlers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
#!/usr/bin/env python3
# Copyright (c) Meta, Inc. and its affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.


import logging

from ..test_server import BaseServerTestCase


class AskHandlerTestCase(BaseServerTestCase):
def test_fixed_ask(self):
config_str = f"""
[common]
parnames = [par1, par2]
stimuli_per_trial = 1
outcome_types = [binary]
strategy_names = [init_strat, opt_strat]
[par1]
par_type = continuous
lower_bound = 1
upper_bound = 100
[par2]
par_type = continuous
lower_bound = 0
upper_bound = 1
[init_strat]
generator = SobolGenerator
min_total_tells = 1
[opt_strat]
generator = OptimizeAcqfGenerator
acqf = MCLevelSetEstimation
model = GPClassificationModel
min_total_tells = 2
"""
setup_request = {
"type": "setup",
"message": {"config_str": config_str},
}
self.s.handle_request(setup_request)

fixed1 = 75
fixed2 = 0.75

# SobolGenerator
# One fixed
with self.assertLogs(level=logging.WARNING) as logs:
resp = self.s.handle_request(
{"type": "ask", "message": {"fixed_pars": {"par1": fixed1}}}
)
outputs = ";".join(logs.output)
self.assertTrue("cannot generate points with specific values" in outputs)

self.s.handle_request(
{"type": "tell", "message": {"config": resp["config"], "outcome": 1}}
)
self.s.handle_request(
{"type": "tell", "message": {"config": resp["config"], "outcome": 0}}
)

# OptimizeAcqfGenerator
# One fixed
resp = self.s.handle_request(
{"type": "ask", "message": {"fixed_pars": {"par1": fixed1}}}
)
self.assertTrue(resp["config"]["par1"][0] == fixed1)

self.s.handle_request(
{"type": "tell", "message": {"config": resp["config"], "outcome": 1}}
)

# All fixed
resp = self.s.handle_request(
{
"type": "ask",
"message": {"fixed_pars": {"par1": fixed1, "par2": fixed2}},
}
)

self.assertTrue(resp["config"]["par1"][0] == fixed1)
self.assertTrue(resp["config"]["par2"][0] == fixed2)

0 comments on commit 66d4a05

Please sign in to comment.