Skip to content

Commit

Permalink
add UpdateByEnqueue.
Browse files Browse the repository at this point in the history
  • Loading branch information
KanaiYuma-aist committed Mar 26, 2024
1 parent 32bf9d3 commit d9c0b6b
Showing 1 changed file with 74 additions and 83 deletions.
157 changes: 74 additions & 83 deletions aiaccel/hpo/samplers/nelder_mead_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,12 @@ class NelderMeadCoefficient:
s: float = 0.5


class UpdateByEnqueue(Exception):
def __init__(self, additional_vertices: list[np.ndarray], additional_values: list[float]) -> None:
self.additional_vertices = additional_vertices
self.additional_values = additional_values


class NelderMeadAlgorism:
vertices: np.ndarray
values: np.ndarray
Expand Down Expand Up @@ -125,62 +131,35 @@ def _initialization(self) -> Generator[np.ndarray, None, None]:

def _recontract_simplex(
self,
vertices: list[np.ndarray],
values: list[float],
additional_vertices: list[np.ndarray],
additional_values: list[float],
) -> None:
dimension = len(self._search_space)
new_vertices = np.array(list(self.vertices) + additional_vertices)
new_values = np.array(list(self.values) + additional_values)

order = np.argsort(new_values)
self.vertices, self.values = new_vertices[order][:dimension + 1], new_values[order][:dimension + 1]

def validate_better_param_in_enqueue(
self,
vertices_before_processing: list[np.ndarray],
values_before_processing: list[float],
enqueue_values: list[float]
) -> bool:
) -> None:
enqueue_vertices = []
while not self.enqueue_vertex_queue.empty():
enqueue_vertices.append(self.enqueue_vertex_queue.get(block=False))

if len([i for i in enqueue_values if i < min(values + [self.values[-1]])]) > 0:
# recontract_simplex
dimension = len(self._search_space)
new_vertices = np.array(list(self.vertices) + vertices + enqueue_vertices)
new_values = np.array(list(self.values) + values + enqueue_values)
worst_value = min(
values_before_processing + [self.values[-1]]
) if len(values_before_processing) > 0 or len(self.values) > 0 else None

order = np.argsort(new_values)
new_vertices, new_values = new_vertices[order][:dimension + 1], new_values[order][:dimension + 1]

return True
else:
return False

def _expand(self, yr: np.ndarray, fr: float, ye: np.ndarray, fe: float) -> tuple[list[np.ndarray], list[float]]:
self.vertices[-1], self.values[-1] = (ye, fe) if fe < fr else (yr, fr)
return ([yr], [fr]) if fe < fr else ([ye], [fe])

def _outside_contract(self,
yr: np.ndarray,
fr: float,
yoc: np.ndarray,
foc: float
) -> tuple[list[np.ndarray], list[float], bool]:
if foc <= fr:
self.vertices[-1], self.values[-1] = yoc, foc
shrink_requied = False
past_vertices, past_values = [yr], [fr]
else:
shrink_requied = True
past_vertices, past_values = [yr, yoc], [fr, foc]

return past_vertices, past_values, shrink_requied

def _inside_contract(self,
yr: np.ndarray,
fr: float,
yic: np.ndarray,
fic: float
) -> tuple[list[np.ndarray], list[float], bool]:
if fic < self.values[-1]:
self.vertices[-1], self.values[-1] = yic, fic
shrink_requied = False
past_vertices, past_values = [yr], [fr]
else:
shrink_requied = True
past_vertices, past_values = [yr, yic], [fr, fic]

return past_vertices, past_values, shrink_requied
if worst_value is not None and len([v for v in enqueue_values if v < worst_value]) > 0:
raise UpdateByEnqueue(
vertices_before_processing + enqueue_vertices,
values_before_processing + enqueue_values
)

def _generator(self) -> Generator[np.ndarray, None, None]:
# initialization
Expand All @@ -189,51 +168,63 @@ def _generator(self) -> Generator[np.ndarray, None, None]:
# main loop
shrink_requied = False
while True:
# sort self.vertices by their self.values
order = np.argsort(self.values)
self.vertices, self.values = self.vertices[order], self.values[order]
try:
# sort self.vertices by their self.values
order = np.argsort(self.values)
self.vertices, self.values = self.vertices[order], self.values[order]

# reflect
yc = self.vertices[:-1].mean(axis=0)
yield (yr := yc + self.coeff.r * (yc - self.vertices[-1]))

fr, enqueue_values = yield from self._waiting_for_float()
self.validate_better_param_in_enqueue([yr], [fr], enqueue_values)

# reflect
yc = self.vertices[:-1].mean(axis=0)
yield (yr := yc + self.coeff.r * (yc - self.vertices[-1]))
if self.values[0] <= fr < self.values[-2]:
self.vertices[-1], self.values[-1] = yr, fr
elif fr < self.values[0]: # expand
yield (ye := yc + self.coeff.e * (yc - self.vertices[-1]))

fr, enqueue_values = yield from self._waiting_for_float()
past_vertices, past_values = [yr], [fr]
fe, enqueue_values = yield from self._waiting_for_float()
self.validate_better_param_in_enqueue([yr, ye], [fr, fe], enqueue_values)

if self._recontract_simplex(past_vertices, past_values, enqueue_values):
continue
self.vertices[-1], self.values[-1] = (ye, fe) if fe < fr else (yr, fr)

if self.values[0] <= fr < self.values[-2]:
self.vertices[-1], self.values[-1] = yr, fr
elif fr < self.values[0]: # expand
yield (ye := yc + self.coeff.e * (yc - self.vertices[-1]))
elif self.values[-2] <= fr < self.values[-1]: # outside contract
yield (yoc := yc + self.coeff.oc * (yc - self.vertices[-1]))

fe, enqueue_values = yield from self._waiting_for_float()
past_vertices, past_values = self._expand(yr, fr, ye, fe)
foc, enqueue_values = yield from self._waiting_for_float()
self.validate_better_param_in_enqueue([yr, yoc], [fr, foc], enqueue_values)

elif self.values[-2] <= fr < self.values[-1]: # outside contract
yield (yoc := yc + self.coeff.oc * (yc - self.vertices[-1]))
if foc <= fr:
self.vertices[-1], self.values[-1] = yoc, foc
shrink_requied = False
else:
shrink_requied = True
elif self.values[-1] <= fr: # inside contract
yield (yic := yc + self.coeff.ic * (yc - self.vertices[-1]))

foc, enqueue_values = yield from self._waiting_for_float()
past_vertices, past_values, shrink_requied = self._outside_contract(yr, fr, yoc, foc)
fic, enqueue_values = yield from self._waiting_for_float()
self.validate_better_param_in_enqueue([yr, yic], [fr, fic], enqueue_values)

elif self.values[-1] <= fr: # inside contract
yield (yic := yc + self.coeff.ic * (yc - self.vertices[-1]))
if fic < self.values[-1]:
self.vertices[-1], self.values[-1] = yic, fic
shrink_requied = False
else:
shrink_requied = True

fic, enqueue_values = yield from self._waiting_for_float()
past_vertices, past_values, shrink_requied = self._inside_contract(yr, fr, yic, fic)
# shrink
if shrink_requied:
self.vertices = self.vertices[0] + self.coeff.s * (self.vertices - self.vertices[0])
yield from self.vertices[1:]

if self._recontract_simplex(past_vertices, past_values, enqueue_values):
continue
self.values[1:], enqueue_values = yield from self._waiting_for_list(len(self.vertices[1:]))
shrink_requied = False

# shrink
if shrink_requied:
self.vertices = self.vertices[0] + self.coeff.s * (self.vertices - self.vertices[0])
yield from self.vertices[1:]
self.values[1:], enqueue_values = yield from self._waiting_for_list(len(self.vertices[1:]))
self._recontract_simplex([], [], enqueue_values)
self.validate_better_param_in_enqueue([], [], enqueue_values)

shrink_requied = False
except UpdateByEnqueue as e:
self._recontract_simplex(e.additional_vertices, e.additional_values)


class NelderMeadSampler(optuna.samplers.BaseSampler):
Expand Down

0 comments on commit d9c0b6b

Please sign in to comment.