From 3c9d1b7d4b820260aa6acb1e12f6d3ab24f6f25e Mon Sep 17 00:00:00 2001 From: Alberto Mercurio Date: Sat, 19 Oct 2024 18:44:09 +0200 Subject: [PATCH] Add test with `accepted_kwargs` as `Val` type --- test/func.jl | 52 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 52 insertions(+) diff --git a/test/func.jl b/test/func.jl index d19e0a5..33acf15 100644 --- a/test/func.jl +++ b/test/func.jl @@ -305,5 +305,57 @@ end @test v2 ≈ a2 * A * u2 + b2 * w2 @test v1 ≈ a1 * A * u1 + b1 * w1 @test v1 + v2 ≈ (a1 * A * u1 + b1 * w1) + (a2 * A * u2 + b2 * w2) + + ## Do the same with Val((:scale,)) + + L = FunctionOperator(f, u, u; p = zero(p), t = zero(t), batch = true, + accepted_kwargs = Val((:scale,)), scale = 1.0) + + @test_throws ArgumentError FunctionOperator( + f, u, u; p = zero(p), t = zero(t), batch = true, + accepted_kwargs = Val((:scale,))) + + @test size(L) == (N, N) + + ans = @. u * p * t * scale + @test L(u, p, t; scale) ≈ ans + v = copy(u) + @test L(v, u, p, t; scale) ≈ ans + + # test that output isn't accidentally mutated by passing an internal cache. + + A = Diagonal(p * t * scale) + u1 = rand(N, K) + u2 = rand(N, K) + + v1 = L * u1 + @test v1 ≈ A * u1 + v2 = L * u2 + @test v2 ≈ A * u2 + @test v1 ≈ A * u1 + @test v1 + v2 ≈ A * (u1 + u2) + + v1 .= 0.0 + v2 .= 0.0 + + mul!(v1, L, u1) + @test v1 ≈ A * u1 + mul!(v2, L, u2) + @test v2 ≈ A * u2 + @test v1 ≈ A * u1 + @test v1 + v2 ≈ A * (u1 + u2) + + v1 = rand(N, K) + w1 = copy(v1) + v2 = rand(N, K) + w2 = copy(v2) + a1, a2, b1, b2 = rand(4) + + mul!(v1, L, u1, a1, b1) + @test v1 ≈ a1 * A * u1 + b1 * w1 + mul!(v2, L, u2, a2, b2) + @test v2 ≈ a2 * A * u2 + b2 * w2 + @test v1 ≈ a1 * A * u1 + b1 * w1 + @test v1 + v2 ≈ (a1 * A * u1 + b1 * w1) + (a2 * A * u2 + b2 * w2) end #