diff --git a/example/config.yaml b/example/config.yaml
index 5f2f31c..8c204f4 100644
--- a/example/config.yaml
+++ b/example/config.yaml
@@ -2,11 +2,16 @@ defaults:
   - override hydra/sweeper: orion
 
 hydra:
+  job:
+    env_set:
+      PREVIOUS_CHECKPOINT: ${hydra.sweeper.experiment.previous_checkpoint}
+      CURRENT_CHECKPOINT: ${hydra.runtime.output_dir}
+
   # makes sure each multirun ends up in a unique folder
   # the defaults can make overlapping folders
   sweep:
     dir: multirun/${now:%Y-%m-%d}/${now:%H-%M-%S}
-    subdir: ${hydra.sweeper.orion.name}/${hydra.sweeper.orion.uuid}/${hydra.job.id}
+    subdir: ${hydra.sweeper.experiment.name}/${hydra.sweeper.experiment.uuid}/${hydra.job.id}
 
   sweeper:
     # default parametrization of the search space
@@ -16,14 +21,14 @@ hydra:
       lr: "uniform(0, 1)"
       dropout: "uniform(0, 1)"
       batch_size: "uniform(4, 16, discrete=True)"
-      epoch: "fidelity(10, 100)"
+      epoch: "fidelity(3, 100)"
 
-    orion:
+    experiment:
       name: 'experiment'
       version: '1'
 
     algorithm:
-      type: random
+      type: hyperband
       config:
         seed: 1
 
diff --git a/example/my_app.py b/example/my_app.py
index 11ab5a3..ff68b07 100644
--- a/example/my_app.py
+++ b/example/my_app.py
@@ -8,6 +8,47 @@
 log = logging.getLogger(__name__)
 
 
+def _load_checkpoint(path, model):
+    checkpoint = os.path.join(path, 'chk.pt')
+
+    if os.path.exists(checkpoint):
+        # load checkpoint
+        # ...
+        return True
+
+    return False
+
+
+def load_checkpoint(model):
+    current_checkpoint_path = os.getenv("CURRENT_CHECKPOINT")
+    assert current_checkpoint_path is not None
+
+    # if the checkpoint file exists, always load it: it is the most recent
+    if _load_checkpoint(current_checkpoint_path, model):
+        return True
+
+    # PREVIOUS_CHECKPOINT points to a job that finished and that we want to resume from;
+    # this is useful for genetic algorithms or algorithms that gradually improve on previous solutions
+    prev_checkpoint_path = os.getenv("PREVIOUS_CHECKPOINT")
+
+    if prev_checkpoint_path and _load_checkpoint(prev_checkpoint_path, model):
+        return True
+
+    return False
+
+
+def save_checkpoint(model):
+    current_checkpoint_path = os.getenv("CURRENT_CHECKPOINT")
+    checkpoint = os.path.join(current_checkpoint_path, 'chk.pt')
+
+    with open(checkpoint, 'w') as fp:
+        # save checkpoint
+        # ...
+        pass
+
+
 @hydra.main(config_path=".", config_name="config", version_base="1.1")
 def dummy_training(cfg: DictConfig) -> float:
     """A dummy function to minimize
@@ -18,16 +59,24 @@ def dummy_training(cfg: DictConfig) -> float:
     # makes sure folders are unique
     os.makedirs('newdir', exist_ok=False)
 
+    model = None
+
+    if load_checkpoint(model):
+        print('Resuming from checkpoint')
+    else:
+        print('No checkpoint found')
+
     do = cfg.dropout
     bs = cfg.batch_size
     out = float(
         abs(do - 0.33) + int(cfg.optimizer.name == "Adam") + abs(cfg.optimizer.lr - 0.12) + abs(bs - 4)
     )
-    # ..../hydra_orion_sweeper/example/multirun/2022-11-08/11-56-45/39
-    # print(os.getcwd())
     log.info(
         f"dummy_training(dropout={do:.3f}, lr={cfg.optimizer.lr:.3f}, opt={cfg.optimizer.name}, batch_size={bs}) = {out:.3f}",
     )
+
+    save_checkpoint(model)
+
     if cfg.error:
         raise RuntimeError("cfg.error is True")
diff --git a/hydra_plugins/hydra_orion_sweeper/config.py b/hydra_plugins/hydra_orion_sweeper/config.py
index 4995a50..4ade603 100644
--- a/hydra_plugins/hydra_orion_sweeper/config.py
+++ b/hydra_plugins/hydra_orion_sweeper/config.py
@@ -23,6 +23,8 @@ class OrionClientConf:
     trial: Optional[str] = None
     uuid: Optional[str] = None
 
+    previous_checkpoint: Optional[str] = None
+
 
 @dataclass
 class WorkerConf:
@@ -82,7 +84,7 @@ class OrionSweeperConf:
 
     _target_: str = "hydra_plugins.hydra_orion_sweeper.orion_sweeper.OrionSweeper"
 
-    orion: OrionClientConf = OrionClientConf()
+    experiment: OrionClientConf = OrionClientConf()
 
     worker: WorkerConf = WorkerConf()
 
diff --git a/hydra_plugins/hydra_orion_sweeper/implementation.py b/hydra_plugins/hydra_orion_sweeper/implementation.py
index 1959afb..20b2ade 100644
--- a/hydra_plugins/hydra_orion_sweeper/implementation.py
+++ b/hydra_plugins/hydra_orion_sweeper/implementation.py
@@ -159,16 +159,19 @@ def override_parser():
     return parser
 
 
-def as_overrides(trial, additional, uuid):
+def as_overrides(trial, additional, uuid, prev_checkpoint):
     """Returns the trial arguments as hydra overrides"""
     kwargs = deepcopy(additional)
     kwargs.update(flatten(trial.params))
 
     args = [f"{k}={v}" for k, v in kwargs.items()]
+
     args += [
-        f"hydra.sweeper.orion.id={trial.experiment}",
-        f"hydra.sweeper.orion.trial={trial.id}",
-        f"hydra.sweeper.orion.uuid={uuid}",
+        f"hydra.sweeper.experiment.id={trial.experiment}",
+        f"hydra.sweeper.experiment.trial={trial.id}",
+        f"hydra.sweeper.experiment.uuid={uuid}",
+        f"hydra.sweeper.experiment.previous_checkpoint={prev_checkpoint}",
+        # "hydra.sweeper.experiment.current_checkpoint=$hydra.runtime.output_dir",
     ]
     return tuple(args)
@@ -338,7 +341,7 @@ class OrionSweeperImpl(Sweeper):
 
     def __init__(
         self,
-        orion: OrionClientConf,
+        experiment: OrionClientConf,
         worker: WorkerConf,
         algorithm: AlgorithmConf,
         storage: StorageConf,
@@ -350,8 +353,9 @@ def __init__(
         self.client = None
         self.storage = None
         self.uuid = uuid.uuid1().hex
+        self.resume_paths = dict()
 
-        self.orion_config = orion
+        self.orion_config = experiment
         self.worker_config = worker
         self.algo_config = algorithm
         self.storage_config = storage
@@ -532,10 +536,15 @@ def sample_trials(self) -> List[Trial]:
         self.pending_trials.update(set(trials))
         return trials
 
+    def trial_as_override(self, trial: Trial):
+        """Create overrides for a specific trial"""
+        checkpoint = self.resume_paths.get(trial.hash_params)
+        return as_overrides(trial, self.arguments, self.uuid, checkpoint)
+
     def execute_trials(self, trials: List[Trial]) -> Sequence[JobReturn]:
         """Execute the given batch of trials"""
-        overrides = list(as_overrides(t, self.arguments, self.uuid) for t in trials)
+        overrides = list(self.trial_as_override(t) for t in trials)
 
         self.validate_batch_is_legal(overrides)
         returns = self.launcher.launch(overrides, initial_job_idx=self.job_idx)
@@ -548,6 +557,10 @@ def observe_one(
         """Observe a single trial"""
         value = result.return_value
 
+        if result.hydra_cfg:
+            trialdir = result.hydra_cfg["hydra"]["runtime"]["output_dir"]
+            self.resume_paths[trial.hash_params] = trialdir
+
         try:
             objective = to_objective(value)
             self.client.observe(trial, objective)
diff --git a/hydra_plugins/hydra_orion_sweeper/orion_sweeper.py b/hydra_plugins/hydra_orion_sweeper/orion_sweeper.py
index 8fdd1d7..8e5fb2d 100644
--- a/hydra_plugins/hydra_orion_sweeper/orion_sweeper.py
+++ b/hydra_plugins/hydra_orion_sweeper/orion_sweeper.py
@@ -18,7 +18,7 @@ class OrionSweeper(Sweeper):
 
     def __init__(
         self,
-        orion: OrionClientConf,
+        experiment: OrionClientConf,
         worker: WorkerConf,
         algorithm: AlgorithmConf,
         storage: StorageConf,
@@ -30,7 +30,7 @@ def __init__(
         # >>> Remove with Issue #8
         if parametrization is not None and params is None:
             warn(
-                "`hydra.sweeper.orion.parametrization` is deprecated;"
+                "`hydra.sweeper.experiment.parametrization` is deprecated;"
                 "use `hydra.sweeper.params` instead",
                 DeprecationWarning,
             )
@@ -38,7 +38,7 @@
 
         elif parametrization is not None and params is not None:
             warn(
-                "Both `hydra.sweeper.orion.parametrization` and `hydra.sweeper.params` are defined;"
+                "Both `hydra.sweeper.experiment.parametrization` and `hydra.sweeper.params` are defined;"
                 "using `hydra.sweeper.params`",
                 DeprecationWarning,
             )
@@ -47,7 +47,7 @@
         if params is None:
             params = dict()
 
-        self.sweeper = OrionSweeperImpl(orion, worker, algorithm, storage, params)
+        self.sweeper = OrionSweeperImpl(experiment, worker, algorithm, storage, params)
 
     def setup(
         self,
diff --git a/tests/hydra_config.py b/tests/hydra_config.py
index 19bdb17..b0b620c 100644
--- a/tests/hydra_config.py
+++ b/tests/hydra_config.py
@@ -7,7 +7,7 @@
     },
     "sweeper": {
         "_target_": "hydra_plugins.hydra_orion_sweeper.orion_sweeper.OrionSweeper",
-        "orion": {
+        "experiment": {
            "name": None,
            "version": None,
            "branching": None,
diff --git a/tests/test_orion.py b/tests/test_orion.py
index 2b60a4e..294f192 100644
--- a/tests/test_orion.py
+++ b/tests/test_orion.py
@@ -42,7 +42,7 @@ def load_hydra_testing_config():
 
 def orion_configuration():
     return dict(
-        orion=OmegaConf.structured(OrionClientConf()),
+        experiment=OmegaConf.structured(OrionClientConf()),
         worker=OmegaConf.structured(WorkerConf()),
         algorithm=OmegaConf.structured(AlgorithmConf()),
         storage=OmegaConf.structured(StorageConf()),
diff --git a/tests/test_warnings.py b/tests/test_warnings.py
index df7407a..dfc39f6 100644
--- a/tests/test_warnings.py
+++ b/tests/test_warnings.py
@@ -25,7 +25,7 @@ def test_parametrization_is_deprecated():
     assert (
         warnings[0]
         .message.args[0]
-        .startswith("`hydra.sweeper.orion.parametrization` is deprecated;")
+        .startswith("`hydra.sweeper.experiment.parametrization` is deprecated;")
     )
 
 
@@ -44,7 +44,7 @@ def test_parametrization_and_params():
     assert (
         warnings[0]
         .message.args[0]
-        .startswith("Both `hydra.sweeper.orion.parametrization` and")
+        .startswith("Both `hydra.sweeper.experiment.parametrization` and")
     )
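
Note on the elided checkpoint bodies: the example in my_app.py deliberately leaves the actual I/O as `# ...`. For reference, here is a minimal sketch of one way those bodies could be filled in, assuming a PyTorch model; torch.save/torch.load and the state_dict layout are illustrative assumptions, not part of this diff.

import os

import torch


def _load_checkpoint(path, model):
    # Hypothetical sketch: assumes the file was written by
    # save_checkpoint() below via torch.save.
    checkpoint = os.path.join(path, 'chk.pt')

    if os.path.exists(checkpoint):
        state = torch.load(checkpoint)
        model.load_state_dict(state['model'])
        return True

    return False


def save_checkpoint(model):
    # CURRENT_CHECKPOINT is injected through hydra.job.env_set in
    # config.yaml, so each trial writes into its own output directory.
    checkpoint = os.path.join(os.getenv("CURRENT_CHECKPOINT"), 'chk.pt')
    torch.save({'model': model.state_dict()}, checkpoint)

The resume flow itself is as in the diff: observe_one records each finished trial's hydra.runtime.output_dir in resume_paths keyed by trial.hash_params, and when the same trial is suggested again (e.g. Hyperband promoting it to a higher epoch fidelity), as_overrides hands that directory back as hydra.sweeper.experiment.previous_checkpoint, which config.yaml exposes to the job as the PREVIOUS_CHECKPOINT environment variable.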