Add local search sampler #208

Merged
merged 19 commits on Jan 13, 2025
Changes from 1 commit
70 changes: 42 additions & 28 deletions src/gfn/containers/trajectories.py
@@ -428,7 +428,9 @@ def to_non_initial_intermediary_and_terminating_states(
        )

    @staticmethod
-    def reverse_backward_trajectories(trajectories: Trajectories) -> Trajectories:
+    def reverse_backward_trajectories(
+        trajectories: Trajectories, debug: bool = False
+    ) -> Trajectories:
        """Reverses a backward trajectory"""
        # FIXME: This method is not compatible with continuous GFN.
Collaborator:
What's the major blocker here? Anyone know?

Collaborator (Author):
I'm not sure either... This was from here.

Collaborator (Author):
One guess is this:

In lines 436-443:

        new_actions = torch.full(
            (
                trajectories.max_length + 1,
                len(trajectories),
                *trajectories.actions.action_shape,
            ),
            -1,
        )

Also, in lines 461-463:

            new_actions[trajectories.when_is_done[i], i] = (
                trajectories.env.n_actions - 1
            )

These assume that the action is an integer, which is not true for the continuous case, right?

Collaborator:
The blocker: see my response to line 462 of this file in the PR.

Collaborator:
This function will not work on non-discrete environments!
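[Editor's note] To make the integer-action assumption above concrete, here is a minimal standalone sketch (made-up shapes and values, not code from the PR) of why a -1 fill and an n_actions - 1 sentinel presuppose integer-coded actions:

    import torch

    max_len, n_trajectories = 5, 3

    # Discrete case: actions are integer indices, so -1 is a safe padding
    # value and n_actions - 1 can encode the exit action.
    n_actions = 4
    discrete_actions = torch.full((max_len + 1, n_trajectories), -1)
    discrete_actions[0, 0] = n_actions - 1  # exit action is a valid index

    # Continuous case: actions are real-valued vectors, so -1 is a legal
    # action value and cannot double as padding, and there is no n_actions
    # from which to derive an exit sentinel.
    action_dim = 2
    continuous_actions = torch.randn(max_len + 1, n_trajectories, action_dim)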


@@ -519,35 +521,11 @@ def reverse_backward_trajectories(trajectories: Trajectories) -> Trajectories:
            0, 1
        ) # shape (max_len + 2, n_trajectories, *state_dim)

-        # TODO: Add below into the test suite to ensure correctness
-        # new_actions2 = torch.full((max_len + 1, len(trajectories), *trajectories.actions.action_shape), -1).to(actions)
-        # new_states2 = trajectories.env.sf.repeat(max_len + 2, len(trajectories), 1).to(states) # shape (max_len + 2, n_trajectories, *state_dim)
-
-        # for i in range(len(trajectories)):
-        #     new_actions2[trajectories.when_is_done[i], i] = (
-        #         trajectories.env.n_actions - 1
-        #     )
-        #     new_actions2[
-        #         : trajectories.when_is_done[i], i
-        #     ] = trajectories.actions.tensor[: trajectories.when_is_done[i], i].flip(0)
-
-        #     new_states2[
-        #         : trajectories.when_is_done[i] + 1, i
-        #     ] = trajectories.states.tensor[: trajectories.when_is_done[i] + 1, i].flip(
-        #         0
-        #     )
-
-        # assert torch.all(new_actions == new_actions2)
-        # assert torch.all(new_states == new_states2)
-
-        trajectories_states = trajectories.env.states_from_tensor(new_states)
-        trajectories_actions = trajectories.env.actions_from_tensor(new_actions)
-
-        return Trajectories(
+        reversed_trajectories = Trajectories(
            env=trajectories.env,
-            states=trajectories_states,
+            states=trajectories.env.states_from_tensor(new_states),
            conditioning=trajectories.conditioning,
-            actions=trajectories_actions,
+            actions=trajectories.env.actions_from_tensor(new_actions),
            when_is_done=trajectories.when_is_done + 1,
            is_backward=False,
            log_rewards=trajectories.log_rewards,
@@ -557,6 +535,42 @@ def reverse_backward_trajectories(trajectories: Trajectories) -> Trajectories:
            estimator_outputs=None,  # Same as `log_probs`.
        )

+        # ------------------------------ DEBUG ------------------------------
+        # If `debug` is True (expected only when testing), compare the
+        # vectorized approach's results (above) to the for-loop results (below).
+        if debug:
+            _new_actions = torch.full(
+                (max_len + 1, len(trajectories), *trajectories.actions.action_shape), -1
+            ).to(actions)
+            _new_states = trajectories.env.sf.repeat(
+                max_len + 2, len(trajectories), 1
+            ).to(states)  # shape (max_len + 2, n_trajectories, *state_dim)
+
+            for i in range(len(trajectories)):
+                _new_actions[trajectories.when_is_done[i], i] = (
+                    trajectories.env.n_actions - 1
+                )
+                _new_actions[
+                    : trajectories.when_is_done[i], i
+                ] = trajectories.actions.tensor[
+                    : trajectories.when_is_done[i], i
+                ].flip(0)
+
+                _new_states[
+                    : trajectories.when_is_done[i] + 1, i
+                ] = trajectories.states.tensor[
+                    : trajectories.when_is_done[i] + 1, i
+                ].flip(0)
+
+            assert torch.all(new_actions == _new_actions)
+            assert torch.all(new_states == _new_states)
+
+        return reversed_trajectories


def pad_dim0_to_target(a: torch.Tensor, target_dim0: int) -> torch.Tensor:
    """Pads tensor a along dim 0 to length target_dim0."""
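[Editor's note] For reference, a hypothetical pytest-style sketch of how the new debug flag might be exercised in the test suite; the backward_trajectories fixture and its sampling setup are assumed, not part of this PR:

    import torch

    from gfn.containers.trajectories import Trajectories  # path per this diff

    # `backward_trajectories` is an assumed fixture holding backward
    # trajectories sampled from some discrete environment.
    def test_reverse_backward_trajectories(backward_trajectories):
        # With debug=True, the method cross-checks the vectorized reversal
        # against the original for-loop construction via internal asserts.
        forward = Trajectories.reverse_backward_trajectories(
            backward_trajectories, debug=True
        )
        assert not forward.is_backward
        # Reversal appends the exit step, so when_is_done shifts by one
        # (matching `when_is_done=trajectories.when_is_done + 1` in the diff).
        assert torch.all(
            forward.when_is_done == backward_trajectories.when_is_done + 1
        )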