diff --git a/Cargo.lock b/Cargo.lock index 9e5192c..89acb4e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -806,7 +806,7 @@ checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" [[package]] name = "phenolrs" -version = "0.4.1" +version = "0.4.2" dependencies = [ "anyhow", "bytes", diff --git a/Cargo.toml b/Cargo.toml index 6262dfd..074042e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "phenolrs" -version = "0.4.1" +version = "0.4.2" edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html diff --git a/pyproject.toml b/pyproject.toml index 63d188b..54cb5bb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,8 +12,6 @@ classifiers = [ ] dependencies = [ "numpy", - "torch", - "torch-geometric", "python-arango" ] @@ -22,6 +20,10 @@ tests = [ "pytest", "arango-datasets" ] +torch = [ + "torch", + "torch-geometric", +] dynamic = ["version"] [tool.maturin] diff --git a/python/phenolrs/pyg_loader.py b/python/phenolrs/pyg_loader.py index 977176d..5a7bca7 100644 --- a/python/phenolrs/pyg_loader.py +++ b/python/phenolrs/pyg_loader.py @@ -1,12 +1,18 @@ import typing import numpy as np -import torch -from torch_geometric.data import Data, HeteroData from phenolrs import PhenolError from phenolrs.numpy_loader import NumpyLoader +try: + import torch + from torch_geometric.data import Data, HeteroData + + TORCH_AVAILABLE = True +except ImportError: + TORCH_AVAILABLE = False + class PygLoader: @staticmethod @@ -20,7 +26,11 @@ def load_into_pyg_data( tls_cert: typing.Any | None = None, parallelism: int | None = None, batch_size: int | None = None, - ) -> tuple[Data, dict[str, dict[str, int]], dict[str, dict[int, str]]]: + ) -> tuple["Data", dict[str, dict[str, int]], dict[str, dict[int, str]]]: + if not TORCH_AVAILABLE: + m = "Missing required dependencies. Install with `pip install phenolrs[torch]`" # noqa: E501 + raise ImportError(m) + if "vertexCollections" not in metagraph: raise PhenolError("vertexCollections not found in metagraph") if "edgeCollections" not in metagraph: @@ -99,7 +109,11 @@ def load_into_pyg_heterodata( tls_cert: typing.Any | None = None, parallelism: int | None = None, batch_size: int | None = None, - ) -> tuple[HeteroData, dict[str, dict[str, int]], dict[str, dict[int, str]]]: + ) -> tuple["HeteroData", dict[str, dict[str, int]], dict[str, dict[int, str]]]: + if not TORCH_AVAILABLE: + m = "Missing required dependencies. Install with `pip install phenolrs[torch]`" # noqa: E501 + raise ImportError(m) + if "vertexCollections" not in metagraph: raise PhenolError("vertexCollections not found in metagraph") if "edgeCollections" not in metagraph: