diff --git a/src/derivative_utils.jl b/src/derivative_utils.jl index 6bfef84229..04a633198b 100644 --- a/src/derivative_utils.jl +++ b/src/derivative_utils.jl @@ -16,7 +16,7 @@ struct StaticWOperator{isinv, T} end end isinv(W::StaticWOperator{S}) where {S} = S -Base.:\(W::StaticWOperator, v) = isinv(W) ? W.W * v : W.W \ v +Base.:\(W::StaticWOperator, v::AbstractArray) = isinv(W) ? W.W * v : W.W \ v function calc_tderivative!(integrator, cache, dtd1, repeat_step) @inbounds begin @@ -355,7 +355,8 @@ function Base.convert(::Type{Number}, W::WOperator) end return W._concrete_form end -Base.size(W::WOperator, args...) = size(W.J, args...) +Base.size(W::WOperator) = size(W.J) +Base.size(W::WOperator, d::Integer) = d <= 2 ? size(W)[d] : 1 function Base.getindex(W::WOperator, i::Int) if W.transform -W.mass_matrix[i] / W.gamma + W.J[i] @@ -370,7 +371,14 @@ function Base.getindex(W::WOperator, I::Vararg{Int, N}) where {N} -W.mass_matrix[I...] + W.gamma * W.J[I...] end end -function Base.:*(W::WOperator, x::Union{AbstractVecOrMat, Number}) +function Base.:*(W::WOperator, x::AbstractVecOrMat) + if W.transform + (W.mass_matrix * x) / -W.gamma + W.J * x + else + -W.mass_matrix * x + W.gamma * (W.J * x) + end +end +function Base.:*(W::WOperator, x::Number) if W.transform (W.mass_matrix * x) / -W.gamma + W.J * x else @@ -827,6 +835,9 @@ end function build_J_W(alg, u, uprev, p, t, dt, f::F, ::Type{uEltypeNoUnits}, ::Val{IIP}) where {IIP, uEltypeNoUnits, F} + # TODO - make J, W AbstractSciMLOperators (lazily defined with scimlops functionality) + # TODO - if jvp given, make it SciMLOperators.FunctionOperator + # TODO - make mass matrix a SciMLOperator so it can be updated with time. Default to IdentityOperator islin, isode = islinearfunction(f, alg) if f.jac_prototype isa DiffEqBase.AbstractDiffEqLinearOperator W = WOperator{IIP}(f, u, dt)