From f481063c09279a62f7a491130a146f0c1c3ed973 Mon Sep 17 00:00:00 2001 From: Pegerto Fernandez Date: Tue, 26 Mar 2024 18:34:31 +0000 Subject: [PATCH] Revert "remove one-encoding for binary target" This reverts commit 8ac69b613e095608b93a00a5babc99447756f7b1. --- src/graphpro/annotations.py | 4 ++-- test/graphpro/annotations_test.py | 6 +++--- test/graphpro/graph_test.py | 2 +- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/graphpro/annotations.py b/src/graphpro/annotations.py index 65522cd..35f73bc 100644 --- a/src/graphpro/annotations.py +++ b/src/graphpro/annotations.py @@ -12,7 +12,7 @@ class NodeTargetBinaryAttribute(NodeTarget): """ def encode(self, G: Graph) -> torch.Tensor: present = [self.attr_name in G.node_attr(n) for n in G.nodes()] - return torch.tensor([present], dtype=torch.int64).to(torch.float).T + return F.one_hot(torch.tensor(present, dtype=torch.int64), num_classes=2).to(torch.float) class NodeAnnotation(): @@ -76,4 +76,4 @@ def generate(self, G: Graph, atom_group: AtomGroup): def encode(self, G: Graph) -> torch.tensor: total_area = [G.node_attr(n)[self.attr_name] if self.attr_name in G.node_attr(n) else 0 for n in G.nodes()] - return F.normalize(torch.tensor([total_area], dtype=torch.float).T, dim=(0,1)) \ No newline at end of file + return F.normalize(torch.tensor([total_area], dtype=torch.float).T, dim=(0,1)) diff --git a/test/graphpro/annotations_test.py b/test/graphpro/annotations_test.py index 8c04d1d..a13c173 100644 --- a/test/graphpro/annotations_test.py +++ b/test/graphpro/annotations_test.py @@ -15,9 +15,9 @@ def test_node_target_binary(): target = NodeTargetBinaryAttribute("test_attr") y = target.encode(G) - assert y.size() == (214, 1) - assert torch.all(y[0].eq(torch.tensor([0.]))) - assert torch.all(y[1].eq(torch.tensor([1.]))) + assert y.size() == (214,2) + assert torch.all(y[0].eq(torch.tensor([1., 0.]))) + assert torch.all(y[1].eq(torch.tensor([0., 1.]))) def test_resname_annotation(): diff --git a/test/graphpro/graph_test.py b/test/graphpro/graph_test.py index 8b7ea6e..1a55347 100644 --- a/test/graphpro/graph_test.py +++ b/test/graphpro/graph_test.py @@ -56,4 +56,4 @@ def test_to_data_x_transformer(): def test_to_data_y_target(): data = SIMPLE_G_ATTR.to_data(node_encoders= [ResidueType()], target = NodeTargetBinaryAttribute("target")) - assert data.y.size() == (2,1) \ No newline at end of file + assert data.y.size() == (2,2) \ No newline at end of file