Skip to content

Commit

Permalink
add test for result module
Browse files Browse the repository at this point in the history
  • Loading branch information
aramoto99 committed Jan 23, 2025
1 parent 4f83ef1 commit 9c5662f
Show file tree
Hide file tree
Showing 15 changed files with 246 additions and 7 deletions.
1 change: 0 additions & 1 deletion aiaccel/hpo/results/base_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@


class BaseResult:

def __init__(self, filename_template: str) -> None:
self.filename_template = filename_template

Expand Down
1 change: 0 additions & 1 deletion aiaccel/hpo/results/json_result.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

import json

from aiaccel.hpo.job_executors import BaseJobExecutor
Expand Down
1 change: 0 additions & 1 deletion aiaccel/hpo/results/pickle_result.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

import pickle as pkl

from aiaccel.hpo.job_executors import BaseJobExecutor
Expand Down
25 changes: 25 additions & 0 deletions tests/hpo/result/config_for_json_test.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
study:
_target_: optuna.create_study
direction: minimize
study_name: my_study
load_if_exists: false
sampler:
_target_: optuna.samplers.TPESampler
seed: 0

result:
_target_: aiaccel.hpo.results.JsonResult
filename_template: "{job.cwd}/{job.job_name}_result.json"

params:
_convert_: partial
_target_: aiaccel.hpo.apps.optimize.HparamsManager
x1: [0, 1]
x2:
_target_: aiaccel.hpo.optuna.suggest_wrapper.SuggestFloat
name: x2
low: 0.0
high: 1.0
log: false

n_trials: 30
25 changes: 25 additions & 0 deletions tests/hpo/result/config_for_pkl_test.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
study:
_target_: optuna.create_study
direction: minimize
study_name: my_study
load_if_exists: false
sampler:
_target_: optuna.samplers.TPESampler
seed: 0

result:
_target_: aiaccel.hpo.results.PickleResult
filename_template: "{job.cwd}/{job.job_name}_result.pkl"

params:
_convert_: partial
_target_: aiaccel.hpo.apps.optimize.HparamsManager
x1: [0, 1]
x2:
_target_: aiaccel.hpo.optuna.suggest_wrapper.SuggestFloat
name: x2
low: 0.0
high: 1.0
log: false

n_trials: 30
25 changes: 25 additions & 0 deletions tests/hpo/result/config_for_stdo_test.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
study:
_target_: optuna.create_study
direction: minimize
study_name: my_study
load_if_exists: false
sampler:
_target_: optuna.samplers.TPESampler
seed: 0

result:
_target_: aiaccel.hpo.results.StdoutResult
filename_template: "{job.cwd}/{job.job_name}_result.txt"

params:
_convert_: partial
_target_: aiaccel.hpo.apps.optimize.HparamsManager
x1: [0, 1]
x2:
_target_: aiaccel.hpo.optuna.suggest_wrapper.SuggestFloat
name: x2
low: 0.0
high: 1.0
log: false

n_trials: 30
22 changes: 22 additions & 0 deletions tests/hpo/result/objective_for_json_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from argparse import ArgumentParser
import json
from pathlib import Path


def main() -> None:
parser = ArgumentParser()
parser.add_argument("dst_filename", type=Path)
parser.add_argument("--x1", type=float)
parser.add_argument("--x2", type=float)
args = parser.parse_args()

x1, x2 = args.x1, args.x2

y = (x1**2) - (4.0 * x1) + (x2**2) - x2 - (x1 * x2)

with open(args.dst_filename, "w") as f:
json.dump({"objective": y}, f)


if __name__ == "__main__":
main()
10 changes: 10 additions & 0 deletions tests/hpo/result/objective_for_json_test.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
#!/bin/bash

#$-l rt_C.small=1
#$-cwd

source /etc/profile.d/modules.sh
module load gcc/13.2.0
module load python/3.10/3.10.14

python objective_for_json_test.py $@
22 changes: 22 additions & 0 deletions tests/hpo/result/objective_for_pkl_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from argparse import ArgumentParser
from pathlib import Path
import pickle as pkl


def main() -> None:
parser = ArgumentParser()
parser.add_argument("dst_filename", type=Path)
parser.add_argument("--x1", type=float)
parser.add_argument("--x2", type=float)
args = parser.parse_args()

x1, x2 = args.x1, args.x2

y = (x1**2) - (4.0 * x1) + (x2**2) - x2 - (x1 * x2)

with open(args.dst_filename, "wb") as f:
pkl.dump(y, f)


if __name__ == "__main__":
main()
10 changes: 10 additions & 0 deletions tests/hpo/result/objective_for_pkl_test.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
#!/bin/bash

#$-l rt_C.small=1
#$-cwd

source /etc/profile.d/modules.sh
module load gcc/13.2.0
module load python/3.10/3.10.14

python objective_for_pkl_test.py $@
18 changes: 18 additions & 0 deletions tests/hpo/result/objective_for_stdo_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from argparse import ArgumentParser


def main() -> None:
parser = ArgumentParser()
parser.add_argument("--x1", type=float)
parser.add_argument("--x2", type=float)
args = parser.parse_args()

x1, x2 = args.x1, args.x2

y = (x1**2) - (4.0 * x1) + (x2**2) - x2 - (x1 * x2)

print(y)


if __name__ == "__main__":
main()
10 changes: 10 additions & 0 deletions tests/hpo/result/objective_for_stdo_test.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
#!/bin/bash

#$-l rt_C.small=1
#$-cwd

source /etc/profile.d/modules.sh
module load gcc/13.2.0
module load python/3.10/3.10.14

python objective_for_stdo_test.py "${@:2}" > "$1"
26 changes: 26 additions & 0 deletions tests/hpo/result/test_json_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
import json
import os
from pathlib import Path
import shutil
import tempfile
from unittest.mock import patch

import pytest

Expand All @@ -17,6 +19,21 @@ def temp_dir() -> Generator[Path]:
temp_path = Path(tmp_dir)
original_dir = os.getcwd()
os.chdir(tmp_dir)
source_dir = Path(__file__).parent
test_files = ["config_for_json_test.yaml", "objective_for_json_test.sh", "objective_for_json_test.py"]

for file_name in test_files:
source_file = source_dir / file_name
target_file = temp_path / file_name
if source_file.exists():
shutil.copy2(source_file, target_file)
os.chmod(target_file, 0o755)
print(f"\n=== Content of {file_name} ===")
print((target_file).read_text())
print("=" * 40)
else:
pytest.skip(f"Required test file {file_name} not found in {source_dir}")

yield temp_path
os.chdir(original_dir)

Expand Down Expand Up @@ -46,3 +63,12 @@ def test_load_str(temp_dir: Path) -> None:
job = LocalJobExecutor(Path(""), work_dir=temp_dir)
json_result = JsonResult("{job.cwd}/result.json")
assert json_result.load(job) == "result"


def test_result(temp_dir: Path) -> None:
with patch("sys.argv", ["optimize.py", "objective_for_json_test.sh", "--config", "config_for_json_test.yaml"]):
from aiaccel.hpo.apps.optimize import main

main()
assert (temp_dir / "objective_for_json_test.sh_29_result.json").exists()
assert (temp_dir / "objective_for_json_test.sh_30_result.json").exists() is False
29 changes: 27 additions & 2 deletions tests/hpo/result/test_pickle_result.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from collections.abc import Generator
import pickle as pkl

import os
from pathlib import Path
import pickle as pkl
import shutil
import tempfile
from unittest.mock import patch

import pytest

Expand All @@ -18,6 +19,21 @@ def temp_dir() -> Generator[Path]:
temp_path = Path(tmp_dir)
original_dir = os.getcwd()
os.chdir(tmp_dir)
source_dir = Path(__file__).parent
test_files = ["config_for_pkl_test.yaml", "objective_for_pkl_test.sh", "objective_for_pkl_test.py"]

for file_name in test_files:
source_file = source_dir / file_name
target_file = temp_path / file_name
if source_file.exists():
shutil.copy2(source_file, target_file)
os.chmod(target_file, 0o755)
print(f"\n=== Content of {file_name} ===")
print((target_file).read_text())
print("=" * 40)
else:
pytest.skip(f"Required test file {file_name} not found in {source_dir}")

yield temp_path
os.chdir(original_dir)

Expand Down Expand Up @@ -50,3 +66,12 @@ def test_load_str(temp_dir: Path) -> None:
job = LocalJobExecutor(Path(""), work_dir=temp_dir)
pkl_result = PickleResult("{job.cwd}/result.pkl")
assert pkl_result.load(job) == "result"


def test_result(temp_dir: Path) -> None:
with patch("sys.argv", ["optimize.py", "objective_for_pkl_test.sh", "--config", "config_for_pkl_test.yaml"]):
from aiaccel.hpo.apps.optimize import main

main()
assert (temp_dir / "objective_for_pkl_test.sh_29_result.pkl").exists()
assert (temp_dir / "objective_for_pkl_test.sh_30_result.pkl").exists() is False
28 changes: 26 additions & 2 deletions tests/hpo/result/test_stdout_result.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from collections.abc import Generator
import pickle as pkl

import os
from pathlib import Path
import shutil
import tempfile
from unittest.mock import patch

import pytest

Expand All @@ -18,6 +18,21 @@ def temp_dir() -> Generator[Path]:
temp_path = Path(tmp_dir)
original_dir = os.getcwd()
os.chdir(tmp_dir)
source_dir = Path(__file__).parent
test_files = ["config_for_stdo_test.yaml", "objective_for_stdo_test.sh", "objective_for_stdo_test.py"]

for file_name in test_files:
source_file = source_dir / file_name
target_file = temp_path / file_name
if source_file.exists():
shutil.copy2(source_file, target_file)
os.chmod(target_file, 0o755)
print(f"\n=== Content of {file_name} ===")
print((target_file).read_text())
print("=" * 40)
else:
pytest.skip(f"Required test file {file_name} not found in {source_dir}")

yield temp_path
os.chdir(original_dir)

Expand Down Expand Up @@ -47,3 +62,12 @@ def test_load_str(temp_dir: Path) -> None:
job = LocalJobExecutor(Path(""), work_dir=temp_dir)
stdout_result = StdoutResult("{job.cwd}/result.txt")
assert stdout_result.load(job) == "result"


def test_result(temp_dir: Path) -> None:
with patch("sys.argv", ["optimize.py", "objective_for_stdo_test.sh", "--config", "config_for_stdo_test.yaml"]):
from aiaccel.hpo.apps.optimize import main

main()
assert (temp_dir / "objective_for_stdo_test.sh_29_result.txt").exists()
assert (temp_dir / "objective_for_stdo_test.sh_30_result.txt").exists() is False

0 comments on commit 9c5662f

Please sign in to comment.