Skip to content

Commit

Permalink
Clarify types.
Browse files Browse the repository at this point in the history
  • Loading branch information
vaxenburg committed May 14, 2024
1 parent a91ffd0 commit 4e22a0c
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion flybody/agents/utils_tf.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Utilities for tensorflow networks and nested data structures."""

import numpy as np
from acme import types
from acme.tf import utils as tf2_utils


Expand All @@ -17,7 +19,9 @@ def __init__(self, policy, sample=False):
self._policy = policy
self._sample = sample

def __call__(self, observation):
def __call__(self, observation: types.NestedArray) -> np.ndarray:
# Add a dummy batch dimension and as a side effect convert numpy to TF,
# batched_observation: types.NestedTensor.
batched_observation = tf2_utils.add_batch_dim(observation)
distribution = self._policy(batched_observation)
if self._sample:
Expand Down

0 comments on commit 4e22a0c

Please sign in to comment.