Skip to content

Commit

Permalink
Tweak example to be resumable
Browse files Browse the repository at this point in the history
  • Loading branch information
Pierre Delaunay committed Nov 9, 2022
1 parent be581e8 commit 048ba19
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 6 deletions.
9 changes: 7 additions & 2 deletions example/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@ defaults:
- override hydra/sweeper: orion

hydra:
job:
env_set:
PREVIOUS_CHECKPOINT: ${hydra.sweeper.orion.previous_checkpoint}
CURRENT_CHECKPOINT: ${hydra.runtime.output_dir}

# makes sure each multirun ends up in a unique folder
# the defaults can make overlapping folders
sweep:
Expand All @@ -16,14 +21,14 @@ hydra:
lr: "uniform(0, 1)"
dropout: "uniform(0, 1)"
batch_size: "uniform(4, 16, discrete=True)"
epoch: "fidelity(10, 100)"
epoch: "fidelity(3, 100)"

orion:
name: 'experiment'
version: '1'

algorithm:
type: random
type: hyperband
config:
seed: 1

Expand Down
55 changes: 53 additions & 2 deletions example/my_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,26 +8,77 @@
log = logging.getLogger(__name__)



def _load_checkpoint(path, model):
checkpoint = os.path.join(path, 'chk.pt')

if os.path.exists(checkpoint):
# load checkpoint
# ...
return True

return False


def load_checkpoint(model):
    """Restore *model* from the most relevant checkpoint, if any.

    The current job's own checkpoint is tried first, since it is always
    the most recent state; otherwise we fall back to the checkpoint of a
    previously finished job we were asked to resume from.

    Returns True when a checkpoint was loaded, False otherwise.
    """
    current = os.getenv("CURRENT_CHECKPOINT")
    assert current is not None

    # If this job already wrote a checkpoint, it supersedes anything else.
    if _load_checkpoint(current, model):
        return True

    # PREVIOUS_CHECKPOINT points to a finished job to resume from; useful
    # for genetic-style algorithms that improve on earlier solutions.
    previous = os.getenv("PREVIOUS_CHECKPOINT")
    if previous and _load_checkpoint(previous, model):
        return True

    return False



def save_checkpoint(model):
    """Write *model* state to ``chk.pt`` inside the current job's folder.

    The destination folder comes from the CURRENT_CHECKPOINT environment
    variable, which Hydra sets to this run's output directory.
    """
    out_dir = os.getenv("CURRENT_CHECKPOINT")
    target = os.path.join(out_dir, 'chk.pt')

    with open(target, 'w') as fp:
        # save checkpoint
        # ...
        pass


@hydra.main(config_path=".", config_name="config", version_base="1.1")
def dummy_training(cfg: DictConfig) -> float:
"""A dummy function to minimize
Minimum is 0.0 at:
lr = 0.12, dropout=0.33, opt=Adam, batch_size=4
"""

# print(cfg.hydra )

# makes sure folders are unique
os.makedirs('newdir', exist_ok=False)

model = None

if load_checkpoint(model):
print('Resuming from checkpoint')
else:
print('No checkpoint found')

do = cfg.dropout
bs = cfg.batch_size
out = float(
abs(do - 0.33) + int(cfg.optimizer.name == "Adam") + abs(cfg.optimizer.lr - 0.12) + abs(bs - 4)
)
# ..../hydra_orion_sweeper/example/multirun/2022-11-08/11-56-45/39
# print(os.getcwd())
log.info(
f"dummy_training(dropout={do:.3f}, lr={cfg.optimizer.lr:.3f}, opt={cfg.optimizer.name}, batch_size={bs}) = {out:.3f}",
)

save_checkpoint(model)

if cfg.error:
raise RuntimeError("cfg.error is True")

Expand Down
2 changes: 2 additions & 0 deletions hydra_plugins/hydra_orion_sweeper/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ class OrionClientConf:
trial: Optional[str] = None
uuid: Optional[str] = None

previous_checkpoint: Optional[str] = None


@dataclass
class WorkerConf:
Expand Down
17 changes: 15 additions & 2 deletions hydra_plugins/hydra_orion_sweeper/implementation.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from typing import Any, List, Optional, Sequence, Union

from hydra.core import utils
from hydra.core.global_hydra import GlobalHydra
from hydra.core.override_parser.overrides_parser import OverridesParser
from hydra.core.override_parser.types import Override, QuotedString
from hydra.core.plugins import Plugins
Expand Down Expand Up @@ -159,16 +160,19 @@ def override_parser():
return parser


def as_overrides(trial, additional, uuid):
def as_overrides(trial, additional, uuid, prev_checkpoint):
    """Returns the trial arguments as hydra overrides"""
    # Merge the static extra arguments with the trial's (flattened) params;
    # copy so the caller's dict is never mutated.
    params = deepcopy(additional)
    params.update(flatten(trial.params))

    overrides = [f"{key}={value}" for key, value in params.items()]

    # Bookkeeping overrides so each job knows which trial it belongs to
    # and where a resumable checkpoint may live.
    overrides.extend([
        f"hydra.sweeper.orion.id={trial.experiment}",
        f"hydra.sweeper.orion.trial={trial.id}",
        f"hydra.sweeper.orion.uuid={uuid}",
        f"hydra.sweeper.orion.previous_checkpoint={prev_checkpoint}",
        # "hydra.sweeper.orion.current_checkpoint=$hydra.runtime.output_dir",
    ])
    return tuple(overrides)

Expand Down Expand Up @@ -350,6 +354,7 @@ def __init__(
self.client = None
self.storage = None
self.uuid = uuid.uuid1().hex
self.resume_paths = dict()

self.orion_config = orion
self.worker_config = worker
Expand Down Expand Up @@ -532,10 +537,15 @@ def sample_trials(self) -> List[Trial]:
self.pending_trials.update(set(trials))
return trials

def trial_as_override(self, trial: Trial):
    """Create overrides for a specific trial"""
    # Resume path recorded when an earlier run of this trial was observed;
    # None when the trial has never produced output before.
    prev_checkpoint = self.resume_paths.get(trial.hash_params)
    return as_overrides(trial, self.arguments, self.uuid, prev_checkpoint)

def execute_trials(self, trials: List[Trial]) -> Sequence[JobReturn]:
"""Execture the given batch of trials"""

overrides = list(as_overrides(t, self.arguments, self.uuid) for t in trials)
overrides = list(self.trial_as_override(t) for t in trials)
self.validate_batch_is_legal(overrides)

returns = self.launcher.launch(overrides, initial_job_idx=self.job_idx)
Expand All @@ -548,6 +558,9 @@ def observe_one(
"""Observe a single trial"""
value = result.return_value

trialdir = result.hydra_cfg["hydra"]["runtime"]["output_dir"]
self.resume_paths[trial.hash_params] = trialdir

try:
objective = to_objective(value)
self.client.observe(trial, objective)
Expand Down

0 comments on commit 048ba19

Please sign in to comment.