-
Notifications
You must be signed in to change notification settings - Fork 231
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
improvement evo_search_loop and add predictors
- Loading branch information
gaoyang07
committed
Sep 20, 2022
1 parent
4e80037
commit 07b516d
Showing
18 changed files
with
740 additions
and
24 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,5 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
from .check import check_subnet_flops | ||
from .check import check_subnet_flops, get_subnet_resources | ||
from .genetic import crossover | ||
|
||
__all__ = ['crossover', 'check_subnet_flops'] | ||
__all__ = ['crossover', 'check_subnet_flops', 'get_subnet_resources'] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
4 changes: 4 additions & 0 deletions
4
mmrazor/models/task_modules/multi_object_optimizer/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
from .packages import HighTradeoffPoints | ||
|
||
__all__ = ['HighTradeoffPoints'] |
4 changes: 4 additions & 0 deletions
4
mmrazor/models/task_modules/multi_object_optimizer/packages/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
from .high_tradeoff_points import HighTradeoffPoints | ||
|
||
__all__ = ['HighTradeoffPoints'] |
81 changes: 81 additions & 0 deletions
81
mmrazor/models/task_modules/multi_object_optimizer/packages/high_tradeoff_points.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
import numpy as np | ||
|
||
from pymoo.core.decision_making import DecisionMaking, NeighborFinder, find_outliers_upper_tail | ||
from pymoo.util.normalization import normalize | ||
from pymoo.util.nds.non_dominated_sorting import NonDominatedSorting | ||
|
||
from pymoo.config import Config | ||
Config.warnings['not_compiled'] = False | ||
|
||
|
||
class HighTradeoffPoints(DecisionMaking): | ||
"""Method for multi-object optimization. | ||
Args: | ||
ratio(float): weight between score_key and sec_obj, details in | ||
demo/nas/demo.ipynb. | ||
epsilon(float): specific a radius for each neighbour. | ||
n_survive(int): how many high-tradeoff points will return finally. | ||
""" | ||
|
||
def __init__(self, | ||
ratio=1, | ||
epsilon=0.125, | ||
n_survive=None, | ||
**kwargs) -> None: | ||
super().__init__(**kwargs) | ||
self.epsilon = epsilon | ||
self.n_survive = n_survive | ||
self.ratio = ratio | ||
|
||
def _do(self, data, **kwargs): | ||
front = NonDominatedSorting().do(data, only_non_dominated_front=True) | ||
F = data[front, :] | ||
|
||
n, m = F.shape | ||
F = normalize(F, self.ideal, self.nadir) | ||
F[:, 1] = F[:, 1] * self.ratio | ||
|
||
neighbors_finder = NeighborFinder( | ||
F, epsilon=0.125, n_min_neigbors='auto', consider_2d=False) | ||
|
||
mu = np.full(n, -np.inf) | ||
|
||
for i in range(n): | ||
|
||
# for each neighbour in a specific radius of that solution | ||
neighbors = neighbors_finder.find(i) | ||
|
||
# calculate the trade-off to all neighbours | ||
diff = F[neighbors] - F[i] | ||
|
||
# calculate sacrifice and gain | ||
sacrifice = np.maximum(0, diff).sum(axis=1) | ||
gain = np.maximum(0, -diff).sum(axis=1) | ||
|
||
np.warnings.filterwarnings('ignore') | ||
tradeoff = sacrifice / gain | ||
|
||
# otherwise find the one with the smalled one | ||
mu[i] = np.nanmin(tradeoff) | ||
|
||
# if given topk | ||
if self.n_survive is not None: | ||
n_survive = min(self.n_survive, len(mu)) | ||
index = np.argsort(mu)[-n_survive:][::-1] | ||
front_survive = front[index] | ||
|
||
self.n_survive -= n_survive | ||
if self.n_survive == 0: | ||
return front_survive | ||
# in case the survived in front is not enough for topk | ||
index = np.array(list(set(np.arange(len(data))) - set(front))) | ||
unused_data = data[index] | ||
no_front_survive = index[self._do(unused_data)] | ||
|
||
return np.concatenate([front_survive, no_front_survive]) | ||
else: | ||
# return points with trade-off > 2*sigma | ||
mu = find_outliers_upper_tail(mu) | ||
return mu if len(mu) else [] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
from .metric_predictor import MetricPredictor | ||
from .zero_shot_predictor import ZeroShotPredictor | ||
|
||
__all__ = ['MetricPredictor', 'ZeroShotPredictor'] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
from abc import abstractmethod | ||
|
||
|
||
class BasePredictor(): | ||
"""Base predictor.""" | ||
|
||
def __init__(self): | ||
"""init.""" | ||
pass | ||
|
||
@abstractmethod | ||
def predict(self, model, predict_args): | ||
"""predict result.""" | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
from .carts_handler import CartsHandler | ||
from .gp_handler import GaussProcessHandler | ||
from .mlp_handler import MLPHandler | ||
from .rbf_handler import RBFHandler | ||
|
||
__all__ = ['CartsHandler', 'GaussProcessHandler', 'MLPHandler', 'RBFHandler'] |
Oops, something went wrong.