
Commit

minor update to improve code efficiency
HumphreyYang committed Jul 6, 2024
1 parent 09d6e1c commit b783beb
Showing 1 changed file with 25 additions and 16 deletions.
41 changes: 25 additions & 16 deletions lectures/calvo_gradient.md
@@ -313,6 +313,7 @@ clq = ChangLQ(β=0.85, c=2, T=T)
```

```{code-cell} ipython3
+@jit
def compute_θ(μ, α=1):
λ = α / (1 + α)
T = len(μ) - 1
@@ -326,6 +327,7 @@ def compute_θ(μ, α=1):
θ = θ.at[-1].set(μbar)
return θ
+@jit
def compute_V(μ, β, c, α=1, u0=1, u1=0.5, u2=3):
θ = compute_θ(μ, α)
@@ -342,9 +344,6 @@ def compute_V(μ, β, c, α=1, u0=1, u1=0.5, u2=3):
V += (β**T / (1 - β)) * (h0 + h1 * μ[-1] + h2 * μ[-1]**2 - 0.5 * c * μ[-1]**2)
return V
-compute_θ = jit(compute_θ)
-compute_V = jit(compute_V)
```

```{code-cell} ipython3
@@ -356,7 +355,7 @@ print(f'deviation = {np.linalg.norm(V_val - clq.J_series[0])}') # good!

Now we want to maximize the function $V$ by choice of $\mu$.

-We will use the [Adam optimizer](https://optax.readthedocs.io/en/latest/api/optimizers.html#optax.adam) from the `optax` library.
+We will use the [`optax.adam`](https://optax.readthedocs.io/en/latest/api/optimizers.html#optax.adam) optimizer from the `optax` library.

```{code-cell} ipython3
def adam_optimizer(grad_func, init_params,
@@ -370,6 +369,7 @@ def adam_optimizer(grad_func, init_params,
opt_state = optimizer.init(params)
# Update parameters and gradients
+@jit
def update(params, opt_state):
grads = grad_func(params)
updates, opt_state = optimizer.update(grads, opt_state)
@@ -390,19 +390,24 @@
return params
```
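
For readers new to `optax`, the following is a minimal sketch of the `init` / `update` / `apply_updates` pattern that `adam_optimizer` is built on (an illustrative aside, not part of the original lecture; it assumes `optax`, `jax`, and `jnp` are already imported earlier in the lecture).

```{code-cell} ipython3
# Illustrative sketch only: use optax.adam to minimize f(x) = sum(x**2)
f = lambda x: jnp.sum(x**2)

optimizer = optax.adam(learning_rate=0.1)
params = jnp.ones(3)
opt_state = optimizer.init(params)

for _ in range(200):
    grads = jax.grad(f)(params)                      # gradient of the loss
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)    # one Adam step

params  # should be close to zero
```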

Here we use automatic differentiation functionality in JAX with `jax.grad`.

```{code-cell} ipython3
:tags: [scroll-output]
%%time
# Initial guess for μ
-μ = jnp.zeros(T)
+μ_init = jnp.zeros(T)
# Maximization instead of minimization
-grad_V = jax.grad(lambda μ: -compute_V(μ, β=0.85, c=2))
+grad_V = jit(jax.grad(
+    lambda μ: -compute_V(μ, β=0.85, c=2)))
```
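
As a quick toy check of what `jax.grad` produces (an illustration added here, not part of the original lecture, assuming as elsewhere that `jax` and `jnp` are imported), the gradient of $f(x) = x \cdot x$ is $2x$:

```{code-cell} ipython3
# Illustrative sketch only: jax.grad of f(x) = x·x returns 2x
toy_grad = jax.grad(lambda x: jnp.dot(x, x))
toy_grad(jnp.array([1.0, 2.0, 3.0]))  # expect [2., 4., 6.]
```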

```{code-cell} ipython3
%%time
# Optimize μ
-optimized_μ = adam_optimizer(grad_V, μ)
+optimized_μ = adam_optimizer(grad_V, μ_init)
print(f"optimized μ = \n{optimized_μ}")
```
@@ -416,8 +421,11 @@ print(f'deviation = {np.linalg.norm(optimized_μ - clq.μ_series)}')
```

```{code-cell} ipython3
-compute_V(optimized_μ, β=0.85, c=2) \
-> compute_V(clq.μ_series, β=0.85, c=2)
+compute_V(optimized_μ, β=0.85, c=2)
```

```{code-cell} ipython3
+compute_V(clq.μ_series, β=0.85, c=2)
```

## Regressing $\vec \theta_t$ and $\vec \mu_t$
@@ -539,13 +547,14 @@ In this case, we restrict $\mu_t = \bar \mu$ for all $t$
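
Under this restriction the implied inflation sequence is also constant. As a short sketch of why (assuming the model's geometric-sum formula $\theta_t = (1-\lambda)\sum_{j \geq 0} \lambda^j \mu_{t+j}$ with $\lambda = \alpha/(1+\alpha)$):

$$
\theta_t = (1-\lambda)\sum_{j \geq 0} \lambda^j \bar \mu
         = (1-\lambda)\,\frac{\bar \mu}{1-\lambda}
         = \bar \mu ,
$$

so a single scalar $\bar \mu$ determines both $\vec \mu$ and $\vec \theta$.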

```{code-cell} ipython3
# Initial guess for single μ
-μ = jnp.zeros(1)
+μ_init = jnp.zeros(1)
# Maximization instead of minimization
-grad_V = jax.grad(lambda μ: -compute_V(μ, β=0.85, c=2))
+grad_V = jit(jax.grad(
+    lambda μ: -compute_V(μ, β=0.85, c=2)))
# Optimize μ
-optimized_μ_CR = adam_optimizer(grad_V, μ)
+optimized_μ_CR = adam_optimizer(grad_V, μ_init)
print(f"optimized μ = \n{optimized_μ_CR}")
```
Expand All @@ -557,9 +566,9 @@ np.linalg.norm(clq.μ_CR - optimized_μ_CR)
```

```{code-cell} ipython3
-compute_V(jnp.array([clq.μ_CR]), β=0.85, c=2)
+compute_V(optimized_μ_CR, β=0.85, c=2)
```

```{code-cell} ipython3
-compute_V(optimized_μ_CR, β=0.85, c=2)
+compute_V(jnp.array([clq.μ_CR]), β=0.85, c=2)
```
