Add local search sampler #208

Merged · 19 commits from hyeok9855/local-search into master · Jan 13, 2025

Conversation

@hyeok9855 (Collaborator) commented Oct 29, 2024

This PR adds LocalSearchSampler.
The local search is based on the works [1] and [2].

Test in hypergrid env:

python tutorials/examples/train_hypergrid_simple_ls.py
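
Roughly, the local search move is: backtrack K steps from a terminal state with the backward policy, reconstruct a candidate with the forward policy, and accept or reject the candidate (greedily by reward, or with a Metropolis-Hastings correction). Below is a minimal, self-contained sketch of one such round; the `backtrack`/`reconstruct` callables and the exact log-probability bookkeeping are illustrative assumptions, not the actual LocalSearchSampler API:

    import torch

    def local_search_step(x, log_reward, backtrack, reconstruct, use_mh=True):
        # Backtrack K steps from each terminal state x to a partial state y.
        # log_pb_old = log p_B(y | x); log_pf_old = log p_F(x | y), i.e. the
        # reversed backward path scored under the forward policy.
        y, log_pb_old, log_pf_old = backtrack(x)
        # Rebuild a candidate terminal state x' from y with the forward policy.
        x_new, log_pf_new, log_pb_new = reconstruct(y)
        if use_mh:
            # Metropolis-Hastings acceptance in log space:
            # A = min(1, R(x') p_B(y|x') p_F(x|y) / (R(x) p_B(y|x) p_F(x'|y)))
            log_accept = (log_reward(x_new) + log_pb_new + log_pf_old
                          - log_reward(x) - log_pb_old - log_pf_new)
            accept = torch.rand_like(log_accept).log() < log_accept
        else:
            # Greedy filtering: keep the candidate only if it improves the reward.
            accept = log_reward(x_new) > log_reward(x)
        return torch.where(accept.unsqueeze(-1), x_new, x), accept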

@hyeok9855 hyeok9855 self-assigned this Oct 29, 2024
@hyeok9855 hyeok9855 marked this pull request as draft October 29, 2024 11:03
@hyeok9855 hyeok9855 force-pushed the hyeok9855/local-search branch 2 times, most recently from c2d59b3 to c6b9f64 Compare October 29, 2024 17:15
@hyeok9855 hyeok9855 force-pushed the hyeok9855/local-search branch from c6b9f64 to 1ccf16c Compare October 29, 2024 18:21
@josephdviviano (Collaborator)

@hyeok9855 can you fix the merge conflicts?

@josephdviviano (Collaborator)

I noticed you are using force-push -- be careful with this: it can put the code in a state that is hard to reconcile with the rest of the work.

https://www.gitkraken.com/learn/git/problems/git-push-force#:~:text=The%20Risks%20of%20Git%20Push%20Force&text=Because%20you%20have%20failed%20to,deleting%20your%20team%20member's%20work.

@hyeok9855 (Author)

I fixed an issue in the backward mask!

Is there anything else that needs to be done next?

@hyeok9855 hyeok9855 changed the title [Draft] Add local search sampler Add local search sampler Nov 29, 2024
@hyeok9855 hyeok9855 requested a review from younik November 29, 2024 18:57
@hyeok9855 hyeok9855 added the enhancement New feature or request label Nov 29, 2024
@hyeok9855 hyeok9855 marked this pull request as ready for review November 29, 2024 18:58
@saleml saleml self-requested a review December 3, 2024 16:05
@josephdviviano (Collaborator) left a comment

Hey @hyeok9855 - please see my comments. This is a really awesome PR; I like the changes you made to trajectories, and with some tweaks the local search sampler looks very clean.

Let me know if you want to schedule a pair programming session.

@staticmethod
def reverse_backward_trajectories(trajectories: Trajectories) -> Trajectories:
"""Reverses a backward trajectory"""
# FIXME: This method is not compatible with continuous GFN.
Collaborator

What's the major blocker here? Anyone know?

hyeok9855 (Author)

I'm not sure either... This was from here.

hyeok9855 (Author)

One guess is this:

In lines 436-443:

        new_actions = torch.full(
            (
                trajectories.max_length + 1,
                len(trajectories),
                *trajectories.actions.action_shape,
            ),
            -1,
        )

Also, in lines 461-463:

            new_actions[trajectories.when_is_done[i], i] = (
                trajectories.env.n_actions - 1
            )

These assume that the action is an integer, which is not true in the continuous case, right?
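
For illustration, a minimal sketch of the difference (the continuous-action convention shown here, a dedicated float fill value, is an assumption for the example, not the library's actual handling; presumably this is why the env carries dedicated self.dummy_action / self.exit_action tensors):

    import torch

    # Discrete case: actions are integer indices, so -1 is a safe dummy fill
    # and `n_actions - 1` conventionally encodes the exit action.
    discrete_actions = torch.full((10, 4), -1, dtype=torch.long)

    # Continuous case: actions are float vectors of shape (..., action_dim);
    # there is no integer index to write, so dummy/exit markers must be
    # dedicated tensors (e.g. filled with -inf), not `n_actions - 1`.
    action_dim = 2
    continuous_actions = torch.full((10, 4, action_dim), float("-inf"))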

saleml (Collaborator)

The blocker: see my response to line 462 of this file in the PR.

saleml (Collaborator)

this function will not work on non-discrete environments!


# FIXME: Can we vectorize this?
# FIXME: Also, loop over batch or sequence?
for i in range(len(trajectories)):
Collaborator

Can we flip the full trajectory tensor in one call, and then use indexing to resolve the padding instead? It should be much faster.

hyeok9855 (Author)

Of course, yes. I'll check and let you know if I need your help.
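
One way to do this (a sketch with hypothetical names, not the code this PR converged on): compute, for every position of the reversed tensor, the source index `lengths - 1 - t`, and resolve the whole batch with a single gather:

    import torch

    def reverse_padded(actions: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
        # Reverse each trajectory in a (max_len, batch) tensor padded with -1,
        # where lengths[i] is the number of valid steps in column i.
        max_len = actions.shape[0]
        steps = torch.arange(max_len, device=actions.device).unsqueeze(1)  # (max_len, 1)
        idx = (lengths.unsqueeze(0) - 1) - steps                           # (max_len, batch)
        valid = idx >= 0
        rev = torch.full_like(actions, -1)
        # gather picks actions[lengths[i] - 1 - t, i]; clamp keeps indices legal,
        # and the mask discards the out-of-range (padding) positions.
        rev[valid] = actions.gather(0, idx.clamp(min=0))[valid]
        return rev

    # Example: three trajectories of lengths 3, 2, 1.
    actions = torch.tensor([[0, 0, 0], [1, 1, -1], [2, -1, -1]])
    lengths = torch.tensor([3, 2, 1])
    print(reverse_padded(actions, lengths))
    # tensor([[ 2,  1,  0],
    #         [ 1,  0, -1],
    #         [ 0, -1, -1]])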

@@ -61,6 +61,8 @@ def __init__(
self.dummy_action = dummy_action
self.exit_action = exit_action

# Warning: don't use self.States or self.Actions to initialize an instance of the class.
Collaborator

Who is this warning intended for?

hyeok9855 (Author)

Maybe us?? Regarding this, what about making them into private variables (e.g., self.__States and self.__Actions)??

saleml (Collaborator)

It seems like we should not initialize a States object using self.States, but rather self.states_from_tensor, as in line 251 of src/gfn/gym/discrete_ebm.py.
I agree with the general sentiment here. Should we actually raise a warning when self.States is used? Or is there a way to prevent it?
I agree with @hyeok9855's comment as well.
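
One way to raise such a warning (a sketch, not something this PR implements): store the class under a private name and expose `States` as a property that warns on access:

    import warnings

    class Env:
        def __init__(self, states_class):
            self._States = states_class  # stored under a private name

        @property
        def States(self):
            warnings.warn(
                "Prefer env.states_from_tensor(...) over env.States(...) "
                "for instantiating states.",
                stacklevel=2,
            )
            return self._States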

)

# Calculate the forward probability if needed (Metropolis-Hastings).
prev_trajectories = Trajectories.reverse_backward_trajectories(
Collaborator

Maybe, for clarity, prev_forward_trajectories or reversed_backward_trajectories, since you also operate on trajectories below.

hyeok9855 (Author)

I wanted to make this correspond to the new_trajectories. (Please check my comment below.)

I think adding a short explanation of why this is called prev_trajectories would be fine, e.g.,

        # By reversing the backward trajectories, obtain the forward trajectories.
        # This is called `prev_trajectories` since they are the trajectories before
        # the local search. The `new_trajectories` will be obtained by performing local
        # search on them.
        prev_trajectories = Trajectories.reverse_backward_trajectories(
            backward_trajectories
        )

What do you think about this??

prev_trajectories = Trajectories.reverse_backward_trajectories(
backward_trajectories
)
prev_trajectories_log_rewards = trajectories.log_rewards
Collaborator

Should this be prev_trajectories? I actually think it does not matter, since you're just looking at the reward at the end of the trajectory, but a comment here would be clarifying, because above prev_trajectories refers to the reverse of a trajectory sampled from pb, and here you're grabbing log_rewards directly from the forward trajectories.

hyeok9855 (Author)

For why this is prev_trajectories, check lines 450-456. I wanted that part to be new... +/- prev....
To alleviate the confusion, we can simply change this line to

prev_trajectories_log_rewards = prev_trajectories.log_rewards

n_back = backward_trajectories.when_is_done[i] - K[i]

# Sanity check
assert (
Collaborator

I think it would be good to move this into a test, to prevent the assertion from running so often in production.

hyeok9855 (Author)

Is it possible to test with the local variables of a function?
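
One pattern that would allow that (a sketch; the helper name is hypothetical, and only the n_back computation above is from the PR): factor the checked quantity out of the loop into a small pure function, assert on it in the test suite, and keep the hot path assertion-free:

    import torch

    def compute_n_back(when_is_done: torch.Tensor, K: torch.Tensor) -> torch.Tensor:
        # Number of backward steps to replay for each trajectory.
        return when_is_done - K

    def test_n_back_is_non_negative():
        when_is_done = torch.tensor([5, 3, 7])
        K = torch.tensor([2, 3, 1])
        assert torch.all(compute_n_back(when_is_done, K) >= 0)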

device=device, dtype=torch.float
)

for i in range(bs): # FIXME: Can we vectorize this?
Collaborator

I think this should be vectorized -- if you need help, no problem, let's schedule a pair programming session (maybe in the evening EST instead of my early morning ;) ).

hyeok9855 (Author)

Let me check first! I will let you know :)

@younik (Collaborator) left a comment

Good job, the code is very well written :)

@saleml (Collaborator) left a comment

great work


# FIXME: Can we vectorize this?
# FIXME: Also, loop over batch or sequence?
for i in range(len(trajectories)):
new_actions[trajectories.when_is_done[i], i] = (
saleml (Collaborator)

The problem is here. Actions are not always integers.


@@ -103,23 +104,34 @@ def get_trajectory_pfs(
valid_actions.tensor
) # Using the actions sampled off-policy.

log_pf_trajectories = torch.full_like(
trajectories.actions.tensor[..., 0],
saleml (Collaborator)

what's the rationale behind removing this?

hyeok9855 (Author)

This was just moved to line 78 to address the edge case in line 84!

@@ -145,13 +160,13 @@ def get_trajectory_pbs(
valid_states, estimator_outputs
).log_prob(valid_actions.tensor)

log_pb_trajectories = torch.full_like(
trajectories.actions.tensor[..., 0],
fill_value=fill_value,
saleml (Collaborator)

same question here

hyeok9855 (Author)

Same here :)

0, 1
) # shape (max_len + 2, n_trajectories, *state_dim)

# TODO: Add below into the test suite to ensure correctness
saleml (Collaborator)

thank you for handling the vectorization. did you test this?

hyeok9855 (Author)

Yes, I tested this by uncommenting it.
@josephdviviano, could you give any advice on how to design a test to check whether the vectorization works appropriately?

Collaborator

What about copying the pre-vectorization code into a test file and comparing the outputs of your new function and that code on a few HyperGrid + other env trajectories?
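
In that spirit, a self-contained sketch of such a comparison test (all names here are hypothetical; the PR's actual parametrized test is shown below): keep the pre-vectorization loop as a reference implementation and check the vectorized version against it on random padded batches:

    import torch

    def reverse_loop(actions, lengths):
        # Pre-vectorization reference: reverse each padded column in a Python loop.
        rev = torch.full_like(actions, -1)
        for i, n in enumerate(lengths.tolist()):
            rev[:n, i] = actions[:n, i].flip(0)
        return rev

    def reverse_vectorized(actions, lengths):
        # Single-gather version (same trick as sketched earlier in this thread).
        steps = torch.arange(actions.shape[0]).unsqueeze(1)
        idx = (lengths.unsqueeze(0) - 1) - steps
        valid = idx >= 0
        rev = torch.full_like(actions, -1)
        rev[valid] = actions.gather(0, idx.clamp(min=0))[valid]
        return rev

    def test_vectorized_matches_loop_reference():
        torch.manual_seed(0)
        max_len, bs = 12, 64
        lengths = torch.randint(1, max_len + 1, (bs,))
        actions = torch.full((max_len, bs), -1)
        for i, n in enumerate(lengths.tolist()):
            actions[:n, i] = torch.randint(0, 5, (n,))
        assert torch.equal(reverse_loop(actions, lengths),
                           reverse_vectorized(actions, lengths))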



@pytest.mark.parametrize("env_name", ["HyperGrid", "DiscreteEBM"])
def test_reverse_backward_trajectories(env_name: str):
saleml (Collaborator)

great

@hyeok9855 (Author)

FYI: In terms of L1 distance (empirical distribution vs. true distribution, as done in the original GFN paper), both LS-GFN (TB) with and without the Metropolis-Hastings correction outperform vanilla TB on the 16x16x16x16 HyperGrid (using the default hyperparameters):

  • TB: 2.571 × 10⁻⁵
  • TB + LS: 2.513 × 10⁻⁵
  • TB + LS + MH: 2.507 × 10⁻⁵

@saleml (Collaborator) commented Jan 11, 2025

LGTM.
@josephdviviano, please merge if you're satisfied with the changes.

@josephdviviano (Collaborator) left a comment

Sorry for the lag reviewing these changes - Salem and I think these contributions are excellent.

@josephdviviano josephdviviano merged commit 7f03681 into master Jan 13, 2025
4 checks passed