-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathstate.py
29 lines (23 loc) · 1.2 KB
/
state.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
from typing import NamedTuple
import jax.numpy as jnp
import optax
import chex
class SWAState(NamedTuple):
    """State for the SWA running mean of iterates.

    Only the running mean and two counters are tracked here; this state
    carries no second-moment or low-rank fields.

    The counter defaults are scalar ``int32`` zeros. They are built once,
    when the class body is executed, and reused by every instance that
    does not pass its own value.
    """

    # Running mean of iterates.
    mean: optax.Params
    # Step count.
    step: chex.Array = jnp.zeros([], jnp.int32)
    # Iterate count used for the running stats.
    n: chex.Array = jnp.zeros([], jnp.int32)
class SWAGDiagState(NamedTuple):
    """Running statistics for diagonal SWAG.

    Holds the running mean of the iterates together with the running
    diagonal non-centered variance, plus the step and iterate counters.
    Both counters default to a scalar ``int32`` zero.
    """

    # Running mean of iterates.
    mean: optax.Params
    # Running non-centered variance of iterates.
    params2: optax.Params
    # Step count.
    step: chex.Array = jnp.zeros((), dtype=jnp.int32)
    # Iterate count used for the running stats.
    n: chex.Array = jnp.zeros((), dtype=jnp.int32)
class SWAGState(NamedTuple):
    """Running statistics for SWAG with low rank terms.

    Holds the running mean of the iterates, the running diagonal
    non-centered variance and the low rank delta columns, plus three
    counters: the step count, the iterate count and the index of the
    column to update next. Every counter defaults to a scalar ``int32``
    zero.
    """

    # Running mean of iterates.
    mean: optax.Params
    # Running non-centered variance of iterates.
    params2: optax.Params
    # Low rank delta columns.
    dparams: optax.Params
    # Step count.
    step: chex.Array = jnp.zeros((), dtype=jnp.int32)
    # Iterate count used for the running stats.
    n: chex.Array = jnp.zeros((), dtype=jnp.int32)
    # Current column to update.
    c: chex.Array = jnp.zeros((), dtype=jnp.int32)