Add local search sampler #208

Merged Jan 13, 2025 (19 commits). The diff below shows changes from 4 commits.
src/gfn/containers/trajectories.py (2 changes: 1 addition & 1 deletion)

@@ -101,7 +101,7 @@ def __init__(
                 and self._log_rewards.dtype == torch.float
             )

-        if log_probs is not None:
+        if log_probs is not None and log_probs.shape != (0, 0):
             assert (
                 log_probs.shape == (self.max_length, self.n_trajectories)
                 and log_probs.dtype == torch.float
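
For context, the added shape check presumably lets callers pass an empty (0, 0) tensor as a "no log-probs recorded" placeholder (e.g. when trajectories are sampled with save_logprobs=False) without tripping the shape assertion. A minimal stand-alone illustration, using plain torch only:

    import torch

    # An empty (0, 0) tensor is "not None" but carries no recorded log-probs,
    # so it must be exempt from the (max_length, n_trajectories) check.
    placeholder = torch.zeros(0, 0)
    recorded = torch.zeros(5, 16)  # e.g. max_length=5, n_trajectories=16

    for log_probs in (placeholder, recorded):
        if log_probs is not None and log_probs.shape != (0, 0):
            assert log_probs.dtype == torch.float
            print("validated shape", tuple(log_probs.shape))
        else:
            print("skipped empty placeholder")
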
src/gfn/env.py (2 changes: 1 addition & 1 deletion)

@@ -449,7 +449,7 @@ def reset(
         return states

     @abstractmethod
-    def update_masks(self, states: type[States]) -> None:
+    def update_masks(self, states: States) -> None:
         """Updates the masks in States.

         Called automatically after each step for discrete environments.
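
The three update_masks fixes in this PR (here and in discrete_ebm.py and hypergrid.py below) correct the same typing mistake: type[States] annotates the class object itself, while the method actually receives an instance. A generic illustration of the distinction:

    class States:
        ...

    def update_masks(states: States) -> None:
        # Expects an *instance*: its tensors, masks, etc.
        ...

    def make_states(states_cls: type[States]) -> States:
        # Expects the *class itself*, e.g. to construct new instances.
        return states_cls()
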
src/gfn/gym/discrete_ebm.py (2 changes: 1 addition & 1 deletion)

@@ -113,7 +113,7 @@ def __init__(
             preprocessor=preprocessor,
         )

-    def update_masks(self, states: type[States]) -> None:
+    def update_masks(self, states: States) -> None:
         states.forward_masks[..., : self.ndim] = states.tensor == -1
         states.forward_masks[..., self.ndim : 2 * self.ndim] = states.tensor == -1
         states.forward_masks[..., -1] = torch.all(states.tensor != -1, dim=-1)
src/gfn/gym/hypergrid.py (2 changes: 1 addition & 1 deletion)

@@ -82,7 +82,7 @@ def __init__(
             preprocessor=preprocessor,
         )

-    def update_masks(self, states: type[DiscreteStates]) -> None:
+    def update_masks(self, states: DiscreteStates) -> None:
         """Update the masks based on the current states."""
         # Not allowed to take any action beyond the environment height, but
         # allow early termination.
src/gfn/samplers.py (328 changes: 324 additions & 4 deletions)

(Large diff not rendered by default.)
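
samplers.py carries the PR's main addition, the LocalSearchSampler used by the new tutorial below. Since the diff is collapsed here, the following is only a sketch of the back-and-forth local-search idea suggested by the sampler's parameters (back_ratio, n_local_search_loops, use_metropolis_hastings), not the actual gfn implementation; log_reward, backtrack, and reconstruct are hypothetical stand-ins:

    import torch

    def local_search(trajectory, log_reward, backtrack, reconstruct,
                     back_ratio=0.5, n_loops=4, use_metropolis_hastings=False):
        """Refine one trajectory by undoing part of it with the backward
        policy and regrowing the tail with the forward policy.

        Stand-ins: log_reward(traj) -> 0-dim float tensor,
        backtrack(traj, k) -> prefix with the last k steps undone (via P_B),
        reconstruct(prefix) -> completed trajectory (via P_F).
        """
        current = trajectory
        for _ in range(n_loops):
            k = max(1, int(back_ratio * len(current)))  # how far to backtrack
            proposal = reconstruct(backtrack(current, k))
            if use_metropolis_hastings:
                # Accept with probability min(1, R(proposal) / R(current)); a
                # full MH correction would also weigh the P_F/P_B proposal
                # probabilities.
                accept = torch.rand(()).log() < log_reward(proposal) - log_reward(current)
            else:
                # Greedy filtering: keep the proposal only if it improves reward.
                accept = log_reward(proposal) > log_reward(current)
            if accept:
                current = proposal
        return current

Under this reading, dividing batch_size by n_local_search_loops in the tutorial below presumably keeps the total number of trajectories per iteration roughly constant.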

src/gfn/states.py (2 changes: 1 addition & 1 deletion)

@@ -235,7 +235,7 @@ def extend_with_sf(self, required_first_dim: int) -> None:
             f"extend_with_sf is not implemented for batch shapes {self.batch_shape}"
         )

-    def compare(self, other: torch.tensor) -> torch.Tensor:
+    def compare(self, other: torch.Tensor) -> torch.Tensor:
         """Computes elementwise equality between state tensor with an external tensor.

         Args:
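
A one-line reminder of why this fix matters: torch.tensor is a factory function, not a type, so only torch.Tensor is a meaningful annotation:

    import torch

    x = torch.tensor([1.0, 2.0])        # torch.tensor: a factory *function*
    assert isinstance(x, torch.Tensor)  # torch.Tensor: the actual class

    def compare(other: torch.Tensor) -> torch.Tensor:  # correct annotation
        return x == other
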
tutorials/examples/train_hypergrid_simple.py (3 changes: 1 addition & 2 deletions)

@@ -41,8 +41,7 @@ def main(args):
     sampler = Sampler(estimator=pf_estimator)

     # Move the gflownet to the GPU.
-    if torch.cuda.is_available():
-        gflownet = gflownet.to("cuda")
+    gflownet = gflownet.to(device_str)

     # Policy parameters have their own LR. Log Z gets dedicated learning rate
     # (typically higher).
tutorials/examples/train_hypergrid_simple_ls.py (new file: 117 additions)

@@ -0,0 +1,117 @@
#!/usr/bin/env python
import argparse

import torch
from tqdm import tqdm

from gfn.gflownet import TBGFlowNet
from gfn.gym import HyperGrid
from gfn.modules import DiscretePolicyEstimator
from gfn.samplers import LocalSearchSampler
from gfn.utils.common import set_seed
from gfn.utils.modules import MLP


def main(args):
    set_seed(args.seed)
    device_str = "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu"

    # Setup the Environment.
    env = HyperGrid(ndim=args.ndim, height=args.height, device_str=device_str)

    # Build the GFlowNet.
    module_PF = MLP(
        input_dim=env.preprocessor.output_dim,
        output_dim=env.n_actions,
    )
    module_PB = MLP(
        input_dim=env.preprocessor.output_dim,
        output_dim=env.n_actions - 1,
        trunk=module_PF.trunk,
    )
    pf_estimator = DiscretePolicyEstimator(
        module_PF, env.n_actions, is_backward=False, preprocessor=env.preprocessor
    )
    pb_estimator = DiscretePolicyEstimator(
        module_PB, env.n_actions, is_backward=True, preprocessor=env.preprocessor
    )
    gflownet = TBGFlowNet(pf=pf_estimator, pb=pb_estimator, logZ=0.0)

    # Feed pf to the sampler.
    sampler = LocalSearchSampler(estimator=pf_estimator, pb_estimator=pb_estimator)

    # Move the gflownet to the GPU.
    gflownet = gflownet.to(device_str)

    # Policy parameters have their own LR. Log Z gets dedicated learning rate
    # (typically higher).
    optimizer = torch.optim.Adam(gflownet.pf_pb_parameters(), lr=args.lr)
    optimizer.add_param_group(
        {"params": gflownet.logz_parameters(), "lr": args.lr_logz}
    )

    for i in (pbar := tqdm(range(args.n_iterations))):
        # Each sampled trajectory is refined over n_local_search_loops rounds
        # of local search, so the batch is divided up front (presumably to
        # keep the total trajectory count near args.batch_size).
        trajectories = sampler.sample_trajectories(
            env,
            n=(args.batch_size // args.n_local_search_loops),
            save_logprobs=False,
            save_estimator_outputs=False,
            epsilon=args.epsilon,
            n_local_search_loops=args.n_local_search_loops,
            back_ratio=args.back_ratio,
            use_metropolis_hastings=False,
        )
        optimizer.zero_grad()
        loss = gflownet.loss(env, trajectories)
        loss.backward()
        optimizer.step()
        pbar.set_postfix({"loss": loss.item()})


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--no_cuda", action="store_true", help="Prevent CUDA usage")
    parser.add_argument(
        "--ndim", type=int, default=2, help="Number of dimensions in the environment"
    )
    parser.add_argument(
        "--height", type=int, default=16, help="Height of the environment"
    )
    parser.add_argument("--seed", type=int, default=0, help="Random seed")
    parser.add_argument(
        "--lr",
        type=float,
        default=1e-3,
        help="Learning rate for the estimators' modules",
    )
    parser.add_argument(
        "--lr_logz",
        type=float,
        default=1e-1,
        help="Learning rate for the logZ parameter",
    )
    parser.add_argument(
        "--n_iterations", type=int, default=1000, help="Number of iterations"
    )
    parser.add_argument("--batch_size", type=int, default=16, help="Batch size")
    parser.add_argument(
        "--epsilon", type=float, default=0.1, help="Epsilon for the sampler"
    )

    # Local search parameters.
    parser.add_argument(
        "--n_local_search_loops",
        type=int,
        default=4,
        help="Number of local search loops",
    )
    parser.add_argument(
        "--back_ratio",
        type=float,
        default=0.5,
        help="The ratio of the number of backward steps to the length of the trajectory",
    )

    args = parser.parse_args()

    main(args)
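
With the defaults above, the new tutorial can be run directly, e.g.:

    python tutorials/examples/train_hypergrid_simple_ls.py --ndim 2 --height 16 --batch_size 16 --n_local_search_loops 4

All flags correspond to the argparse definitions in the script.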