diff --git a/examples/transit_infer.py b/examples/transit_infer.py index 066a0470..80673fdb 100644 --- a/examples/transit_infer.py +++ b/examples/transit_infer.py @@ -5,7 +5,8 @@ from qusi.finite_standard_light_curve_dataset import FiniteStandardLightCurveDataset from qusi.hadryss_model import Hadryss -from qusi.infer_session import get_device, infer_session +from qusi.infer_session import infer_session +from qusi.device import get_device from qusi.light_curve_collection import LightCurveCollection from ramjet.photometric_database.tess_two_minute_cadence_light_curve import TessMissionLightCurve diff --git a/examples/transit_infinite_dataset_test.py b/examples/transit_infinite_dataset_test.py index 8776206c..7c3cb124 100644 --- a/examples/transit_infinite_dataset_test.py +++ b/examples/transit_infinite_dataset_test.py @@ -8,6 +8,7 @@ from torchmetrics.classification import BinaryAccuracy from qusi.hadryss_model import Hadryss +from qusi.device import get_device from qusi.light_curve_collection import LabeledLightCurveCollection from qusi.light_curve_dataset import LightCurveDataset from ramjet.photometric_database.tess_two_minute_cadence_light_curve import TessMissionLightCurve @@ -71,14 +72,6 @@ def infinite_datasets_test_session(test_datasets: list[LightCurveDataset], model return results -def get_device(): - if torch.cuda.is_available(): - device = torch.device('cuda') - else: - device = torch.device('cpu') - return device - - def infinite_dataset_test_phase(dataloader, model: Module, metric_functions: list[Module], device: Device, steps: int): batch_count = 0 metric_totals = torch.zeros(size=[len(metric_functions)]) diff --git a/src/qusi/device.py b/src/qusi/device.py new file mode 100644 index 00000000..37455a22 --- /dev/null +++ b/src/qusi/device.py @@ -0,0 +1,10 @@ +import torch +from torch.types import Device + + +def get_device() -> Device: + if torch.cuda.is_available(): + device = torch.device("cuda") + else: + device = torch.device("cpu") + return device diff --git a/src/qusi/infer_session.py b/src/qusi/infer_session.py index 60f580fc..bc2a2f7b 100644 --- a/src/qusi/infer_session.py +++ b/src/qusi/infer_session.py @@ -22,14 +22,6 @@ def infer_session( return results -def get_device() -> Device: - if torch.cuda.is_available(): - device = torch.device("cuda") - else: - device = torch.device("cpu") - return device - - def infer_phase(dataloader, model: Module, device: Device): batch_count = 0 batches_of_predicted_targets = [] diff --git a/tests/end_to_end_tests/test_toy_infer_session.py b/tests/end_to_end_tests/test_toy_infer_session.py index dd2ec2ab..9c3f046f 100644 --- a/tests/end_to_end_tests/test_toy_infer_session.py +++ b/tests/end_to_end_tests/test_toy_infer_session.py @@ -3,7 +3,8 @@ import numpy as np -from qusi.infer_session import get_device, infer_session +from qusi.infer_session import infer_session +from qusi.device import get_device from qusi.light_curve_dataset import ( default_light_curve_post_injection_transform, )