diff --git a/Project.toml b/Project.toml index 44440b84..2875b439 100644 --- a/Project.toml +++ b/Project.toml @@ -19,6 +19,7 @@ SafeTestsets = "0.0.1" StaticArrays = "1.9" StaticArraysCore = "1.4" Test = "1" +Zygote = "0.6.67" julia = "1.10" [extras] @@ -27,6 +28,7 @@ Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["Aqua", "Pkg", "Test", "SafeTestsets", "StaticArrays"] +test = ["Aqua", "Pkg", "Test", "SafeTestsets", "StaticArrays", "Zygote"] diff --git a/test/runtests.jl b/test/runtests.jl index eaf2ebe0..b91706cf 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -45,6 +45,9 @@ if GROUP == "All" || GROUP == "Core" @safetestset "BatchedInterface test" begin @time include("batched_interface_test.jl") end + @safetestset "Simple Adjoints test" begin + @time include("simple_adjoints_test.jl") + end end if GROUP == "All" || GROUP == "Downstream" diff --git a/test/simple_adjoints_test.jl b/test/simple_adjoints_test.jl new file mode 100644 index 00000000..329fd104 --- /dev/null +++ b/test/simple_adjoints_test.jl @@ -0,0 +1,17 @@ +using SymbolicIndexingInterface +using Zygote + +sys = SymbolCache([:x, :y, :z], [:a, :b, :c], :t) +pstate = ProblemState(; u = rand(3), p = rand(3), t = rand()) + +getter = getu(sys, :x) +@test Zygote.gradient(getter, pstate)[1].u == [1.0, 0.0, 0.0] + +getter = getu(sys, [:x, :z]) +@test Zygote.gradient(sum ∘ getter, pstate)[1].u == [1.0, 0.0, 1.0] + +getter = getu(sys, :a) +@test Zygote.gradient(getter, pstate)[1].p == [1.0, 0.0, 0.0] + +getter = getu(sys, [:a, :c]) +@test Zygote.gradient(sum ∘ getter, pstate)[1].p == [1.0, 0.0, 1.0]