Skip to content

Commit

Permalink
Add sample unit tests (foundation-model-stack#61)
Browse files Browse the repository at this point in the history
* add sample unit test

Signed-off-by: ted chang <[email protected]>

* Update tests/utils/test_data_type_utils.py

Co-authored-by: Sukriti Sharma <[email protected]>
Signed-off-by: ted chang <[email protected]>

* Update tests/utils/test_data_type_utils.py

Co-authored-by: Sukriti Sharma <[email protected]>
Signed-off-by: ted chang <[email protected]>

* re-raise valueError instead of exit

Signed-off-by: ted chang <[email protected]>

---------

Signed-off-by: ted chang <[email protected]>
Co-authored-by: Sukriti Sharma <[email protected]>
  • Loading branch information
2 people authored and jbusche committed Mar 25, 2024
1 parent 82307c2 commit 71a3ed6
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 2 deletions.
22 changes: 22 additions & 0 deletions .github/workflows/test.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
name: Test

on:
push:
branches: [ "main" ]
pull_request:
branches: [ "main" ]

jobs:
build:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Set up Python 3.9
uses: actions/setup-python@v4
with:
python-version: 3.9
- name: Install dependencies
run: |
python -m pip install -r setup_requirements.txt
- name: Run unit tests
run: tox -e py
35 changes: 35 additions & 0 deletions tests/utils/test_data_type_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# SPDX-License-Identifier: Apache-2.0
# https://spdx.dev/learn/handling-license-info/

# Third Party
import pytest
import torch

# Local
from tuning.utils import data_type_utils

dtype_dict = {
"bool": torch.bool,
"double": torch.double,
"float32": torch.float32,
"int64": torch.int64,
"long": torch.long,
}


def test_str_to_torch_dtype():
for t in dtype_dict.keys():
assert data_type_utils.str_to_torch_dtype(t) == dtype_dict.get(t)


def test_str_to_torch_dtype_exit():
with pytest.raises(ValueError):
data_type_utils.str_to_torch_dtype("foo")


def test_get_torch_dtype():
for t in dtype_dict.keys():
# When passed a string, it gets converted to torch.dtype
assert data_type_utils.get_torch_dtype(t) == dtype_dict.get(t)
# When passed a torch.dtype, we get the same torch.dtype returned
assert data_type_utils.get_torch_dtype(dtype_dict.get(t)) == dtype_dict.get(t)
10 changes: 9 additions & 1 deletion tox.ini
Original file line number Diff line number Diff line change
@@ -1,5 +1,13 @@
[tox]
envlist = lint, fmt
envlist = py, lint, fmt

[testenv]
description = run unit tests
deps =
pytest>=7
-r requirements.txt
commands =
pytest {posargs:tests}

[testenv:fmt]
description = format with pre-commit
Expand Down
2 changes: 1 addition & 1 deletion tuning/utils/data_type_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def str_to_torch_dtype(dtype_str: str) -> torch.dtype:
dt = getattr(torch, dtype_str, None)
if not isinstance(dt, torch.dtype):
logger.error(" ValueError: Unrecognized data type of a torch.Tensor")
exit(-1)
raise ValueError("Unrecognized data type of a torch.Tensor")
return dt


Expand Down

0 comments on commit 71a3ed6

Please sign in to comment.