Fix gradient fitting (#19)
* update doc

* fix Calcium bugs

* update requirements

* add examples for fitting with Adam optimizer

* fix gradient computation
chaoming0625 authored Sep 6, 2024
1 parent b3438b6 commit 087bb5b
Showing 2 changed files with 15 additions and 14 deletions.
1 change: 1 addition & 0 deletions examples/simple_dendrite_model.py
@@ -93,3 +93,4 @@ def __init__(self, n_neuron: int, g_na, g_k):
     def step_run(self, t, inp):
         dx.rk4_step(self, t, inp)
         return self.V.value, self.spike.value
+
28 changes: 14 additions & 14 deletions examples/simple_dendrite_model_fitting_by_adam.py
@@ -98,15 +98,16 @@ def loss_per_param(param, step=10):
     losses = bts.metric.squared_error(simulated_vs.mantissa[..., ::step, 0], target_vs.mantissa[..., ::step, 0])
     return losses.mean()
 
-# calculate the average loss for all parameters,
-# this is the loss function to be gradient-based optimizations
-def loss_fun(step=10):
-    return jax.vmap(functools.partial(loss_per_param, step=step))(param_to_optimize.value).mean()
+# calculate the gradients and loss for each parameter
+@jax.vmap
+@jax.jit
+def compute_grad(param):
+    grads, loss = bst.transform.grad(loss_per_param, argnums=0, return_value=True)(param)
+    return grads, loss
 
 # find the best loss and parameter in the batch
 @bst.transform.jit
-def best_loss_and_param(params):
-    losses = jax.vmap(loss_per_param)(params)
+def best_loss_and_param(params, losses):
     i_best = u.math.argmin(losses)
     return losses[i_best], params[i_best]
 
@@ -117,19 +118,18 @@ def best_loss_and_param(params):
 # Step 6: training
 @bst.transform.jit
 def train_step_per_epoch():
-    grads, loss = bst.transform.grad(loss_fun, grad_vars={'param': param_to_optimize}, return_value=True)()
-    optimizer.update(grads)
-    return loss
+    grads, losses = compute_grad(param_to_optimize.value)
+    optimizer.update({'param': grads})
+    return losses
 
 for i_epoch in range(1000):
-    loss = train_step_per_epoch()
-    best_loss, best_param = best_loss_and_param(param_to_optimize.value)
+    losses = train_step_per_epoch()
+    best_loss, best_param = best_loss_and_param(param_to_optimize.value, losses)
     if best_loss < 1e-5:
-        best_param = best_loss_and_param(param_to_optimize.value)[1]
-        print(f'Epoch {i_epoch}, loss={loss}, best loss={best_loss}, best param={best_param}')
+        print(f'Epoch {i_epoch}, loss={losses.mean()}, best loss={best_loss}, best param={best_param}')
         break
     if i_epoch % 10 == 0:
-        print(f'Epoch {i_epoch}, loss={loss}, best loss={best_loss}, best param={best_param}')
+        print(f'Epoch {i_epoch}, loss={losses.mean()}, best loss={best_loss}, best param={best_param}')
 
 # Step 7: visualize the results
 visualize_a_simulate(target_params, functools.partial(f_current, i_current=0), title='Target', show=False)
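For readers following the change: the fix replaces a single gradient of the batch-mean loss (the old loss_fun, differentiated through the param_to_optimize state) with one gradient per candidate parameter set, obtained by vmapping the gradient of loss_per_param, and it returns the per-candidate losses so best_loss_and_param no longer has to re-run the forward pass. Below is a minimal, self-contained sketch of that pattern using plain JAX and Optax's Adam rather than the repository's brainstate/braintools API; the toy quadratic loss and every name in the sketch are illustrative assumptions, not code from the repo.

# Sketch of the per-candidate gradient pattern adopted by this commit,
# written with plain JAX + Optax (not the repository's bst/bts API).
# The quadratic loss and all names here are illustrative only.
import jax
import jax.numpy as jnp
import optax

def loss_per_param(param):
    # toy loss: squared distance of one candidate parameter vector from a fixed target
    target = jnp.array([1.0, -2.0, 0.5])
    return jnp.mean((param - target) ** 2)

# one (loss, gradient) pair per row of the parameter batch
compute_grad = jax.jit(jax.vmap(jax.value_and_grad(loss_per_param)))

params = jnp.zeros((16, 3))              # 16 candidate parameter sets
optimizer = optax.adam(1e-1)
opt_state = optimizer.init(params)

for epoch in range(200):
    losses, grads = compute_grad(params)              # shapes (16,) and (16, 3)
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    if losses.min() < 1e-6:                           # stop once the best candidate fits
        break

best = params[jnp.argmin(losses)]
print(f'epoch={epoch}, best loss={float(losses.min()):.2e}, best param={best}')

One practical advantage of the per-candidate form: because each candidate's loss depends only on its own row, differentiating per candidate keeps every gradient at full scale (the gradient of the mean loss scales each row's gradient by 1/batch size) and hands the per-candidate losses to the early-stopping check without an extra forward pass.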
