From 9d8c96dafe7310e4d9cb5e5a3f358fc71e9bbd91 Mon Sep 17 00:00:00 2001 From: Anthony Mahanna Date: Wed, 29 May 2024 13:26:24 -0400 Subject: [PATCH 1/4] new: make `torch` optional --- pyproject.toml | 6 ++++-- python/phenolrs/pyg_loader.py | 23 +++++++++++++++++++---- 2 files changed, 23 insertions(+), 6 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 63d188b..54cb5bb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,8 +12,6 @@ classifiers = [ ] dependencies = [ "numpy", - "torch", - "torch-geometric", "python-arango" ] @@ -22,6 +20,10 @@ tests = [ "pytest", "arango-datasets" ] +torch = [ + "torch", + "torch-geometric", +] dynamic = ["version"] [tool.maturin] diff --git a/python/phenolrs/pyg_loader.py b/python/phenolrs/pyg_loader.py index 977176d..b04324a 100644 --- a/python/phenolrs/pyg_loader.py +++ b/python/phenolrs/pyg_loader.py @@ -1,12 +1,19 @@ import typing import numpy as np -import torch -from torch_geometric.data import Data, HeteroData from phenolrs import PhenolError from phenolrs.numpy_loader import NumpyLoader +try: + import torch + from torch_geometric.data import Data, HeteroData + + TORCH_AVAILABLE = True +except ImportError: + TORCH_AVAILABLE = False + + class PygLoader: @staticmethod @@ -20,7 +27,11 @@ def load_into_pyg_data( tls_cert: typing.Any | None = None, parallelism: int | None = None, batch_size: int | None = None, - ) -> tuple[Data, dict[str, dict[str, int]], dict[str, dict[int, str]]]: + ) -> tuple["Data", dict[str, dict[str, int]], dict[str, dict[int, str]]]: + if not TORCH_AVAILABLE: + m = "Missing required dependencies. Install with `pip install phenolrs[torch]`" + raise ImportError(m) + if "vertexCollections" not in metagraph: raise PhenolError("vertexCollections not found in metagraph") if "edgeCollections" not in metagraph: @@ -99,7 +110,11 @@ def load_into_pyg_heterodata( tls_cert: typing.Any | None = None, parallelism: int | None = None, batch_size: int | None = None, - ) -> tuple[HeteroData, dict[str, dict[str, int]], dict[str, dict[int, str]]]: + ) -> tuple["HeteroData", dict[str, dict[str, int]], dict[str, dict[int, str]]]: + if not TORCH_AVAILABLE: + m = "Missing required dependencies. Install with `pip install phenolrs[torch]`" + raise ImportError(m) + if "vertexCollections" not in metagraph: raise PhenolError("vertexCollections not found in metagraph") if "edgeCollections" not in metagraph: From 9fb79170ba2a9e6a357326347b93c78143a0bf8a Mon Sep 17 00:00:00 2001 From: Anthony Mahanna Date: Wed, 29 May 2024 13:27:39 -0400 Subject: [PATCH 2/4] fix: lint --- python/phenolrs/pyg_loader.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/phenolrs/pyg_loader.py b/python/phenolrs/pyg_loader.py index b04324a..d4f6ec5 100644 --- a/python/phenolrs/pyg_loader.py +++ b/python/phenolrs/pyg_loader.py @@ -14,7 +14,6 @@ TORCH_AVAILABLE = False - class PygLoader: @staticmethod def load_into_pyg_data( From 600c02bfece36c23b5e2fb51cf0c4fd29a5cb1f3 Mon Sep 17 00:00:00 2001 From: Anthony Mahanna Date: Wed, 29 May 2024 13:28:58 -0400 Subject: [PATCH 3/4] bump version --- Cargo.lock | 2 +- Cargo.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 9e5192c..89acb4e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -806,7 +806,7 @@ checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" [[package]] name = "phenolrs" -version = "0.4.1" +version = "0.4.2" dependencies = [ "anyhow", "bytes", diff --git a/Cargo.toml b/Cargo.toml index 6262dfd..074042e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "phenolrs" -version = "0.4.1" +version = "0.4.2" edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html From 3f0e2f5b66eaea260e0648d063461ec949f6d9c1 Mon Sep 17 00:00:00 2001 From: Anthony Mahanna Date: Wed, 29 May 2024 13:30:43 -0400 Subject: [PATCH 4/4] fix: noqa --- python/phenolrs/pyg_loader.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/phenolrs/pyg_loader.py b/python/phenolrs/pyg_loader.py index d4f6ec5..5a7bca7 100644 --- a/python/phenolrs/pyg_loader.py +++ b/python/phenolrs/pyg_loader.py @@ -28,7 +28,7 @@ def load_into_pyg_data( batch_size: int | None = None, ) -> tuple["Data", dict[str, dict[str, int]], dict[str, dict[int, str]]]: if not TORCH_AVAILABLE: - m = "Missing required dependencies. Install with `pip install phenolrs[torch]`" + m = "Missing required dependencies. Install with `pip install phenolrs[torch]`" # noqa: E501 raise ImportError(m) if "vertexCollections" not in metagraph: @@ -111,7 +111,7 @@ def load_into_pyg_heterodata( batch_size: int | None = None, ) -> tuple["HeteroData", dict[str, dict[str, int]], dict[str, dict[int, str]]]: if not TORCH_AVAILABLE: - m = "Missing required dependencies. Install with `pip install phenolrs[torch]`" + m = "Missing required dependencies. Install with `pip install phenolrs[torch]`" # noqa: E501 raise ImportError(m) if "vertexCollections" not in metagraph: