Skip to content

Commit

Permalink
Initial Working Version
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Oct 3, 2023
0 parents commit e088f0b
Show file tree
Hide file tree
Showing 12 changed files with 377 additions and 0 deletions.
7 changes: 7 additions & 0 deletions .JuliaFormatter.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
style = "sciml"
margin = 92
indent = 4
format_docstrings = true
separate_kwargs_with_semicolon = true
always_for_in = true
annotate_untyped_fields_with_any = false
59 changes: 59 additions & 0 deletions .buildkite/pipeline.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
steps:
- label: ":julia: Julia {{matrix.julia}}"
plugins:
- JuliaCI/julia#v1:
version: "{{matrix.julia}}"
- JuliaCI/julia-test#v1:
test_args: "--quickfail"
agents:
os: "linux"
queue: "juliaecosystem"
arch: "x86_64"
env:
GROUP: "CPU"
if: build.message !~ /\[skip tests\]/
timeout_in_minutes: 240
matrix:
setup:
julia:
- "1"

- label: ":julia: Julia {{matrix.julia}} + CUDA GPU"
plugins:
- JuliaCI/julia#v1:
version: "{{matrix.julia}}"
- JuliaCI/julia-test#v1:
test_args: "--quickfail"
agents:
queue: "juliagpu"
cuda: "*"
env:
GROUP: "CUDA"
if: build.message !~ /\[skip tests\]/
timeout_in_minutes: 240
matrix:
setup:
julia:
- "1"

- label: ":julia: Julia {{matrix.julia}} + AMD GPU"
plugins:
- JuliaCI/julia#v1:
version: "{{matrix.julia}}"
- JuliaCI/julia-test#v1:
test_args: "--quickfail"
env:
JULIA_AMDGPU_CORE_MUST_LOAD: "1"
JULIA_AMDGPU_HIP_MUST_LOAD: "1"
JULIA_AMDGPU_DISABLE_ARTIFACTS: "1"
GROUP: "AMDGPU"
agents:
queue: "juliagpu"
rocm: "*"
rocmgpu: "*"
if: build.message !~ /\[skip tests\]/
timeout_in_minutes: 240
matrix:
setup:
julia:
- "1"
16 changes: 16 additions & 0 deletions .github/workflows/CompatHelper.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
name: CompatHelper
on:
schedule:
- cron: 0 0 * * *
workflow_dispatch:
jobs:
CompatHelper:
runs-on: ubuntu-latest
steps:
- name: Pkg.add("CompatHelper")
run: julia -e 'using Pkg; Pkg.add("CompatHelper")'
- name: CompatHelper.main()
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
COMPATHELPER_PRIV: ${{ secrets.DOCUMENTER_KEY }}
run: julia -e 'using CompatHelper; CompatHelper.main()'
31 changes: 31 additions & 0 deletions .github/workflows/TagBot.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
name: TagBot
on:
issue_comment:
types:
- created
workflow_dispatch:
inputs:
lookback:
default: 3
permissions:
actions: read
checks: read
contents: write
deployments: read
issues: read
discussions: read
packages: read
pages: read
pull-requests: read
repository-projects: read
security-events: read
statuses: read
jobs:
TagBot:
if: github.event_name == 'workflow_dispatch' || github.actor == 'JuliaTagBot'
runs-on: ubuntu-latest
steps:
- uses: JuliaRegistries/TagBot@v1
with:
token: ${{ secrets.GITHUB_TOKEN }}
ssh: ${{ secrets.DOCUMENTER_KEY }}
7 changes: 7 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
/Manifest.toml
test/Manifest.toml
.vscode
wip
test_ext
ext_compat
.CondaPkg
7 changes: 7 additions & 0 deletions CondaPkg.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
[deps]
# Conda package names and versions
python = ">=3.10,<3.12"

[pip.deps]
flax = ">= 0.7"
numpy = ">= 1"
21 changes: 21 additions & 0 deletions LICENSE
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2023 Avik Pal <[email protected]> and contributors

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
17 changes: 17 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
name = "LuxJax"
uuid = "0265fd5b-45c4-47b1-878d-f9552b087dff"
authors = ["Avik Pal <[email protected]> and contributors"]
version = "0.1.0"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
CondaPkg = "992eb4ea-22a4-4c89-a5bb-47a3300528ab"
DLPack = "53c2dc0f-f7d5-43fd-8906-6c0220547083"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

[compat]
julia = "1.9"
54 changes: 54 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# LuxJax

LuxJax allows you to use Neural Networks written in Jax with the Lux API, allowing seamless
integration with the rest of the SciML ecosystem.

Lux.jl is great and is quite fast and useful if you are implementing custom operations.
However, there are quite a few standard workloads where XLA can highly optimize the training
and inference. This package bridges that gap, and allows you to use the fast Jax Neural
Networks with the SciMLverse!

## Installation

The installation is currently a bit manual. First install this package.

```julia
import Pkg
Pkg.add(; url = "https://github.com/LuxDL/LuxJax.jl")
```

Then install the Jax dependencies.

```julia
using LuxJax
LuxJax.install("<setup>")
```

`install` will install the Jax dependencies based on the `setup` provided!

## Usage Example

```julia
using LuxJax
```

## Tips

* When mixing jax and julia it's recommended to disable jax's preallocation by setting the
environment variable `XLA_PYTHON_CLIENT_PREALLOCATE=false`.

## Roadmap

- [ ] Automatic Differentiation
- [ ] Capture Chain Rules
- [ ] Reverse Mode
- [ ] Forward Mode (Very Low Priority)
- [ ] Capture ForwardDiff Duals for Forward Mode
- [ ] Automatically Map Lux Models to Flax Models similar to the Flux to Lux conversion
- [ ] Handle Component Arrays
- [ ] Demonstrate Training of Neural ODEs using Jax and SciMLSensitivity.jl

## Acknowledgements

This package is a more opinionated take on
[PyCallChainRules.jl](https://github.com/rejuvyesh/PyCallChainRules.jl)
144 changes: 144 additions & 0 deletions src/LuxJax.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
module LuxJax

using CondaPkg, PythonCall # Interact with Jax
using DLPack # Transfer Data between Julia and Jax
using ChainRulesCore, Functors, Random
using LuxCore
import ConcreteStructs: @concrete

# Setup Jax
# Lazily-populated handles for the python modules. They are created empty here and
# filled in by `__init__` / `__load_jax_dependencies` via `PythonCall.pycopy!`, so
# other code can reference them as stable `const` bindings.
const jax = PythonCall.pynew()
const dlpack = PythonCall.pynew()
const jnp = PythonCall.pynew()
const numpy = PythonCall.pynew()
const flax = PythonCall.pynew()
const linen = PythonCall.pynew()

# Whether the jax/flax imports above have succeeded (set in `__load_jax_dependencies`)
const is_jax_setup = Ref{Bool}(false)

# The set of accepted `setup` arguments to `install`
const VALID_JAX_SETUPS = ("cpu", "cuda12_pip", "cuda11_pip", "cuda12_local", "cuda11_local",
    "tpu")

# Import the jax/flax python modules into the preallocated module handles.
# With `force = true` a missing installation rethrows the import error; with
# `force = false` it only emits a warning (used during module init, where jax may
# legitimately not be installed yet). Updates `is_jax_setup[]` accordingly.
function __load_jax_dependencies(; force::Bool = true)
    try
        CondaPkg.withenv() do
            for (handle, modname) in ((jax, "jax"), (dlpack, "jax.dlpack"),
                (jnp, "jax.numpy"), (flax, "flax"), (linen, "flax.linen"))
                PythonCall.pycopy!(handle, pyimport(modname))
            end
        end

        is_jax_setup[] = true
    catch err
        is_jax_setup[] = false

        if force
            rethrow(err)
        else
            @warn "Jax is not installed. Please install Jax first using `LuxJax.install(<setup>)`!"
            @debug err
        end
    end
end

# Module initialization: numpy is always loaded; jax/flax are loaded best-effort
# (force = false) so that the package can be imported before `install` has run.
function __init__()
    CondaPkg.withenv(() -> PythonCall.pycopy!(numpy, pyimport("numpy")))
    return __load_jax_dependencies(; force = false)
end

"""
install(setup::String = "cpu")
Installs Jax into the correct environment. The `setup` argument can be one of the following:
- `"cpu"`: Installs Jax with CPU support.
- `"cuda12_pip"`: Installs Jax with CUDA 12 support using pip.
- `"cuda11_pip"`: Installs Jax with CUDA 11 support using pip.
- `"cuda12_local"`: Installs Jax with CUDA 12 support using a local CUDA installation.
- `"cuda11_local"`: Installs Jax with CUDA 11 support using a local CUDA installation.
- `"tpu"`: Installs Jax with TPU support.
"""
function install(setup::String = "cpu")
@assert setup VALID_JAX_SETUPS "Invalid setup: $(setup)! Select one of $(VALID_JAX_SETUPS)!"
CondaPkg.withenv() do
python = CondaPkg.which("python")
run(`$(python) --version`)
run(`$(python) -m pip install --upgrade pip`)
if occursin("cpu", setup)
run(`$(python) -m pip install --upgrade jax\[cpu\]\>=0.4`)
elseif occursin("cuda", setup)
run(`$(python) -m pip install --upgrade jax\[$(setup)\]\>=0.4 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html`)
else # TPU
run(`$(python) -m pip install --upgrade jax\[$(setup)\]\>=0.4 -f https://storage.googleapis.com/jax-releases/libtpu_releases.html`)
end
end

__load_jax_dependencies()

return
end

# FIXME: Wrap DLPack functions to allow them to work for tuple, namedtuple and functors

# Lux Interop
__from_flax_params(ps) = (; params = Tuple(map(p -> DLPack.wrap(p, dlpack.to_dlpack), ps)))

# Rebuild the flax parameter pytree from the Lux-side parameters `ps`, using the stored
# pytree definition `tree`; each leaf is shared back to jax through DLPack.
function __to_flax_params(ps, tree)
    leaves = pylist(map(p -> DLPack.share(p, dlpack.from_dlpack), ps.params))
    return jax.tree_util.tree_unflatten(tree, leaves)
end

# A Lux layer wrapping a flax (linen) model so it can be used through the LuxCore API.
@concrete struct LuxFlaxWrapper <: LuxCore.AbstractExplicitLayer
    flaxmodel       # the wrapped flax/linen module
    input_shape     # per-sample input shape; reversed before being passed to jax
    tree_structure  # python pytree-definition handle, populated by `initialparameters`
end

# Convenience constructor: the parameter tree structure is not known until the model is
# initialized, so allocate an empty python handle to be filled in later.
function LuxFlaxWrapper(flaxmodel, input_shape)
    tree_placeholder = PythonCall.pynew()
    return LuxFlaxWrapper(flaxmodel, input_shape, tree_placeholder)
end

# Initialize the wrapped flax model's parameters from the Julia RNG.
# NOTE(review): this mutates `layer.tree_structure` in place (stores the pytree
# definition needed by `__to_flax_params`), so it must run before `apply` is called.
function LuxCore.initialparameters(rng::AbstractRNG, layer::LuxFlaxWrapper)
    # Derive the jax PRNG key from the Julia RNG so initialization is reproducible
    seed = rand(rng, UInt32)
    # Dummy input with a leading batch dim of 1; shape is reversed for jax —
    # presumably to translate Julia's column-major convention. TODO confirm.
    params = layer.flaxmodel.init(jax.random.PRNGKey(seed),
        jnp.ones((1, reverse(layer.input_shape)...)))
    ps_flat, tree_structure = jax.tree_util.tree_flatten(params)
    PythonCall.pycopy!(layer.tree_structure, tree_structure)
    return __from_flax_params(ps_flat)
end

(l::LuxFlaxWrapper)(x, ps, st) = LuxCore.apply(l, x, ps, st)

# Forward pass: share the input and parameters with jax (zero-copy via DLPack), run the
# flax model, and wrap the jax output back into a Julia array. `st` is passed through.
function LuxCore.apply(l::LuxFlaxWrapper, x, ps, st::NamedTuple)
    input_jax = DLPack.share(x, dlpack.from_dlpack)
    params_jax = __to_flax_params(ps, l.tree_structure)
    out = l.flaxmodel.apply(params_jax, input_jax)
    return (DLPack.wrap(out, dlpack.to_dlpack), st)
end

# Reverse-mode AD rule for the wrapped flax model: the VJP is computed entirely by
# `jax.vjp`, then the resulting gradients are projected back onto the Julia-side
# `x` and `ps` tangent types.
function ChainRulesCore.rrule(::typeof(LuxCore.apply), l::LuxFlaxWrapper, x, ps,
    st::NamedTuple)
    projectₓ = ProjectTo(x)
    projectₚ = ProjectTo(ps)
    x_jax = DLPack.share(x, dlpack.from_dlpack)
    ps_jax = __to_flax_params(ps, l.tree_structure)
    # `jax.vjp` returns the primal output and a pullback closure
    y, jax_vjpfun = jax.vjp(l.flaxmodel.apply, ps_jax, x_jax)
    function ∇flax_apply(Δ)
        # FIXME: Fix dispatches so that we dont have to collect
        # Only the first component of Δ carries a cotangent; `st` has no gradient
        ∂y = DLPack.share(collect(first(unthunk(Δ))), dlpack.from_dlpack)
        (∂ps_jax, ∂x_jax) = jax_vjpfun(∂y)
        ∂x = projectₓ(DLPack.wrap(∂x_jax, dlpack.to_dlpack))
        # `[0]` is Python-side (0-based) indexing into the (leaves, treedef) pair
        ∂ps = projectₚ(__from_flax_params(jax.tree_util.tree_flatten(∂ps_jax)[0]))
        return (NoTangent(), NoTangent(), ∂x, ∂ps, NoTangent())
    end
    return (DLPack.wrap(y, dlpack.to_dlpack), st), ∇flax_apply
end

# exports
export jax, jnp, flax, linen
export LuxFlaxWrapper

end
9 changes: 9 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
[deps]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
5 changes: 5 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
using Aqua, LuxJax, SafeTestsets, Test

# Placeholder test suite — currently only verifies that the package loads and the
# test harness runs; functional tests are yet to be added.
@testset "LuxJax.jl" begin
    @test 1 == 1
end

0 comments on commit e088f0b

Please sign in to comment.