Skip to content

Commit

Permalink
Fixed CUDA Execution Provider
Browse files Browse the repository at this point in the history
  • Loading branch information
clementpoiret committed Sep 25, 2021
1 parent 612d6d8 commit 1f8d61c
Show file tree
Hide file tree
Showing 5 changed files with 10 additions and 5 deletions.
2 changes: 1 addition & 1 deletion README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ Now, let's dive into the details.
They specify the path to the dataset (or MRI) and the pattern to find the files.

All parameters starting with ``roiloc.`` are directly linked to our home-made ROI location algorithm.
You can find more information about it in the `related GitHub repository <https://github.com/clementpoiret/HSF>`_.
You can find more information about it in the `related GitHub repository <https://github.com/clementpoiret/ROILoc>`_.

To date, we propose 4 different segmentation algorithms (from the fastest to the most accurate):

Expand Down
2 changes: 1 addition & 1 deletion hsf/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '0.1.0'
__version__ = '0.1.1'
7 changes: 6 additions & 1 deletion hsf/segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,12 @@ def get_inference_sessions(models_path: PosixPath) -> list:
p = Path(models_path).expanduser()
models = list(p.glob("*.onnx"))

return [ort.InferenceSession(str(model_path)) for model_path in models]
ep_list = ['CUDAExecutionProvider', 'CPUExecutionProvider']

return [
ort.InferenceSession(str(model_path), providers=ep_list)
for model_path in models
]


def to_ca_mode(logits: torch.Tensor, ca_mode: str = "1/2/3") -> torch.Tensor:
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "HSF"
version = "0.1.0"
version = "0.1.1"
description = "Simple yet exhaustive segmentation tool of the Hippocampal Subfields in T1w and T2w MRIs."
authors = ["Clément POIRET <[email protected]>"]
license = "MIT"
Expand Down
2 changes: 1 addition & 1 deletion tests/test_hsf.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@


def test_version():
assert __version__ == '0.1.0'
assert __version__ == '0.1.1'

0 comments on commit 1f8d61c

Please sign in to comment.