Skip to content

Commit

Permalink
add tests for action
Browse files Browse the repository at this point in the history
  • Loading branch information
younik committed Jan 15, 2025
1 parent 5e4fc4e commit 705b4cc
Show file tree
Hide file tree
Showing 2 changed files with 130 additions and 8 deletions.
29 changes: 21 additions & 8 deletions src/gfn/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,7 @@ class GraphActionType(enum.IntEnum):
ADD_NODE = 0
ADD_EDGE = 1
EXIT = 2
DUMMY = 3


class GraphActions(Actions):
Expand Down Expand Up @@ -209,7 +210,7 @@ def __init__(self, tensor: TensorDict):
self.batch_shape = tensor["action_type"].shape
features = tensor.get("features", None)
if features is None:
assert torch.all(tensor["action_type"] == GraphActionType.EXIT)
assert torch.all(torch.logical_or(tensor["action_type"] == GraphActionType.EXIT, tensor["action_type"] == GraphActionType.DUMMY))
features = torch.zeros((*self.batch_shape, self.features_dim))
edge_index = tensor.get("edge_index", None)
if edge_index is None:
Expand Down Expand Up @@ -269,6 +270,11 @@ def compare(self, other: GraphActions) -> torch.Tensor:
def is_exit(self) -> torch.Tensor:
"""Returns a boolean tensor of shape `batch_shape` indicating whether the actions are exit actions."""
return self.action_type == GraphActionType.EXIT

@property
def is_dummy(self) -> torch.Tensor:
"""Returns a boolean tensor of shape `batch_shape` indicating whether the actions are dummy actions."""
return self.action_type == GraphActionType.DUMMY

@property
def action_type(self) -> torch.Tensor:
Expand All @@ -287,22 +293,29 @@ def edge_index(self) -> torch.Tensor:

@classmethod
def make_dummy_actions(cls, batch_shape: tuple[int]) -> GraphActions:
"""Creates an Actions object of dummy actions with the given batch shape."""
"""Creates a GraphActions object of dummy actions with the given batch shape."""
return cls(
TensorDict(
{
"action_type": torch.full(
batch_shape, fill_value=GraphActionType.EXIT
batch_shape, fill_value=GraphActionType.DUMMY
),
},
batch_size=batch_shape,
)
)

@classmethod
def stack(cls, actions_list: list[GraphActions]) -> GraphActions:
"""Stacks a list of GraphActions objects into a single GraphActions object."""
actions_tensor = torch.stack(
[actions.tensor for actions in actions_list], dim=0
def make_exit_actions(cls, batch_shape: tuple[int]) -> Actions:
"""Creates an GraphActions object of exit actions with the given batch shape."""
return cls(
TensorDict(
{
"action_type": torch.full(
batch_shape, fill_value=GraphActionType.EXIT
),
},
batch_size=batch_shape,
)
)
return cls(actions_tensor)

109 changes: 109 additions & 0 deletions testing/test_actions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
from copy import deepcopy
from gfn.actions import Actions, GraphActions
import pytest
import torch
from tensordict import TensorDict


class ContinuousActions(Actions):
action_shape = (10,)
dummy_action = torch.zeros(10)
exit_action = torch.ones(10)

class GraphActions(GraphActions):
features_dim = 10


@pytest.fixture
def continuous_action():
return ContinuousActions(
tensor=torch.arange(0, 10)
)

@pytest.fixture
def graph_action():
return GraphActions(
tensor=TensorDict(
{
"action_type": torch.zeros((1,), dtype=torch.float32),
"features": torch.zeros((1, 10), dtype=torch.float32),
},
device="cpu",
)
)


def test_continuous_action(continuous_action):
BATCH = 5

exit_actions = continuous_action.make_exit_actions((BATCH,))
assert torch.all(exit_actions.tensor == continuous_action.exit_action.repeat(BATCH, 1))
assert torch.all(exit_actions.is_exit == torch.ones(BATCH, dtype=torch.bool))
assert torch.all(exit_actions.is_dummy == torch.zeros(BATCH, dtype=torch.bool))

dummy_actions = continuous_action.make_dummy_actions((BATCH,))
assert torch.all(dummy_actions.tensor == continuous_action.dummy_action.repeat(BATCH, 1))
assert torch.all(dummy_actions.is_dummy == torch.ones(BATCH, dtype=torch.bool))
assert torch.all(dummy_actions.is_exit == torch.zeros(BATCH, dtype=torch.bool))

# Test stack
stacked_actions = continuous_action.stack([exit_actions, dummy_actions])
assert stacked_actions.batch_shape == (2, BATCH)
assert torch.all(stacked_actions.tensor == torch.stack([exit_actions.tensor, dummy_actions.tensor], dim=0))
is_exit_stacked = torch.stack([exit_actions.is_exit, dummy_actions.is_exit], dim=0)
assert torch.all(stacked_actions.is_exit == is_exit_stacked)
assert stacked_actions[0, 1].is_exit
stacked_actions[0, 1] = stacked_actions[1, 1]
is_exit_stacked[0, 1] = False
assert torch.all(stacked_actions.is_exit == is_exit_stacked)

# Test extend
extended_actions = deepcopy(exit_actions)
extended_actions.extend(dummy_actions)
assert extended_actions.batch_shape == (BATCH * 2,)
assert torch.all(extended_actions.tensor == torch.cat([exit_actions.tensor, dummy_actions.tensor], dim=0))
is_exit_extended = torch.cat([exit_actions.is_exit, dummy_actions.is_exit], dim=0)
assert torch.all(extended_actions.is_exit == is_exit_extended)
assert extended_actions[0].is_exit and extended_actions[BATCH].is_dummy
extended_actions[0] = extended_actions[BATCH]
is_exit_extended[0] = False
assert torch.all(extended_actions.is_exit == is_exit_extended)

def test_graph_action(graph_action):
BATCH = 5

exit_actions = graph_action.make_exit_actions((BATCH,))
assert torch.all(exit_actions.is_exit == torch.ones(BATCH, dtype=torch.bool))
assert torch.all(exit_actions.is_dummy == torch.zeros(BATCH, dtype=torch.bool))
dummy_actions = graph_action.make_dummy_actions((BATCH,))
assert torch.all(dummy_actions.is_dummy == torch.ones(BATCH, dtype=torch.bool))
assert torch.all(dummy_actions.is_exit == torch.zeros(BATCH, dtype=torch.bool))

# Test stack
stacked_actions = graph_action.stack([exit_actions, dummy_actions])
assert stacked_actions.batch_shape == (2, BATCH)
manually_stacked_tensor = torch.stack([exit_actions.tensor, dummy_actions.tensor], dim=0)
assert torch.all(stacked_actions.tensor["action_type"] == manually_stacked_tensor["action_type"])
assert torch.all(stacked_actions.tensor["features"] == manually_stacked_tensor["features"])
assert torch.all(stacked_actions.tensor["edge_index"] == manually_stacked_tensor["edge_index"])
is_exit_stacked = torch.stack([exit_actions.is_exit, dummy_actions.is_exit], dim=0)
assert torch.all(stacked_actions.is_exit == is_exit_stacked)
assert stacked_actions[0, 1].is_exit
stacked_actions[0, 1] = stacked_actions[1, 1]
is_exit_stacked[0, 1] = False
assert torch.all(stacked_actions.is_exit == is_exit_stacked)

# Test extend
extended_actions = deepcopy(exit_actions)
extended_actions.extend(dummy_actions)
assert extended_actions.batch_shape == (BATCH * 2,)
manually_extended_tensor = torch.cat([exit_actions.tensor, dummy_actions.tensor], dim=0)
assert torch.all(extended_actions.tensor["action_type"] == manually_extended_tensor["action_type"])
assert torch.all(extended_actions.tensor["features"] == manually_extended_tensor["features"])
assert torch.all(extended_actions.tensor["edge_index"] == manually_extended_tensor["edge_index"])
is_exit_extended = torch.cat([exit_actions.is_exit, dummy_actions.is_exit], dim=0)
assert torch.all(extended_actions.is_exit == is_exit_extended)
assert extended_actions[0].is_exit and extended_actions[BATCH].is_dummy
extended_actions[0] = extended_actions[BATCH]
is_exit_extended[0] = False
assert torch.all(extended_actions.is_exit == is_exit_extended)

0 comments on commit 705b4cc

Please sign in to comment.