draft commit
hyeok9855 committed Oct 29, 2024
1 parent 5a4198e commit c2d59b3
Showing 1 changed file with 166 additions and 3 deletions.
169 changes: 166 additions & 3 deletions src/gfn/samplers.py
@@ -39,7 +39,6 @@ def sample_actions(
"""Samples actions from the given states.
Args:
estimator: A GFNModule to pass to the probability distribution calculator.
env: The environment to sample actions from.
states: A batch of states.
conditioning: An optional tensor of conditioning information.
@@ -203,11 +202,11 @@ def sample_trajectories(
all_estimator_outputs.append(estimator_outputs_padded)

actions[~dones] = valid_actions
trajectories_actions.append(actions)
if save_logprobs:
# When off_policy, actions_log_probs are None.
log_probs[~dones] = actions_log_probs
trajectories_actions.append(actions)
trajectories_logprobs.append(log_probs)
trajectories_logprobs.append(log_probs)

if self.estimator.is_backward:
new_states = env._backward_step(states, actions)
@@ -264,3 +263,167 @@ def sample_trajectories(
)

return trajectories


class LocalSearchSampler(Sampler):
    """Sampler equipped with local search capabilities.

    The local search operation is based on the back-and-forth heuristic, first
    proposed by Zhang et al. 2022 (https://arxiv.org/abs/2202.01361) for negative
    sampling; its effectiveness in various applications was further explored by
    Kim et al. 2023 (https://arxiv.org/abs/2310.02710).

    Attributes:
        estimator: the PolicyEstimator for the forward pass.
        backward_sampler: a `Sampler` wrapping the PolicyEstimator for the
            backward pass (`pb_estimator`).
    """

    def __init__(self, estimator: GFNModule, pb_estimator: GFNModule):
        super().__init__(estimator)
        self.backward_sampler = Sampler(pb_estimator)
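    # A minimal usage sketch (illustrative; `env`, `pf_estimator`, and `pb_estimator`
    # are assumed to be constructed elsewhere):
    #
    #   sampler = LocalSearchSampler(estimator=pf_estimator, pb_estimator=pb_estimator)
    #   trajectories = sampler.sample_trajectories(
    #       env, n=16, n_local_search_loops=2, back_ratio=0.5
    #   )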

    def local_search(
        self,
        env: Env,
        trajectories: Trajectories,
        conditioning: torch.Tensor | None = None,
        save_estimator_outputs: bool = False,
        save_logprobs: bool = True,
        back_steps: torch.Tensor | None = None,
        back_ratio: float | None = None,
        **policy_kwargs: Any,
    ) -> Trajectories:
        # K-step backward sampling with the backward estimator,
        # where K is the number of backward steps used in https://arxiv.org/abs/2202.01361.
        if back_steps is None:
            assert (
                back_ratio is not None
            ), "Either kwarg `back_steps` or `back_ratio` must be specified"
            K = torch.ceil(back_ratio * (trajectories.when_is_done - 1)).long()
        else:
            K = torch.where(
                back_steps > trajectories.when_is_done,
                trajectories.when_is_done,
                back_steps,
            )
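        # Illustration: with back_ratio=0.5 and when_is_done = [5, 9],
        # K = ceil(0.5 * [4, 8]) = [2, 4] backward steps per trajectory.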

        backward_trajectories = self.backward_sampler.sample_trajectories(
            env,
            states=trajectories.last_states,
            conditioning=conditioning,
            save_estimator_outputs=save_estimator_outputs,
            save_logprobs=save_logprobs,
            **policy_kwargs,
        )
        # Calculate the forward probability if needed (metropolis-hastings).
        if save_logprobs:
            raise NotImplementedError("metropolis-hastings is not implemented yet.")

        all_states = backward_trajectories.to_states()
        bs = backward_trajectories.n_trajectories
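        # Locate the junction state of each trajectory: assuming `to_states()` flattens
        # the states batch in step-major order, the state reached after K[i] backward
        # steps of trajectory i sits at flat index i + bs * K[i].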
        junction_states = all_states[
            torch.arange(bs, device=all_states.device) + bs * K
        ]

        ### Reconstructing with self.estimator
        recon_trajectories = super().sample_trajectories(
            env,
            states=junction_states,
            conditioning=conditioning,
            save_estimator_outputs=save_estimator_outputs,
            save_logprobs=save_logprobs,
            **policy_kwargs,
        )
        # Calculate backward probability if needed (metropolis-hastings).
        if save_logprobs:
            raise NotImplementedError("metropolis-hastings is not implemented yet.")

        # TODO: Obtain full trajectories by concatenating the backward and forward parts.

        if save_logprobs:  # concatenate log_probs
            raise NotImplementedError("metropolis-hastings is not implemented yet.")
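        # Sketch of the remaining step (not implemented in this draft): take the portion
        # of each backward trajectory beyond the junction state (s_K -> ... -> s_0),
        # reverse it into a forward prefix s_0 -> ... -> s_K, and splice it with the
        # reconstructed suffix s_K -> ... -> x' to form the proposal trajectory.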

    def sample_trajectories(
        self,
        env: Env,
        n: Optional[int] = None,
        states: Optional[States] = None,
        conditioning: Optional[torch.Tensor] = None,
        save_estimator_outputs: bool = False,
        save_logprobs: bool = True,
        n_local_search_loops: int = 0,
        back_steps: torch.Tensor | None = None,
        back_ratio: float | None = None,
        use_metropolis_hastings: bool = False,
        **policy_kwargs: Any,
    ) -> Trajectories:
        """Samples trajectories sequentially with optional local search.

        Args:
            env: The environment to sample trajectories from.
            n: If given, a batch of `n` trajectories will be sampled, all starting
                from the environment's $s_0$.
            states: If given, trajectories will start from these states. Otherwise,
                trajectories are sampled from $s_0$ and `n` must be provided.
            conditioning: An optional tensor of conditioning information.
            save_estimator_outputs: If True, the estimator outputs will be returned.
                This is useful for off-policy training with a tempered policy.
            save_logprobs: If True, calculates and saves the log probabilities of
                sampled actions. This is useful for on-policy training.
            n_local_search_loops: The number of local search iterations applied after
                the initial sampling.
            back_steps: The number of backward steps.
            back_ratio: The ratio of the number of backward steps to the length of
                the trajectory.
            use_metropolis_hastings: If True, applies the Metropolis-Hastings
                acceptance criterion.
            policy_kwargs: keyword arguments to be passed to the
                `to_probability_distribution` method of the estimator. For example, for
                DiscretePolicyEstimators, the kwargs can contain the `temperature`
                parameter, `epsilon`, and `sf_bias`. In the continuous case these
                kwargs will be user defined. This can be used to, for example, sample
                off-policy.

        Returns:
            A Trajectories object representing the batch of sampled trajectories,
            where the batch size is n * (1 + n_local_search_loops).
        """

        trajectories = super().sample_trajectories(
            env,
            n,
            states,
            conditioning,
            save_estimator_outputs,
            save_logprobs,
            **policy_kwargs,
        )
        all_trajectories = trajectories
        for _ in range(n_local_search_loops):
            # Search phase
            ls_trajectories = self.local_search(
                env,
                trajectories,
                conditioning,
                save_estimator_outputs,
                save_logprobs or use_metropolis_hastings,
                back_steps,
                back_ratio,
                **policy_kwargs,
            )
            all_trajectories.extend(
                ls_trajectories
            )  # Store all regardless of the acceptance.

            # Selection phase
            if not use_metropolis_hastings:
                update_indices = trajectories.log_rewards < ls_trajectories.log_rewards
                trajectories[update_indices] = ls_trajectories[update_indices]
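                # Greedy hill-climbing: a trajectory is replaced by its local-search
                # counterpart whenever the latter achieves a strictly higher log-reward.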
            else:  # Metropolis-Hastings acceptance criterion
                # TODO: Implement Metropolis-Hastings acceptance criterion.
                # We need p(x -> s -> x') = p_B(x -> s) * p_F(s -> x')
                # and p(x' -> s' -> x) = p_B(x' -> s') * p_F(s' -> x)
                # to calculate the acceptance ratio.
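                # One way the ratio could be assembled from those quantities (a sketch,
                # not part of this draft), with x the current terminal state and x' the
                # local-search proposal:
                #   log A = [log R(x') + log p_B(x' -> s') + log p_F(s' -> x)]
                #         - [log R(x)  + log p_B(x -> s)   + log p_F(s -> x')]
                # and x' is accepted when log(U) < log A for U ~ Uniform(0, 1).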
                raise NotImplementedError(
                    "Metropolis-Hastings acceptance criterion is not implemented."
                )

        return all_trajectories
