-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathactor.py
42 lines (34 loc) · 1.08 KB
/
actor.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
from typing import Tuple
import jax.numpy as jnp
from common import Batch, InfoDict, Model, Params, PRNGKey
def update_actor(
    key: PRNGKey,
    actor: Model,
    critic: Model,
    value: Model,
    batch: Batch,
    alpha: float,
    epsilon: float,
    alg: str,
) -> Tuple[Model, InfoDict]:
    """Perform one advantage-weighted actor update step.

    Computes a per-sample weight from the critic/value advantage and
    maximizes the weighted log-likelihood of the batch actions under the
    actor's policy distribution.

    Args:
        key: PRNG key used for the actor's dropout RNG stream.
        actor: Policy model; ``apply`` must return a distribution with
            ``log_prob``, and ``apply_gradient`` runs one optimizer step.
        critic: Twin-Q critic; returns ``(q1, q2)`` for (obs, action) pairs.
        value: State-value model; returns ``V(s)`` for observations.
        batch: Transition batch providing ``observations`` and ``actions``.
        alpha: Temperature scaling the advantage inside the weight.
        epsilon: Accepted but unused here — kept for interface
            compatibility with the other update functions.
        alg: Algorithm name; only ``"PORelDICE"`` is implemented.

    Returns:
        Tuple of (updated actor model, info dict with ``actor_loss``).

    Raises:
        NotImplementedError: If ``alg`` is not ``"PORelDICE"``.
    """
    v = value(batch.observations)
    if alg == "PORelDICE":
        # Pessimistic Q estimate from the twin critics.
        q1, q2 = critic(batch.observations, batch.actions)
        q = jnp.minimum(q1, q2)
        # PORelDICE weight: 1 + (Q - V) / alpha. The lower bound at 0 and
        # the cap at 100 (to keep updates stable) are both handled by the
        # single clip below, replacing the original maximum-then-clip pair.
        weight = 1.0 + (q - v) / alpha
    else:
        # BUG FIX: the original had the bare expression `NotImplementedError`
        # here, which raises nothing — execution fell through and crashed
        # later with a NameError on the undefined `weight`. Raise explicitly.
        raise NotImplementedError(f"alg '{alg}' is not implemented")
    weight = jnp.clip(weight, 0.0, 100.0)

    def actor_loss_fn(actor_params: Params) -> Tuple[jnp.ndarray, InfoDict]:
        # `weight` is computed outside this closure from critic/value
        # outputs, so it is a constant w.r.t. `actor_params` — gradients
        # flow only through the log-probabilities.
        dist = actor.apply(
            {"params": actor_params},
            batch.observations,
            training=True,
            rngs={"dropout": key},
        )
        log_probs = dist.log_prob(batch.actions)
        actor_loss = -(weight * log_probs).mean()
        return actor_loss, {"actor_loss": actor_loss}

    new_actor, info = actor.apply_gradient(actor_loss_fn)
    return new_actor, info