Skip to content

Commit

Permalink
Merge pull request #11 from arangoml/support-features-none-syntax
Browse files Browse the repository at this point in the history
address `None` value in metagraph
  • Loading branch information
Alex Geenen authored May 10, 2024
2 parents 984a906 + f8d8920 commit e80c789
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 46 deletions.
19 changes: 19 additions & 0 deletions python/phenolrs/numpy_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,25 @@ def load_graph_to_numpy(
if "vertexCollections" not in metagraph:
raise PhenolError("vertexCollections not found in metagraph")

# A field may be given in the nested single-key form {"<field>": None}, e.g.:
# "USER": {"x": {"features": None}}
# Normalize it in place to the plain string form:
# "USER": {"x": "features"}
entries: dict[str, typing.Any]
for v_col_name, entries in metagraph["vertexCollections"].items():
for source_name, value in entries.items():
if isinstance(value, dict):
if len(value) != 1:
m = f"Only one feature field should be specified per attribute. Found {value}" # noqa: E501
raise PhenolError(m)

value_key = list(value.keys())[0]
if value[value_key] is not None:
m = f"Invalid value for feature {source_name}: {value_key}. Found {value[value_key]}" # noqa: E501
raise PhenolError(m)

metagraph["vertexCollections"][v_col_name][source_name] = value_key

vertex_collections = [
{"name": v_col_name, "fields": list(entries.values())}
for v_col_name, entries in metagraph["vertexCollections"].items()
Expand Down
54 changes: 30 additions & 24 deletions python/phenolrs/pyg_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,18 +43,21 @@ def load_into_pyg_data(
v_col_spec_name = list(metagraph["vertexCollections"].keys())[0]
v_col_spec = list(metagraph["vertexCollections"].values())[0]

features_by_col, coo_map, col_to_key_inds, vertex_cols_source_to_output = (
NumpyLoader.load_graph_to_numpy(
database,
metagraph,
hosts,
user_jwt,
username,
password,
tls_cert,
parallelism,
batch_size,
)
(
features_by_col,
coo_map,
col_to_key_inds,
vertex_cols_source_to_output,
) = NumpyLoader.load_graph_to_numpy(
database,
metagraph,
hosts,
user_jwt,
username,
password,
tls_cert,
parallelism,
batch_size,
)

data = Data()
Expand Down Expand Up @@ -106,18 +109,21 @@ def load_into_pyg_heterodata(
if len(metagraph["edgeCollections"]) == 0:
raise PhenolError("edgeCollections must map to non-empty dictionary")

features_by_col, coo_map, col_to_key_inds, vertex_cols_source_to_output = (
NumpyLoader.load_graph_to_numpy(
database,
metagraph,
hosts,
user_jwt,
username,
password,
tls_cert,
parallelism,
batch_size,
)
(
features_by_col,
coo_map,
col_to_key_inds,
vertex_cols_source_to_output,
) = NumpyLoader.load_graph_to_numpy(
database,
metagraph,
hosts,
user_jwt,
username,
password,
tls_cert,
parallelism,
batch_size,
)
data = HeteroData()
for col in features_by_col.keys():
Expand Down
65 changes: 43 additions & 22 deletions python/tests/test_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,21 +22,39 @@ def test_phenol_abide_hetero(
assert isinstance(result, HeteroData)
assert result["Subjects"]["x"].shape == (871, 2000)

result = PygLoader.load_into_pyg_heterodata(
connection_information["dbName"],
{
"vertexCollections": {
"Subjects": {"x": {"brain_fmri_features": None}, "y": "label"}
},
"edgeCollections": {"medical_affinity_graph": {}},
},
[connection_information["url"]],
username=connection_information["username"],
password=connection_information["password"],
)
assert isinstance(result, HeteroData)
assert result["Subjects"]["x"].shape == (871, 2000)


def test_phenol_abide_numpy(
load_abide: None, connection_information: dict[str, str]
) -> None:
features_by_col, coo_map, col_to_key_inds, vertex_cols_source_to_output = (
NumpyLoader.load_graph_to_numpy(
connection_information["dbName"],
{
"vertexCollections": {"Subjects": {"x": "brain_fmri_features"}},
"edgeCollections": {"medical_affinity_graph": {}},
},
[connection_information["url"]],
username=connection_information["username"],
password=connection_information["password"],
)
(
features_by_col,
coo_map,
col_to_key_inds,
vertex_cols_source_to_output,
) = NumpyLoader.load_graph_to_numpy(
connection_information["dbName"],
{
"vertexCollections": {"Subjects": {"x": "brain_fmri_features"}},
"edgeCollections": {"medical_affinity_graph": {}},
},
[connection_information["url"]],
username=connection_information["username"],
password=connection_information["password"],
)

assert features_by_col["Subjects"]["brain_fmri_features"].shape == (871, 2000)
Expand All @@ -47,17 +65,20 @@ def test_phenol_abide_numpy(
assert len(col_to_key_inds["Subjects"]) == 871
assert vertex_cols_source_to_output == {"Subjects": {"brain_fmri_features": "x"}}

features_by_col, coo_map, col_to_key_inds, vertex_cols_source_to_output = (
NumpyLoader.load_graph_to_numpy(
connection_information["dbName"],
{
"vertexCollections": {"Subjects": {"x": "brain_fmri_features"}},
# "edgeCollections": {"medical_affinity_graph": {}},
},
[connection_information["url"]],
username=connection_information["username"],
password=connection_information["password"],
)
(
features_by_col,
coo_map,
col_to_key_inds,
vertex_cols_source_to_output,
) = NumpyLoader.load_graph_to_numpy(
connection_information["dbName"],
{
"vertexCollections": {"Subjects": {"x": "brain_fmri_features"}},
# "edgeCollections": {"medical_affinity_graph": {}},
},
[connection_information["url"]],
username=connection_information["username"],
password=connection_information["password"],
)

assert features_by_col["Subjects"]["brain_fmri_features"].shape == (871, 2000)
Expand Down

0 comments on commit e80c789

Please sign in to comment.