Skip to content

Commit

Permalink
Make garage.torch.ObservationBatch constructable
Browse files Browse the repository at this point in the history
  • Loading branch information
krzentner committed Jul 3, 2022
1 parent b540f21 commit 9985301
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions src/garage/torch/_dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,14 +54,17 @@ class ObservationBatch(torch.Tensor):
order: ObservationOrder
lengths: torch.Tensor = None

def __init__(self, observations, order, lengths=None):
def __new__(cls, observations, order, lengths=None):
"""Check that lengths is consistent with the rest of the fields.
Raises:
ValueError: If lengths is not consistent with another field.
Returns:
ObservationBatch: A new observation batch.
"""
super().__init__(observations)
self = super().__new__(cls, observations)
self.order = order
self.lengths = lengths
if self.order == ObservationOrder.EPISODES:
Expand All @@ -86,6 +89,7 @@ def __init__(self, observations, order, lengths=None):
raise ValueError(
f'lengths has value {self.lengths}, but must be None '
f'when order == {self.order}')
return self


def observation_batch_to_packed_sequence(observations):
Expand Down

0 comments on commit 9985301

Please sign in to comment.