Skip to content

Commit

Permalink
ensure resencode to be int64 (#26)
Browse files Browse the repository at this point in the history
Co-authored-by: Pegerto Fernandez <[email protected]>
  • Loading branch information
pegerto and Pegerto Fernandez authored Dec 3, 2023
1 parent 3c11104 commit ba8535a
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 3 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "graphpro"
version = "0.9.1"
version = "0.9.2"
authors = [
{ name="Pegerto Fernandez", email="[email protected]" },
]
Expand Down
2 changes: 1 addition & 1 deletion src/graphpro/annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,4 @@ def generate(self, G: Graph, atom_group: AtomGroup):
def encode(self, G: Graph) -> torch.tensor:
res_names = [G.node_attr(n)['resname'] for n in G.nodes()]
res_ids = [self.res_letters.index(name) for name in res_names]
return F.one_hot(torch.tensor(res_ids), num_classes=len(self.res_letters))
return F.one_hot(torch.tensor(res_ids, dtype=torch.int64), num_classes=len(self.res_letters))
4 changes: 3 additions & 1 deletion test/graphpro/annotations_test.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import torch
import MDAnalysis as mda

from graphpro import md_analisys
Expand All @@ -19,4 +20,5 @@ def test_resname_encoded():
G = md_analisys(u1).generate(ContactMap(cutoff=6), [ResidueType()])
data = G.to_data(node_encoders= [ResidueType()])

assert data.x.size() == (214, 22)
assert data.x.size() == (214, 22)
assert data.x.dtype == torch.int64

0 comments on commit ba8535a

Please sign in to comment.