From c2d59b32cff3f246d6e9c526025cacf4ca82f493 Mon Sep 17 00:00:00 2001
From: hyeok9855
Date: Tue, 29 Oct 2024 23:57:34 +0900
Subject: [PATCH] draft commit

---
 src/gfn/samplers.py | 169 +++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 166 insertions(+), 3 deletions(-)

diff --git a/src/gfn/samplers.py b/src/gfn/samplers.py
index 819620f0..d0871470 100644
--- a/src/gfn/samplers.py
+++ b/src/gfn/samplers.py
@@ -39,7 +39,6 @@ def sample_actions(
         """Samples actions from the given states.
 
         Args:
-            estimator: A GFNModule to pass to the probability distribution calculator.
             env: The environment to sample actions from.
             states: A batch of states.
             conditioning: An optional tensor of conditioning information.
@@ -203,11 +202,11 @@ def sample_trajectories(
                 all_estimator_outputs.append(estimator_outputs_padded)
 
             actions[~dones] = valid_actions
+            trajectories_actions.append(actions)
             if save_logprobs:
                 # When off_policy, actions_log_probs are None.
                 log_probs[~dones] = actions_log_probs
-            trajectories_actions.append(actions)
-            trajectories_logprobs.append(log_probs)
+                trajectories_logprobs.append(log_probs)
 
             if self.estimator.is_backward:
                 new_states = env._backward_step(states, actions)
@@ -264,3 +263,167 @@ def sample_trajectories(
         )
 
         return trajectories
+
+
+class LocalSearchSampler(Sampler):
+    """Sampler equipped with local search capabilities.
+
+    The local search operation is based on the back-and-forth heuristic, first
+    proposed by Zhang et al. 2022 (https://arxiv.org/abs/2202.01361) for negative
+    sampling; its effectiveness in various applications was further explored by
+    Kim et al. 2023 (https://arxiv.org/abs/2310.02710).
+
+    Attributes:
+        estimator: the PolicyEstimator for the forward pass.
+        pb_estimator: the PolicyEstimator for the backward pass.
+    """
+
+    def __init__(self, estimator: GFNModule, pb_estimator: GFNModule):
+        super().__init__(estimator)
+        self.backward_sampler = Sampler(pb_estimator)
+
+    def local_search(
+        self,
+        env: Env,
+        trajectories: Trajectories,
+        conditioning: torch.Tensor | None = None,
+        save_estimator_outputs: bool = False,
+        save_logprobs: bool = True,
+        back_steps: torch.Tensor | None = None,
+        back_ratio: float | None = None,
+        **policy_kwargs: Any,
+    ) -> Trajectories:
+        # K-step backward sampling with the backward estimator,
+        # where K is the number of backward steps used in https://arxiv.org/abs/2202.01361.
+        if back_steps is None:
+            assert (
+                back_ratio is not None
+            ), "Either kwarg `back_steps` or `back_ratio` must be specified"
+            K = torch.ceil(back_ratio * (trajectories.when_is_done - 1)).long()
+        else:
+            K = torch.where(
+                back_steps > trajectories.when_is_done,
+                trajectories.when_is_done,
+                back_steps,
+            )
+
+        backward_trajectories = self.backward_sampler.sample_trajectories(
+            env,
+            states=trajectories.last_states,
+            conditioning=conditioning,
+            save_estimator_outputs=save_estimator_outputs,
+            save_logprobs=save_logprobs,
+            **policy_kwargs,
+        )
+        # Calculate the forward probability if needed (Metropolis-Hastings).
+        if save_logprobs:
+            raise NotImplementedError("metropolis-hastings is not implemented yet.")
+
+        all_states = backward_trajectories.to_states()
+        bs = backward_trajectories.n_trajectories
+        junction_states = all_states[
+            torch.arange(bs, device=all_states.device) + bs * K
+        ]
+
+        ### Reconstructing with self.estimator
+        recon_trajectories = super().sample_trajectories(
+            env,
+            states=junction_states,
+            conditioning=conditioning,
+            save_estimator_outputs=save_estimator_outputs,
+            save_logprobs=save_logprobs,
+            **policy_kwargs,
+        )
+        # Calculate the backward probability if needed (Metropolis-Hastings).
+        if save_logprobs:
+            raise NotImplementedError("metropolis-hastings is not implemented yet.")
+
+        # Obtain full trajectories by concatenating the backward and forward parts.
+        # TODO: concatenate the first K backward steps of `backward_trajectories`
+        # with `recon_trajectories` and return the resulting Trajectories object.
+
+        if save_logprobs:  # concatenate log_probs
+            raise NotImplementedError("metropolis-hastings is not implemented yet.")
+
+    def sample_trajectories(
+        self,
+        env: Env,
+        n: Optional[int] = None,
+        states: Optional[States] = None,
+        conditioning: Optional[torch.Tensor] = None,
+        save_estimator_outputs: bool = False,
+        save_logprobs: bool = True,
+        n_local_search_loops: int = 0,
+        back_steps: torch.Tensor | None = None,
+        back_ratio: float | None = None,
+        use_metropolis_hastings: bool = False,
+        **policy_kwargs: Any,
+    ) -> Trajectories:
+        """Sample trajectories sequentially with optional local search.
+
+        Args:
+            env: The environment to sample trajectories from.
+            n: If given, a batch of `n` trajectories will be sampled, all starting
+                from the environment's $s_0$.
+            states: If given, trajectories start from these states. Otherwise,
+                trajectories are sampled from $s_0$ and `n` must be provided.
+            conditioning: An optional tensor of conditioning information.
+            save_estimator_outputs: If True, the estimator outputs will be returned.
+                This is useful for off-policy training with a tempered policy.
+            save_logprobs: If True, calculates and saves the log probabilities of
+                sampled actions. This is useful for on-policy training.
+            n_local_search_loops: The number of local search refinement loops to
+                apply to the sampled trajectories.
+            back_steps: The number of backward steps taken during local search.
+            back_ratio: The ratio of the number of backward steps to the length of
+                the trajectory. Either `back_steps` or `back_ratio` must be specified.
+            use_metropolis_hastings: If True, applies the Metropolis-Hastings
+                acceptance criterion during the selection phase.
+            policy_kwargs: keyword arguments to be passed to the
+                `to_probability_distribution` method of the estimator. For example,
+                for DiscretePolicyEstimators, the kwargs can contain the
+                `temperature` parameter, `epsilon`, and `sf_bias`. In the continuous
+                case these kwargs will be user defined. This can be used to, for
+                example, sample off-policy.
+
+        Returns: A Trajectories object representing the batch of sampled trajectories,
+            where the batch size is n * (1 + n_local_search_loops).
+        """
+
+        trajectories = super().sample_trajectories(
+            env,
+            n,
+            states,
+            conditioning,
+            save_estimator_outputs,
+            save_logprobs,
+            **policy_kwargs,
+        )
+        all_trajectories = trajectories
+        for _ in range(n_local_search_loops):
+            # Search phase
+            ls_trajectories = self.local_search(
+                env,
+                trajectories,
+                conditioning,
+                save_estimator_outputs,
+                save_logprobs or use_metropolis_hastings,
+                back_steps,
+                back_ratio,
+                **policy_kwargs,
+            )
+            all_trajectories.extend(
+                ls_trajectories
+            )  # Store all regardless of the acceptance.
+
+            # Selection phase
+            if not use_metropolis_hastings:
+                update_indices = trajectories.log_rewards < ls_trajectories.log_rewards
+                trajectories[update_indices] = ls_trajectories[update_indices]
+            else:  # Metropolis-Hastings acceptance criterion
+                # TODO: Implement Metropolis-Hastings acceptance criterion.
+                # We need p(x -> s -> x') = p_B(x -> s) * p_F(s -> x')
+                # and p(x' -> s' -> x) = p_B(x' -> s') * p_F(s' -> x)
+                # to calculate the acceptance ratio.
+                raise NotImplementedError(
+                    "Metropolis-Hastings acceptance criterion is not implemented."
+                )
+
+        return all_trajectories
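
Note: the following is a self-contained toy sketch (not library code) of the back-and-forth heuristic that LocalSearchSampler implements: drop the last K steps of a trajectory with a backward policy, rebuild them with a forward policy from the junction state, and keep the candidate only if its terminal reward improves (the greedy selection phase; no Metropolis-Hastings). The bit-sequence "environment" and the uniform policies below are illustrative assumptions.

import math
import random

TRAJ_LEN = 8

def reward(x: list[int]) -> float:
    # Toy reward: prefer sequences containing many 1s.
    return math.exp(sum(x))

def sample_forward(prefix: list[int], length: int) -> list[int]:
    # Uniform "forward policy": append random bits until the sequence has `length` items.
    return prefix + [random.randint(0, 1) for _ in range(length - len(prefix))]

def local_search(x: list[int], back_ratio: float, n_loops: int) -> list[int]:
    K = math.ceil(back_ratio * len(x))
    for _ in range(n_loops):
        junction = x[: len(x) - K]                     # backward phase: drop the last K steps
        candidate = sample_forward(junction, len(x))   # forward phase: reconstruct from the junction
        if reward(candidate) > reward(x):              # greedy selection phase
            x = candidate
    return x

random.seed(0)
x = sample_forward([], TRAJ_LEN)
print("initial:", x, "reward:", reward(x))
print("refined:", local_search(x, back_ratio=0.5, n_loops=5))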
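For context, here is a minimal usage sketch of how the new sampler could be wired up once the trajectory-concatenation TODO in `local_search` is filled in. It follows the HyperGrid example from the torchgfn README; the environment choice, the plain torch MLP modules, and all sizes/hyperparameters are illustrative assumptions, not part of this patch.

from torch import nn

from gfn.gym import HyperGrid
from gfn.modules import DiscretePolicyEstimator
from gfn.samplers import LocalSearchSampler

env = HyperGrid(ndim=2, height=8)

# Any nn.Module mapping preprocessed states to action logits works here;
# the backward policy has one fewer output (no exit action).
pf_module = nn.Sequential(
    nn.Linear(env.preprocessor.output_dim, 128), nn.ReLU(), nn.Linear(128, env.n_actions)
)
pb_module = nn.Sequential(
    nn.Linear(env.preprocessor.output_dim, 128), nn.ReLU(), nn.Linear(128, env.n_actions - 1)
)
pf_estimator = DiscretePolicyEstimator(
    pf_module, env.n_actions, preprocessor=env.preprocessor
)
pb_estimator = DiscretePolicyEstimator(
    pb_module, env.n_actions, preprocessor=env.preprocessor, is_backward=True
)

sampler = LocalSearchSampler(estimator=pf_estimator, pb_estimator=pb_estimator)

# Each local search loop resamples roughly the last half of every trajectory
# (back_ratio=0.5); with n=16 and 2 loops, the returned batch holds 16 * (1 + 2)
# trajectories. save_logprobs=False because the draft's Metropolis-Hastings
# branches are not implemented yet.
trajectories = sampler.sample_trajectories(
    env,
    n=16,
    save_logprobs=False,
    n_local_search_loops=2,
    back_ratio=0.5,
)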