Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

concretize methods #205

Merged
merged 14 commits into from
Jul 18, 2023
2 changes: 2 additions & 0 deletions src/SciMLOperators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,8 @@ export update_coefficients!,

issquare,
islinear,
concretize,
isconvertible,

has_adjoint,
has_expmv,
Expand Down
1 change: 1 addition & 0 deletions src/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -761,6 +761,7 @@

getops(L::InvertedOperator) = (L.L,)
islinear(L::InvertedOperator) = islinear(L.L)
isconvertible(::InvertedOperator) = false

Check warning on line 764 in src/basic.jl

View check run for this annotation

Codecov / codecov/patch

src/basic.jl#L764

Added line #L764 was not covered by tests

has_mul(L::InvertedOperator) = has_ldiv(L.L)
has_mul!(L::InvertedOperator) = has_ldiv!(L.L)
Expand Down
7 changes: 7 additions & 0 deletions src/batch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,12 @@
)
end

function Base.convert(::Type{AbstractMatrix}, L::BatchedDiagonalOperator)
m, n = size(L)
msg = """$L cannot be represented by an $m × $n AbstractMatrix"""
throw(ArgumentError(msg))

Check warning on line 73 in src/batch.jl

View check run for this annotation

Codecov / codecov/patch

src/batch.jl#L70-L73

Added lines #L70 - L73 were not covered by tests
end

LinearAlgebra.issymmetric(L::BatchedDiagonalOperator) = true
function LinearAlgebra.ishermitian(L::BatchedDiagonalOperator)
if isreal(L)
Expand All @@ -91,6 +97,7 @@
update_func_isconstant(L.update_func) & update_func_isconstant(L.update_func!)
end
islinear(::BatchedDiagonalOperator) = true
isconvertible(::BatchedDiagonalOperator) = false

Check warning on line 100 in src/batch.jl

View check run for this annotation

Codecov / codecov/patch

src/batch.jl#L100

Added line #L100 was not covered by tests
has_adjoint(L::BatchedDiagonalOperator) = true
has_ldiv(L::BatchedDiagonalOperator) = all(x -> !iszero(x), L.diag)
has_ldiv!(L::BatchedDiagonalOperator) = has_ldiv(L)
Expand Down
6 changes: 6 additions & 0 deletions src/func.jl
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@
* `has_mul5` - `true` if the operator provides a five-argument `mul!` via the signature `op(v, u, p, t, α, β; <accepted_kwargs>)`. This trait is inferred if no value is provided.
* `isconstant` - `true` if the operator is constant, and doesn't need to be updated via `update_coefficients[!]` during operator evaluation.
* `islinear` - `true` if the operator is linear. Defaults to `false`.
* `isconvertible` - `true` a cheap `convert(AbstractMatrix, L.op)` method is available. Defaults to `false`.
* `batch` - Boolean indicating if the input/output arrays comprise of batched column-vectors stacked in a matrix. If `true`, the input/output arrays must be `AbstractVecOrMat`s, and the length of the second dimension (the batch dimension) must be the same. The batch dimension is not involved in size computation. For example, with `batch = true`, and `size(output), size(input) = (M, K), (N, K)`, the `FunctionOperator` size is set to `(M, N)`. If `batch = false`, which is the default, the `input`/`output` arrays may of any size so long as `ndims(input) == ndims(output)`, and the `size` of `FunctionOperator` is set to `(length(input), length(output))`.
* `ifcache` - Allocate cache arrays in constructor. Defaults to `true`. Cache can be generated afterwards by calling `cache_operator(L, input, output)`
* `cache` - Pregenerated cache arrays for in-place evaluations. Expected to be of type and shape `(similar(input), similar(output),)`. The constructor generates cache if no values are provided. Cache generation by the constructor can be disabled by setting the kwarg `ifcache = false`.
Expand Down Expand Up @@ -138,6 +139,7 @@
has_mul5::Union{Nothing,Bool}=nothing,
isconstant::Bool = false,
islinear::Bool = false,
isconvertible::Bool = false,

batch::Bool = false,
ifcache::Bool = true,
Expand Down Expand Up @@ -248,6 +250,7 @@

traits = (;
islinear = islinear,
isconvertible = isconvertible,
isconstant = isconstant,

opnorm = opnorm,
Expand Down Expand Up @@ -480,6 +483,8 @@
)
end

Base.convert(::Type{AbstractMatrix}, L::FunctionOperator) = convert(AbstractMatrix, L.op)

Check warning on line 486 in src/func.jl

View check run for this annotation

Codecov / codecov/patch

src/func.jl#L486

Added line #L486 was not covered by tests

function Base.resize!(L::FunctionOperator, n::Integer)

# input/output to `L` must be `AbstractVector`s
Expand Down Expand Up @@ -526,6 +531,7 @@
end

islinear(L::FunctionOperator) = L.traits.islinear
isconvertible(L::FunctionOperator) = L.traits.isconvertible

Check warning on line 534 in src/func.jl

View check run for this annotation

Codecov / codecov/patch

src/func.jl#L534

Added line #L534 was not covered by tests
isconstant(L::FunctionOperator) = L.traits.isconstant
has_adjoint(L::FunctionOperator) = !(L.op_adjoint isa Nothing)
has_mul(::FunctionOperator{iip}) where{iip} = true
Expand Down
86 changes: 76 additions & 10 deletions src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -187,37 +187,53 @@
Base.iszero(::AbstractSciMLOperator) = false # TODO

"""
$SIGNATURES

Check if `adjoint(L)` is lazily defined.
"""
has_adjoint(L::AbstractSciMLOperator) = false # L', adjoint(L)
"""
$SIGNATURES

Check if `expmv!(v, L, u, t)`, equivalent to `mul!(v, exp(t * A), u)`, is
defined for `Number` `t`, and `AbstractArray`s `u, v` of appropriate sizes.
"""
has_expmv!(L::AbstractSciMLOperator) = false # expmv!(v, L, t, u)
"""
$SIGNATURES

Check if `expmv(L, u, t)`, equivalent to `exp(t * A) * u`, is defined for
`Number` `t`, and `AbstractArray` `u` of appropriate size.
"""
has_expmv(L::AbstractSciMLOperator) = false # v = exp(L, t, u)
"""
$SIGNATURES

Check if `exp(L)` is defined lazily defined.
"""
has_exp(L::AbstractSciMLOperator) = islinear(L)
"""
$SIGNATURES

Check if `L * u` is defined for `AbstractArray` `u` of appropriate size.
"""
has_mul(L::AbstractSciMLOperator) = true # du = L*u
"""
$SIGNATURES

Check if `mul!(v, L, u)` is defined for `AbstractArray`s `u, v` of
appropriate sizes.
"""
has_mul!(L::AbstractSciMLOperator) = true # mul!(du, L, u)
"""
$SIGNATURES

Check if `L \\ u` is defined for `AbstractArray` `u` of appropriate size.
"""
has_ldiv(L::AbstractSciMLOperator) = false # du = L\u
"""
$SIGNATURES

Check if `ldiv!(v, L, u)` is defined for `AbstractArray`s `u, v` of
appropriate sizes.
"""
Expand All @@ -244,7 +260,57 @@
) = true
isconstant(L::AbstractSciMLOperator) = all(isconstant, getops(L))

#islinear(L) = false
"""
isconvertible(L) -> Bool

Checks if `L` can be cheaply converted to an `AbstractMatrix` via eager fusion.
"""
isconvertible(L::AbstractSciMLOperator) = all(isconvertible, getops(L))

Check warning on line 268 in src/interface.jl

View check run for this annotation

Codecov / codecov/patch

src/interface.jl#L268

Added line #L268 was not covered by tests

isconvertible(::Union{
# LinearAlgebra
AbstractMatrix,
UniformScaling,
Factorization,

# Base
Number,

# SciMLOperators
AbstractSciMLScalarOperator,
}
) = true

Check warning on line 282 in src/interface.jl

View check run for this annotation

Codecov / codecov/patch

src/interface.jl#L282

Added line #L282 was not covered by tests

"""
concretize(L) -> AbstractMatrix

concretize(L) -> Number

Convert `SciMLOperator` to a concrete type via eager fusion. This method is a
no-op for types that are already concrete.
"""
concretize(L::Union{
# LinearAlgebra
AbstractMatrix,
Factorization,

# SciMLOperators
AbstractSciMLOperator,
}
) = convert(AbstractMatrix, L)
Comment on lines +294 to +300
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
AbstractMatrix,
Factorization,
# SciMLOperators
AbstractSciMLOperator,
}
) = convert(AbstractMatrix, L)
AbstractArray,
Factorization,
# SciMLOperators
AbstractSciMLOperator,
}
) = convert(AbstractArray, L)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should that be extended to array?

Copy link
Member Author

@vpuri3 vpuri3 Jul 18, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No because, an ND array cannot be used as an operator, i.e. you can't mul! or * it to another array.

julia> ones(4,4,4) * ones(4)
ERROR: MethodError: no method matching *(::Array{Float64, 3}, ::Vector{Float64})

So we strictly want to return AbstractMatrix types

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And, in the package, I've defined convert(::Type{AbstractMatrix}, L::MyOperator) methods for each operator type like

Base.convert(::Type{AbstractMatrix}, ii::IdentityOperator) = Diagonal(ones(Bool, ii.len))

So calling convert with AbstractArray would lead to errors:

julia> convert(AbstractArray, DiagonalOperator(ones(4)))                                                                                                                                                 [0/110]
ERROR: MethodError: Cannot `convert` an object of type 
  MatrixOperator{Float64, LinearAlgebra.Diagonal{Float64, Vector{Float64}}, SciMLOperators.FilterKwargs{typeof(SciMLOperators.DEFAULT_UPDATE_FUNC), Tuple{}}, SciMLOperators.FilterKwargs{typeof(SciMLOperators.DEFAULT_UPDATE_FUNC), Tuple{}}} to an object of type  
  AbstractArray


concretize(L::Union{
# LinearAlgebra
UniformScaling,

# Base
Number,

# SciMLOperators
AbstractSciMLScalarOperator,
}
) = convert(Number, L)

Check warning on line 312 in src/interface.jl

View check run for this annotation

Codecov / codecov/patch

src/interface.jl#L312

Added line #L312 was not covered by tests

"""
$SIGNATURES

Expand Down Expand Up @@ -349,22 +415,22 @@
function Base.conj(L::AbstractSciMLOperator)
isreal(L) && return L
@warn """using convert-based fallback for Base.conj"""
convert(AbstractMatrix, L) |> conj
concretize(L) |> conj

Check warning on line 418 in src/interface.jl

View check run for this annotation

Codecov / codecov/patch

src/interface.jl#L418

Added line #L418 was not covered by tests
end

function Base.:(==)(L1::AbstractSciMLOperator, L2::AbstractSciMLOperator)
@warn """using convert-based fallback for Base.=="""
size(L1) != size(L2) && return false
convert(AbstractMatrix, L1) == convert(AbstractMatrix, L1)
concretize(L1) == concretize(L2)

Check warning on line 424 in src/interface.jl

View check run for this annotation

Codecov / codecov/patch

src/interface.jl#L424

Added line #L424 was not covered by tests
end

Base.@propagate_inbounds function Base.getindex(L::AbstractSciMLOperator, I::Vararg{Any,N}) where {N}
@warn """using convert-based fallback for Base.getindex"""
convert(AbstractMatrix, L)[I...]
concretize(L)[I...]

Check warning on line 429 in src/interface.jl

View check run for this annotation

Codecov / codecov/patch

src/interface.jl#L429

Added line #L429 was not covered by tests
end
function Base.getindex(L::AbstractSciMLOperator, I::Vararg{Int, N}) where {N}
@warn """using convert-based fallback for Base.getindex"""
convert(AbstractMatrix, L)[I...]
concretize(L)[I...]
end

function Base.resize!(L::AbstractSciMLOperator, n::Integer)
Expand All @@ -375,15 +441,15 @@

function LinearAlgebra.opnorm(L::AbstractSciMLOperator, p::Real=2)
@warn """using convert-based fallback in LinearAlgebra.opnorm."""
opnorm(convert(AbstractMatrix, L), p)
opnorm(concretize(L), p)

Check warning on line 444 in src/interface.jl

View check run for this annotation

Codecov / codecov/patch

src/interface.jl#L444

Added line #L444 was not covered by tests
end

for op in (
:sum, :prod,
)
@eval function Base.$op(L::AbstractSciMLOperator; kwargs...)
@warn """using convert-based fallback in $($op)."""
$op(convert(AbstractMatrix, L); kwargs...)
$op(concretize(L); kwargs...)
end
end

Expand All @@ -394,17 +460,17 @@
)
@eval function LinearAlgebra.$pred(L::AbstractSciMLOperator)
@warn """using convert-based fallback in $($pred)."""
$pred(convert(AbstractMatrix, L))
$pred(concretize(L))

Check warning on line 463 in src/interface.jl

View check run for this annotation

Codecov / codecov/patch

src/interface.jl#L463

Added line #L463 was not covered by tests
end
end

function LinearAlgebra.mul!(v::AbstractArray, L::AbstractSciMLOperator, u::AbstractArray)
@warn """using convert-based fallback in mul!."""
mul!(v, convert(AbstractMatrix, L), u)
mul!(v, concretize(L), u)

Check warning on line 469 in src/interface.jl

View check run for this annotation

Codecov / codecov/patch

src/interface.jl#L469

Added line #L469 was not covered by tests
end

function LinearAlgebra.mul!(v::AbstractArray, L::AbstractSciMLOperator, u::AbstractArray, α, β)
@warn """using convert-based fallback in mul!."""
mul!(v, convert(AbstractMatrix, L), u, α, β)
mul!(v, concretize(L), u, α, β)

Check warning on line 474 in src/interface.jl

View check run for this annotation

Codecov / codecov/patch

src/interface.jl#L474

Added line #L474 was not covered by tests
end
#
12 changes: 11 additions & 1 deletion src/matrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,8 @@
has_ldiv,
has_ldiv!,
)

isconvertible(::MatrixOperator) = true

Check warning on line 107 in src/matrix.jl

View check run for this annotation

Codecov / codecov/patch

src/matrix.jl#L107

Added line #L107 was not covered by tests
islinear(::MatrixOperator) = true

function Base.show(io::IO, L::MatrixOperator)
Expand Down Expand Up @@ -162,7 +164,7 @@

# TODO - add tests for MatrixOperator indexing
# propagate_inbounds here for the getindex fallback
Base.@propagate_inbounds Base.convert(::Type{AbstractMatrix}, L::MatrixOperator) = L.A
Base.@propagate_inbounds Base.convert(::Type{AbstractMatrix}, L::MatrixOperator) = convert(AbstractMatrix, L.A)
Base.@propagate_inbounds Base.setindex!(L::MatrixOperator, v, i::Int) = (L.A[i] = v)
Base.@propagate_inbounds Base.setindex!(L::MatrixOperator, v, I::Vararg{Int, N}) where{N} = (L.A[I...] = v)

Expand Down Expand Up @@ -322,6 +324,7 @@

getops(L::InvertibleOperator) = (L.L, L.F,)
islinear(L::InvertibleOperator) = islinear(L.L)
isconvertible(L::InvertibleOperator) = isconvertible(L.L)

Check warning on line 327 in src/matrix.jl

View check run for this annotation

Codecov / codecov/patch

src/matrix.jl#L327

Added line #L327 was not covered by tests

@forward InvertibleOperator.L (
# LinearAlgebra
Expand Down Expand Up @@ -510,6 +513,7 @@
getops(L::AffineOperator) = (L.A, L.B, L.b)

islinear(::AffineOperator) = false
isconvertible(::AffineOperator) = false

Check warning on line 516 in src/matrix.jl

View check run for this annotation

Codecov / codecov/patch

src/matrix.jl#L516

Added line #L516 was not covered by tests

function Base.show(io::IO, L::AffineOperator)
show(io, L.A)
Expand Down Expand Up @@ -537,6 +541,12 @@
L
end

function Base.convert(::Type{AbstractMatrix}, L::AffineOperator)
m, n = size(L)
msg = """$L cannot be represented by an $m × $n AbstractMatrix"""
throw(ArgumentError(msg))

Check warning on line 547 in src/matrix.jl

View check run for this annotation

Codecov / codecov/patch

src/matrix.jl#L544-L547

Added lines #L544 - L547 were not covered by tests
end

has_adjoint(L::AffineOperator) = false
has_mul(L::AffineOperator) = has_mul(L.A)
has_mul!(L::AffineOperator) = has_mul!(L.A)
Expand Down
1 change: 1 addition & 0 deletions src/scalar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
Base.transpose(α::AbstractSciMLScalarOperator) = α

has_mul!(::AbstractSciMLScalarOperator) = true
isconcrete(::AbstractSciMLScalarOperator) = true

Check warning on line 35 in src/scalar.jl

View check run for this annotation

Codecov / codecov/patch

src/scalar.jl#L35

Added line #L35 was not covered by tests
islinear(::AbstractSciMLScalarOperator) = true
has_adjoint(::AbstractSciMLScalarOperator) = true

Expand Down
1 change: 1 addition & 0 deletions src/tensor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@

getops(L::TensorProductOperator) = L.ops
islinear(L::TensorProductOperator) = reduce(&, islinear.(L.ops))
isconvertible(::TensorProductOperator) = false

Check warning on line 124 in src/tensor.jl

View check run for this annotation

Codecov / codecov/patch

src/tensor.jl#L124

Added line #L124 was not covered by tests
Base.iszero(L::TensorProductOperator) = reduce(|, iszero.(L.ops))
has_adjoint(L::TensorProductOperator) = reduce(&, has_adjoint.(L.ops))
has_mul(L::TensorProductOperator) = reduce(&, has_mul.(L.ops))
Expand Down