talk about on policy and off policy in README
Salem Lahlou committed Jan 21, 2025
1 parent 5478f25 commit a0db872
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions README.md
@@ -102,7 +102,7 @@ optimizer.add_param_group({"params": gfn.logz_parameters(), "lr": 1e-1})

# 6 - We train the GFlowNet for 1000 iterations, with 16 trajectories per iteration
for i in (pbar := tqdm(range(1000))):
-trajectories = sampler.sample_trajectories(env=env, n=16)
+trajectories = sampler.sample_trajectories(env=env, n=16, save_logprobs=True)  # save_logprobs=True makes on-policy training faster
optimizer.zero_grad()
loss = gfn.loss(env, trajectories)
loss.backward()
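For context on the change above: with save_logprobs=True the sampled trajectories keep the log-probabilities computed during sampling, so an on-policy loss does not have to re-evaluate the forward policy on them. Below is a minimal sketch of that speed comparison, assuming env, gfn and sampler have already been built as in the surrounding README; the timed_iteration helper and its name are illustrative, not part of the library.

import time

def timed_iteration(save_logprobs: bool) -> float:
    # Illustrative helper (not part of torchgfn): times one sample + loss evaluation.
    start = time.perf_counter()
    trajectories = sampler.sample_trajectories(env=env, n=16, save_logprobs=save_logprobs)
    loss = gfn.loss(env, trajectories)  # per the README, stored sampling log-probs can be reused here
    return time.perf_counter() - start

print("save_logprobs=True :", timed_iteration(True))
print("save_logprobs=False:", timed_iteration(False))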
@@ -152,7 +152,7 @@ logF_estimator = ScalarEstimator(module=module_logF, preprocessor=env.preprocess
gfn = SubTBGFlowNet(pf=pf_estimator, pb=pb_estimator, logF=logF_estimator, lamda=0.9)

# 5 - We define the sampler and the optimizer.
-sampler = Sampler(estimator=pf_estimator)  # We use an on-policy sampler, based on the forward policy
+sampler = Sampler(estimator=pf_estimator)

# Different policy parameters can have their own LR.
# Log F gets dedicated learning rate (typically higher).
@@ -161,7 +161,10 @@ optimizer.add_param_group({"params": gfn.logF_parameters(), "lr": 1e-2})

# 6 - We train the GFlowNet for 1000 iterations, with 16 trajectories per iteration
for i in (pbar := tqdm(range(1000))):
-trajectories = sampler.sample_trajectories(env=env, n=16)
+# We sample trajectories off-policy by tempering the distribution.
+# We do not save the sampling logprobs, since they are not used for training.
+# We do save the estimator outputs to make training faster.
+trajectories = sampler.sample_trajectories(env=env, n=16, save_logprobs=False, save_estimator_outputs=True, temperature=1.5)
optimizer.zero_grad()
loss = gfn.loss(env, trajectories)
loss.backward()
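The temperature=1.5 argument is what makes this sampling off-policy: the forward policy's distribution is tempered, increasing exploration relative to the policy being trained. Conceptually, tempering a softmax policy usually means dividing its logits by the temperature before normalizing; the toy snippet below only illustrates that effect and is not torchgfn internals (the logits tensor is made up).

import torch

logits = torch.tensor([2.0, 1.0, 0.0])  # hypothetical action logits from a forward policy
for temperature in (1.0, 1.5):
    probs = torch.softmax(logits / temperature, dim=-1)
    print(f"temperature={temperature}: {probs}")  # higher temperature -> flatter, more exploratory distribution

Because these trajectories no longer come from the current forward policy, their sampling log-probs are not the quantities the loss needs, which is why save_logprobs=False here while the estimator outputs are kept.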
