Skip to content

Commit

Permalink
Add doc to use MultiTrainStep
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 732914402
  • Loading branch information
Conchylicultor authored and The kauldron Authors committed Mar 3, 2025
1 parent 939b2f6 commit c23cf0b
Showing 1 changed file with 36 additions and 0 deletions.
36 changes: 36 additions & 0 deletions docs/eval.md
Original file line number Diff line number Diff line change
Expand Up @@ -206,3 +206,39 @@ def forward(step, params, batch):
rngs = rng_streams.train_rngs(step) # Create the rng for current `step`
return model.apply(params, batch, rngs=rngs)
```

## Use cases

### GAN & Multi optimizers

Training on multi optimizer can be done using `kd.contrib.train.multi_optimizer`
and `trainstep=kd.contrib.train.MultiTrainStep()`.

```python
trainer = kd.train.Trainer(
...,
model=MyGan(
generator=MyGenerator(),
discriminator=MyDiscriminator(),
),
# Define the loss for the generator and the discriminator
losses={
'discriminator': kd.losses.L2(...),
'generator': kd.losses.L2(...),
},
optimizer=kd.contrib.train.multi_optimizer(
# Using `kd.optim.partial_update`, you can mask out which weights
# each of the optimizer will be applied too.
discriminator=kd.optim.partial_update(
optimizer=optax.adam(1e-4),
mask=kd.optim.select('discriminator'),
),
generator=kd.optim.partial_update(
optimizer=optax.adam(1e-4),
mask=kd.optim.select('generator'),
),
),
# Using `kd.contrib.train.multi_optimizer` require to use `MultiTrainStep`
trainstep=kd.contrib.train.MultiTrainStep(),
)
```

0 comments on commit c23cf0b

Please sign in to comment.