-
Notifications
You must be signed in to change notification settings - Fork 38
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add warm-up functionality with tensor to trajectory helper functions #224
Open
alexandrelarouche
wants to merge
5
commits into
GFNOrg:master
Choose a base branch
from
alexandrelarouche:warmup
base: master
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
5 commits
Select commit
Hold shift + click to select a range
51c1b54
Add crude warmup functions
alexandrelarouche 6d12656
Add output type for GFlowNet abstract loss function to avoid LSP flag…
alexandrelarouche 0ca9082
Implement feedback from saleml
alexandrelarouche 9b7a9d3
Set logprobs and estimator_outputs to None
alexandrelarouche e7b2057
Run pre-commit
alexandrelarouche File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,10 +2,14 @@ | |
from typing import Dict, Optional | ||
|
||
import torch | ||
from tqdm import trange | ||
|
||
from gfn.env import Env | ||
from gfn.containers import ReplayBuffer | ||
from gfn.env import DiscreteEnv, Env | ||
from gfn.gflownet import GFlowNet, TBGFlowNet | ||
from gfn.states import States | ||
from gfn.gflownet.base import PFBasedGFlowNet | ||
from gfn.samplers import Trajectories | ||
from gfn.states import States, stack_states | ||
|
||
|
||
def get_terminating_state_dist_pmf(env: Env, states: States) -> torch.Tensor: | ||
|
@@ -74,3 +78,106 @@ def validate( | |
if logZ is not None: | ||
validation_info["logZ_diff"] = abs(logZ - true_logZ) | ||
return validation_info | ||
|
||
|
||
def states_actions_tns_to_traj( | ||
states_tns: torch.Tensor, | ||
actions_tns: torch.Tensor, | ||
env: DiscreteEnv, | ||
) -> Trajectories: | ||
""" | ||
This utility function helps integrate external data (e.g. expert demonstrations) | ||
into the GFlowNet framework by converting raw tensors into proper Trajectories objects. | ||
The downstream GFN needs to be capable of recalculating all logprobs (e.g. PFBasedGFlowNets) | ||
|
||
Args: | ||
states_tns: Tensor of shape [traj_len, *state_shape] containing states for a single trajectory | ||
actions_tns: Tensor of shape [traj_len] containing discrete action indices | ||
env: The discrete environment that defines the state/action spaces | ||
|
||
Returns: | ||
Trajectories: A Trajectories object containing the converted states and actions | ||
|
||
Raises: | ||
ValueError: If tensor shapes are invalid or inconsistent | ||
""" | ||
|
||
if states_tns.shape[1:] != env.state_shape: | ||
raise ValueError( | ||
f"states_tns state dimensions must match env.state_shape {env.state_shape}, " | ||
f"got shape {states_tns.shape[1:]}" | ||
) | ||
if len(actions_tns.shape) != 1: | ||
raise ValueError(f"actions_tns must be 1D, got batch_shape {actions_tns.shape}") | ||
if states_tns.shape[0] != actions_tns.shape[0] + 1: | ||
raise ValueError( | ||
f"states and actions must have same trajectory length, got " | ||
f"states: {states_tns.shape[0]}, actions: {actions_tns.shape[0]}" | ||
) | ||
|
||
states = [env.states_from_tensor(s.unsqueeze(0)) for s in states_tns] | ||
actions = [ | ||
env.actions_from_tensor(a.unsqueeze(0).unsqueeze(0)) for a in actions_tns | ||
] | ||
|
||
# stack is a class method, so actions[0] is just to access a class instance and is not particularly relevant | ||
actions = actions[0].stack(actions) | ||
log_rewards = env.log_reward(states[-2]) | ||
states = stack_states(states) | ||
when_is_done = torch.tensor([len(states_tns) - 1]) | ||
|
||
log_probs = None | ||
estimator_outputs = None | ||
|
||
trajectory = Trajectories( | ||
env, | ||
states, | ||
actions, | ||
log_rewards=log_rewards, | ||
when_is_done=when_is_done, | ||
log_probs=log_probs, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
estimator_outputs=estimator_outputs, | ||
) | ||
return trajectory | ||
|
||
|
||
def warm_up( | ||
replay_buf: ReplayBuffer, | ||
optimizer: torch.optim.Optimizer, | ||
gflownet: GFlowNet, | ||
env: Env, | ||
n_epochs: int, | ||
batch_size: int, | ||
recalculate_all_logprobs: bool = True, | ||
): | ||
""" | ||
This utility function is an example implementation of pre-training for GFlowNets agent. | ||
|
||
Args: | ||
replay_buf: Replay Buffer, which collects Trajectories | ||
optimizer: Any torch.optim optimizer (e.g. Adam, SGD) | ||
gflownet: The GFlowNet to train | ||
env: The environment instance | ||
n_epochs: Number of epochs for warmup | ||
batch_size: Number of trajectories to sample from replay buffer | ||
recalculate_all_logprobs: For PFBasedGFlowNets only, force recalculating all log probs. Useful trajectories do not already have log probs. | ||
Returns: | ||
GFlowNet: A trained GFlowNet | ||
""" | ||
t = trange(n_epochs, desc="Bar desc", leave=True) | ||
for epoch in t: | ||
training_trajs = replay_buf.sample(batch_size) | ||
optimizer.zero_grad() | ||
if isinstance(gflownet, PFBasedGFlowNet): | ||
loss = gflownet.loss( | ||
env, training_trajs, recalculate_all_logprobs=recalculate_all_logprobs | ||
) | ||
else: | ||
loss = gflownet.loss(env, training_trajs) | ||
|
||
loss.backward() | ||
optimizer.step() | ||
t.set_description(f"{epoch=}, {loss=}") | ||
|
||
optimizer.zero_grad() | ||
return gflownet |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
really appreciate this comment