Skip to content

Commit

Permalink
Fix types and add callback.
Browse files: browse the repository at this point in the history
  • Loading branch information
vaxenburg committed May 15, 2024
1 parent 4e22a0c commit a88bd83
Showing 1 changed file with 12 additions and 5 deletions.
17 changes: 12 additions & 5 deletions flybody/agents/actors.py
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
"""Acme agent implementations."""

from typing import Optional
from typing import Callable

from acme import adders
from acme import core
Expand All @@ -27,9 +27,10 @@ class DelayedFeedForwardActor(core.Actor):
def __init__(
self,
policy_network: snt.Module,
adder: Optional[adders.Adder] = None,
variable_client: Optional[tf2_variable_utils.VariableClient] = None,
action_delay: Optional[int] = None,
adder: adders.Adder | None = None,
variable_client: tf2_variable_utils.VariableClient | None = None,
action_delay: int | None = None,
observation_callback: Callable | None = None,
):
"""Initializes the actor.
Expand All @@ -40,6 +41,8 @@ def __init__(
variable_client: object which allows to copy weights from the learner copy
of the policy to the actor copy (in case they are separate).
action_delay: number of timesteps to delay the action for.
observation_callback: Optional callable to process observations before
passing them to policy.
"""

# Store these for later use.
Expand All @@ -49,9 +52,10 @@ def __init__(
self._action_delay = action_delay
if action_delay is not None:
self._action_queue = []
self._observation_callback = observation_callback

@tf.function
def _policy(self, observation: types.NestedTensor) -> types.NestedTensor:
def _policy(self, observation: types.NestedArray) -> types.NestedTensor:
# Add a dummy batch dimension and as a side effect convert numpy to TF.
batched_observation = tf2_utils.add_batch_dim(observation)

Expand All @@ -66,6 +70,9 @@ def _policy(self, observation: types.NestedTensor) -> types.NestedTensor:

def select_action(self,
observation: types.NestedArray) -> types.NestedArray:
"""Samples from the policy and returns an action."""
if self._observation_callback is not None:
observation = self._observation_callback(observation)
# Pass the observation through the policy network.
action = self._policy(observation)

Expand Down

0 comments on commit a88bd83

Please sign in to comment.