Skip to content

Commit

Permalink
Add JaxLayer and FlaxLayer to wrap JAX/Flax modules as layers. (#…
Browse files Browse the repository at this point in the history
…19342)

- `JaxLayer` can wrap any JAX model defined by a function.
- `FlaxLayer` is a subclass of `JaxLayer` that can wrap a Flax module.
  • Loading branch information
hertschuh authored Mar 21, 2024
1 parent ed4c802 commit ff28c35
Show file tree
Hide file tree
Showing 3 changed files with 1,325 additions and 0 deletions.
Loading

0 comments on commit ff28c35

Please sign in to comment.