Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tweak example to be resumable #21

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions example/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,16 @@ defaults:
- override hydra/sweeper: orion

hydra:
job:
env_set:
PREVIOUS_CHECKPOINT: ${hydra.sweeper.experiment.previous_checkpoint}
CURRENT_CHECKPOINT: ${hydra.runtime.output_dir}

# makes sure each multirun ends up in a unique folder
# the defaults can make overlapping folders
sweep:
dir: multirun/${now:%Y-%m-%d}/${now:%H-%M-%S}
subdir: ${hydra.sweeper.orion.name}/${hydra.sweeper.orion.uuid}/${hydra.job.id}
subdir: ${hydra.sweeper.experiment.name}/${hydra.sweeper.experiment.uuid}/${hydra.job.id}

sweeper:
# default parametrization of the search space
Expand All @@ -16,14 +21,14 @@ hydra:
lr: "uniform(0, 1)"
dropout: "uniform(0, 1)"
batch_size: "uniform(4, 16, discrete=True)"
epoch: "fidelity(10, 100)"
epoch: "fidelity(3, 100)"

orion:
experiment:
name: 'experiment'
version: '1'

algorithm:
type: random
type: hyperband
config:
seed: 1

Expand Down
53 changes: 51 additions & 2 deletions example/my_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,47 @@
log = logging.getLogger(__name__)



def _load_checkpoint(path, model):
    """Restore *model* from the ``chk.pt`` file inside *path*, if present.

    Returns True when a checkpoint file was found (and would be loaded),
    False when no checkpoint exists at that location.
    """
    checkpoint_file = os.path.join(path, 'chk.pt')

    if not os.path.exists(checkpoint_file):
        return False

    # load checkpoint into the model here
    # ...
    return True


def load_checkpoint(model):
    """Try to restore *model* from a checkpoint.

    The checkpoint in the current run's output dir (CURRENT_CHECKPOINT) wins
    when it exists, since it is always the most recent; otherwise fall back to
    the finished job pointed at by PREVIOUS_CHECKPOINT.

    Returns True when a checkpoint was loaded, False otherwise.
    """
    current = os.getenv("CURRENT_CHECKPOINT")
    assert current is not None

    # A checkpoint written by this very run is always the freshest one.
    if _load_checkpoint(current, model):
        return True

    # PREVIOUS_CHECKPOINT references a completed job to resume from; this is
    # useful for genetic-style algorithms that refine earlier solutions.
    previous = os.getenv("PREVIOUS_CHECKPOINT")
    return bool(previous and _load_checkpoint(previous, model))



def save_checkpoint(model):
    """Write a checkpoint for *model* into the current run's output directory.

    The target directory comes from the CURRENT_CHECKPOINT environment
    variable (set via the hydra job config); the file is named ``chk.pt``.
    """
    target_dir = os.getenv("CURRENT_CHECKPOINT")
    checkpoint_file = os.path.join(target_dir, 'chk.pt')

    with open(checkpoint_file, 'w') as handle:
        # serialize the model state here
        # ...
        pass


@hydra.main(config_path=".", config_name="config", version_base="1.1")
def dummy_training(cfg: DictConfig) -> float:
"""A dummy function to minimize
Expand All @@ -18,16 +59,24 @@ def dummy_training(cfg: DictConfig) -> float:
# makes sure folders are unique
os.makedirs('newdir', exist_ok=False)

model = None

if load_checkpoint(model):
print('Resuming from checkpoint')
else:
print('No checkpoint found')

do = cfg.dropout
bs = cfg.batch_size
out = float(
abs(do - 0.33) + int(cfg.optimizer.name == "Adam") + abs(cfg.optimizer.lr - 0.12) + abs(bs - 4)
)
# ..../hydra_orion_sweeper/example/multirun/2022-11-08/11-56-45/39
# print(os.getcwd())
log.info(
f"dummy_training(dropout={do:.3f}, lr={cfg.optimizer.lr:.3f}, opt={cfg.optimizer.name}, batch_size={bs}) = {out:.3f}",
)

save_checkpoint(model)

if cfg.error:
raise RuntimeError("cfg.error is True")

Expand Down
4 changes: 3 additions & 1 deletion hydra_plugins/hydra_orion_sweeper/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ class OrionClientConf:
trial: Optional[str] = None
uuid: Optional[str] = None

previous_checkpoint: Optional[str] = None


@dataclass
class WorkerConf:
Expand Down Expand Up @@ -82,7 +84,7 @@ class OrionSweeperConf:

_target_: str = "hydra_plugins.hydra_orion_sweeper.orion_sweeper.OrionSweeper"

orion: OrionClientConf = OrionClientConf()
experiment: OrionClientConf = OrionClientConf()

worker: WorkerConf = WorkerConf()

Expand Down
27 changes: 20 additions & 7 deletions hydra_plugins/hydra_orion_sweeper/implementation.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,16 +159,19 @@ def override_parser():
return parser


def as_overrides(trial, additional, uuid):
def as_overrides(trial, additional, uuid, prev_checkpoint):
    """Returns the trial arguments as hydra overrides"""
    params = deepcopy(additional)
    params.update(flatten(trial.params))

    overrides = [f"{key}={value}" for key, value in params.items()]

    overrides.extend(
        [
            f"hydra.sweeper.experiment.id={trial.experiment}",
            f"hydra.sweeper.experiment.trial={trial.id}",
            f"hydra.sweeper.experiment.uuid={uuid}",
            f"hydra.sweeper.experiment.previous_checkpoint={prev_checkpoint}",
            # NOTE: the current checkpoint dir is resolved at runtime instead:
            # "hydra.sweeper.experiment.current_checkpoint=$hydra.runtime.output_dir",
        ]
    )
    return tuple(overrides)

Expand Down Expand Up @@ -338,7 +341,7 @@ class OrionSweeperImpl(Sweeper):

def __init__(
self,
orion: OrionClientConf,
experiment: OrionClientConf,
worker: WorkerConf,
algorithm: AlgorithmConf,
storage: StorageConf,
Expand All @@ -350,8 +353,9 @@ def __init__(
self.client = None
self.storage = None
self.uuid = uuid.uuid1().hex
self.resume_paths = dict()

self.orion_config = orion
self.orion_config = experiment
self.worker_config = worker
self.algo_config = algorithm
self.storage_config = storage
Expand Down Expand Up @@ -532,10 +536,15 @@ def sample_trials(self) -> List[Trial]:
self.pending_trials.update(set(trials))
return trials

def trial_as_override(self, trial: Trial):
    """Create overrides for a specific trial"""
    # Reuse the output dir of a prior run of this trial (if any) so the
    # launched job can resume from its checkpoint.
    resume_from = self.resume_paths.get(trial.hash_params)
    return as_overrides(trial, self.arguments, self.uuid, resume_from)

def execute_trials(self, trials: List[Trial]) -> Sequence[JobReturn]:
"""Execture the given batch of trials"""

overrides = list(as_overrides(t, self.arguments, self.uuid) for t in trials)
overrides = list(self.trial_as_override(t) for t in trials)
self.validate_batch_is_legal(overrides)

returns = self.launcher.launch(overrides, initial_job_idx=self.job_idx)
Expand All @@ -548,6 +557,10 @@ def observe_one(
"""Observe a single trial"""
value = result.return_value

if result.hydra_cfg:
trialdir = result.hydra_cfg["hydra"]["runtime"]["output_dir"]
self.resume_paths[trial.hash_params] = trialdir

try:
objective = to_objective(value)
self.client.observe(trial, objective)
Expand Down
8 changes: 4 additions & 4 deletions hydra_plugins/hydra_orion_sweeper/orion_sweeper.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class OrionSweeper(Sweeper):

def __init__(
self,
orion: OrionClientConf,
experiment: OrionClientConf,
worker: WorkerConf,
algorithm: AlgorithmConf,
storage: StorageConf,
Expand All @@ -30,15 +30,15 @@ def __init__(
# >>> Remove with Issue #8
if parametrization is not None and params is None:
warn(
"`hydra.sweeper.orion.parametrization` is deprecated;"
"`hydra.sweeper.experiment.parametrization` is deprecated;"
"use `hydra.sweeper.params` instead",
DeprecationWarning,
)
params = parametrization

elif parametrization is not None and params is not None:
warn(
"Both `hydra.sweeper.orion.parametrization` and `hydra.sweeper.params` are defined;"
"Both `hydra.sweeper.experiment.parametrization` and `hydra.sweeper.params` are defined;"
"using `hydra.sweeper.params`",
DeprecationWarning,
)
Expand All @@ -47,7 +47,7 @@ def __init__(
if params is None:
params = dict()

self.sweeper = OrionSweeperImpl(orion, worker, algorithm, storage, params)
self.sweeper = OrionSweeperImpl(experiment, worker, algorithm, storage, params)

def setup(
self,
Expand Down
2 changes: 1 addition & 1 deletion tests/hydra_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
},
"sweeper": {
"_target_": "hydra_plugins.hydra_orion_sweeper.orion_sweeper.OrionSweeper",
"orion": {
"experiment": {
"name": None,
"version": None,
"branching": None,
Expand Down
2 changes: 1 addition & 1 deletion tests/test_orion.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def load_hydra_testing_config():

def orion_configuration():
return dict(
orion=OmegaConf.structured(OrionClientConf()),
experiment=OmegaConf.structured(OrionClientConf()),
worker=OmegaConf.structured(WorkerConf()),
algorithm=OmegaConf.structured(AlgorithmConf()),
storage=OmegaConf.structured(StorageConf()),
Expand Down
4 changes: 2 additions & 2 deletions tests/test_warnings.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def test_parametrization_is_deprecated():
assert (
warnings[0]
.message.args[0]
.startswith("`hydra.sweeper.orion.parametrization` is deprecated;")
.startswith("`hydra.sweeper.experiment.parametrization` is deprecated;")
)


Expand All @@ -44,7 +44,7 @@ def test_parametrization_and_params():
assert (
warnings[0]
.message.args[0]
.startswith("Both `hydra.sweeper.orion.parametrization` and")
.startswith("Both `hydra.sweeper.experiment.parametrization` and")
)


Expand Down