Skip to content

Commit

Permalink
Move get_device to a separate location
Browse files Browse the repository at this point in the history
  • Loading branch information
golmschenk committed Apr 29, 2024
1 parent d8eff4c commit 2235440
Show file tree
Hide file tree
Showing 5 changed files with 15 additions and 18 deletions.
3 changes: 2 additions & 1 deletion examples/transit_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@

from qusi.finite_standard_light_curve_dataset import FiniteStandardLightCurveDataset
from qusi.hadryss_model import Hadryss
from qusi.infer_session import get_device, infer_session
from qusi.infer_session import infer_session
from qusi.device import get_device
from qusi.light_curve_collection import LightCurveCollection
from ramjet.photometric_database.tess_two_minute_cadence_light_curve import TessMissionLightCurve

Expand Down
9 changes: 1 addition & 8 deletions examples/transit_infinite_dataset_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from torchmetrics.classification import BinaryAccuracy

from qusi.hadryss_model import Hadryss
from qusi.device import get_device
from qusi.light_curve_collection import LabeledLightCurveCollection
from qusi.light_curve_dataset import LightCurveDataset
from ramjet.photometric_database.tess_two_minute_cadence_light_curve import TessMissionLightCurve
Expand Down Expand Up @@ -71,14 +72,6 @@ def infinite_datasets_test_session(test_datasets: list[LightCurveDataset], model
return results


def get_device():
if torch.cuda.is_available():
device = torch.device('cuda')
else:
device = torch.device('cpu')
return device


def infinite_dataset_test_phase(dataloader, model: Module, metric_functions: list[Module], device: Device, steps: int):
batch_count = 0
metric_totals = torch.zeros(size=[len(metric_functions)])
Expand Down
10 changes: 10 additions & 0 deletions src/qusi/device.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import torch
from torch.types import Device


def get_device() -> Device:
if torch.cuda.is_available():
device = torch.device("cuda")
else:
device = torch.device("cpu")
return device
8 changes: 0 additions & 8 deletions src/qusi/infer_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,6 @@ def infer_session(
return results


def get_device() -> Device:
if torch.cuda.is_available():
device = torch.device("cuda")
else:
device = torch.device("cpu")
return device


def infer_phase(dataloader, model: Module, device: Device):
batch_count = 0
batches_of_predicted_targets = []
Expand Down
3 changes: 2 additions & 1 deletion tests/end_to_end_tests/test_toy_infer_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@

import numpy as np

from qusi.infer_session import get_device, infer_session
from qusi.infer_session import infer_session
from qusi.device import get_device
from qusi.light_curve_dataset import (
default_light_curve_post_injection_transform,
)
Expand Down

0 comments on commit 2235440

Please sign in to comment.