forked from facebookresearch/aepsych
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implement Fixed Parameter Asks (facebookresearch#512)
Summary: Allow asks to provide some some parameter values to fix when generating points. When set only the unset parameters will be automatically generated. This more or less only works for OptimizeAcqfGenerator. Other generators will simply throw a warning and ignore the directive. To conform API, every generator's gen method now includes the **kwargs but most of them will simply ignore it. Differential Revision: D67956510
- Loading branch information
1 parent
f046f93
commit 66d4a05
Showing
14 changed files
with
155 additions
and
12 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
#!/usr/bin/env python3 | ||
# Copyright (c) Meta, Inc. and its affiliates. | ||
# All rights reserved. | ||
|
||
# This source code is licensed under the license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
|
||
import logging | ||
|
||
from ..test_server import BaseServerTestCase | ||
|
||
|
||
class AskHandlerTestCase(BaseServerTestCase): | ||
def test_fixed_ask(self): | ||
config_str = f""" | ||
[common] | ||
parnames = [par1, par2] | ||
stimuli_per_trial = 1 | ||
outcome_types = [binary] | ||
strategy_names = [init_strat, opt_strat] | ||
[par1] | ||
par_type = continuous | ||
lower_bound = 1 | ||
upper_bound = 100 | ||
[par2] | ||
par_type = continuous | ||
lower_bound = 0 | ||
upper_bound = 1 | ||
[init_strat] | ||
generator = SobolGenerator | ||
min_total_tells = 1 | ||
[opt_strat] | ||
generator = OptimizeAcqfGenerator | ||
acqf = MCLevelSetEstimation | ||
model = GPClassificationModel | ||
min_total_tells = 2 | ||
""" | ||
setup_request = { | ||
"type": "setup", | ||
"message": {"config_str": config_str}, | ||
} | ||
self.s.handle_request(setup_request) | ||
|
||
fixed1 = 75 | ||
fixed2 = 0.75 | ||
|
||
# SobolGenerator | ||
# One fixed | ||
with self.assertLogs(level=logging.WARNING) as logs: | ||
resp = self.s.handle_request( | ||
{"type": "ask", "message": {"fixed_pars": {"par1": fixed1}}} | ||
) | ||
outputs = ";".join(logs.output) | ||
self.assertTrue("cannot generate points with specific values" in outputs) | ||
|
||
self.s.handle_request( | ||
{"type": "tell", "message": {"config": resp["config"], "outcome": 1}} | ||
) | ||
self.s.handle_request( | ||
{"type": "tell", "message": {"config": resp["config"], "outcome": 0}} | ||
) | ||
|
||
# OptimizeAcqfGenerator | ||
# One fixed | ||
resp = self.s.handle_request( | ||
{"type": "ask", "message": {"fixed_pars": {"par1": fixed1}}} | ||
) | ||
self.assertTrue(resp["config"]["par1"][0] == fixed1) | ||
|
||
self.s.handle_request( | ||
{"type": "tell", "message": {"config": resp["config"], "outcome": 1}} | ||
) | ||
|
||
# All fixed | ||
resp = self.s.handle_request( | ||
{ | ||
"type": "ask", | ||
"message": {"fixed_pars": {"par1": fixed1, "par2": fixed2}}, | ||
} | ||
) | ||
|
||
self.assertTrue(resp["config"]["par1"][0] == fixed1) | ||
self.assertTrue(resp["config"]["par2"][0] == fixed2) |