Skip to content

Commit

Permalink
fix merging issues
Browse files Browse the repository at this point in the history
younik committed Jan 13, 2025
1 parent c3df427 commit 78b729a
Showing 3 changed files with 10 additions and 4 deletions.
2 changes: 1 addition & 1 deletion src/gfn/containers/trajectories.py
Original file line number Diff line number Diff line change
@@ -104,7 +104,7 @@ def __init__(
assert (
log_probs.shape == (self.max_length, self.n_trajectories)
and log_probs.dtype == torch.float
)
), f"log_probs.shape={log_probs.shape}, self.max_length={self.max_length}, self.n_trajectories={self.n_trajectories}"
else:
log_probs = torch.full(size=(0, 0), fill_value=0, dtype=torch.float)
self.log_probs: torch.Tensor = log_probs
6 changes: 3 additions & 3 deletions src/gfn/samplers.py
Original file line number Diff line number Diff line change
@@ -207,7 +207,6 @@ def sample_trajectories(
all_estimator_outputs.append(estimator_outputs_padded)

actions[~dones] = valid_actions
trajectories_actions.append(actions)
if save_logprobs:
# When off_policy, actions_log_probs are None.
log_probs[~dones] = actions_log_probs
@@ -247,7 +246,9 @@ def sample_trajectories(
trajectories_states.append(deepcopy(states))

trajectories_states = env.States.stack(trajectories_states)
trajectories_actions = env.Actions.stack(trajectories_actions)[1:] # Drop dummy action
trajectories_actions = env.Actions.stack(trajectories_actions)[
1:
] # Drop dummy action
trajectories_logprobs = (
torch.stack(trajectories_logprobs, dim=0)[1:] # Drop dummy logprob
if save_logprobs
@@ -257,7 +258,6 @@ def sample_trajectories(
# TODO: use torch.nested.nested_tensor(dtype, device, requires_grad).
if save_estimator_outputs:
all_estimator_outputs = torch.stack(all_estimator_outputs, dim=0)

trajectories = Trajectories(
env=env,
states=trajectories_states,
6 changes: 6 additions & 0 deletions testing/test_samplers_and_trajectories.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
from typing import Literal, Tuple

import pytest
import torch
from tensordict import TensorDict
from torch import nn
from torch_geometric.nn import GCNConv

from gfn.actions import GraphActionType
from gfn.containers import Trajectories
from gfn.containers.replay_buffer import ReplayBuffer
from gfn.gym import Box, DiscreteEBM, HyperGrid
from gfn.gym.graph_building import GraphBuilding
from gfn.gym.helpers.box_utils import BoxPBEstimator, BoxPBMLP, BoxPFEstimator, BoxPFMLP
from gfn.modules import DiscretePolicyEstimator, GFNModule, GraphActionPolicyEstimator
from gfn.samplers import LocalSearchSampler, Sampler

0 comments on commit 78b729a

Please sign in to comment.