Skip to content

Commit

Permalink
Implement Fixed Parameter Asks (facebookresearch#512)
Browse files Browse the repository at this point in the history
Summary:

Allow asks to provide some some parameter values to fix when generating points. When set only the unset parameters will be automatically generated.

This more or less only works for OptimizeAcqfGenerator. Other generators will simply throw a warning and ignore the directive.

To conform API, every generator's gen method now includes the **kwargs but most of them will simply ignore it.

Differential Revision: D67956510
  • Loading branch information
JasonKChow authored and facebook-github-bot committed Jan 10, 2025
1 parent f81261f commit 4ee84a0
Show file tree
Hide file tree
Showing 19 changed files with 353 additions and 35 deletions.
35 changes: 31 additions & 4 deletions aepsych/generators/acqf_thompson_sampler_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,36 +86,56 @@ def _instantiate_acquisition_fn(self, model: ModelProtocol) -> AcquisitionFuncti
else:
return self.acqf(model=model, **self.acqf_kwargs)

def gen(self, num_points: int, model: ModelProtocol, **gen_options) -> torch.Tensor:
def gen(
self,
num_points: int,
model: ModelProtocol,
fixed_features: Optional[dict[int, float]] = None,
**gen_options,
) -> torch.Tensor:
"""Query next point(s) to run by optimizing the acquisition function.
Args:
num_points (int): Number of points to query.
model (ModelProtocol): Fitted model of the data.
fixed_features: (Dict[int, float], optional): Parameters that are fixed to specific values.
Returns:
torch.Tensor: Next set of point(s) to evaluate, [num_points x dim].
"""

if self.stimuli_per_trial == 2:
qbatch_points = self._gen(
num_points=num_points * 2, model=model, **gen_options
num_points=num_points * 2,
model=model,
fixed_features=fixed_features,
**gen_options,
)

# output of super() is (q, dim) but the contract is (num_points, dim, 2)
# so we need to split q into q and pairs and then move the pair dim to the end
return qbatch_points.reshape(num_points, 2, -1).swapaxes(-1, -2)

else:
return self._gen(num_points=num_points, model=model, **gen_options)
return self._gen(
num_points=num_points,
model=model,
fixed_features=fixed_features,
**gen_options,
)

def _gen(
self, num_points: int, model: ModelProtocol, **gen_options
self,
num_points: int,
model: ModelProtocol,
fixed_features: Optional[Dict[int, float]] = None,
**gen_options,
) -> torch.Tensor:
"""
Generates the next query points by optimizing the acquisition function.
Args:
num_points (int): The number of points to query.
model (ModelProtocol): The fitted model used to evaluate the acquisition function.
fixed_features: (Dict[int, float], optional): Parameters that are fixed to specific values.
gen_options (dict): Additional options for generating points, including:
- "seed": Random seed for reproducibility.
Expand Down Expand Up @@ -145,6 +165,13 @@ def _gen(
)
X_rnd = bounds_cpu[0] + (bounds_cpu[1] - bounds_cpu[0]) * X_rnd_nlzd

if fixed_features is not None:
logger.warning(
"Fixing parameters for generation with the AcqfThompsonSamplerGenerator changes initial random points to be evaluated, thus not guaranteed to have the same pseudorandom properties as without fixed parameters."
)
for key, value in fixed_features.items():
X_rnd[:, key] = value

acqf_vals = acqf(X_rnd).to(torch.float64)
acqf_vals -= acqf_vals.min()
probability_dist = acqf_vals / acqf_vals.sum()
Expand Down
8 changes: 7 additions & 1 deletion aepsych/generators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,13 @@ def __init__(
pass

@abc.abstractmethod
def gen(self, num_points: int, model: AEPsychModelType) -> torch.Tensor:
def gen(
self,
num_points: int,
model: AEPsychModelType,
fixed_features: Optional[Dict[int, float]] = None,
**kwargs,
) -> torch.Tensor:
pass

@classmethod
Expand Down
26 changes: 22 additions & 4 deletions aepsych/generators/epsilon_greedy_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from typing import Dict, Optional

import numpy as np
import torch
from aepsych.config import Config
Expand Down Expand Up @@ -56,17 +58,33 @@ def from_config(cls, config: Config) -> "EpsilonGreedyGenerator":
epsilon = config.getfloat(classname, "epsilon", fallback=0.1)
return cls(lb=lb, ub=ub, subgenerator=subgen, epsilon=epsilon)

def gen(self, num_points: int, model: ModelProtocol) -> torch.Tensor:
def gen(
self,
num_points: int,
model: ModelProtocol,
fixed_features: Optional[Dict[int, float]] = None,
**kwargs,
) -> torch.Tensor:
"""Query next point(s) to run by sampling from the subgenerator with probability 1-epsilon, and randomly otherwise.
Args:
num_points (int): Number of points to query.
model (ModelProtocol): Model to use for generating points.
fixed_features: (Dict[int, float], optional): Parameters that are fixed to specific values.
**kwargs: Passed to subgenerator if not exploring
"""
if num_points > 1:
raise NotImplementedError("Epsilon-greedy batched gen is not implemented!")
if np.random.uniform() < self.epsilon:
sample = np.random.uniform(low=self.lb, high=self.ub)
return torch.tensor(sample).reshape(1, -1)
sample_ = np.random.uniform(low=self.lb, high=self.ub)
sample = torch.tensor(sample_).reshape(1, -1)

if fixed_features is not None:
for key, value in fixed_features.items():
sample[:, key] = value

return sample
else:
return self.subgenerator.gen(num_points, model)
return self.subgenerator.gen(
num_points, model, fixed_features=fixed_features, **kwargs
)
9 changes: 9 additions & 0 deletions aepsych/generators/manual_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,16 @@ def gen(
self,
num_points: int = 1,
model: Optional[AEPsychMixin] = None, # included for API compatibility
fixed_features: Optional[Dict[int, float]] = None,
**kwargs, # Ignored
) -> torch.Tensor:
"""Query next point(s) to run by quasi-randomly sampling the parameter space.
Args:
num_points (int): Number of points to query. Defaults to 1.
model (AEPsychMixin, optional): Model to use for generating points. Not used in this generator. Defaults to None.
fixed_features (Dict[int, float], optional): Ignored, kept for consistent
API.
**kwargs: Ignored, API compatibility
Returns:
torch.Tensor: Next set of point(s) to evaluate, [num_points x dim].
"""
Expand All @@ -67,6 +72,10 @@ def gen(
"Asked for more points than are left in the generator! Giving everthing it has!",
RuntimeWarning,
)

if fixed_features is not None:
warnings.warn("Cannot fix features when generating from ManualGenerator")

points = self.points[self._idx : self._idx + num_points]
self._idx += num_points
return points
Expand Down
17 changes: 12 additions & 5 deletions aepsych/generators/monotonic_rejection_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,15 +93,22 @@ def gen(
self,
num_points: int, # Current implementation only generates 1 point at a time
model: MonotonicRejectionGP,
fixed_features: Optional[Dict[int, float]] = None,
**kwargs,
) -> torch.Tensor:
"""Query next point(s) to run by optimizing the acquisition function.
Args:
num_points (int): Number of points to query. Currently only supports 1.
model (AEPsychMixin): Fitted model of the data.
fixed_features (Dict[int, float], optional): Not implemented for this generator.
**kwargs: Ignored, API compatibility
Returns:
torch.Tensor: Next set of point(s) to evaluate, [num_points x dim].
"""

if fixed_features is not None:
logger.warning(
"Cannot fix features when generating from MonotonicRejectionGenerator"
)
options = self.model_gen_options or {}
num_restarts = options.get("num_restarts", 10)
raw_samples = options.get("raw_samples", 1000)
Expand All @@ -119,7 +126,7 @@ def gen(
# Augment bounds with deriv indicator
bounds = torch.cat((self.bounds, torch.zeros(2, 1)), dim=1)
# Fix deriv indicator to 0 during optimization
fixed_features = {(bounds.shape[1] - 1): 0.0}
fixed_features_ = {(bounds.shape[1] - 1): 0.0}
# Fix explore features to random values
if self.explore_features is not None:
for idx in self.explore_features:
Expand All @@ -128,7 +135,7 @@ def gen(
+ torch.rand(1, dtype=bounds.dtype)
* (bounds[1, idx] - bounds[0, idx])
).item()
fixed_features[idx] = val
fixed_features_[idx] = val
bounds[0, idx] = val
bounds[1, idx] = val

Expand All @@ -145,7 +152,7 @@ def gen(
clamped_candidates = columnwise_clamp(
X=batch_initial_conditions, lower=bounds[0], upper=bounds[1]
).requires_grad_(True)
candidates = fix_features(clamped_candidates, fixed_features)
candidates = fix_features(clamped_candidates, fixed_features_)
optimizer = torch.optim.SGD(
params=[clamped_candidates], lr=lr, momentum=momentum, nesterov=nesterov
)
Expand Down Expand Up @@ -174,7 +181,7 @@ def closure():
clamped_candidates.data = columnwise_clamp(
X=clamped_candidates, lower=bounds[0], upper=bounds[1]
)
candidates = fix_features(clamped_candidates, fixed_features)
candidates = fix_features(clamped_candidates, fixed_features_)
lr_scheduler.step()

# Extract best point
Expand Down
12 changes: 10 additions & 2 deletions aepsych/generators/monotonic_thompson_sampler_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from typing import List, Optional, Type
import warnings
from typing import Dict, List, Optional, Type

import torch
from aepsych.acquisition.objective import ProbitObjective
Expand Down Expand Up @@ -58,15 +59,22 @@ def gen(
self,
num_points: int, # Current implementation only generates 1 point at a time
model: MonotonicRejectionGP,
fixed_features: Optional[Dict[int, float]] = None,
**kwargs,
) -> torch.Tensor:
"""Query next point(s) to run by optimizing the acquisition function.
Args:
num_points (int): Number of points to query. current implementation only generates 1 point at a time.
model (MonotonicRejectionGP): Fitted model of the data.
fixed_features (Dict[int, float], optional): Not implemented for this generator.
**kwargs: Ignored, API compatibility
Returns:
torch.Tensor: Next set of point(s) to evaluate, [num_points x dim].
"""

if fixed_features is not None:
warnings.warn(
"Cannot fix features when generating from MonotonicRejectionGenerator"
)
# Generate the points at which to sample
X = draw_sobol_samples(bounds=model.bounds_, n=self.num_ts_points, q=1).squeeze(
1
Expand Down
37 changes: 33 additions & 4 deletions aepsych/generators/optimize_acqf_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,36 +104,64 @@ def _instantiate_acquisition_fn(self, model: ModelProtocol) -> AcquisitionFuncti
else:
return self.acqf(model=model, **self.acqf_kwargs)

def gen(self, num_points: int, model: ModelProtocol, **gen_options) -> torch.Tensor:
def gen(
self,
num_points: int,
model: ModelProtocol,
fixed_features: Optional[Dict[int, float]] = None,
**gen_options,
) -> torch.Tensor:
"""Query next point(s) to run by optimizing the acquisition function.
Args:
num_points (int): Number of points to query.
model (ModelProtocol): Fitted model of the data.
fixed_features (Dict[int, float], optional): The values where the specified
parameters should be at when generating. Should be a dictionary where
the keys are the indices of the parameters to fix and the values are the
values to fix them at.
**gen_options: Additional options for generating points, such as custom configurations.
Returns:
torch.Tensor: Next set of point(s) to evaluate, [num_points x dim].
"""

if self.stimuli_per_trial == 2:
qbatch_points = self._gen(
num_points=num_points * 2, model=model, **gen_options
num_points=num_points * 2,
model=model,
fixed_features=fixed_features,
**gen_options,
)

# output of super() is (q, dim) but the contract is (num_points, dim, 2)
# so we need to split q into q and pairs and then move the pair dim to the end
return qbatch_points.reshape(num_points, 2, -1).swapaxes(-1, -2)

else:
return self._gen(num_points=num_points, model=model, **gen_options)
return self._gen(
num_points=num_points,
model=model,
fixed_features=fixed_features,
**gen_options,
)

def _gen(
self, num_points: int, model: ModelProtocol, **gen_options: Dict[str, Any]
self,
num_points: int,
model: ModelProtocol,
fixed_features: Optional[Dict[int, float]] = None,
**gen_options: Dict[str, Any],
) -> torch.Tensor:
"""
Generates the next query points by optimizing the acquisition function.
Args:
num_points (int): Number of points to query.
model (ModelProtocol): Fitted model of the data.
fixed_features (Dict[int, float], optional): The values where the specified
parameters should be at when generating. Should be a dictionary where
the keys are the indices of the parameters to fix and the values are the
values to fix them at.
gen_options (Dict[str, Any]): Additional options for generating points, such as custom configurations.
Returns:
Expand All @@ -157,6 +185,7 @@ def _gen(
num_restarts=self.restarts,
raw_samples=self.samps,
timeout_sec=self.max_gen_time,
fixed_features=fixed_features,
**gen_options,
)

Expand Down
12 changes: 11 additions & 1 deletion aepsych/generators/random_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional
from typing import Dict, Optional

import torch
from aepsych.config import Config
Expand Down Expand Up @@ -39,17 +39,27 @@ def gen(
self,
num_points: int = 1,
model: Optional[AEPsychMixin] = None, # included for API compatibility.
fixed_features: Optional[Dict[int, float]] = None,
**kwargs,
) -> torch.Tensor:
"""Query next point(s) to run by randomly sampling the parameter space.
Args:
num_points (int): Number of points to query. Currently, only 1 point can be queried at a time.
model (AEPsychMixin, optional): Model to use for generating points. Not used in this generator.
fixed_features: (Dict[int, float], optional): Parameters that are fixed to specific values.
**kwargs: Ignored, API compatibility
Returns:
torch.Tensor: Next set of point(s) to evaluate, [num_points x dim].
"""
X = self.bounds_[0] + torch.rand((num_points, self.bounds_.shape[1])) * (
self.bounds_[1] - self.bounds_[0]
)

if fixed_features is not None:
for key, value in fixed_features.items():
X[:, key] = value

return X

@classmethod
Expand Down
Loading

0 comments on commit 4ee84a0

Please sign in to comment.