Skip to content

Commit

Permalink
Add test with accepted_kwargs as Val type
Browse files Browse the repository at this point in the history
  • Loading branch information
albertomercurio committed Oct 19, 2024
1 parent 531ede2 commit 3c9d1b7
Showing 1 changed file with 52 additions and 0 deletions.
52 changes: 52 additions & 0 deletions test/func.jl
Original file line number Diff line number Diff line change
Expand Up @@ -305,5 +305,57 @@ end
@test v2 a2 * A * u2 + b2 * w2
@test v1 a1 * A * u1 + b1 * w1
@test v1 + v2 (a1 * A * u1 + b1 * w1) + (a2 * A * u2 + b2 * w2)

## Do the same with Val((:scale,))

L = FunctionOperator(f, u, u; p = zero(p), t = zero(t), batch = true,
accepted_kwargs = Val((:scale,)), scale = 1.0)

@test_throws ArgumentError FunctionOperator(
f, u, u; p = zero(p), t = zero(t), batch = true,
accepted_kwargs = Val((:scale,)))

@test size(L) == (N, N)

ans = @. u * p * t * scale
@test L(u, p, t; scale) ans
v = copy(u)
@test L(v, u, p, t; scale) ans

# test that output isn't accidentally mutated by passing an internal cache.

A = Diagonal(p * t * scale)
u1 = rand(N, K)
u2 = rand(N, K)

v1 = L * u1
@test v1 A * u1
v2 = L * u2
@test v2 A * u2
@test v1 A * u1
@test v1 + v2 A * (u1 + u2)

v1 .= 0.0
v2 .= 0.0

mul!(v1, L, u1)
@test v1 A * u1
mul!(v2, L, u2)
@test v2 A * u2
@test v1 A * u1
@test v1 + v2 A * (u1 + u2)

v1 = rand(N, K)
w1 = copy(v1)
v2 = rand(N, K)
w2 = copy(v2)
a1, a2, b1, b2 = rand(4)

mul!(v1, L, u1, a1, b1)
@test v1 a1 * A * u1 + b1 * w1
mul!(v2, L, u2, a2, b2)
@test v2 a2 * A * u2 + b2 * w2
@test v1 a1 * A * u1 + b1 * w1
@test v1 + v2 (a1 * A * u1 + b1 * w1) + (a2 * A * u2 + b2 * w2)
end
#

0 comments on commit 3c9d1b7

Please sign in to comment.