Add local search sampler #208
Conversation
@hyeok9855 can you fix the merge conflicts?
I noticed you are using
I fixed an issue in the backward mask! Is there anything necessary to do next?
Hey @hyeok9855 - please see my comments. This is a really awesome PR, I like the changes you made to trajectories
and with some tweaks the local search sampler looks very clean.
Let me know if you want to schedule a pair programming session.
```python
@staticmethod
def reverse_backward_trajectories(trajectories: Trajectories) -> Trajectories:
    """Reverses a backward trajectory"""
    # FIXME: This method is not compatible with continuous GFN.
```
What's the major blocker here? Anyone know?
I'm not sure either... This was from here.
One guess is this. In lines 436-443:

```python
new_actions = torch.full(
    (
        trajectories.max_length + 1,
        len(trajectories),
        *trajectories.actions.action_shape,
    ),
    -1,
)
```

Also, in lines 461-463:

```python
new_actions[trajectories.when_is_done[i], i] = (
    trajectories.env.n_actions - 1
)
```

These assume that the action is an integer, which is not true in the continuous case, right?
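To make the assumption concrete: padding with the integer `-1` (and writing `n_actions - 1` as the exit action) bakes in a discrete-action layout. One dtype-agnostic alternative, sketched here with illustrative names rather than the library's actual API, is to pad with the environment's dummy action tensor, which carries the right dtype and shape for both discrete and continuous actions:

```python
import torch

# Hypothetical sketch: build the padded action tensor from a dummy action
# tensor instead of hard-coding the integer fill value -1.
def pad_actions(max_length: int, batch_size: int, dummy_action: torch.Tensor) -> torch.Tensor:
    """Builds a padded action tensor whose fill value matches the action
    dtype/shape, instead of assuming integer actions."""
    # dummy_action has shape (*action_shape,); repeat it over time and batch.
    return dummy_action.repeat(
        max_length + 1, batch_size, *((1,) * dummy_action.ndim)
    )

# Discrete case: dummy action is the integer -1, action_shape (1,).
discrete_pad = pad_actions(3, 2, torch.tensor([-1]))             # shape (4, 2, 1)
# Continuous case: dummy action is an inf-filled float vector, action_shape (2,).
continuous_pad = pad_actions(3, 2, torch.full((2,), float("inf")))  # shape (4, 2, 2)
```

With this layout, the exit action would likewise come from an `exit_action` tensor rather than from `n_actions - 1`.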
The blocker: see my response to line 462 of this file in the PR.
this function will not work on non-discrete environments!
src/gfn/containers/trajectories.py (Outdated)

```python
# FIXME: Can we vectorize this?
# FIXME: Also, loop over batch or sequence?
for i in range(len(trajectories)):
```
Can we flip the full trajectory tensor in one call, and then use indexing to resolve the padding instead? It should be much faster.
Of course, yes. I'll check and let you know if I need your help.
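The flip-then-reindex idea could be sketched like this (a toy stand-in on a 2-D tensor, not the library's implementation): build a per-column source index with `gather`, so each variable-length trajectory is reversed in one call while the padding stays in place.

```python
import torch

def reverse_padded(seq: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
    """Reverses each column of `seq` (shape (max_len, batch)) up to its
    length, leaving padded positions untouched, without a Python loop."""
    max_len, batch = seq.shape
    pos = torch.arange(max_len).unsqueeze(1).expand(max_len, batch)
    # For t < length, read from position (length - 1 - t); otherwise keep t.
    src = torch.where(pos < lengths, lengths - 1 - pos, pos)
    return seq.gather(0, src)

seq = torch.tensor([[1, 4], [2, 5], [3, 0]])  # column 1 has length 2, 0 is padding
lengths = torch.tensor([3, 2])
reverse_padded(seq, lengths)  # column 0 -> [3, 2, 1], column 1 -> [5, 4, 0]
```

The real trajectories tensor has extra state dimensions, but the same index tensor can be broadcast over them.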
@@ -61,6 +61,8 @@ def __init__(

```python
self.dummy_action = dummy_action
self.exit_action = exit_action

# Warning: don't use self.States or self.Actions to initialize an instance of the class.
```
Who is this warning intended for?
Maybe us?? Regarding this, what about making them into private variables (e.g., `self.__States` and `self.__Actions`)?
It seems like we should not initialize a States object using `self.States`, but rather `self.states_from_tensor`, as in line 251 of src/gfn/gym/discrete_ebm.py.
I agree with the general sentiment here. Should we actually raise a warning when `self.States` is used? Or is there a way to prevent it?
I agree with @hyeok9855's comment as well.
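One way to implement the "raise a warning" option, sketched on a simplified stand-in class (the real `Env` class differs), is to expose `States` through a property that warns on access while keeping `states_from_tensor` as the canonical, warning-free construction path:

```python
import warnings

class Env:
    """Simplified stand-in for the environment base class."""

    def __init__(self):
        # Using `list` as a placeholder for the real States class.
        self._States = list

    @property
    def States(self):
        # Direct access is discouraged; emit a warning pointing at the
        # preferred constructor.
        warnings.warn(
            "Prefer env.states_from_tensor(...) over instantiating "
            "env.States directly.",
            stacklevel=2,
        )
        return self._States

    def states_from_tensor(self, tensor):
        # Canonical construction path; no warning here.
        return self._States(tensor)
```

Name-mangled private attributes (`self.__States`) would be stricter but also make legitimate internal use awkward; a warning keeps the escape hatch open.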
```python
)

# Calculate the forward probability if needed (Metropolis-Hastings).
prev_trajectories = Trajectories.reverse_backward_trajectories(
```
Maybe for clarity `prev_forward_trajectories` or `reversed_backward_trajectories`, since you also operate on `trajectories` below.
I wanted to make this correspond to the `new_trajectories`. (Please check my comment below.)
I think adding a short explanation of why this is called `prev_trajectories` would be fine, e.g.,

```python
# By reversing the backward trajectories, obtain the forward trajectories.
# This is called `prev_trajectories` since they are the trajectories before
# the local search. The `new_trajectories` will be obtained by performing local
# search on them.
prev_trajectories = Trajectories.reverse_backward_trajectories(
    backward_trajectories
)
```

What do you think about this??
src/gfn/samplers.py (Outdated)

```python
prev_trajectories = Trajectories.reverse_backward_trajectories(
    backward_trajectories
)
prev_trajectories_log_rewards = trajectories.log_rewards
```
Should this be `prev_trajectories`? I actually think it does not matter since you're just looking at the reward at the end of the trajectory, but a comment here would be clarifying, b/c above `prev_trajectories` refers to the reverse of a trajectory sampled from pf, and here you're grabbing log_rewards directly from the forward trajectories.
Why this is `prev_trajectories` -> check lines 450-456. I wanted that part to be `new...` +/- `prev...`.
To alleviate the confusion, we can simply change this line to

```python
prev_trajectories_log_rewards = prev_trajectories.log_rewards
```
src/gfn/samplers.py (Outdated)

```python
n_back = backward_trajectories.when_is_done[i] - K[i]

# Sanity check
assert (
```
I think it would be good to move this into a test, to prevent the assertion being run so often in production.
Is it possible to test with the local variables of a function?
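On testing local variables: one common option is to factor the asserted quantity into a small pure helper, so the invariant can be unit-tested with constructed inputs while the sampler's hot path skips the assert. The names below are illustrative, not the library's:

```python
import torch

def n_backward_steps(when_is_done: torch.Tensor, K: torch.Tensor) -> torch.Tensor:
    """Number of backward steps per trajectory; must be non-negative
    whenever K does not exceed the trajectory length."""
    return when_is_done - K

# In the sampler, just call the helper (no assert in the hot path):
#     n_back = n_backward_steps(backward_trajectories.when_is_done, K)
#
# In tests/test_samplers.py, check the invariant explicitly on sampled
# trajectories:
#     assert (n_backward_steps(done, K) >= 0).all()
```

This keeps the sanity check, but it runs once in CI instead of on every sampling call.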
src/gfn/samplers.py (Outdated)

```python
    device=device, dtype=torch.float
)

for i in range(bs):  # FIXME: Can we vectorize this?
```
I think this should be vectorized -- if you need help, no problem, let's schedule a pair programming session (maybe in the evening EST instead of my early morning ;) ).
Let me check first! I will let you know :)
Good job, the code is very well written :)
great work
src/gfn/containers/trajectories.py (Outdated)

```python
# FIXME: Can we vectorize this?
# FIXME: Also, loop over batch or sequence?
for i in range(len(trajectories)):
    new_actions[trajectories.when_is_done[i], i] = (
```
The problem is here. Actions are not always integers.
@@ -103,23 +104,34 @@ def get_trajectory_pfs(

```python
    valid_actions.tensor
)  # Using the actions sampled off-policy.

log_pf_trajectories = torch.full_like(
    trajectories.actions.tensor[..., 0],
```
what's the rationale behind removing this?
This was just moved to line 78 to address the edge case in line 84!
@@ -145,13 +160,13 @@ def get_trajectory_pbs(

```python
    valid_states, estimator_outputs
).log_prob(valid_actions.tensor)

log_pb_trajectories = torch.full_like(
    trajectories.actions.tensor[..., 0],
    fill_value=fill_value,
```
same question here
Same here :)
src/gfn/containers/trajectories.py (Outdated)

```python
    0, 1
)  # shape (max_len + 2, n_trajectories, *state_dim)

# TODO: Add below into the test suite to ensure correctness
```
thank you for handling the vectorization. did you test this?
Yes, I tested this by uncommenting it.
@josephdviviano, could you give any advice on how to design a test to check whether the vectorization works appropriately?
What about copying the pre-vectorization code into a test file and comparing the outputs of your new function and that code on a few HyperGrid + other env trajectories?
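Such a regression test could look roughly like this, with the pre-vectorization loop kept as a reference implementation inside the test file. Both functions below are simplified stand-ins operating on a 2-D padded tensor, not the library's actual code:

```python
import torch

def reverse_loop(seq: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
    """Reference: reverse each column up to its length with a Python loop."""
    out = seq.clone()
    for i in range(seq.shape[1]):
        n = lengths[i]
        out[:n, i] = seq[:n, i].flip(0)
    return out

def reverse_vectorized(seq: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
    """Candidate: same reversal via a gathered index, no loop."""
    max_len, batch = seq.shape
    pos = torch.arange(max_len).unsqueeze(1).expand(max_len, batch)
    src = torch.where(pos < lengths, lengths - 1 - pos, pos)
    return seq.gather(0, src)

def test_reversals_agree():
    # Random variable-length batches; both implementations must match exactly.
    torch.manual_seed(0)
    seq = torch.randint(0, 10, (7, 5))
    lengths = torch.randint(1, 8, (5,))
    assert torch.equal(reverse_loop(seq, lengths), reverse_vectorized(seq, lengths))
```

In the actual suite, the inputs would be trajectories sampled from HyperGrid and DiscreteEBM rather than random tensors.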
```python
@pytest.mark.parametrize("env_name", ["HyperGrid", "DiscreteEBM"])
def test_reverse_backward_trajectories(env_name: str):
```
great
FYI: In terms of L1 distance (empirical distribution vs. true distribution, as in the original GFN paper), both LS-GFN (TB) with and without the Metropolis-Hastings correction outperform vanilla TB on the 16x16x16x16 HyperGrid (using the default hyperparameters).
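For reference, the L1 metric mentioned here can be computed as below. This is a sketch; representing samples as flat terminal-state indices and `true_dist` as the known normalized reward distribution are assumptions about the data layout, not the repo's actual evaluation code.

```python
import torch

def l1_distance(samples: torch.Tensor, true_dist: torch.Tensor) -> float:
    """L1 distance between the empirical distribution of sampled terminal
    states (given as flat indices) and a known true distribution."""
    empirical = torch.bincount(samples, minlength=true_dist.numel()).float()
    empirical /= empirical.sum()
    return (empirical - true_dist).abs().sum().item()

samples = torch.tensor([0, 0, 1, 1])
true_dist = torch.tensor([0.5, 0.5])
l1_distance(samples, true_dist)  # -> 0.0 for a perfectly matched sample
```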
LGTM.
sorry for the lag reviewing these changes - Salem and I think these contributions are excellent
This is a PR for adding `LocalSearchSampler`. The local search is based on the work of [1] and [2].
Test in the HyperGrid env: