Skip to content

Commit

Permalink
improvement evo_search_loop and add predictors
Browse files Browse the repository at this point in the history
  • Loading branch information
gaoyang07 committed Sep 20, 2022
1 parent 4e80037 commit 07b516d
Show file tree
Hide file tree
Showing 18 changed files with 740 additions and 24 deletions.
147 changes: 131 additions & 16 deletions mmrazor/engine/runner/evolution_search_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import random
import warnings
from typing import Dict, List, Optional, Tuple, Union
import numpy as np

import torch
from mmengine import fileio
Expand All @@ -13,11 +14,11 @@
from mmengine.utils import is_list_of
from torch.utils.data import DataLoader

from mmrazor.models.task_modules import ResourceEstimator
from mmrazor.registry import LOOPS
from mmrazor.models.task_modules import ResourceEstimator, HighTradeoffPoints
from mmrazor.registry import LOOPS, TASK_UTILS
from mmrazor.structures import Candidates, export_fix_subnet
from mmrazor.utils import SupportRandomSubnet
from .utils import check_subnet_flops, crossover
from .utils import check_subnet_flops, get_subnet_resources, crossover


@LOOPS.register_module()
Expand Down Expand Up @@ -50,6 +51,14 @@ class EvolutionSearchLoop(EpochBasedTrainLoop):
init_candidates (str, optional): The candidates file path, which is
used to init `self.candidates`. Its format is usually in .yaml
format. Defaults to None.
trade_off (dict, optional): Whether to set sec_obj and use
    multi-objective optimization. Defaults to None.
    sec_obj (str): The second optimization objective.
    max_score_key (int): Whether your score key is ascending or
        descending; if the score key is descending, set
        max_score_key = 0, otherwise set max_score_key to the
        highest possible score. Defaults to 100 (applies to
        accuracy, mAP, etc.).
    ratio (float): Weight between score_key and sec_obj.
"""

def __init__(self,
Expand All @@ -66,7 +75,9 @@ def __init__(self,
mutate_prob: float = 0.1,
flops_range: Optional[Tuple[float, float]] = (0., 330.),
resource_estimator_cfg: Optional[dict] = None,
predictor_cfg: Optional[dict] = None,
score_key: str = 'accuracy/top1',
resource_key: str = 'flops',
init_candidates: Optional[str] = None) -> None:
super().__init__(runner, dataloader, max_epochs)
if isinstance(evaluator, dict) or is_list_of(evaluator, dict):
Expand All @@ -85,12 +96,17 @@ def __init__(self,
self.top_k = top_k
self.flops_range = flops_range
self.score_key = score_key
self.resource_key = resource_key
self.num_mutation = num_mutation
self.num_crossover = num_crossover
self.mutate_prob = mutate_prob
self.max_keep_ckpts = max_keep_ckpts
self.resume_from = resume_from

self.num_candidates = 4
self.num_mutation = 2
self.num_crossover = 2

if init_candidates is None:
self.candidates = Candidates()
else:
Expand All @@ -109,6 +125,15 @@ def __init__(self,
else:
self.model = runner.model

self.use_predictor = False
if predictor_cfg is not None:
predictor_cfg['score_key'] = self.score_key
predictor_cfg['search_groups'] = self.model.mutator.search_groups

self.predictor = TASK_UTILS.build(predictor_cfg)
self._init_predictor()
self.use_predictor = True

def run(self) -> None:
"""Launch searching."""
self.runner.call_hook('before_train')
Expand Down Expand Up @@ -144,7 +169,7 @@ def run_epoch(self) -> None:
f'{scores_before}')

self.candidates.extend(self.top_k_candidates)
self.candidates.sort(key=lambda x: x[1], reverse=True)
self.sort_candidates(trade_off=None, reverse=True)
self.top_k_candidates = Candidates(self.candidates[:self.top_k])

scores_after = self.top_k_candidates.scores
Expand Down Expand Up @@ -172,19 +197,55 @@ def sample_candidates(self) -> None:
# broadcast candidates to val with multi-GPUs.
broadcast_object_list(self.candidates.data)

def update_candidates_scores(self, finetune: bool = False) -> None:
    """Validate each candidate in the candidate pool and record its
    score and resource consumption.

    Args:
        finetune (bool): Whether to finetune the model on the current
            subnet before validation. Defaults to False.
    """
    for i, candidate in enumerate(self.candidates.subnets):
        self.model.set_subnet(candidate)
        if finetune:
            self._finetune_model()
        metrics = self._val_candidate(
            use_predictor=self.use_predictor, valid_resources=True)
        score = round(metrics[self.score_key], 3) if len(metrics) != 0 else 0.
        # use .get() so an empty or partial metrics dict cannot raise
        # KeyError (the score above already tolerates empty metrics)
        resource = metrics.get(self.resource_key, 0.)
        self.candidates.set_score(i, score)
        self.candidates.set_resource(i, resource)

        self.runner.logger.info(
            f'Epoch:[{self._epoch}/{self._max_epochs}] '
            f'Candidate:[{i + 1}/{self.num_candidates}] '
            f'Score:{score} '
            f'Resource:{resource}')

def sort_candidates(self,
                    trade_off: Optional[dict] = None,
                    reverse: bool = False) -> None:
    """Sort candidates for single- or multi-objective optimization.

    Args:
        trade_off (dict, optional): Settings for the multi-objective
            trade-off (keys: 'sec_obj', 'max_score_key', 'ratio'). When
            None, candidates are simply sorted by score.
            Defaults to None.
        reverse (bool, optional): Whether to sort in descending order
            (single-objective branch only). Defaults to False.
    """
    # BUG FIX: the original unconditionally overwrote ``trade_off`` with
    # dict(sec_obj='flops', max_score_key=100), which ignored the
    # caller's argument and made the single-objective branch dead code
    # (run_epoch explicitly calls sort_candidates(trade_off=None, ...)).
    if trade_off is not None:
        ratio = trade_off.get('ratio', 1)
        max_score_key = trade_off.get('max_score_key', 100)
        objectives = np.array(
            [(cand[1], cand[2]) for cand in self.candidates])
        if max_score_key != 0:
            # turn an ascending score (accuracy, mAP, ...) into a cost
            # to minimize, so both objectives point the same way
            objectives[:, 0] = max_score_key - objectives[:, 0]
        sort_idx = np.argsort(objectives[:, 0])
        F = objectives[sort_idx]
        dm = HighTradeoffPoints(ratio, n_survive=len(objectives))
        candidate_index = sort_idx[dm.do(F)]
        self.candidates = [self.candidates[idx] for idx in candidate_index]
    else:
        self.candidates.sort(key=lambda x: x[1], reverse=reverse)

def gen_mutation_candidates(self) -> List:
"""Generate specified number of mutation candicates."""
Expand Down Expand Up @@ -257,14 +318,68 @@ def _save_best_fix_subnet(self):
'Search finished and '
f'{save_name} saved in {self.runner.work_dir}.')

def _init_predictor(self):
    """Initialize the predictor, training it first when no pretrained
    checkpoint is available.

    Training samples either come from a file named by
    ``self.predictor.train_samples`` or are randomly sampled from the
    search space and evaluated.
    """
    if self.predictor.pretrained:
        self.predictor.load_checkpoint()
        self.runner.logger.info(
            f'Loaded Checkpoints from {self.predictor.pretrained}')
    else:
        self.runner.logger.info('No checkpoints found. Start training.')
        if isinstance(self.predictor.train_samples, str):
            self.runner.logger.info(
                'Find specified samples in '
                f'{self.predictor.train_samples}')
            train_samples = fileio.load(self.predictor.train_samples)
            # NOTE(review): assumes the loaded 'subnets' entry is a
            # Candidates-like object (``.subnets`` / ``.scores`` are
            # read below) — confirm against the dump format.
            self.candidates = train_samples['subnets']
        else:
            self.runner.logger.info(
                'Without specified samples. Start random sampling.')
            # temporarily enlarge the pool to the number of training
            # samples the predictor needs, then restore it
            temp_num_candidates = self.num_candidates
            self.num_candidates = self.predictor.train_samples

            self.sample_candidates()
            self.update_candidates_scores(finetune=True)
            self.num_candidates = temp_num_candidates

        inputs = []
        for candidate in self.candidates.subnets:
            self.model.set_subnet(candidate)
            inputs.append(self.predictor.spec2feats(self.model))
        inputs = np.array(inputs)
        labels = np.array(self.candidates.scores)
        self.predictor.fit(inputs, labels)
        if self.runner.rank == 0:
            predictor_dir = self.predictor.save(
                osp.join(self.runner.work_dir, 'predictor'))
            # FIX: log the actual save path returned by save() instead
            # of the bare work_dir (predictor_dir was unused before)
            self.runner.logger.info(
                f'Predictor pre-trained, saved in {predictor_dir}.')
        self.candidates = Candidates()

def _finetune_model(self):
    """Finetune the model on the current subnet before validation.

    Intentionally a no-op here; presumably meant to be overridden by
    subclasses with an actual finetuning procedure — TODO confirm.
    """
    pass

@torch.no_grad()
def _val_candidate(self,
                   use_predictor: bool = False,
                   valid_resources: bool = False) -> Dict:
    """Run validation for the currently-set subnet.

    Args:
        use_predictor (bool): Whether to estimate the score with the
            trained predictor instead of running real evaluation.
            Defaults to False.
        valid_resources (bool): Whether to add resource statistics
            (e.g. flops) to the returned metrics. Defaults to False.

    Returns:
        Dict: Evaluation metrics, optionally extended with resources.
    """
    if use_predictor:
        assert self.predictor is not None
        metrics = self.predictor.predict(self.model)
    else:
        self.runner.model.eval()
        for data_batch in self.dataloader:
            outputs = self.runner.model.val_step(data_batch)
            self.evaluator.process(
                data_samples=outputs, data_batch=data_batch)
        metrics = self.evaluator.evaluate(len(self.dataloader.dataset))

    if valid_resources:
        metrics.update(get_subnet_resources(self.model, self.estimator))
    return metrics

def _save_searcher_ckpt(self) -> None:
Expand Down
4 changes: 2 additions & 2 deletions mmrazor/engine/runner/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .check import check_subnet_flops
from .check import check_subnet_flops, get_subnet_resources
from .genetic import crossover

__all__ = ['crossover', 'check_subnet_flops']
__all__ = ['crossover', 'check_subnet_flops', 'get_subnet_resources']
20 changes: 20 additions & 0 deletions mmrazor/engine/runner/utils/check.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,23 @@ def check_subnet_flops(
return True
else:
return False

def get_subnet_resources(model: nn.Module,
                         estimator: ResourceEstimator) -> dict:
    """Estimate the resources (e.g. flops) of the currently-set subnet.

    Args:
        model (nn.Module): The supernet with the target subnet already
            set via ``model.set_subnet(...)``.
        estimator (ResourceEstimator): Estimator used to measure the
            resources.

    Returns:
        dict: Resource estimation results from the estimator.
    """
    assert hasattr(model, 'set_subnet') and hasattr(model, 'architecture')

    # export and fix the current subnet on a deep copy, so estimation
    # runs on the fixed subnet rather than the dynamic supernet
    fix_mutable = export_fix_subnet(model)
    copied_model = copy.deepcopy(model)
    load_fix_subnet(copied_model, fix_mutable)

    # FIX: the original estimated ``model.architecture`` and left the
    # fixed copy unused; estimate the copy instead, matching the
    # pattern used by check_subnet_flops — TODO confirm intent
    model_to_check = copied_model.architecture
    if isinstance(model_to_check, BaseDetector):
        results = estimator.estimate(model=model_to_check.backbone)
    else:
        results = estimator.estimate(model=model_to_check)

    return results
2 changes: 2 additions & 0 deletions mmrazor/models/task_modules/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .delivery import * # noqa: F401,F403
from .estimators import ResourceEstimator
from .multi_object_optimizer import * # noqa: F401,F403
from .predictor import * # noqa: F401,F403
from .recorder import * # noqa: F401,F403
from .tracer import * # noqa: F401,F403

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .packages import HighTradeoffPoints

__all__ = ['HighTradeoffPoints']
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .high_tradeoff_points import HighTradeoffPoints

__all__ = ['HighTradeoffPoints']
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np

from pymoo.core.decision_making import DecisionMaking, NeighborFinder, find_outliers_upper_tail
from pymoo.util.normalization import normalize
from pymoo.util.nds.non_dominated_sorting import NonDominatedSorting

from pymoo.config import Config
Config.warnings['not_compiled'] = False


class HighTradeoffPoints(DecisionMaking):
    """Decision-making method that selects high trade-off points for
    multi-objective optimization.

    Args:
        ratio (float): Weight between score_key and sec_obj; scales the
            second objective before the trade-off computation.
        epsilon (float): Radius used to collect the neighbours of each
            solution. Defaults to 0.125.
        n_survive (int): How many high trade-off points to return.
            Defaults to None, in which case upper-tail outliers
            (trade-off > 2 sigma) are returned instead.
    """

    def __init__(self,
                 ratio=1,
                 epsilon=0.125,
                 n_survive=None,
                 **kwargs) -> None:
        super().__init__(**kwargs)
        self.epsilon = epsilon
        self.n_survive = n_survive
        self.ratio = ratio

    def _do(self, data, **kwargs):
        # keep only the non-dominated front of the data
        front = NonDominatedSorting().do(data, only_non_dominated_front=True)
        F = data[front, :]

        n, m = F.shape
        F = normalize(F, self.ideal, self.nadir)
        F[:, 1] = F[:, 1] * self.ratio

        # BUG FIX: the radius was hard-coded to 0.125, silently
        # ignoring the ``epsilon`` constructor argument
        neighbors_finder = NeighborFinder(
            F,
            epsilon=self.epsilon,
            n_min_neigbors='auto',
            consider_2d=False)

        mu = np.full(n, -np.inf)

        for i in range(n):
            # neighbours within the epsilon-radius of solution i
            neighbors = neighbors_finder.find(i)

            # trade-off to all neighbours: what is sacrificed vs gained
            diff = F[neighbors] - F[i]
            sacrifice = np.maximum(0, diff).sum(axis=1)
            gain = np.maximum(0, -diff).sum(axis=1)

            # FIX: ``np.warnings`` was removed in recent NumPy; use
            # errstate to silence the expected 0/0 divisions instead
            with np.errstate(divide='ignore', invalid='ignore'):
                tradeoff = sacrifice / gain

            # keep the smallest trade-off among the neighbours
            mu[i] = np.nanmin(tradeoff)

        if self.n_survive is not None:
            # top-k selection by trade-off value
            n_survive = min(self.n_survive, len(mu))
            index = np.argsort(mu)[-n_survive:][::-1]
            front_survive = front[index]

            self.n_survive -= n_survive
            if self.n_survive == 0:
                return front_survive
            # the front alone is not enough for top-k: recurse on the
            # points outside the current front
            index = np.array(list(set(np.arange(len(data))) - set(front)))
            unused_data = data[index]
            no_front_survive = index[self._do(unused_data)]

            return np.concatenate([front_survive, no_front_survive])
        else:
            # return points whose trade-off is an upper-tail outlier
            # (> 2 * sigma); may legitimately be empty
            mu = find_outliers_upper_tail(mu)
            return mu if len(mu) else []
5 changes: 5 additions & 0 deletions mmrazor/models/task_modules/predictor/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .metric_predictor import MetricPredictor
from .zero_shot_predictor import ZeroShotPredictor

__all__ = ['MetricPredictor', 'ZeroShotPredictor']
15 changes: 15 additions & 0 deletions mmrazor/models/task_modules/predictor/base_predictor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABC, abstractmethod


class BasePredictor(ABC):
    """Abstract base class for performance predictors.

    FIX: the original used ``@abstractmethod`` without inheriting from
    ``ABC``, so the decorator was ineffective and the base class was
    silently instantiable; subclasses must now implement ``predict``.
    """

    def __init__(self):
        """Init the predictor (no state kept in the base class)."""
        pass

    @abstractmethod
    def predict(self, model, predict_args):
        """Predict the result for ``model``; implemented by subclasses."""
        pass
7 changes: 7 additions & 0 deletions mmrazor/models/task_modules/predictor/handler/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .carts_handler import CartsHandler
from .gp_handler import GaussProcessHandler
from .mlp_handler import MLPHandler
from .rbf_handler import RBFHandler

__all__ = ['CartsHandler', 'GaussProcessHandler', 'MLPHandler', 'RBFHandler']
Loading

0 comments on commit 07b516d

Please sign in to comment.