
Function to revert backward trajectories #109

Closed
saleml opened this issue Aug 2, 2023 · 7 comments
Labels
enhancement New feature or request high priority Let's do these first!

Comments

@saleml
Collaborator

saleml commented Aug 2, 2023

In previous versions of the code, when actions were integers, we had this function that reverts backward trajectories. It isn't used anywhere in the codebase, but I remember using it for another project (probably GFN vs HVI). I just removed it (in an upcoming PR), and it would be nice to fix it and have it back.

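    @staticmethod
    def revert_backward_trajectories(trajectories: Trajectories) -> Trajectories:
        """Reverses a trajectory, but not compatible with continuous GFN. Remove."""
        # TODO: this isn't used anywhere - it doesn't work as it assumes that the
        # actions are ints. Do we need it?
        # Assumes `torch` is imported and `Trajectories` is in scope (old integer-action API).
        assert trajectories.is_backward
        n_trajectories = len(trajectories)
        max_len = int(trajectories.when_is_done.max())

        # env.sf should never be None unless something went wrong during class
        # instantiation.
        if trajectories.env.sf is None:
            raise AttributeError(
                f"Something went wrong during the instantiation of environment {trajectories.env}"
            )

        # Pad with -1 ("no action"); the extra row makes room for the exit action.
        new_actions = torch.full_like(trajectories.actions, -1)
        new_actions = torch.cat(
            [
                new_actions,
                torch.full(
                    (1, n_trajectories),
                    -1,
                    dtype=new_actions.dtype,
                    device=new_actions.device,
                ),
            ],
            dim=0,
        )

        # A reversed trajectory with T + 1 actions (T reversed ones plus exit)
        # visits T + 2 states (s0, ..., x, sf), so allocate max_len + 2 rows
        # pre-filled with the sink state sf.
        new_states = trajectories.env.sf.repeat(max_len + 2, n_trajectories, 1)
        new_when_is_done = trajectories.when_is_done + 1

        for i in range(n_trajectories):
            t_done = int(trajectories.when_is_done[i])
            # The exit action (last action index) terminates the forward trajectory.
            new_actions[t_done, i] = trajectories.env.n_actions - 1
            # Replay the backward actions in reverse order (assumes backward action k
            # undoes forward action k, as in discrete envs with integer actions).
            new_actions[:t_done, i] = trajectories.actions[:t_done, i].flip(0)
            # States go from (x, ..., s0) to (s0, ..., x).
            new_states[: t_done + 1, i] = trajectories.states.tensor[
                : t_done + 1, i
            ].flip(0)

        new_states = trajectories.env.States(new_states)

        return Trajectories(
            env=trajectories.env,
            states=new_states,
            actions=new_actions,
            # The stored log-probs are backward (P_B) log-probs in backward order and
            # one step shorter, so they are not valid forward log-probs; drop them.
            log_probs=None,
            when_is_done=new_when_is_done,
            is_backward=False,
        )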
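To make the index bookkeeping concrete, here is a minimal, dependency-free sketch of the same reversal on time-major nested lists instead of torch tensors. The function name, its arguments, and the toy states and actions are hypothetical illustrations, not the library's API; it assumes discrete integer actions, with the padding value (-1) and the exit action passed in explicitly rather than derived from `env.n_actions - 1`.

```python
from typing import Any, List, Sequence, Tuple


def reverse_backward_trajectories(
    states: Sequence[Sequence[Any]],   # (max_len + 1) rows x n trajectories, x ... s0
    actions: Sequence[Sequence[int]],  # max_len rows x n trajectories, padded
    when_is_done: Sequence[int],       # number of backward steps per trajectory
    sink_state: Any,                   # plays the role of env.sf (hypothetical)
    exit_action: int,                  # plays the role of env.n_actions - 1
    pad_action: int = -1,
) -> Tuple[List[List[Any]], List[List[int]], List[int]]:
    """Sketch: turn padded backward trajectories into forward ones."""
    n = len(when_is_done)
    max_len = max(when_is_done)
    # One extra action row for the exit action.
    new_actions = [[pad_action] * n for _ in range(max_len + 1)]
    # T + 1 forward actions visit T + 2 states (s0, ..., x, sf).
    new_states = [[sink_state] * n for _ in range(max_len + 2)]
    for i, t_done in enumerate(when_is_done):
        # Equivalent of actions[:t_done, i].flip(0).
        for t in range(t_done):
            new_actions[t][i] = actions[t_done - 1 - t][i]
        new_actions[t_done][i] = exit_action
        # Equivalent of states[:t_done + 1, i].flip(0).
        for t in range(t_done + 1):
            new_states[t][i] = states[t_done - t][i]
    new_when_is_done = [t_done + 1 for t_done in when_is_done]
    return new_states, new_actions, new_when_is_done


# Two backward trajectories: x -> a -> s0 (2 steps) and y -> s0 (1 step).
states = [["x", "y"], ["a", "s0"], ["s0", "sf"]]
actions = [[3, 5], [1, -1]]
s, a, w = reverse_backward_trajectories(states, actions, [2, 1], "sf", exit_action=6)
print(s)  # [['s0', 's0'], ['a', 'y'], ['x', 'sf'], ['sf', 'sf']]
print(a)  # [[1, 5], [3, 6], [6, -1]]
print(w)  # [3, 2]
```

The sink-filled state buffer is why it needs `max_len + 2` rows: the longest reversed trajectory ends with the exit action, whose target state `sf` must still fit in the buffer.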
@saleml saleml added the after_v1 label Aug 2, 2023
@josephdviviano josephdviviano added enhancement New feature or request high priority Let's do these first! labels Oct 29, 2024
@josephdviviano
Collaborator

Note on the commit that removed this feature: 05f0c68

@josephdviviano
Collaborator

See here:

        else env.Actions.make_dummy_actions(batch_shape=(0, 0))
    )

@hyeok9855
Collaborator

> it would be nice to fix it and have it back

What needs to be fixed here?

@saleml
Collaborator Author

saleml commented Jan 13, 2025

@hyeok9855, did you actually do this in #208?

@hyeok9855
Collaborator

> @hyeok9855, did you actually do this in #208?

@saleml Yes, see here! Still, it is not compatible with the continuous case.

@josephdviviano
Collaborator

@hyeok9855 can you raise the blocker regarding the continuous case during our meeting tomorrow? We can brainstorm.

@hyeok9855
Collaborator

hyeok9855 commented Feb 5, 2025

Resolved via #233


3 participants