Add local search sampler #208
Changes from 9 commits
@@ -5,9 +5,8 @@
 if TYPE_CHECKING:
     from gfn.actions import Actions
     from gfn.env import Env
-    from gfn.states import States, DiscreteStates
+    from gfn.states import States

 import numpy as np
 import torch

 from gfn.containers.base import Container
@@ -101,7 +100,7 @@ def __init__(
             and self._log_rewards.dtype == torch.float
         )

-        if log_probs is not None:
+        if log_probs is not None and log_probs.shape != (0, 0):
             assert (
                 log_probs.shape == (self.max_length, self.n_trajectories)
                 and log_probs.dtype == torch.float
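A plausible reading of the new guard (an assumption on my part; the placeholder convention is not shown in this diff): an empty (0, 0) float tensor stands in for "no log-probs recorded", and the shape/dtype assertion must not fire for that sentinel:

    import torch

    # Hypothetical sentinel assumed by the guard above: an empty (0, 0) tensor
    # means "no log-probs", so the (max_length, n_trajectories) check is skipped.
    log_probs = torch.full(size=(0, 0), fill_value=0.0, dtype=torch.float)
    assert log_probs is not None and log_probs.shape == (0, 0)  # assertion branch not taken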
@@ -122,15 +121,15 @@ def __repr__(self) -> str:
         for traj in states[:10]:
             one_traj_repr = []
             for step in traj:
-                one_traj_repr.append(str(step.numpy()))
+                one_traj_repr.append(str(step.cpu().numpy()))
                 if step.equal(self.env.s0 if self.is_backward else self.env.sf):
                     break
             trajectories_representation += "-> ".join(one_traj_repr) + "\n"
         return (
             f"Trajectories(n_trajectories={self.n_trajectories}, max_length={self.max_length}, First 10 trajectories:"
             + f"states=\n{trajectories_representation}"
             # + f"actions=\n{self.actions.tensor.squeeze().transpose(0, 1)[:10].numpy()}, "
-            + f"when_is_done={self.when_is_done[:10].numpy()})"
+            + f"when_is_done={self.when_is_done[:10].cpu().numpy()})"
         )

     @property
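The two `.cpu()` additions fix `__repr__` for GPU-resident trajectories: PyTorch raises a TypeError when converting a CUDA tensor directly to NumPy. A minimal reproduction (assuming a CUDA device is available):

    import torch

    if torch.cuda.is_available():
        t = torch.zeros(3, device="cuda")
        # t.numpy()  # TypeError: can't convert cuda:0 device type tensor to numpy
        arr = t.cpu().numpy()  # copy to host memory first, then convert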
@@ -428,6 +427,68 @@ def to_non_initial_intermediary_and_terminating_states(
             conditioning,
         )

+    @staticmethod
+    def reverse_backward_trajectories(trajectories: Trajectories) -> Trajectories:
+        """Reverses a backward trajectory."""
+        # FIXME: This method is not compatible with continuous GFN.
+        assert trajectories.is_backward, "Trajectories must be backward."
+        new_actions = torch.full(
+            (
+                trajectories.max_length + 1,
+                len(trajectories),
+                *trajectories.actions.action_shape,
+            ),
+            -1,
+        )
+
+        # env.sf should never be None unless something went wrong during class
+        # instantiation.
+        if trajectories.env.sf is None:
+            raise AttributeError(
+                "Something went wrong during the instantiation of environment {}".format(
+                    trajectories.env
+                )
+            )
+
+        new_when_is_done = trajectories.when_is_done + 1
+        new_states = trajectories.env.sf.repeat(
+            new_when_is_done.max() + 1, len(trajectories), 1
+        )
+
+        # FIXME: Can we vectorize this?
+        # FIXME: Also, loop over batch or sequence?
+        for i in range(len(trajectories)):

Review comment: Can we flip the full trajectory tensor in one call, and then use indexing to resolve the padding instead? It should be much faster.
Reply: Of course, yes. I'll check and let you know if I need your help.
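A sketch of the reviewer's suggestion (mine, not from the PR; it assumes the `(max_length, batch, ...)` tensor layout used above and a `lengths` vector playing the role of `when_is_done`): build one index tensor and gather, instead of flipping each trajectory in a Python loop.

    import torch

    def flip_padded_prefix(x: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
        """Reverse the first lengths[i] steps of column i of x; leave padding untouched."""
        max_len, batch = x.shape[0], x.shape[1]
        steps = torch.arange(max_len, device=x.device).unsqueeze(1).expand(max_len, batch)
        # Inside the prefix, read from the mirrored position; past it, read from itself.
        mirrored = lengths.unsqueeze(0) - 1 - steps
        idx = torch.where(steps < lengths.unsqueeze(0), mirrored, steps)
        while idx.dim() < x.dim():  # align index rank with x before gathering along dim 0
            idx = idx.unsqueeze(-1)
        return x.gather(0, idx.expand_as(x))

For instance, `flip_padded_prefix(trajectories.states.tensor, trajectories.when_is_done + 1)` would reverse every state sequence in one call; the exit-action writes would still need a separate index assignment.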
+            new_actions[trajectories.when_is_done[i], i] = (
+                trajectories.env.n_actions - 1
+            )

Review comment: The problem is here. Actions are not always integers.

+            new_actions[
+                : trajectories.when_is_done[i], i
+            ] = trajectories.actions.tensor[: trajectories.when_is_done[i], i].flip(0)
+
+            new_states[
+                : trajectories.when_is_done[i] + 1, i
+            ] = trajectories.states.tensor[: trajectories.when_is_done[i] + 1, i].flip(
+                0
+            )
+
+        trajectories_states = trajectories.env.states_from_tensor(new_states)
+        trajectories_actions = trajectories.env.actions_from_tensor(new_actions)
+
+        return Trajectories(
+            env=trajectories.env,
+            states=trajectories_states,
+            conditioning=trajectories.conditioning,
+            actions=trajectories_actions,
+            when_is_done=new_when_is_done,
+            is_backward=False,
+            log_rewards=trajectories.log_rewards,
+            log_probs=None,  # We can't simply pass trajectories.log_probs,
+            # since `log_probs` is assumed to hold the forward log-probabilities.
+            # FIXME: To resolve this, we could save log_pfs and log_pbs in the trajectories object.
+            estimator_outputs=None,  # Same as `log_probs`.
+        )


 def pad_dim0_to_target(a: torch.Tensor, target_dim0: int) -> torch.Tensor:
     """Pads tensor a along dim 0 to match target_dim0."""
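Following the "Actions are not always integers" comment, one defensive option (my sketch, not part of this PR) is to fail fast on non-discrete environments, since the method bakes in the integer exit action `n_actions - 1`:

    from gfn.env import DiscreteEnv

    # Hypothetical guard at the top of reverse_backward_trajectories:
    if not isinstance(trajectories.env, DiscreteEnv):
        raise NotImplementedError(
            "reverse_backward_trajectories assumes integer actions and the "
            "discrete exit action n_actions - 1; continuous envs are unsupported."
        )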
@@ -61,6 +61,8 @@ def __init__(
         self.dummy_action = dummy_action
         self.exit_action = exit_action

+        # Warning: don't use self.States or self.Actions to initialize an instance of the class.
+        # Use self.states_from_tensor or self.actions_from_tensor instead.
         self.States = self.make_states_class()
         self.Actions = self.make_actions_class()

Review comment: Who is this warning intended for?
Reply: Maybe us?? Regarding this, what about making them into private variables (e.g., …)?
Reply: It seems like we should not initialize a States object using …
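The warning's recommended pattern, as a sketch (assuming `env` is an instantiated Env subclass and the tensors match the environment's state and action shapes):

    # Preferred: let the env wrap raw tensors.
    states = env.states_from_tensor(state_tensor)
    actions = env.actions_from_tensor(action_tensor)
    # Discouraged per the warning: env.States(state_tensor), env.Actions(action_tensor)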
@@ -85,7 +87,9 @@ def states_from_tensor(self, tensor: torch.Tensor): | |
""" | ||
return self.States(tensor) | ||
|
||
def states_from_batch_shape(self, batch_shape: Tuple): | ||
def states_from_batch_shape( | ||
self, batch_shape: Tuple, random: bool = False, sink: bool = False | ||
): | ||
         """Returns a batch of s0 states with a given batch_shape.

         Args:
@@ -94,7 +98,7 @@ def states_from_batch_shape(self, batch_shape: Tuple):
         Returns:
             States: A batch of initial states.
         """
-        return self.States.from_batch_shape(batch_shape)
+        return self.States.from_batch_shape(batch_shape, random=random, sink=sink)

     def actions_from_tensor(self, tensor: torch.Tensor):
         """Wraps the supplied Tensor in an Actions instance.
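Usage of the extended helper, sketched under the assumption that `env` is a concrete Env instance and that `random`/`sink` select random states and sf states respectively, as in `States.from_batch_shape`:

    s0_batch = env.states_from_batch_shape((16,))                  # batch of s0 states
    rand_batch = env.states_from_batch_shape((16,), random=True)   # random states
    sink_batch = env.states_from_batch_shape((16,), sink=True)     # batch of sf states

This is what lets both `reset` implementations below delegate to the wrapper instead of calling `self.States.from_batch_shape` directly.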
@@ -218,7 +222,7 @@ def reset(
             batch_shape = (1,)
         if isinstance(batch_shape, int):
             batch_shape = (batch_shape,)
-        return self.States.from_batch_shape(
+        return self.states_from_batch_shape(
             batch_shape=batch_shape, random=random, sink=sink
         )
@@ -441,21 +445,21 @@ def reset(
             batch_shape = (1,)
         if isinstance(batch_shape, int):
             batch_shape = (batch_shape,)
-        states = self.States.from_batch_shape(
+        states = self.states_from_batch_shape(
             batch_shape=batch_shape, random=random, sink=sink
         )
         self.update_masks(states)

         return states

     @abstractmethod
-    def update_masks(self, states: type[States]) -> None:
+    def update_masks(self, states: States) -> None:
         """Updates the masks in States.

         Called automatically after each step for discrete environments.
         """

-    def make_states_class(self) -> type[States]:
+    def make_states_class(self) -> type[DiscreteStates]:
         env = self

         class DiscreteEnvStates(DiscreteStates):
Review comment: What's the major blocker here? Anyone know?
Reply: I'm not sure either... This was from here.
Reply: One guess is this: in lines 436-443, and also in lines 461-463, the code assumes that the action is an integer, which is not true for the continuous case, right?
Reply: The blocker: see my response to line 462 of this file in the PR.
Reply: this function will not work on non-discrete environments!
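To make the reviewers' point concrete, a toy contrast (hypothetical tensors, not from the PR): discrete actions are integer indices, so a reserved exit index exists; continuous actions are float vectors, so `n_actions - 1` has no analogue.

    import torch

    # Discrete: actions are integer indices; n_actions - 1 can encode "exit".
    discrete_actions = torch.full((5, 4), -1, dtype=torch.long)
    discrete_actions[2, :] = 10 - 1  # e.g., exit index for a hypothetical n_actions = 10

    # Continuous: actions are float vectors; there is no integer index to
    # reserve for "exit", so the same encoding is ill-defined.
    continuous_actions = torch.randn(5, 4, 2)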