From 0b622b133e34ba9f1981a41730586c84d268ed46 Mon Sep 17 00:00:00 2001 From: Idriss-Malek Date: Wed, 29 Jan 2025 14:55:19 +0400 Subject: [PATCH 1/2] changed stack_states from a fn to a classmethod --- src/gfn/samplers.py | 4 ++-- src/gfn/states.py | 48 +++++++++++++++++++++++---------------------- 2 files changed, 27 insertions(+), 25 deletions(-) diff --git a/src/gfn/samplers.py b/src/gfn/samplers.py index 88cad461..c51d5e62 100644 --- a/src/gfn/samplers.py +++ b/src/gfn/samplers.py @@ -7,7 +7,7 @@ from gfn.containers import Trajectories from gfn.env import Env from gfn.modules import GFNModule -from gfn.states import States, stack_states +from gfn.states import States from gfn.utils.handlers import ( has_conditioning_exception_handler, no_conditioning_exception_handler, @@ -245,7 +245,7 @@ def sample_trajectories( trajectories_states.append(deepcopy(states)) # TODO: do not ignore the next three ignores - trajectories_states = stack_states(trajectories_states) # pyright: ignore + trajectories_states = states.stack_states(trajectories_states) # pyright: ignore trajectories_actions = env.Actions.stack(trajectories_actions)[ 1: # Drop dummy action ] # pyright: ignore diff --git a/src/gfn/states.py b/src/gfn/states.py index 1e818e7e..159ffd3b 100644 --- a/src/gfn/states.py +++ b/src/gfn/states.py @@ -276,6 +276,26 @@ def log_rewards(self, log_rewards: torch.Tensor) -> None: def sample(self, n_samples: int) -> States: """Samples a subset of the States object.""" return self[torch.randperm(len(self))[:n_samples]] + + @classmethod + def stack_states(cls, states: List[States]): + """Given a list of states, stacks them along a new dimension (0).""" + state_example = states[0] # We assume all elems of `states` are the same. + + stacked_states = state_example.from_batch_shape((0, 0)) # Empty. + stacked_states.tensor = torch.stack([s.tensor for s in states], dim=0) + # TODO: do not ignore the next ignore + if state_example._log_rewards: + stacked_states._log_rewards = torch.stack( + [s._log_rewards for s in states], dim=0 # pyright: ignore + ) + + # Adds the trajectory dimension. + stacked_states.batch_shape = ( + stacked_states.tensor.shape[0], + ) + state_example.batch_shape + + return stacked_states class DiscreteStates(States, ABC): @@ -462,33 +482,15 @@ def init_forward_masks(self, set_ones: bool = True): self.forward_masks = torch.ones(shape).bool() else: self.forward_masks = torch.zeros(shape).bool() - - -def stack_states(states: List[States]): - """Given a list of states, stacks them along a new dimension (0).""" - state_example = states[0] # We assume all elems of `states` are the same. - - stacked_states = state_example.from_batch_shape((0, 0)) # Empty. - stacked_states.tensor = torch.stack([s.tensor for s in states], dim=0) - # TODO: do not ignore the next ignore - if state_example._log_rewards: - stacked_states._log_rewards = torch.stack( - [s._log_rewards for s in states], dim=0 # pyright: ignore - ) - - # We are dealing with a list of DiscretrStates instances. - if isinstance(stacked_states, DiscreteStates): - # TODO: do not ignore the next two ignores + + @classmethod + def stack_states(cls, states:List[DiscreteStates]): + stacked_states: DiscreteStates = super().stack_states(states) # pyright: ignore stacked_states.forward_masks = torch.stack( [s.forward_masks for s in states], dim=0 # pyright: ignore ) stacked_states.backward_masks = torch.stack( [s.backward_masks for s in states], dim=0 # pyright: ignore ) + return stacked_states - # Adds the trajectory dimension. - stacked_states.batch_shape = ( - stacked_states.tensor.shape[0], - ) + state_example.batch_shape - - return stacked_states From 0e1c58008430c37a0732c84d780338f3f6c6a73e Mon Sep 17 00:00:00 2001 From: Idriss-Malek Date: Wed, 29 Jan 2025 15:13:48 +0400 Subject: [PATCH 2/2] reformatted --- src/gfn/samplers.py | 4 +++- src/gfn/states.py | 9 ++++----- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/src/gfn/samplers.py b/src/gfn/samplers.py index c51d5e62..78d0c42d 100644 --- a/src/gfn/samplers.py +++ b/src/gfn/samplers.py @@ -245,7 +245,9 @@ def sample_trajectories( trajectories_states.append(deepcopy(states)) # TODO: do not ignore the next three ignores - trajectories_states = states.stack_states(trajectories_states) # pyright: ignore + trajectories_states = states.stack_states( + trajectories_states + ) # pyright: ignore trajectories_actions = env.Actions.stack(trajectories_actions)[ 1: # Drop dummy action ] # pyright: ignore diff --git a/src/gfn/states.py b/src/gfn/states.py index 159ffd3b..20476868 100644 --- a/src/gfn/states.py +++ b/src/gfn/states.py @@ -276,7 +276,7 @@ def log_rewards(self, log_rewards: torch.Tensor) -> None: def sample(self, n_samples: int) -> States: """Samples a subset of the States object.""" return self[torch.randperm(len(self))[:n_samples]] - + @classmethod def stack_states(cls, states: List[States]): """Given a list of states, stacks them along a new dimension (0).""" @@ -482,10 +482,10 @@ def init_forward_masks(self, set_ones: bool = True): self.forward_masks = torch.ones(shape).bool() else: self.forward_masks = torch.zeros(shape).bool() - + @classmethod - def stack_states(cls, states:List[DiscreteStates]): - stacked_states: DiscreteStates = super().stack_states(states) # pyright: ignore + def stack_states(cls, states: List[DiscreteStates]): + stacked_states: DiscreteStates = super().stack_states(states) # pyright: ignore stacked_states.forward_masks = torch.stack( [s.forward_masks for s in states], dim=0 # pyright: ignore ) @@ -493,4 +493,3 @@ def stack_states(cls, states:List[DiscreteStates]): [s.backward_masks for s in states], dim=0 # pyright: ignore ) return stacked_states -