Skip to content

Commit

Permalink
RNSGA-II - BootstrapNAS sub-network search (openvinotoolkit#1235)
Browse files Browse the repository at this point in the history
### Changes

Enable RNSGA-II algorithm to search for efficient sub-networks.

### Reason for changes

Improvements in BootstrapNAS search. 

### Related tickets

N/A

### Tests

Improved search tests to include RNSGA-II as search algorithm.

---------

Co-authored-by: Yuan, Jinjie <[email protected]>
  • Loading branch information
jpablomch and Yuan0320 authored Dec 18, 2023
1 parent 1e26f1f commit 4660729
Show file tree
Hide file tree
Showing 9 changed files with 385 additions and 68 deletions.
6 changes: 3 additions & 3 deletions examples/experimental/torch/classification/bootstrap_nas.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@
from examples.torch.common.utils import is_pretrained_model_requested
from examples.torch.common.utils import print_args
from nncf.config.structures import BNAdaptationInitArgs
from nncf.experimental.torch.nas.bootstrapNAS import BaseSearchAlgorithm
from nncf.experimental.torch.nas.bootstrapNAS import EpochBasedTrainingAlgorithm
from nncf.experimental.torch.nas.bootstrapNAS import SearchAlgorithm
from nncf.torch.initialization import default_criterion_fn
from nncf.torch.initialization import wrap_dataloader_for_init
from nncf.torch.model_creation import create_nncf_network
Expand Down Expand Up @@ -196,9 +196,9 @@ def validate_model_fn_top1(model_, loader_):
)

if resuming_checkpoint_path is None:
search_algo = SearchAlgorithm.from_config(nncf_network, elasticity_ctrl, nncf_config)
search_algo = BaseSearchAlgorithm.from_config(nncf_network, elasticity_ctrl, nncf_config)
else:
search_algo = SearchAlgorithm.from_checkpoint(
search_algo = BaseSearchAlgorithm.from_checkpoint(
nncf_network, elasticity_ctrl, bn_adapt_args, resuming_checkpoint_path
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
from examples.torch.common.utils import is_pretrained_model_requested
from examples.torch.common.utils import print_args
from nncf.config.structures import BNAdaptationInitArgs
from nncf.experimental.torch.nas.bootstrapNAS import SearchAlgorithm
from nncf.experimental.torch.nas.bootstrapNAS import BaseSearchAlgorithm
from nncf.experimental.torch.nas.bootstrapNAS.training.model_creator_helpers import resume_compression_from_state
from nncf.torch.checkpoint_loading import load_state
from nncf.torch.initialization import wrap_dataloader_for_init
Expand Down Expand Up @@ -151,7 +151,7 @@ def validate_model_fn_top1(model_, loader_):
top1_acc = validate_model_fn_top1(model, val_loader)
logger.info("SuperNetwork Top 1: {top1_acc}".format(top1_acc=top1_acc))

search_algo = SearchAlgorithm.from_config(model, elasticity_ctrl, nncf_config)
search_algo = BaseSearchAlgorithm.from_config(model, elasticity_ctrl, nncf_config)

elasticity_ctrl, best_config, performance_metrics = search_algo.run(
validate_model_fn_top1, val_loader, config.checkpoint_save_dir, tensorboard_writer=config.tb
Expand Down
15 changes: 14 additions & 1 deletion nncf/config/schemata/experimental_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@

SEARCH_ALGORITHMS_SCHEMA = {
"type": "string",
"enum": ["NSGA2"],
"enum": ["NSGA2", "RNSGA2"],
}

BOOTSTRAP_NAS_SEARCH_SCHEMA = {
Expand All @@ -303,10 +303,15 @@
NUMBER,
description="Defines the number of evaluations that will be used by the search algorithm.",
),
"num_constraints": with_attributes(NUMBER, description="Number of constraints in search problem."),
"population": with_attributes(
NUMBER,
description="Defines the population size when using an evolutionary search algorithm.",
),
"crossover_prob": with_attributes(NUMBER, description="Crossover probability used by a genetic algorithm."),
"crossover_eta": with_attributes(NUMBER, description="Crossover eta."),
"mutation_eta": with_attributes(NUMBER, description="Mutation eta for genetic algorithm."),
"mutation_prob": with_attributes(NUMBER, description="Mutation probability for genetic algorithm."),
"acc_delta": with_attributes(
NUMBER,
description="Defines the absolute difference in accuracy that is tolerated "
Expand All @@ -317,6 +322,14 @@
description="Defines the reference accuracy from the pre-trained model used "
"to generate the super-network.",
),
"aspiration_points": with_attributes(
ARRAY_OF_NUMBERS, description="Information to indicate the preferred parts of the Pareto front"
),
"epsilon": with_attributes(NUMBER, description="epsilon distance of surviving solutions for RNSGA-II."),
"weights": with_attributes(NUMBER, description="weights used by RNSGA-II."),
"extreme_points_as_ref_points": with_attributes(
BOOLEAN, description="Find extreme points and use them as aspiration points."
),
"compression": make_object_or_array_of_objects_schema(
{"oneOf": [{"$ref": f"#/$defs/{KNOWLEDGE_DISTILLATION_ALGO_NAME_IN_CONFIG}"}]}
),
Expand Down
2 changes: 1 addition & 1 deletion nncf/experimental/torch/nas/bootstrapNAS/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,6 @@

# noqa
from nncf.experimental.torch.nas.bootstrapNAS.elasticity import elasticity_builder as elasticity_algo
from nncf.experimental.torch.nas.bootstrapNAS.search.search import SearchAlgorithm
from nncf.experimental.torch.nas.bootstrapNAS.search.search import BaseSearchAlgorithm
from nncf.experimental.torch.nas.bootstrapNAS.training import progressive_shrinking_builder as ps_algo
from nncf.experimental.torch.nas.bootstrapNAS.training.training_algorithm import EpochBasedTrainingAlgorithm
Loading

0 comments on commit 4660729

Please sign in to comment.