Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add local search sampler #208

Merged
merged 19 commits into from
Jan 13, 2025
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 66 additions & 5 deletions src/gfn/containers/trajectories.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,8 @@
if TYPE_CHECKING:
from gfn.actions import Actions
from gfn.env import Env
from gfn.states import States, DiscreteStates
from gfn.states import States

import numpy as np
import torch

from gfn.containers.base import Container
Expand Down Expand Up @@ -101,7 +100,7 @@ def __init__(
and self._log_rewards.dtype == torch.float
)

if log_probs is not None:
if log_probs is not None and log_probs.shape != (0, 0):
assert (
log_probs.shape == (self.max_length, self.n_trajectories)
and log_probs.dtype == torch.float
Expand All @@ -122,15 +121,15 @@ def __repr__(self) -> str:
for traj in states[:10]:
one_traj_repr = []
for step in traj:
one_traj_repr.append(str(step.numpy()))
one_traj_repr.append(str(step.cpu().numpy()))
if step.equal(self.env.s0 if self.is_backward else self.env.sf):
break
trajectories_representation += "-> ".join(one_traj_repr) + "\n"
return (
f"Trajectories(n_trajectories={self.n_trajectories}, max_length={self.max_length}, First 10 trajectories:"
+ f"states=\n{trajectories_representation}"
# + f"actions=\n{self.actions.tensor.squeeze().transpose(0, 1)[:10].numpy()}, "
+ f"when_is_done={self.when_is_done[:10].numpy()})"
+ f"when_is_done={self.when_is_done[:10].cpu().numpy()})"
)

@property
Expand Down Expand Up @@ -428,6 +427,68 @@ def to_non_initial_intermediary_and_terminating_states(
conditioning,
)

@staticmethod
def reverse_backward_trajectories(trajectories: Trajectories) -> Trajectories:
"""Reverses a backward trajectory into an equivalent forward trajectory."""
# FIXME: This method is not compatible with continuous GFN.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the major blocker here? Anyone know?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure either... This was from here.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One guess is this:

In line 436-443:

        new_actions = torch.full(
            (
                trajectories.max_length + 1,
                len(trajectories),
                *trajectories.actions.action_shape,
            ),
            -1,
        )

Also, in line 461-463:

            new_actions[trajectories.when_is_done[i], i] = (
                trajectories.env.n_actions - 1
            )

These assume that the action is an integer, which is not true for continuous case, right?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The blocker: see my response to line 462 of this file in the PR.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this function will not work on non-discrete environments!


assert trajectories.is_backward, "Trajectories must be backward."
new_actions = torch.full(
(
trajectories.max_length + 1,
len(trajectories),
*trajectories.actions.action_shape,
),
-1,
)

# env.sf should never be None unless something went wrong during class
# instantiation.
if trajectories.env.sf is None:
raise AttributeError(
"Something went wrong during the instantiation of environment {}".format(
trajectories.env
)
)

new_when_is_done = trajectories.when_is_done + 1
new_states = trajectories.env.sf.repeat(
new_when_is_done.max() + 1, len(trajectories), 1
)

# FIXME: Can we vectorize this?
# FIXME: Also, loop over batch or sequence?
for i in range(len(trajectories)):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we flip the full trajectory tensor in one call, and then use indexing to resolve the padding instead? it should be much faster.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Of course, yes. I'll check and let you know if I need your help.

new_actions[trajectories.when_is_done[i], i] = (
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The problem is here. Actions are not always integers.

trajectories.env.n_actions - 1
)
new_actions[
: trajectories.when_is_done[i], i
] = trajectories.actions.tensor[: trajectories.when_is_done[i], i].flip(0)

new_states[
: trajectories.when_is_done[i] + 1, i
] = trajectories.states.tensor[: trajectories.when_is_done[i] + 1, i].flip(
0
)

trajectories_states = trajectories.env.states_from_tensor(new_states)
trajectories_actions = trajectories.env.actions_from_tensor(new_actions)

return Trajectories(
env=trajectories.env,
states=trajectories_states,
conditioning=trajectories.conditioning,
actions=trajectories_actions,
when_is_done=new_when_is_done,
is_backward=False,
log_rewards=trajectories.log_rewards,
log_probs=None, # We can't simply pass the trajectories.log_probs
# Since `log_probs` is assumed to be the forward log probabilities.
# FIXME: To resolve this, we can save log_pfs and log_pbs in the trajectories object.
estimator_outputs=None, # Same as `log_probs`.
)


def pad_dim0_to_target(a: torch.Tensor, target_dim0: int) -> torch.Tensor:
"""Pads tensor `a` along dim 0 so its first dimension equals `target_dim0`."""
Expand Down
16 changes: 10 additions & 6 deletions src/gfn/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ def __init__(
self.dummy_action = dummy_action
self.exit_action = exit_action

# Warning: don't use self.States or self.Actions to initialize an instance of the class.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Who is this warning intended for?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe us?? Regarding this, what about making them into private variables (e.g., self.__States and self.__Actions)??

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems like we should not initialize a States object using self.States, but rather self.states_from_tensor, as in line 251 of src/gfn/gym/discrete_ebm.py.
I agree with the general sentiment here. Should we actually raise a warning when self.States is used? Or is there a way to prevent it?
I agree with @hyeok9855's comment as well.

# Use self.states_from_tensor or self.actions_from_tensor instead.
self.States = self.make_states_class()
self.Actions = self.make_actions_class()

Expand All @@ -85,7 +87,9 @@ def states_from_tensor(self, tensor: torch.Tensor):
"""
return self.States(tensor)

def states_from_batch_shape(self, batch_shape: Tuple):
def states_from_batch_shape(
self, batch_shape: Tuple, random: bool = False, sink: bool = False
):
"""Returns a batch of states with a given batch_shape (s0 by default; optionally random or sink states).

Args:
Expand All @@ -94,7 +98,7 @@ def states_from_batch_shape(self, batch_shape: Tuple):
Returns:
States: A batch of initial states.
"""
return self.States.from_batch_shape(batch_shape)
return self.States.from_batch_shape(batch_shape, random=random, sink=sink)

def actions_from_tensor(self, tensor: torch.Tensor):
"""Wraps the supplied Tensor in an Actions instance.
Expand Down Expand Up @@ -218,7 +222,7 @@ def reset(
batch_shape = (1,)
if isinstance(batch_shape, int):
batch_shape = (batch_shape,)
return self.States.from_batch_shape(
return self.states_from_batch_shape(
batch_shape=batch_shape, random=random, sink=sink
)

Expand Down Expand Up @@ -441,21 +445,21 @@ def reset(
batch_shape = (1,)
if isinstance(batch_shape, int):
batch_shape = (batch_shape,)
states = self.States.from_batch_shape(
states = self.states_from_batch_shape(
batch_shape=batch_shape, random=random, sink=sink
)
self.update_masks(states)

return states

@abstractmethod
def update_masks(self, states: type[States]) -> None:
def update_masks(self, states: States) -> None:
"""Updates the masks in States.

Called automatically after each step for discrete environments.
"""

def make_states_class(self) -> type[States]:
def make_states_class(self) -> type[DiscreteStates]:
env = self

class DiscreteEnvStates(DiscreteStates):
Expand Down
6 changes: 3 additions & 3 deletions src/gfn/gym/discrete_ebm.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def __init__(
preprocessor=preprocessor,
)

def update_masks(self, states: type[States]) -> None:
def update_masks(self, states: DiscreteStates) -> None:
states.forward_masks[..., : self.ndim] = states.tensor == -1
states.forward_masks[..., self.ndim : 2 * self.ndim] = states.tensor == -1
states.forward_masks[..., -1] = torch.all(states.tensor != -1, dim=-1)
Expand Down Expand Up @@ -248,13 +248,13 @@ def all_states(self) -> DiscreteStates:
digits = torch.arange(3, device=self.device)
all_states = torch.cartesian_prod(*[digits] * self.ndim)
all_states = all_states - 1
return self.States(all_states)
return self.states_from_tensor(all_states)

@property
def terminating_states(self) -> DiscreteStates:
digits = torch.arange(2, device=self.device)
all_states = torch.cartesian_prod(*[digits] * self.ndim)
return self.States(all_states)
return self.states_from_tensor(all_states)

@property
def true_dist_pmf(self) -> torch.Tensor:
Expand Down
6 changes: 3 additions & 3 deletions src/gfn/gym/hypergrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def __init__(
preprocessor=preprocessor,
)

def update_masks(self, states: type[DiscreteStates]) -> None:
def update_masks(self, states: DiscreteStates) -> None:
"""Update the masks based on the current states."""
# Not allowed to take any action beyond the environment height, but
# allow early termination.
Expand Down Expand Up @@ -223,13 +223,13 @@ def build_grid(self) -> DiscreteStates:
rearrange_string += " ".join([f"n{i}" for i in range(ndim, 0, -1)])
rearrange_string += " ndim"
grid = rearrange(grid, rearrange_string).long()
return self.States(grid)
return self.states_from_tensor(grid)

@property
def all_states(self) -> DiscreteStates:
grid = self.build_grid()
flat_grid = rearrange(grid.tensor, "... ndim -> (...) ndim")
return self.States(flat_grid)
return self.states_from_tensor(flat_grid)

@property
def terminating_states(self) -> DiscreteStates:
Expand Down
Loading
Loading