
Commit

Merge pull request #239 from GFNOrg/hyeok9855/test_to_transition
add test for to_transition and some refactorings
saleml authored Feb 8, 2025
2 parents a59b581 + 4872a98 commit 84ca656
Showing 4 changed files with 168 additions and 147 deletions.
src/gfn/containers/trajectories.py: 49 additions & 73 deletions
@@ -360,6 +360,7 @@ def to_transitions(self) -> Transitions:
             dtype=torch.float,
             device=actions.device,
         )
+        # Can we vectorize this?
         log_rewards[is_done] = torch.cat(
             [
                 self._log_rewards[self.when_is_done == i]
@@ -368,7 +369,8 @@ def to_transitions(self) -> Transitions:
             dim=0,
         )

-        # Only return logprobs if they exist.
+        # FIXME: Transitions requires log_probs for initialization (see line 107 in transitions.py).
+        # Shouldn't we make sure that log_probs are always available?
         log_probs = (
             self.log_probs[~self.actions.is_dummy] if has_log_probs(self) else None
         )
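
The newly added `# Can we vectorize this?` comment asks whether the per-step concatenation above can become a pure tensor op. One answer: concatenating the groups `when_is_done == i` in increasing `i`, with original order preserved inside each group, is exactly a stable sort by finishing step. A minimal sketch with hypothetical stand-in tensors (not code from the PR; `stable=True` requires a reasonably recent PyTorch):

```python
import torch

# Hypothetical stand-ins for self._log_rewards and self.when_is_done.
log_rewards = torch.tensor([0.1, 0.2, 0.3, 0.4])
when_is_done = torch.tensor([2, 1, 2, 1])

# The loop from the diff: concatenate rewards grouped by finishing step.
grouped = torch.cat(
    [log_rewards[when_is_done == i] for i in range(int(when_is_done.max()) + 1)]
)

# Vectorized equivalent: a stable argsort by finishing step keeps the
# original order within each group, producing the same tensor.
vectorized = log_rewards[torch.argsort(when_is_done, stable=True)]
assert torch.equal(grouped, vectorized)
```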
@@ -427,82 +429,62 @@ def to_non_initial_intermediary_and_terminating_states(
             conditioning,
         )

-    @staticmethod
-    def reverse_backward_trajectories(
-        trajectories: Trajectories, debug: bool = False
-    ) -> Trajectories:
-        """Reverses a backward trajectory"""
-        assert trajectories.is_backward, "Trajectories must be backward."
+    def reverse_backward_trajectories(self, debug: bool = False) -> Trajectories:
+        """Return a reversed version of the backward trajectories."""
+        assert self.is_backward, "Trajectories must be backward."

         # env.sf should never be None unless something went wrong during class
         # instantiation.
-        if trajectories.env.sf is None:
+        if self.env.sf is None:
             raise AttributeError(
                 "Something went wrong during the instantiation of environment {}".format(
-                    trajectories.env
+                    self.env
                 )
             )

         # Compute sequence lengths and maximum length
-        seq_lengths = trajectories.when_is_done  # shape (n_trajectories,)
+        seq_lengths = self.when_is_done  # shape (n_trajectories,)
         max_len = seq_lengths.max().item()

         # Get actions and states
-        actions = (
-            trajectories.actions.tensor
-        )  # shape (max_len, n_trajectories *action_dim)
-        states = (
-            trajectories.states.tensor
-        )  # shape (max_len + 1, n_trajectories, *state_dim)
+        actions = self.actions.tensor  # shape (max_len, n_trajectories *action_dim)
+        states = self.states.tensor  # shape (max_len + 1, n_trajectories, *state_dim)

         # Initialize new actions and states
-        new_actions = trajectories.env.dummy_action.repeat(
-            max_len + 1, len(trajectories), 1  # pyright: ignore
-        ).to(
-            actions
-        )  # shape (max_len + 1, n_trajectories, *action_dim)
-        new_states = trajectories.env.sf.repeat(
-            max_len + 2, len(trajectories), 1  # pyright: ignore
-        ).to(
-            states
-        )  # shape (max_len + 2, n_trajectories, *state_dim)
+        new_actions = self.env.dummy_action.repeat(
+            max_len + 1, len(self), 1  # pyright: ignore
+        ).to(actions)
+        # shape (max_len + 1, n_trajectories, *action_dim)
+        new_states = self.env.sf.repeat(
+            max_len + 2, len(self), 1  # pyright: ignore
+        ).to(states)
+        # shape (max_len + 2, n_trajectories, *state_dim)

         # Create helper indices and masks
-        idx = (
-            torch.arange(max_len)
-            .unsqueeze(1)
-            .expand(-1, len(trajectories))
-            .to(seq_lengths)
-        )
+        idx = torch.arange(max_len).unsqueeze(1).expand(-1, len(self)).to(seq_lengths)
         rev_idx = seq_lengths - 1 - idx  # shape (max_len, n_trajectories)
         mask = rev_idx >= 0  # shape (max_len, n_trajectories)
         rev_idx[:, 1:] += seq_lengths.cumsum(0)[:-1]

         # Transpose for easier indexing
-        actions = actions.transpose(
-            0, 1
-        )  # shape (n_trajectories, max_len, *action_dim)
-        new_actions = new_actions.transpose(
-            0, 1
-        )  # shape (n_trajectories, max_len + 1, *action_dim)
-        states = states.transpose(
-            0, 1
-        )  # shape (n_trajectories, max_len + 1, *state_dim)
-        new_states = new_states.transpose(
-            0, 1
-        )  # shape (n_trajectories, max_len + 2, *state_dim)
+        actions = actions.transpose(0, 1)
+        # shape (n_trajectories, max_len, *action_dim)
+        new_actions = new_actions.transpose(0, 1)
+        # shape (n_trajectories, max_len + 1, *action_dim)
+        states = states.transpose(0, 1)
+        # shape (n_trajectories, max_len + 1, *state_dim)
+        new_states = new_states.transpose(0, 1)
+        # shape (n_trajectories, max_len + 2, *state_dim)
         rev_idx = rev_idx.transpose(0, 1)
         mask = mask.transpose(0, 1)

         # Assign reversed actions to new_actions
         new_actions[:, :-1][mask] = actions[mask][rev_idx[mask]]
-        new_actions[torch.arange(len(trajectories)), seq_lengths] = (
-            trajectories.env.exit_action
-        )
+        new_actions[torch.arange(len(self)), seq_lengths] = self.env.exit_action

         # Assign reversed states to new_states
-        assert torch.all(states[:, -1] == trajectories.env.s0), "Last state must be s0"
-        new_states[:, 0] = trajectories.env.s0
+        assert torch.all(states[:, -1] == self.env.s0), "Last state must be s0"
+        new_states[:, 0] = self.env.s0
         new_states[:, 1:-1][mask] = states[:, :-1][mask][rev_idx[mask]]

         # Transpose back
@@ -514,13 +496,13 @@ def reverse_backward_trajectories(
         )  # shape (max_len + 2, n_trajectories, *state_dim)

         reversed_trajectories = Trajectories(
-            env=trajectories.env,
-            states=trajectories.env.states_from_tensor(new_states),
-            conditioning=trajectories.conditioning,
-            actions=trajectories.env.actions_from_tensor(new_actions),
-            when_is_done=trajectories.when_is_done + 1,
+            env=self.env,
+            states=self.env.states_from_tensor(new_states),
+            conditioning=self.conditioning,
+            actions=self.env.actions_from_tensor(new_actions),
+            when_is_done=self.when_is_done + 1,
             is_backward=False,
-            log_rewards=trajectories.log_rewards,
+            log_rewards=self.log_rewards,
             log_probs=None,  # We can't simply pass the trajectories.log_probs
             # Since `log_probs` is assumed to be the forward log probabilities.
             # FIXME: To resolve this, we can save log_pfs and log_pbs in the trajectories object.
@@ -531,32 +513,26 @@ def reverse_backward_trajectories(
         # If `debug` is True (expected only when testing), compare the
         # vectorized approach's results (above) to the for-loop results (below).
         if debug:
-            _new_actions = trajectories.env.dummy_action.repeat(
-                max_len + 1, len(trajectories), 1  # pyright: ignore
+            _new_actions = self.env.dummy_action.repeat(
+                max_len + 1, len(self), 1  # pyright: ignore
             ).to(
                 actions
             )  # shape (max_len + 1, n_trajectories, *action_dim)
-            _new_states = trajectories.env.sf.repeat(
-                max_len + 2, len(trajectories), 1  # pyright: ignore
+            _new_states = self.env.sf.repeat(
+                max_len + 2, len(self), 1  # pyright: ignore
             ).to(
                 states
             )  # shape (max_len + 2, n_trajectories, *state_dim)

-            for i in range(len(trajectories)):
-                _new_actions[trajectories.when_is_done[i], i] = (
-                    trajectories.env.exit_action
-                )
-                _new_actions[: trajectories.when_is_done[i], i] = (
-                    trajectories.actions.tensor[: trajectories.when_is_done[i], i].flip(
-                        0
-                    )
-                )
+            for i in range(len(self)):
+                _new_actions[self.when_is_done[i], i] = self.env.exit_action
+                _new_actions[: self.when_is_done[i], i] = self.actions.tensor[
+                    : self.when_is_done[i], i
+                ].flip(0)

-                _new_states[: trajectories.when_is_done[i] + 1, i] = (
-                    trajectories.states.tensor[
-                        : trajectories.when_is_done[i] + 1, i
-                    ].flip(0)
-                )
+                _new_states[: self.when_is_done[i] + 1, i] = self.states.tensor[
+                    : self.when_is_done[i] + 1, i
+                ].flip(0)

             assert torch.all(new_actions == _new_actions)
             assert torch.all(new_states == _new_states)
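
The heart of the refactored `reverse_backward_trajectories` is the indexing around `rev_idx` and `mask`: once the tensors are transposed to (n_trajectories, max_len) and masked, the valid entries are grouped trajectory by trajectory, so adding per-trajectory offsets from `seq_lengths.cumsum(0)` turns each local reversed position into an index into the flattened buffer. Below is a self-contained sketch of the same trick on toy integer actions, hypothetical data with scalar entries in place of `*action_dim`, not code from the PR:

```python
import torch

# Three toy trajectories of lengths 3, 1, 2, padded to max_len = 3 with -1.
seq_lengths = torch.tensor([3, 1, 2])
padded = torch.tensor([
    [10, 11, 12],  # reversed -> [12, 11, 10]
    [20, -1, -1],  # reversed -> [20]
    [30, 31, -1],  # reversed -> [31, 30]
])
n, max_len = padded.shape

# Per-trajectory step index (one column per trajectory), as in the diff.
idx = torch.arange(max_len).unsqueeze(1).expand(-1, n)  # (max_len, n)
rev_idx = seq_lengths - 1 - idx  # reversed position; negative on padding
mask = rev_idx >= 0              # marks valid (non-padding) entries

# Offset each trajectory's indices so they address the flattened,
# mask-selected buffer, whose entries are grouped trajectory by trajectory.
rev_idx[:, 1:] += seq_lengths.cumsum(0)[:-1]

# Transpose to (n, max_len) so `padded[mask]` lists entries per trajectory.
rev_idx, mask = rev_idx.T, mask.T
reversed_padded = torch.full_like(padded, -1)
reversed_padded[mask] = padded[mask][rev_idx[mask]]

print(reversed_padded)
# tensor([[12, 11, 10],
#         [20, -1, -1],
#         [31, 30, -1]])
```

The `debug` branch in the diff keeps the straightforward per-trajectory `flip(0)` loop as a reference implementation; the expected output above is what that loop would produce.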
src/gfn/samplers.py: 1 addition & 3 deletions
@@ -360,9 +360,7 @@ def local_search(
         # This is called `prev_trajectories` since they are the trajectories before
         # the local search. The `new_trajectories` will be obtained by performing local
         # search on them.
-        prev_trajectories = Trajectories.reverse_backward_trajectories(
-            prev_trajectories
-        )
+        prev_trajectories = prev_trajectories.reverse_backward_trajectories()
         assert prev_trajectories.log_rewards is not None

         ### Reconstructing with self.estimator
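
With the staticmethod turned into an instance method, call sites like this one shrink to a plain method call. A schematic before/after, using a hypothetical stub in place of the real `Trajectories` class from `src/gfn/containers/trajectories.py`:

```python
class Trajectories:
    """Hypothetical stub mirroring only the refactored method's shape."""

    def reverse_backward_trajectories(self) -> "Trajectories":
        # The real method builds and returns a new, forward-direction
        # Trajectories object; the stub just returns a fresh instance.
        return Trajectories()


prev_trajectories = Trajectories()
# Before: prev_trajectories = Trajectories.reverse_backward_trajectories(prev_trajectories)
# After:
prev_trajectories = prev_trajectories.reverse_backward_trajectories()
```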
src/gfn/states.py: 2 additions & 2 deletions
@@ -371,8 +371,8 @@ def __getitem__(
         backward_masks = self.backward_masks[index]
         out = self.__class__(states, forward_masks, backward_masks)
         if self._log_rewards is not None:
-            log_probs = self._log_rewards[index]
-            out.log_rewards = log_probs
+            log_rewards = self._log_rewards[index]
+            out.log_rewards = log_rewards
         return out

     def __setitem__(
