diff --git a/README.md b/README.md
index 8249e6ed..46a4e386 100644
--- a/README.md
+++ b/README.md
@@ -102,7 +102,7 @@ optimizer.add_param_group({"params": gfn.logz_parameters(), "lr": 1e-1})
 
 # 6 - We train the GFlowNet for 1000 iterations, with 16 trajectories per iteration
 for i in (pbar := tqdm(range(1000))):
-    trajectories = sampler.sample_trajectories(env=env, n=16)
+    trajectories = sampler.sample_trajectories(env=env, n=16, save_logprobs=True)  # The save_logprobs=True makes on-policy training faster
     optimizer.zero_grad()
     loss = gfn.loss(env, trajectories)
     loss.backward()
@@ -152,7 +152,7 @@ logF_estimator = ScalarEstimator(module=module_logF, preprocessor=env.preprocess
 gfn = SubTBGFlowNet(pf=pf_estimator, pb=pb_estimator, logF=logF, lamda=0.9)
 
 # 5 - We define the sampler and the optimizer.
-sampler = Sampler(estimator=pf_estimator)  # We use an on-policy sampler, based on the forward policy
+sampler = Sampler(estimator=pf_estimator)
 
 # Different policy parameters can have their own LR.
 # Log F gets dedicated learning rate (typically higher).
@@ -161,7 +161,10 @@ optimizer.add_param_group({"params": gfn.logF_parameters(), "lr": 1e-2})
 
 # 6 - We train the GFlowNet for 1000 iterations, with 16 trajectories per iteration
 for i in (pbar := tqdm(range(1000))):
-    trajectories = sampler.sample_trajectories(env=env, n=16)
+    # We are going to sample trajectories off policy, by tempering the distribution.
+    # We should not save the sampling logprobs, as we are not using them for training.
+    # We should save the estimator outputs to make training faster.
+    trajectories = sampler.sample_trajectories(env=env, n=16, save_logprobs=False, save_estimator_outputs=True, temperature=1.5)
     optimizer.zero_grad()
     loss = gfn.loss(env, trajectories)
     loss.backward()
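
For context, a minimal sketch of how the two updated sampling calls differ, using only the calls shown in the diff. The helper name `train_step` is hypothetical, and `env`, `sampler`, `gfn`, and `optimizer` are assumed to already be constructed as in the earlier README steps (not reproduced here); the `optimizer.step()` call is the usual PyTorch follow-up to `loss.backward()`, which lies outside the hunk context.

```python
# Hypothetical helper contrasting the on-policy and off-policy sampling modes
# shown in the diff. All objects are assumed to be built as in the README setup.
def train_step(sampler, env, gfn, optimizer, off_policy: bool = False):
    if off_policy:
        # Off-policy: temper the forward policy; the sampling log-probs are not
        # reused for training, but saving the estimator outputs makes the loss
        # computation faster (per the comments added in the diff).
        trajectories = sampler.sample_trajectories(
            env=env,
            n=16,
            save_logprobs=False,
            save_estimator_outputs=True,
            temperature=1.5,
        )
    else:
        # On-policy: saving the sampling log-probs lets the loss reuse them directly.
        trajectories = sampler.sample_trajectories(env=env, n=16, save_logprobs=True)

    optimizer.zero_grad()
    loss = gfn.loss(env, trajectories)
    loss.backward()
    optimizer.step()  # assumed standard PyTorch step; not shown in the hunk context
    return loss
```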