Skip to content

Commit

Permalink
Merge pull request #8 from eukaryo/add-timeline-plot
Browse files Browse the repository at this point in the history
Add timeline plot
  • Loading branch information
toshihikoyanase authored Mar 31, 2023
2 parents 2f9d850 + 2dd2ed1 commit 060dde1
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 1 deletion.
49 changes: 48 additions & 1 deletion studies.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import time
from typing import List
from typing import Optional
from typing import Sequence
Expand Down Expand Up @@ -257,7 +258,6 @@ def get_mnist() -> Tuple[torch.utils.data.DataLoader, torch.utils.data.DataLoade
return train_loader, valid_loader

def objective(trial: optuna.Trial) -> float:

# Generate the model.
model = define_model(trial).to(DEVICE)

Expand Down Expand Up @@ -316,3 +316,50 @@ def objective(trial: optuna.Trial) -> float:
)
study.optimize(objective, n_trials=50, timeout=600)
return study


def create_single_objective_studies_for_timeline() -> List[Tuple[str, StudiesType]]:
studies: List[Tuple[str, StudiesType]] = []
storage = optuna.storages.InMemoryStorage()

def objective_timeline(trial: optuna.Trial) -> float:
x = trial.suggest_float("x", 0, 1)
time.sleep(x * 0.1)
if x > 0.8:
raise ValueError()
if x > 0.4:
raise optuna.TrialPruned()
return x**2

# Single-objective study
study = optuna.create_study(
study_name="A single objective study consuming time",
storage=storage,
)

study.enqueue_trial({"x": 0.3}) # Add a COMPLETE trial.
study.enqueue_trial({"x": 0.9}) # Add a FAIL trial.
study.enqueue_trial({"x": 0.5}) # Add a PRUNED trial.
study.optimize(objective_timeline, n_trials=50, n_jobs=2, catch=(ValueError,))
studies.append((study.study_name, study))

# Single-objective study
study = optuna.create_study(
study_name=(
"A single objective study consuming time and "
"the order of legends is different from the order of trials"
),
storage=storage,
)
study.enqueue_trial({"x": 0.9}) # Add a FAIL trial.
study.enqueue_trial({"x": 0.5}) # Add a PRUNED trial.
study.enqueue_trial({"x": 0.3}) # Add a COMPLETE trial.
study.optimize(objective_timeline, n_trials=50, n_jobs=2, catch=(ValueError,))
studies.append((study.study_name, study))

# No trials single-objective study
study = optuna.create_study(
study_name="A single objective study that has no trials", storage=storage
)
studies.append((study.study_name, study))
return studies
18 changes: 18 additions & 0 deletions visual_regression_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from studies import create_multiple_single_objective_studies
from studies import create_pytorch_study
from studies import create_single_objective_studies
from studies import create_single_objective_studies_for_timeline
from studies import StudiesType


Expand Down Expand Up @@ -241,6 +242,21 @@ def generate_pareto_front_plots(
)


def generate_timeline_plots(
studies: List[Tuple[str, StudiesType]], base_dir: str, plot_kwargs: Dict[str, Any]
) -> List[Tuple[str, str, str]]:
filename_prefix = "timeline"
if len(plot_kwargs) > 0:
filename_prefix = f"{filename_prefix}-{stringify_plot_kwargs(plot_kwargs)}"
return generate_plot_files(
studies,
base_dir,
wrap_plot_func(lambda s: plotly_visualization.plot_timeline(s, **plot_kwargs)),
wrap_plot_func(lambda s: matplotlib_visualization.plot_timeline(s, **plot_kwargs)),
filename_prefix=filename_prefix,
)


def main() -> None:
if not os.path.exists(abs_output_dir):
os.mkdir(abs_output_dir)
Expand All @@ -256,6 +272,7 @@ def main() -> None:
multi_objective_studies = create_multi_objective_studies()
print("Creating studies that have intermediate values")
intermediate_value_studies = create_intermediate_value_studies()
single_objective_studies_for_timeline = create_single_objective_studies_for_timeline()

if args.heavy:
print("Creating pytorch study")
Expand Down Expand Up @@ -305,6 +322,7 @@ def main() -> None:
generate_edf_plots,
{},
),
("plot_timeline", single_objective_studies_for_timeline, generate_timeline_plots, {}),
]:
assert isinstance(plot_kwargs, Dict)
plot_files = generate(studies, abs_output_dir, plot_kwargs)
Expand Down

0 comments on commit 060dde1

Please sign in to comment.