Add warm-up functionality with tensor to trajectory helper functions #224

Open · wants to merge 5 commits into base: master
2 changes: 1 addition & 1 deletion src/gfn/gflownet/base.py
@@ -71,7 +71,7 @@ def to_training_samples(self, trajectories: Trajectories) -> TrainingSampleType:
"""Converts trajectories to training samples. The type depends on the GFlowNet."""

@abstractmethod
def loss(self, env: Env, training_objects: Any):
def loss(self, env: Env, training_objects: Any) -> torch.Tensor:
"""Computes the loss given the training objects."""


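Not part of the diff: a minimal sketch of what the new return annotation promises at call sites, assuming a concrete gflownet, env, and training_objects are already in scope.

# Hypothetical call site: the annotation documents that loss() returns a
# torch.Tensor, so autograd and scalar access are well-typed.
loss = gflownet.loss(env, training_objects)
loss.backward()     # valid: a torch.Tensor supports autograd
print(loss.item())  # valid: scalar extraction from a 0-dim tensor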
111 changes: 109 additions & 2 deletions src/gfn/utils/training.py
@@ -2,10 +2,14 @@
 from typing import Dict, Optional
 
 import torch
+from tqdm import trange
 
-from gfn.env import Env
+from gfn.containers import ReplayBuffer
+from gfn.env import DiscreteEnv, Env
 from gfn.gflownet import GFlowNet, TBGFlowNet
-from gfn.states import States
+from gfn.gflownet.base import PFBasedGFlowNet
+from gfn.samplers import Trajectories
+from gfn.states import States, stack_states
 
 
 def get_terminating_state_dist_pmf(env: Env, states: States) -> torch.Tensor:
@@ -74,3 +78,106 @@ def validate(
     if logZ is not None:
         validation_info["logZ_diff"] = abs(logZ - true_logZ)
     return validation_info
+
+
+def states_actions_tns_to_traj(
+    states_tns: torch.Tensor,
+    actions_tns: torch.Tensor,
+    env: DiscreteEnv,
+) -> Trajectories:
+    """
+    This utility function helps integrate external data (e.g. expert demonstrations)
+    into the GFlowNet framework by converting raw tensors into proper Trajectories
+    objects. The downstream GFlowNet must be able to recalculate all log probs
+    (e.g. PFBasedGFlowNets).
+
+    Args:
+        states_tns: Tensor of shape [traj_len, *state_shape] containing the states of a single trajectory
+        actions_tns: Tensor of shape [traj_len - 1] containing discrete action indices
+        env: The discrete environment that defines the state/action spaces
+
+    Returns:
+        Trajectories: A Trajectories object containing the converted states and actions
+
+    Raises:
+        ValueError: If tensor shapes are invalid or inconsistent
+    """
+    if states_tns.shape[1:] != env.state_shape:
+        raise ValueError(
+            f"states_tns state dimensions must match env.state_shape {env.state_shape}, "
+            f"got shape {states_tns.shape[1:]}"
+        )
+    if len(actions_tns.shape) != 1:
+        raise ValueError(f"actions_tns must be 1D, got shape {actions_tns.shape}")
+    if states_tns.shape[0] != actions_tns.shape[0] + 1:
+        raise ValueError(
+            f"states must contain exactly one more entry than actions, got "
+            f"states: {states_tns.shape[0]}, actions: {actions_tns.shape[0]}"
+        )
+
+    states = [env.states_from_tensor(s.unsqueeze(0)) for s in states_tns]
+    actions = [
+        env.actions_from_tensor(a.unsqueeze(0).unsqueeze(0)) for a in actions_tns
+    ]
+    # stack is a class method, so actions[0] is used only to access an instance;
+    # which instance is not particularly relevant.

Collaborator: really appreciate this comment

+    actions = actions[0].stack(actions)
+    log_rewards = env.log_reward(states[-2])
+    states = stack_states(states)
+    when_is_done = torch.tensor([len(states_tns) - 1])
+
+    log_probs = None
+    estimator_outputs = None
+
+    trajectory = Trajectories(
+        env,
+        states,
+        actions,
+        log_rewards=log_rewards,
+        when_is_done=when_is_done,
+        log_probs=log_probs,
Collaborator: log_probs=None
+        estimator_outputs=estimator_outputs,
+    )
+    return trajectory
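Not part of the diff: a minimal usage sketch of the helper above. It assumes torchgfn's HyperGrid environment (from gfn.gym) with ndim=2, where action i increments coordinate i, the highest action index is the exit action, and the sink state s_f is filled with -1; exact constructor arguments and conventions may differ between versions.

import torch
from gfn.gym import HyperGrid  # assumed import path

env = HyperGrid(ndim=2, height=4)

# One expert trajectory: (0, 0) -> (1, 0) -> (1, 1) -> s_f.
# The sink state is the final row, so len(states) == len(actions) + 1,
# and the log reward is read from the second-to-last state.
states = torch.tensor([[0, 0], [1, 0], [1, 1], [-1, -1]])
actions = torch.tensor([0, 1, 2])  # step along dim 0, then dim 1, then exit

traj = states_actions_tns_to_traj(states, actions, env)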


+def warm_up(
+    replay_buf: ReplayBuffer,
+    optimizer: torch.optim.Optimizer,
+    gflownet: GFlowNet,
+    env: Env,
+    n_epochs: int,
+    batch_size: int,
+    recalculate_all_logprobs: bool = True,
+):
+    """
+    This utility function is an example implementation of pre-training for a
+    GFlowNet agent.
+
+    Args:
+        replay_buf: Replay buffer that collects Trajectories
+        optimizer: Any torch.optim optimizer (e.g. Adam, SGD)
+        gflownet: The GFlowNet to train
+        env: The environment instance
+        n_epochs: Number of epochs for warm-up
+        batch_size: Number of trajectories to sample from the replay buffer
+        recalculate_all_logprobs: For PFBasedGFlowNets only, force recalculating all
+            log probs. Useful when trajectories do not already have log probs.
+
+    Returns:
+        GFlowNet: The trained GFlowNet
+    """
+    t = trange(n_epochs, desc="Warm-up", leave=True)
+    for epoch in t:
+        training_trajs = replay_buf.sample(batch_size)
+        optimizer.zero_grad()
+        if isinstance(gflownet, PFBasedGFlowNet):
+            loss = gflownet.loss(
+                env, training_trajs, recalculate_all_logprobs=recalculate_all_logprobs
+            )
+        else:
+            loss = gflownet.loss(env, training_trajs)
+
+        loss.backward()
+        optimizer.step()
+        t.set_description(f"{epoch=}, {loss=}")
+
+    optimizer.zero_grad()
+    return gflownet
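Not part of the diff: a usage sketch of the warm-up loop above. The ReplayBuffer and TBGFlowNet constructor arguments shown here are assumptions (signatures vary across torchgfn versions), and pf_estimator, pb_estimator, and expert_trajs are hypothetical objects built beforehand.

# Hypothetical wiring; see states_actions_tns_to_traj for building expert_trajs.
replay_buf = ReplayBuffer(env, capacity=1000)  # assumed constructor
for traj in expert_trajs:
    replay_buf.add(traj)                       # assumed ReplayBuffer API

gflownet = TBGFlowNet(pf=pf_estimator, pb=pb_estimator)  # assumed constructor
optimizer = torch.optim.Adam(gflownet.parameters(), lr=1e-3)

gflownet = warm_up(
    replay_buf,
    optimizer,
    gflownet,
    env,
    n_epochs=100,
    batch_size=16,
    recalculate_all_logprobs=True,  # expert data carries no stored log probs
)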