Skip to content

Commit

Permalink
Attempt to fix test
Browse files Browse the repository at this point in the history
  • Loading branch information
wesselb committed Aug 8, 2024
1 parent 483a95f commit ac8ef38
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 10 deletions.
26 changes: 19 additions & 7 deletions tests/test_headers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,29 @@

import pytest

SKIP_FILES: set[str] = {"_version.py"}
"""set[str]: These files are not required to have a copyright notice."""

COPYRIGHT_NOTICE: str = '"""Copyright (c) Microsoft Corporation. Licensed under the MIT license.'
"""str: Every file must start with this notice."""

PYTHON_FILES: list[Path] = []
"""list[Path]: Python files to scan for headers."""

@pytest.mark.parametrize("python_file", Path(__file__).parents[1].rglob("**/*.py"))
def test_presence_of_copyright_header(python_file: Path) -> None:
if python_file.name in SKIP_FILES:
return
_root = Path(__file__).parents[1]
for path in _root.rglob("**/*.py"):
relative_path = path.relative_to(_root)

# Ignore a possible virtual environment.
if str(relative_path.parents[-2]) in {"venv"}:
continue

# Ignore the automatically generated version file.
if relative_path.name in {"_version.py"}:
continue

PYTHON_FILES.append(path)


@pytest.mark.parametrize("python_file", PYTHON_FILES)
def test_presence_of_copyright_header(python_file: Path) -> None:
with open(python_file) as f:
lines = list(f.read().splitlines())

Expand Down
6 changes: 3 additions & 3 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,15 @@ def test_aurora_small() -> None:
# Load test input.
path = hf_hub_download(
repo_id=os.environ["HUGGINGFACE_REPO"],
filename="aurora-0.25-small-pretrained-test-input.pickle",
filename="aurora-0.25-small-pretrained-test-input-sse2.pickle",
)
with open(path, "rb") as f:
test_input: SavedBatch = pickle.load(f)

# Load test output.
path = hf_hub_download(
repo_id=os.environ["HUGGINGFACE_REPO"],
filename="aurora-0.25-small-pretrained-test-output.pickle",
filename="aurora-0.25-small-pretrained-test-output-sse2.pickle",
)
with open(path, "rb") as f:
test_output: SavedBatch = pickle.load(f)
Expand Down Expand Up @@ -80,7 +80,7 @@ def test_aurora_small() -> None:

# Load the checkpoint and run the model.
model.load_checkpoint(os.environ["HUGGINGFACE_REPO"], "aurora-0.25-small-pretrained.ckpt")
torch.manual_seed(0) # Very important to seed! The test data was generated using this.
model.eval()
with torch.inference_mode():
pred = model.forward(batch)

Expand Down

0 comments on commit ac8ef38

Please sign in to comment.