diff --git a/src/gfn/gflownet/base.py b/src/gfn/gflownet/base.py
index 025c11bd..60d11e3b 100644
--- a/src/gfn/gflownet/base.py
+++ b/src/gfn/gflownet/base.py
@@ -71,7 +71,7 @@
     def to_training_samples(self, trajectories: Trajectories) -> TrainingSampleType:
         """Converts trajectories to training samples. The type depends on the GFlowNet."""
 
     @abstractmethod
-    def loss(self, env: Env, training_objects: Any):
+    def loss(self, env: Env, training_objects: Any) -> torch.Tensor:
        """Computes the loss given the training objects."""
 
diff --git a/src/gfn/utils/training.py b/src/gfn/utils/training.py
index 614820b7..36afb450 100644
--- a/src/gfn/utils/training.py
+++ b/src/gfn/utils/training.py
@@ -2,9 +2,13 @@
 from typing import Dict, Optional
 
 import torch
+from tqdm import trange
 
-from gfn.env import Env
+from gfn.containers import ReplayBuffer
+from gfn.env import DiscreteEnv, Env
 from gfn.gflownet import GFlowNet, TBGFlowNet
+from gfn.gflownet.base import PFBasedGFlowNet
+from gfn.samplers import Trajectories
 from gfn.states import States
 
@@ -81,3 +85,111 @@ def validate(
     if logZ is not None:
         validation_info["logZ_diff"] = abs(logZ - true_logZ)
     return validation_info
+
+
+def states_actions_tns_to_traj(
+    states_tns: torch.Tensor,
+    actions_tns: torch.Tensor,
+    env: DiscreteEnv,
+    conditioning: torch.Tensor | None = None,
+) -> Trajectories:
+    """Converts raw tensors for a single trajectory into a Trajectories object.
+
+    This utility helps integrate external data (e.g. expert demonstrations) into
+    the GFlowNet framework. Since no log probs are attached, the downstream
+    GFlowNet must be capable of recalculating them (e.g. a PFBasedGFlowNet).
+
+    Args:
+        states_tns: Tensor of shape [traj_len + 1, *state_shape] containing the
+            states of a single trajectory, including the final sink state
+        actions_tns: Tensor of shape [traj_len] containing discrete action indices
+        env: The discrete environment that defines the state/action spaces
+        conditioning: Optional tensor of shape [traj_len, *conditioning_shape]
+            containing the conditioning information for the trajectory
+
+    Returns:
+        Trajectories: A Trajectories object containing the converted states and actions
+
+    Raises:
+        ValueError: If tensor shapes are invalid or inconsistent
+    """
+    if states_tns.shape[1:] != env.state_shape:
+        raise ValueError(
+            f"states_tns state dimensions must match env.state_shape {env.state_shape}, "
+            f"got shape {states_tns.shape[1:]}"
+        )
+    if len(actions_tns.shape) != 1:
+        raise ValueError(f"actions_tns must be 1D, got shape {actions_tns.shape}")
+    if states_tns.shape[0] != actions_tns.shape[0] + 1:
+        raise ValueError(
+            f"states must contain exactly one more entry than actions, got "
+            f"states: {states_tns.shape[0]}, actions: {actions_tns.shape[0]}"
+        )
+
+    states = [env.states_from_tensor(s.unsqueeze(0)) for s in states_tns]
+    actions = [
+        env.actions_from_tensor(a.unsqueeze(0).unsqueeze(0)) for a in actions_tns
+    ]
+
+    # stack is a classmethod; actions[0] is only used to reach the class, which
+    # particular instance it is does not matter.
+    actions = actions[0].stack(actions)
+    # The last entry of states_tns is the sink state s_f, so the terminating
+    # state whose reward we need is the second-to-last one.
+    log_rewards = env.log_reward(states[-2])
+    states = states[0].stack_states(states)
+    when_is_done = torch.tensor([len(states_tns) - 1])
+
+    log_probs = None
+    estimator_outputs = None
+
+    trajectory = Trajectories(
+        env,
+        states,
+        conditioning,
+        actions,
+        log_rewards=log_rewards,
+        when_is_done=when_is_done,
+        log_probs=log_probs,
+        estimator_outputs=estimator_outputs,
+    )
+    return trajectory
+
+
+def warm_up(
+    replay_buf: ReplayBuffer,
+    optimizer: torch.optim.Optimizer,
+    gflownet: GFlowNet,
+    env: Env,
+    n_epochs: int,
+    batch_size: int,
+    recalculate_all_logprobs: bool = True,
+):
+    """Example implementation of pre-training (warm-up) for a GFlowNet agent.
+
+    Args:
+        replay_buf: Replay buffer that collects Trajectories
+        optimizer: Any torch.optim optimizer (e.g. Adam, SGD)
+        gflownet: The GFlowNet to train
+        env: The environment instance
+        n_epochs: Number of epochs for the warm-up
+        batch_size: Number of trajectories to sample from the replay buffer
+        recalculate_all_logprobs: For PFBasedGFlowNets only, force recalculating
+            all log probs. Useful when trajectories do not already carry log probs.
+
+    Returns:
+        GFlowNet: The warmed-up GFlowNet
+    """
+    t = trange(n_epochs, desc="warm-up", leave=True)
+    for epoch in t:
+        training_trajs = replay_buf.sample(batch_size)
+        optimizer.zero_grad()
+        if isinstance(gflownet, PFBasedGFlowNet):
+            loss = gflownet.loss(
+                env,
+                training_trajs,
+                recalculate_all_logprobs=recalculate_all_logprobs,  # pyright: ignore
+            )
+        else:
+            loss = gflownet.loss(env, training_trajs)
+
+        loss.backward()
+        optimizer.step()
+        t.set_description(f"epoch={epoch}, loss={loss.item():.4f}")
+
+    optimizer.zero_grad()
+    return gflownet
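
Usage sketch for `states_actions_tns_to_traj` (not part of the diff): `env` is assumed to be a `DiscreteEnv` in the style of a 2-D `HyperGrid`, where states are grid coordinates, the last action index is the exit action, and the sink state is encoded as -1; the tensor values are illustrative placeholders for one expert trajectory.

```python
import torch

# One expert trajectory: 4 states (including the final sink state) and
# 3 actions (the last one, index 2, is the exit action for ndim=2).
# The sink-state encoding [-1, -1] is a HyperGrid-style assumption.
expert_states = torch.tensor(
    [[0, 0], [1, 0], [1, 1], [-1, -1]]
)  # shape [traj_len + 1, *state_shape]
expert_actions = torch.tensor([0, 1, 2])  # shape [traj_len]

traj = states_actions_tns_to_traj(expert_states, expert_actions, env)
```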
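And a sketch of the warm-up loop itself, assuming `replay_buffer` is a `ReplayBuffer` filled with such trajectories via its `add` method and that `gflownet` is a torch module; the optimizer settings are arbitrary. `recalculate_all_logprobs=True` matters here because the converted trajectories carry no log probs.

```python
import torch

# Hypothetical setup: replay_buffer, gflownet, and env already exist.
replay_buffer.add(traj)  # buffer the converted expert trajectory

optimizer = torch.optim.Adam(gflownet.parameters(), lr=1e-3)
gflownet = warm_up(
    replay_buffer,
    optimizer,
    gflownet,
    env,
    n_epochs=100,
    batch_size=16,
    recalculate_all_logprobs=True,  # expert trajectories have no stored log probs
)
```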