From 705b4ccbf6d3f089e58cff7b1ac3d8b2dcb8faf3 Mon Sep 17 00:00:00 2001
From: "Omar G. Younis"
Date: Wed, 15 Jan 2025 17:37:50 +0100
Subject: [PATCH] add tests for action

---
 src/gfn/actions.py      |  29 ++++++++---
 testing/test_actions.py | 109 ++++++++++++++++++++++++++++++++++++++++
 2 files changed, 130 insertions(+), 8 deletions(-)
 create mode 100644 testing/test_actions.py

diff --git a/src/gfn/actions.py b/src/gfn/actions.py
index e76388d9..88b4e635 100644
--- a/src/gfn/actions.py
+++ b/src/gfn/actions.py
@@ -176,6 +176,7 @@ class GraphActionType(enum.IntEnum):
     ADD_NODE = 0
     ADD_EDGE = 1
     EXIT = 2
+    DUMMY = 3
 
 
 class GraphActions(Actions):
@@ -209,7 +210,7 @@ def __init__(self, tensor: TensorDict):
         self.batch_shape = tensor["action_type"].shape
         features = tensor.get("features", None)
         if features is None:
-            assert torch.all(tensor["action_type"] == GraphActionType.EXIT)
+            assert torch.all(torch.logical_or(tensor["action_type"] == GraphActionType.EXIT, tensor["action_type"] == GraphActionType.DUMMY))
             features = torch.zeros((*self.batch_shape, self.features_dim))
         edge_index = tensor.get("edge_index", None)
         if edge_index is None:
@@ -269,6 +270,11 @@ def compare(self, other: GraphActions) -> torch.Tensor:
     def is_exit(self) -> torch.Tensor:
         """Returns a boolean tensor of shape `batch_shape` indicating whether the actions are exit actions."""
         return self.action_type == GraphActionType.EXIT
+
+    @property
+    def is_dummy(self) -> torch.Tensor:
+        """Returns a boolean tensor of shape `batch_shape` indicating whether the actions are dummy actions."""
+        return self.action_type == GraphActionType.DUMMY
 
     @property
     def action_type(self) -> torch.Tensor:
@@ -287,12 +293,12 @@ def edge_index(self) -> torch.Tensor:
 
     @classmethod
     def make_dummy_actions(cls, batch_shape: tuple[int]) -> GraphActions:
-        """Creates an Actions object of dummy actions with the given batch shape."""
+        """Creates a GraphActions object of dummy actions with the given batch shape."""
         return cls(
             TensorDict(
                 {
                     "action_type": torch.full(
-                        batch_shape, fill_value=GraphActionType.EXIT
+                        batch_shape, fill_value=GraphActionType.DUMMY
                     ),
                 },
                 batch_size=batch_shape,
@@ -300,9 +306,16 @@
         )
 
     @classmethod
-    def stack(cls, actions_list: list[GraphActions]) -> GraphActions:
-        """Stacks a list of GraphActions objects into a single GraphActions object."""
-        actions_tensor = torch.stack(
-            [actions.tensor for actions in actions_list], dim=0
+    def make_exit_actions(cls, batch_shape: tuple[int]) -> GraphActions:
+        """Creates a GraphActions object of exit actions with the given batch shape."""
+        return cls(
+            TensorDict(
+                {
+                    "action_type": torch.full(
+                        batch_shape, fill_value=GraphActionType.EXIT
+                    ),
+                },
+                batch_size=batch_shape,
+            )
         )
-        return cls(actions_tensor)
+
diff --git a/testing/test_actions.py b/testing/test_actions.py
new file mode 100644
index 00000000..4dcf05f4
--- /dev/null
+++ b/testing/test_actions.py
@@ -0,0 +1,109 @@
+from copy import deepcopy
+from gfn.actions import Actions, GraphActions
+import pytest
+import torch
+from tensordict import TensorDict
+
+
+class ContinuousActions(Actions):
+    action_shape = (10,)
+    dummy_action = torch.zeros(10)
+    exit_action = torch.ones(10)
+
+class GraphActions(GraphActions):
+    features_dim = 10
+
+
+@pytest.fixture
+def continuous_action():
+    return ContinuousActions(
+        tensor=torch.arange(0, 10)
+    )
+
+@pytest.fixture
+def graph_action():
+    return GraphActions(
+        tensor=TensorDict(
+            {
+                "action_type": torch.zeros((1,), dtype=torch.float32),
+                "features": torch.zeros((1, 10), dtype=torch.float32),
+            },
+            device="cpu",
+        )
+    )
+
+
+def test_continuous_action(continuous_action):
+    BATCH = 5
+
+    exit_actions = continuous_action.make_exit_actions((BATCH,))
+    assert torch.all(exit_actions.tensor == continuous_action.exit_action.repeat(BATCH, 1))
+    assert torch.all(exit_actions.is_exit == torch.ones(BATCH, dtype=torch.bool))
+    assert torch.all(exit_actions.is_dummy == torch.zeros(BATCH, dtype=torch.bool))
+
+    dummy_actions = continuous_action.make_dummy_actions((BATCH,))
+    assert torch.all(dummy_actions.tensor == continuous_action.dummy_action.repeat(BATCH, 1))
+    assert torch.all(dummy_actions.is_dummy == torch.ones(BATCH, dtype=torch.bool))
+    assert torch.all(dummy_actions.is_exit == torch.zeros(BATCH, dtype=torch.bool))
+
+    # Test stack
+    stacked_actions = continuous_action.stack([exit_actions, dummy_actions])
+    assert stacked_actions.batch_shape == (2, BATCH)
+    assert torch.all(stacked_actions.tensor == torch.stack([exit_actions.tensor, dummy_actions.tensor], dim=0))
+    is_exit_stacked = torch.stack([exit_actions.is_exit, dummy_actions.is_exit], dim=0)
+    assert torch.all(stacked_actions.is_exit == is_exit_stacked)
+    assert stacked_actions[0, 1].is_exit
+    stacked_actions[0, 1] = stacked_actions[1, 1]
+    is_exit_stacked[0, 1] = False
+    assert torch.all(stacked_actions.is_exit == is_exit_stacked)
+
+    # Test extend
+    extended_actions = deepcopy(exit_actions)
+    extended_actions.extend(dummy_actions)
+    assert extended_actions.batch_shape == (BATCH * 2,)
+    assert torch.all(extended_actions.tensor == torch.cat([exit_actions.tensor, dummy_actions.tensor], dim=0))
+    is_exit_extended = torch.cat([exit_actions.is_exit, dummy_actions.is_exit], dim=0)
+    assert torch.all(extended_actions.is_exit == is_exit_extended)
+    assert extended_actions[0].is_exit and extended_actions[BATCH].is_dummy
+    extended_actions[0] = extended_actions[BATCH]
+    is_exit_extended[0] = False
+    assert torch.all(extended_actions.is_exit == is_exit_extended)
+
+def test_graph_action(graph_action):
+    BATCH = 5
+
+    exit_actions = graph_action.make_exit_actions((BATCH,))
+    assert torch.all(exit_actions.is_exit == torch.ones(BATCH, dtype=torch.bool))
+    assert torch.all(exit_actions.is_dummy == torch.zeros(BATCH, dtype=torch.bool))
+    dummy_actions = graph_action.make_dummy_actions((BATCH,))
+    assert torch.all(dummy_actions.is_dummy == torch.ones(BATCH, dtype=torch.bool))
+    assert torch.all(dummy_actions.is_exit == torch.zeros(BATCH, dtype=torch.bool))
+
+    # Test stack
+    stacked_actions = graph_action.stack([exit_actions, dummy_actions])
+    assert stacked_actions.batch_shape == (2, BATCH)
+    manually_stacked_tensor = torch.stack([exit_actions.tensor, dummy_actions.tensor], dim=0)
+    assert torch.all(stacked_actions.tensor["action_type"] == manually_stacked_tensor["action_type"])
+    assert torch.all(stacked_actions.tensor["features"] == manually_stacked_tensor["features"])
+    assert torch.all(stacked_actions.tensor["edge_index"] == manually_stacked_tensor["edge_index"])
+    is_exit_stacked = torch.stack([exit_actions.is_exit, dummy_actions.is_exit], dim=0)
+    assert torch.all(stacked_actions.is_exit == is_exit_stacked)
+    assert stacked_actions[0, 1].is_exit
+    stacked_actions[0, 1] = stacked_actions[1, 1]
+    is_exit_stacked[0, 1] = False
+    assert torch.all(stacked_actions.is_exit == is_exit_stacked)
+
+    # Test extend
+    extended_actions = deepcopy(exit_actions)
+    extended_actions.extend(dummy_actions)
+    assert extended_actions.batch_shape == (BATCH * 2,)
+    manually_extended_tensor = torch.cat([exit_actions.tensor, dummy_actions.tensor], dim=0)
+    assert torch.all(extended_actions.tensor["action_type"] == manually_extended_tensor["action_type"])
+    assert torch.all(extended_actions.tensor["features"] == manually_extended_tensor["features"])
+    assert torch.all(extended_actions.tensor["edge_index"] == manually_extended_tensor["edge_index"])
+    is_exit_extended = torch.cat([exit_actions.is_exit, dummy_actions.is_exit], dim=0)
+    assert torch.all(extended_actions.is_exit == is_exit_extended)
+    assert extended_actions[0].is_exit and extended_actions[BATCH].is_dummy
+    extended_actions[0] = extended_actions[BATCH]
+    is_exit_extended[0] = False
+    assert torch.all(extended_actions.is_exit == is_exit_extended)
\ No newline at end of file