common.py
import collections
import os
from typing import Any, Callable, Dict, Optional, Sequence, Tuple

import flax
import flax.linen as nn
import jax
import jax.numpy as jnp
import optax

Batch = collections.namedtuple(
    "Batch",
    [
        "observations",
        "actions",
        "rewards",
        "masks",
        "next_observations",
        "next_actions",
    ],
)

def default_init(scale: Optional[float] = jnp.sqrt(2)):
    return nn.initializers.orthogonal(scale)


PRNGKey = Any
Params = flax.core.FrozenDict[str, Any]
Shape = Sequence[int]
Dtype = Any  # this could be a real type?
InfoDict = Dict[str, float]

class MLP(nn.Module):
    """Multi-layer perceptron with optional layer norm and dropout."""

    hidden_dims: Sequence[int]
    activations: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu
    activate_final: bool = False
    layer_norm: bool = False
    dropout_rate: Optional[float] = None

    @nn.compact
    def __call__(self, x: jnp.ndarray, training: bool = False) -> jnp.ndarray:
        for i, size in enumerate(self.hidden_dims):
            x = nn.Dense(size, kernel_init=default_init())(x)
            # Apply norm/activation/dropout to every layer except the last,
            # unless activate_final is set.
            if i + 1 < len(self.hidden_dims) or self.activate_final:
                if self.layer_norm:
                    x = nn.LayerNorm()(x)
                x = self.activations(x)
                if self.dropout_rate is not None and self.dropout_rate > 0:
                    x = nn.Dropout(rate=self.dropout_rate)(
                        x, deterministic=not training
                    )
        return x
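
# Example usage (a minimal sketch, not part of the original file; the input
# dimension and hyperparameters below are illustrative assumptions):
#
#   mlp = MLP(hidden_dims=(256, 256), dropout_rate=0.1)
#   params = mlp.init(jax.random.PRNGKey(0), jnp.zeros((1, 17)))["params"]
#   out = mlp.apply({"params": params}, jnp.zeros((1, 17)), training=False)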

@flax.struct.dataclass
class Model:
    """Bundles a Flax module with its parameters, optimizer, and optimizer state."""

    step: int
    apply_fn: nn.Module = flax.struct.field(pytree_node=False)
    params: Params
    tx: Optional[optax.GradientTransformation] = flax.struct.field(pytree_node=False)
    opt_state: Optional[optax.OptState] = None

    @classmethod
    def create(
        cls,
        model_def: nn.Module,
        inputs: Sequence[jnp.ndarray],
        tx: Optional[optax.GradientTransformation] = None,
    ) -> "Model":
        variables = model_def.init(*inputs)
        _, params = variables.pop("params")

        if tx is not None:
            opt_state = tx.init(params)
        else:
            opt_state = None

        return cls(
            step=1, apply_fn=model_def, params=params, tx=tx, opt_state=opt_state
        )

    def __call__(self, *args, **kwargs):
        return self.apply_fn.apply({"params": self.params}, *args, **kwargs)

    def apply(self, *args, **kwargs):
        return self.apply_fn.apply(*args, **kwargs)

    def apply_gradient(self, loss_fn) -> Tuple["Model", Any]:
        """Takes one optimizer step; loss_fn(params) must return (loss, info)."""
        grad_fn = jax.grad(loss_fn, has_aux=True)
        grads, info = grad_fn(self.params)

        updates, new_opt_state = self.tx.update(grads, self.opt_state, self.params)
        new_params = optax.apply_updates(self.params, updates)

        return (
            self.replace(
                step=self.step + 1, params=new_params, opt_state=new_opt_state
            ),
            info,
        )

    def save(self, save_path: str):
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        with open(save_path, "wb") as f:
            f.write(flax.serialization.to_bytes(self.params))

    def load(self, load_path: str) -> "Model":
        with open(load_path, "rb") as f:
            params = flax.serialization.from_bytes(self.params, f.read())
        return self.replace(params=params)
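

# A minimal end-to-end sketch of how Model is typically used (illustrative only,
# not part of the original file; the dimensions, learning rate, and paths below
# are assumptions):
if __name__ == "__main__":
    rng = jax.random.PRNGKey(0)
    obs_dim, hidden_dims = 17, (256, 256, 1)

    # Build a small critic-style network and wrap it in a Model with Adam.
    model = Model.create(
        MLP(hidden_dims=hidden_dims),
        inputs=[rng, jnp.zeros((1, obs_dim))],
        tx=optax.adam(3e-4),
    )

    # Dummy regression loss; loss_fn must return (loss, info) because
    # apply_gradient uses jax.grad(..., has_aux=True).
    def loss_fn(params: Params) -> Tuple[jnp.ndarray, InfoDict]:
        preds = model.apply({"params": params}, jnp.zeros((1, obs_dim)))
        loss = jnp.mean(preds**2)
        return loss, {"loss": loss}

    new_model, info = model.apply_gradient(loss_fn)

    # Persisting and restoring parameters (the path is a placeholder).
    new_model.save("/tmp/common_example/params.msgpack")
    restored = new_model.load("/tmp/common_example/params.msgpack")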