Skip to content

Commit

Permalink
Merge pull request #236 from Idriss-Malek/idriss_malek
Browse files Browse the repository at this point in the history
changed stack_states from a fn to a classmethod
  • Loading branch information
saleml authored Jan 29, 2025
2 parents 5db4162 + 0e1c580 commit 38782ed
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 25 deletions.
6 changes: 4 additions & 2 deletions src/gfn/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from gfn.containers import Trajectories
from gfn.env import Env
from gfn.modules import GFNModule
from gfn.states import States, stack_states
from gfn.states import States
from gfn.utils.handlers import (
has_conditioning_exception_handler,
no_conditioning_exception_handler,
Expand Down Expand Up @@ -245,7 +245,9 @@ def sample_trajectories(

trajectories_states.append(deepcopy(states))
# TODO: do not ignore the next three ignores
trajectories_states = stack_states(trajectories_states) # pyright: ignore
trajectories_states = states.stack_states(
trajectories_states
) # pyright: ignore
trajectories_actions = env.Actions.stack(trajectories_actions)[
1: # Drop dummy action
] # pyright: ignore
Expand Down
47 changes: 24 additions & 23 deletions src/gfn/states.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,26 @@ def sample(self, n_samples: int) -> States:
"""Samples a subset of the States object."""
return self[torch.randperm(len(self))[:n_samples]]

@classmethod
def stack_states(cls, states: List[States]):
"""Given a list of states, stacks them along a new dimension (0)."""
state_example = states[0] # We assume all elems of `states` are the same.

stacked_states = state_example.from_batch_shape((0, 0)) # Empty.
stacked_states.tensor = torch.stack([s.tensor for s in states], dim=0)
# TODO: do not ignore the next ignore
if state_example._log_rewards:
stacked_states._log_rewards = torch.stack(
[s._log_rewards for s in states], dim=0 # pyright: ignore
)

# Adds the trajectory dimension.
stacked_states.batch_shape = (
stacked_states.tensor.shape[0],
) + state_example.batch_shape

return stacked_states


class DiscreteStates(States, ABC):
"""Base class for states of discrete environments.
Expand Down Expand Up @@ -463,32 +483,13 @@ def init_forward_masks(self, set_ones: bool = True):
else:
self.forward_masks = torch.zeros(shape).bool()


def stack_states(states: List[States]):
"""Given a list of states, stacks them along a new dimension (0)."""
state_example = states[0] # We assume all elems of `states` are the same.

stacked_states = state_example.from_batch_shape((0, 0)) # Empty.
stacked_states.tensor = torch.stack([s.tensor for s in states], dim=0)
# TODO: do not ignore the next ignore
if state_example._log_rewards:
stacked_states._log_rewards = torch.stack(
[s._log_rewards for s in states], dim=0 # pyright: ignore
)

# We are dealing with a list of DiscretrStates instances.
if isinstance(stacked_states, DiscreteStates):
# TODO: do not ignore the next two ignores
@classmethod
def stack_states(cls, states: List[DiscreteStates]):
stacked_states: DiscreteStates = super().stack_states(states) # pyright: ignore
stacked_states.forward_masks = torch.stack(
[s.forward_masks for s in states], dim=0 # pyright: ignore
)
stacked_states.backward_masks = torch.stack(
[s.backward_masks for s in states], dim=0 # pyright: ignore
)

# Adds the trajectory dimension.
stacked_states.batch_shape = (
stacked_states.tensor.shape[0],
) + state_example.batch_shape

return stacked_states
return stacked_states

0 comments on commit 38782ed

Please sign in to comment.