Skip to content

Commit

Permalink
Use ObservationBatch in StochasticPolicy
Browse files Browse the repository at this point in the history
  • Loading branch information
krzentner committed Jul 3, 2022
1 parent d450450 commit b540f21
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 4 deletions.
4 changes: 3 additions & 1 deletion src/garage/torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
pad_to_last, prefer_gpu,
product_of_gaussians, set_gpu_mode,
soft_update_model, state_dict_to,
torch_to_np, update_module_params)
torch_to_np, update_module_params,
list_to_tensor)

# yapf: enable
__all__ = [
Expand All @@ -23,6 +24,7 @@
'flatten_batch',
'flatten_to_single_vector',
'global_device',
'list_to_tensor',
'np_to_torch',
'ObservationBatch',
'observation_batch_to_packed_sequence',
Expand Down
2 changes: 1 addition & 1 deletion src/garage/torch/_dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ class ObservationBatch(torch.Tensor):
order: ObservationOrder
lengths: torch.Tensor = None

def __init__(self, observations, order, lengths):
def __init__(self, observations, order, lengths=None):
"""Check that lengths is consistent with the rest of the fields.
Raises:
Expand Down
12 changes: 10 additions & 2 deletions src/garage/torch/policies/stochastic_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
import numpy as np
import torch

from garage.torch._functions import list_to_tensor, np_to_torch
from garage.torch import (list_to_tensor, np_to_torch, ObservationBatch,
ObservationOrder)
from garage.torch.policies.policy import Policy


Expand Down Expand Up @@ -92,6 +93,8 @@ def get_actions(self, observations):

if isinstance(self._env_spec.observation_space, akro.Image):
observations /= 255.0 # scale image
observations = ObservationBatch(observations,
order=ObservationOrder.LAST)
dist, info = self.forward(observations)
return dist.sample().cpu().numpy(), {
k: v.detach().cpu().numpy()
Expand All @@ -105,7 +108,12 @@ def forward(self, observations):
Args:
observations (torch.Tensor): Batch of observations on default
torch device.
torch device. Stateful policies may require this input to be a
garage.torch.ObservationBatch.
Raises:
ShuffledOptimizationNotSupported: If this policy is a stateful
policy and requires an ObservationBatch.
Returns:
torch.distributions.Distribution: Batch distribution of actions.
Expand Down

0 comments on commit b540f21

Please sign in to comment.