Skip to content

Commit

Permalink
Add ion capabilities (#146)
Browse files Browse the repository at this point in the history
* add dgl ion capabilities

* move files around and update tests

* add test model and tidy files

* add ion tests

* remove commented out lines

* add pointer to apache 2 license

* update changelog [skip ci]
  • Loading branch information
lilyminium authored Oct 17, 2024
1 parent dffc164 commit 6cb4387
Show file tree
Hide file tree
Showing 102 changed files with 3,833 additions and 83 deletions.
1 change: 1 addition & 0 deletions docs/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ The rules for this file:

### Fixed
- Check lookup tables for allowed molecules before ChemicalDomain for forbidden ones (PR #145, Issue #144)
- Add support for single atoms (PR #146, Issue #138)


## v0.4.0 -- 2024-07-18
Expand Down
80 changes: 78 additions & 2 deletions openff/nagl/molecule/_dgl/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Dict, List, TYPE_CHECKING, Optional

import torch
import numpy as np
from openff.utilities import requires_package

from openff.nagl.features.atoms import AtomFeature
Expand Down Expand Up @@ -39,7 +40,8 @@ def openff_molecule_to_base_dgl_graph(
{
("atom", forward, "atom"): (indices_a, indices_b),
("atom", reverse, "atom"): (indices_b, indices_a),
}
},
num_nodes_dict={"atom": molecule.n_atoms},
)
return molecule_graph

Expand Down Expand Up @@ -99,7 +101,7 @@ def openff_molecule_to_dgl_graph(

for direction in (forward, reverse):
n_bonds = len(molecule.bonds)
if bond_feature_tensor is not None:
if bond_feature_tensor is not None and n_bonds:
bond_feature_tensor = bond_feature_tensor.reshape(n_bonds, -1)
else:
bond_feature_tensor = torch.zeros((n_bonds, 0))
Expand All @@ -108,13 +110,87 @@ def openff_molecule_to_dgl_graph(

return molecule_graph

@requires_package("dgl")
def heterograph_to_homograph_no_edges(G: "dgl.DGLHeteroGraph", ndata=None, edata=None) -> "dgl.DGLGraph":
"""
Copied and modified from dgl.python.dgl.convert.to_homogeneous,
but with the edges removed.
This part of the code is licensed under the Apache 2.0 license according
to the terms of DGL (https://github.com/dmlc/dgl?tab=Apache-2.0-1-ov-file).
Please see our third-party license file for more information
(https://github.com/openforcefield/openff-nagl/blob/main/LICENSE-3RD-PARTY)
"""
import dgl
from dgl import backend as F
from dgl.base import EID, NID, ETYPE, NTYPE
from dgl.heterograph import combine_frames

# TODO: revisit in case DGL accounts for this in the future
num_nodes_per_ntype = [G.num_nodes(ntype) for ntype in G.ntypes]
offset_per_ntype = np.insert(np.cumsum(num_nodes_per_ntype), 0, 0)
srcs = []
dsts = []
nids = []
eids = []
ntype_ids = []
etype_ids = []
total_num_nodes = 0

for ntype_id, ntype in enumerate(G.ntypes):
num_nodes = G.num_nodes(ntype)
total_num_nodes += num_nodes
# Type ID is always in int64
ntype_ids.append(F.full_1d(num_nodes, ntype_id, F.int64, G.device))
nids.append(F.arange(0, num_nodes, G.idtype, G.device))

for etype_id, etype in enumerate(G.canonical_etypes):
srctype, _, dsttype = etype
src, dst = G.all_edges(etype=etype, order="eid")
num_edges = len(src)
srcs.append(src + int(offset_per_ntype[G.get_ntype_id(srctype)]))
dsts.append(dst + int(offset_per_ntype[G.get_ntype_id(dsttype)]))
etype_ids.append(F.full_1d(num_edges, etype_id, F.int64, G.device))
eids.append(F.arange(0, num_edges, G.idtype, G.device))

retg = dgl.graph(
(F.cat(srcs, 0), F.cat(dsts, 0)),
num_nodes=total_num_nodes,
idtype=G.idtype,
device=G.device,
)

# copy features
if ndata is None:
ndata = []
if edata is None:
edata = []
comb_nf = combine_frames(
G._node_frames, range(len(G.ntypes)), col_names=ndata
)
if comb_nf is not None:
retg.ndata.update(comb_nf)

retg.ndata[NID] = F.cat(nids, 0)
retg.edata[EID] = F.cat(eids, 0)
retg.ndata[NTYPE] = F.cat(ntype_ids, 0)
retg.edata[ETYPE] = F.cat(etype_ids, 0)

return retg




@requires_package("dgl")
def dgl_heterograph_to_homograph(graph: "dgl.DGLHeteroGraph") -> "dgl.DGLGraph":
import dgl

try:
homo_graph = dgl.to_homogeneous(graph, ndata=[FEATURE], edata=[FEATURE])
except TypeError as e:
if graph.num_edges() == 0:
homo_graph = heterograph_to_homograph_no_edges(graph)
except KeyError:
# A nasty workaround to check when we don't have any atom / bond features as
# DGL doesn't allow easy querying of features dicts for hetereographs with
Expand Down
8 changes: 7 additions & 1 deletion openff/nagl/molecule/_graph/_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,13 @@ def in_edges(self, nodes, form="uv"):
raise ValueError("Unknown form: {}".format(form))

def _bond_indices(self):
u, v = map(list, zip(*self.graph.edges()))
try:
u, v = map(list, zip(*self.graph.edges()))
except ValueError as e:
# this may be due to there not being bonds
if not self.graph.edges():
return torch.tensor([], dtype=torch.long), torch.tensor([], dtype=torch.long)
raise e
U = torch.tensor(u, dtype=torch.long)
V = torch.tensor(v, dtype=torch.long)
return U, V
Expand Down
52 changes: 52 additions & 0 deletions openff/nagl/tests/data/example_charges/generate-new-sdfs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import pathlib

import click
import tqdm

from openff.toolkit import Molecule
from openff.units import unit
from openff.nagl import GNNModel


@click.command()
@click.option(
"--input", "-i",
"input_directory",
required=True,
type=click.Path(exists=True, file_okay=False, dir_okay=True),
)
@click.option(
"--output", "-o",
"output_directory",
required=True,
type=click.Path(file_okay=False, dir_okay=True),
)
@click.option(
"--model", "-m",
"model_path",
required=True,
type=click.Path(exists=True, file_okay=True, dir_okay=False),
)
def main(
input_directory: str,
output_directory: str,
model_path: str,
):
input_files = sorted(pathlib.Path(input_directory).glob("*.sdf"))
output_directory = pathlib.Path(output_directory)
output_directory.mkdir(exist_ok=True, parents=True)

model = GNNModel.load(model_path, eval_mode=True)

for input_file in tqdm.tqdm(input_files):
mol = Molecule.from_file(input_file, "SDF", allow_undefined_stereo=True)
mol._partial_charges = (
model.compute_property(mol, as_numpy=True)
* unit.elementary_charge
)
output_file = output_directory / input_file.name
mol.to_file(output_file, "SDF")


if __name__ == "__main__":
main()
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@

-OEChem-10172415303D

3 2 0 0 0 0 0 0 0999 V2000
0.0000 0.0000 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0
1.0000 0.0000 0.0000 N 0 0 0 0 0 0 0 0 0 0 0 0
-0.7500 0.0000 0.0000 H 0 0 0 0 0 0 0 0 0 0 0 0
1 2 3 0 0 0 0
1 3 1 0 0 0 0
M END
> <atom.dprop.PartialCharge>
0.190730 -0.372090 0.181360

$$$$
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@

-OEChem-10172415303D

9 9 0 0 0 0 0 0 0999 V2000
0.0000 0.0000 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0
1.0000 0.0000 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0
0.5000 0.8682 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0
-0.1300 -0.7387 0.0000 H 0 0 0 0 0 0 0 0 0 0 0 0
-0.7050 0.2560 0.0000 H 0 0 0 0 0 0 0 0 0 0 0 0
1.7050 0.2560 0.0000 H 0 0 0 0 0 0 0 0 0 0 0 0
1.1300 -0.7387 0.0000 H 0 0 0 0 0 0 0 0 0 0 0 0
-0.0747 1.3501 0.0000 H 0 0 0 0 0 0 0 0 0 0 0 0
1.0747 1.3501 0.0000 H 0 0 0 0 0 0 0 0 0 0 0 0
1 3 1 0 0 0 0
1 2 1 0 0 0 0
2 3 1 0 0 0 0
1 4 1 0 0 0 0
1 5 1 0 0 0 0
2 6 1 0 0 0 0
2 7 1 0 0 0 0
3 8 1 0 0 0 0
3 9 1 0 0 0 0
M END
> <atom.dprop.PartialCharge>
-0.136720 -0.136690 -0.136690 0.068350 0.068350 0.068350 0.068350 0.068350 0.068350

$$$$
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@

-OEChem-10172415303D

12 12 0 0 0 0 0 0 0999 V2000
0.0000 1.0000 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0
0.0000 0.0000 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0
1.0000 0.0000 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0
1.0000 1.0000 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0
-0.7500 1.0000 0.0000 H 0 0 0 0 0 0 0 0 0 0 0 0
-0.0000 1.7500 0.0000 H 0 0 0 0 0 0 0 0 0 0 0 0
-0.0000 -0.7500 0.0000 H 0 0 0 0 0 0 0 0 0 0 0 0
-0.7500 0.0000 0.0000 H 0 0 0 0 0 0 0 0 0 0 0 0
1.7500 0.0000 0.0000 H 0 0 0 0 0 0 0 0 0 0 0 0
1.0000 -0.7500 0.0000 H 0 0 0 0 0 0 0 0 0 0 0 0
1.0000 1.7500 0.0000 H 0 0 0 0 0 0 0 0 0 0 0 0
1.7500 1.0000 0.0000 H 0 0 0 0 0 0 0 0 0 0 0 0
1 4 1 0 0 0 0
1 2 1 0 0 0 0
2 3 1 0 0 0 0
3 4 1 0 0 0 0
1 5 1 0 0 0 0
1 6 1 0 0 0 0
2 7 1 0 0 0 0
2 8 1 0 0 0 0
3 9 1 0 0 0 0
3 10 1 0 0 0 0
4 11 1 0 0 0 0
4 12 1 0 0 0 0
M END
> <atom.dprop.PartialCharge>
-0.096850 -0.096850 -0.096850 -0.096850 0.048425 0.048425 0.048425 0.048425 0.048425 0.048425 0.048425 0.048425

$$$$
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@

-OEChem-10172415303D

16 16 0 0 0 0 0 0 0999 V2000
0.8674 0.4976 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0
0.8674 1.5027 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0
-0.0000 2.0102 0.0000 N 0 0 0 0 0 0 0 0 0 0 0 0
-0.8674 1.5027 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0
-1.7349 2.0002 0.0000 O 0 0 0 0 0 0 0 0 0 0 0 0
-0.8674 0.4976 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0
0.0000 0.0000 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0
1.6061 0.6272 0.0000 H 0 0 0 0 0 0 0 0 0 0 0 0
1.1226 -0.2076 0.0000 H 0 0 0 0 0 0 0 0 0 0 0 0
1.1266 2.2065 0.0000 H 0 0 0 0 0 0 0 0 0 0 0 0
1.6058 1.3710 0.0000 H 0 0 0 0 0 0 0 0 0 0 0 0
-0.0000 2.7602 0.0000 H 0 0 0 0 0 0 0 0 0 0 0 0
-1.1226 -0.2076 0.0000 H 0 0 0 0 0 0 0 0 0 0 0 0
-1.6061 0.6272 0.0000 H 0 0 0 0 0 0 0 0 0 0 0 0
0.4816 -0.5750 0.0000 H 0 0 0 0 0 0 0 0 0 0 0 0
-0.4816 -0.5750 0.0000 H 0 0 0 0 0 0 0 0 0 0 0 0
1 7 1 0 0 0 0
1 2 1 0 0 0 0
2 3 1 0 0 0 0
3 4 1 0 0 0 0
4 5 2 0 0 0 0
4 6 1 0 0 0 0
6 7 1 0 0 0 0
1 8 1 0 0 0 0
1 9 1 0 0 0 0
2 10 1 0 0 0 0
2 11 1 0 0 0 0
3 12 1 0 0 0 0
6 13 1 0 0 0 0
6 14 1 0 0 0 0
7 15 1 0 0 0 0
7 16 1 0 0 0 0
M END
> <atom.dprop.PartialCharge>
-0.105587 0.100671 -0.583052 0.702970 -0.645974 -0.149781 -0.084091 0.052848 0.052848 0.041457 0.041457 0.321251 0.075829 0.075829 0.051663 0.051663

$$$$
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@

-OEChem-10172415303D

14 14 0 0 0 0 0 0 0999 V2000
0.0000 0.0000 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0
-0.3065 0.9519 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0
0.5007 1.5426 0.0000 N 0 0 0 0 0 0 0 0 0 0 0 0
1.3131 0.9519 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0
1.0014 0.0000 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0
0.0777 -0.7460 0.0000 H 0 0 0 0 0 0 0 0 0 0 0 0
-0.7339 -0.1546 0.0000 H 0 0 0 0 0 0 0 0 0 0 0 0
-0.9919 0.6473 0.0000 H 0 0 0 0 0 0 0 0 0 0 0 0
-0.6824 1.6009 0.0000 H 0 0 0 0 0 0 0 0 0 0 0 0
0.4996 2.2926 0.0000 H 0 0 0 0 0 0 0 0 0 0 0 0
1.6888 1.6010 0.0000 H 0 0 0 0 0 0 0 0 0 0 0 0
1.9978 0.6457 0.0000 H 0 0 0 0 0 0 0 0 0 0 0 0
1.7348 -0.1571 0.0000 H 0 0 0 0 0 0 0 0 0 0 0 0
0.9224 -0.7458 0.0000 H 0 0 0 0 0 0 0 0 0 0 0 0
1 5 1 0 0 0 0
1 2 1 0 0 0 0
2 3 1 0 0 0 0
3 4 1 0 0 0 0
4 5 1 0 0 0 0
1 6 1 0 0 0 0
1 7 1 0 0 0 0
2 8 1 0 0 0 0
2 9 1 0 0 0 0
3 10 1 0 0 0 0
4 11 1 0 0 0 0
4 12 1 0 0 0 0
5 13 1 0 0 0 0
5 14 1 0 0 0 0
M END
> <atom.dprop.PartialCharge>
-0.102176 0.146808 -0.808465 0.146808 -0.102176 0.048320 0.048320 0.043794 0.043794 0.350745 0.043794 0.043794 0.048320 0.048320

$$$$
Loading

0 comments on commit 6cb4387

Please sign in to comment.