Skip to content

Commit

Permalink
Project import generated by Copybara.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 544346317
  • Loading branch information
psc-g committed Jun 29, 2023
1 parent a6f414c commit ce36aab
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 3 deletions.
3 changes: 2 additions & 1 deletion dopamine/jax/agents/dqn/dqn_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,8 @@ def __init__(self,

def _build_networks_and_optimizer(self):
self._rng, rng = jax.random.split(self._rng)
self.online_params = self.network_def.init(rng, x=self.state)
state = self.preprocess_fn(self.state)
self.online_params = self.network_def.init(rng, x=state)
self.optimizer = create_optimizer(self._optimizer_name)
self.optimizer_state = self.optimizer.init(self.online_params)
self.target_network_params = self.online_params
Expand Down
5 changes: 3 additions & 2 deletions dopamine/jax/agents/rainbow/rainbow_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,8 @@ def __init__(self,

def _build_networks_and_optimizer(self):
self._rng, rng = jax.random.split(self._rng)
self.online_params = self.network_def.init(rng, x=self.state,
state = self.preprocess_fn(self.state)
self.online_params = self.network_def.init(rng, x=state,
support=self._support)
self.optimizer = dqn_agent.create_optimizer(self._optimizer_name)
self.optimizer_state = self.optimizer.init(self.online_params)
Expand Down Expand Up @@ -316,7 +317,7 @@ def begin_episode(self, observation):

self._rng, self.action = select_action(self.network_def,
self.online_params,
self.state,
self.preprocess_fn(self.state),
self._rng,
self.num_actions,
self.eval_mode,
Expand Down

0 comments on commit ce36aab

Please sign in to comment.