diff --git a/Project.toml b/Project.toml index f22c04ab..4ec9f49a 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "FillArrays" uuid = "1a297f60-69ca-5386-bcde-b61e274b549b" -version = "1.12.0" +version = "1.13.0" [deps] LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" diff --git a/src/FillArrays.jl b/src/FillArrays.jl index 170e07d3..fa339cf1 100644 --- a/src/FillArrays.jl +++ b/src/FillArrays.jl @@ -7,7 +7,8 @@ import Base: size, getindex, setindex!, IndexStyle, checkbounds, convert, any, all, axes, isone, iszero, iterate, unique, allunique, permutedims, inv, copy, vec, setindex!, count, ==, reshape, map, zero, show, view, in, mapreduce, one, reverse, promote_op, promote_rule, repeat, - parent, similar, issorted, add_sum, accumulate, OneTo, permutedims + parent, similar, issorted, add_sum, accumulate, OneTo, permutedims, + real, imag, conj import LinearAlgebra: rank, svdvals!, tril, triu, tril!, triu!, diag, transpose, adjoint, fill!, dot, norm2, norm1, normInf, normMinusInf, normp, lmul!, rmul!, diagzero, AdjointAbsVec, TransposeAbsVec, diff --git a/src/fillalgebra.jl b/src/fillalgebra.jl index 94d763bd..dcc78be2 100644 --- a/src/fillalgebra.jl +++ b/src/fillalgebra.jl @@ -78,6 +78,13 @@ mult_zeros(a::AbstractArray{<:Number}, b::AbstractArray{<:Number}) = mult_zeros( mult_zeros(a, b) = mult_fill(a, b, mult_axes(a, b)) mult_ones(a, b) = mult_ones(a, b, mult_axes(a, b)) +# scaling +*(a::AbstractFill, b::Number) = Fill(getindex_value(a) * b, axes(a)) +*(a::Number, b::AbstractFill) = Fill(a * getindex_value(b), axes(b)) +*(a::AbstractZeros, b::Number) = Zeros(typeof(getindex_value(a) * b), axes(a)) +*(a::Number, b::AbstractZeros) = Zeros(typeof(a * getindex_value(b)), axes(b)) + +# matmul *(a::AbstractFillMatrix, b::AbstractFillMatrix) = mult_fill(a,b) *(a::AbstractFillMatrix, b::AbstractFillVector) = mult_fill(a,b) diff --git a/src/fillbroadcast.jl b/src/fillbroadcast.jl index 2b5ea59c..115fc6ac 100644 --- a/src/fillbroadcast.jl +++ b/src/fillbroadcast.jl @@ -73,22 +73,100 @@ function mapreduce(f, op, A::AbstractFill, B::AbstractFill, Cs::AbstractArray... end -### Unary broadcasting +## BroadcastStyle + +abstract type AbstractFillStyle{N} <: Broadcast.AbstractArrayStyle{N} end +struct FillStyle{N} <: AbstractFillStyle{N} end +struct ZerosStyle{N} <: AbstractFillStyle{N} end +FillStyle{N}(::Val{M}) where {N,M} = FillStyle{M}() +ZerosStyle{N}(::Val{M}) where {N,M} = ZerosStyle{M}() +Broadcast.BroadcastStyle(::Type{<:AbstractFill{<:Any,N}}) where {N} = FillStyle{N}() +Broadcast.BroadcastStyle(::Type{<:AbstractZeros{<:Any,N}}) where {N} = ZerosStyle{N}() +Broadcast.BroadcastStyle(::FillStyle{M}, ::ZerosStyle{N}) where {M,N} = FillStyle{max(M,N)}() +Broadcast.BroadcastStyle(S::LinearAlgebra.StructuredMatrixStyle, ::ZerosStyle{2}) = S +Broadcast.BroadcastStyle(S::LinearAlgebra.StructuredMatrixStyle, ::ZerosStyle{1}) = S +Broadcast.BroadcastStyle(S::LinearAlgebra.StructuredMatrixStyle, ::ZerosStyle{0}) = S + +_getindex_value(f::AbstractFill) = getindex_value(f) +_getindex_value(x::Number) = x +_getindex_value(x::Ref) = x[] +function _getindex_value(bc::Broadcast.Broadcasted) + bc.f(map(_getindex_value, bc.args)...) +end + +has_static_value(x) = false +has_static_value(x::Union{AbstractZeros, AbstractOnes}) = true +has_static_value(x::Broadcast.Broadcasted) = all(has_static_value, x.args) -function broadcasted(::DefaultArrayStyle{N}, op, r::AbstractFill{T,N}) where {T,N} - return Fill(op(getindex_value(r)), axes(r)) +function _iszeros(bc::Broadcast.Broadcasted) + all(has_static_value, bc.args) && _iszero(_getindex_value(bc)) end +# conservative check for zeros. In most cases, there isn't a zero element to compare with +_iszero(x::Union{Number, AbstractArray}) = iszero(x) +_iszero(_) = false -broadcasted(::DefaultArrayStyle, ::typeof(+), r::AbstractZeros) = r -broadcasted(::DefaultArrayStyle, ::typeof(-), r::AbstractZeros) = r -broadcasted(::DefaultArrayStyle, ::typeof(+), r::AbstractOnes) = r +function _isones(bc::Broadcast.Broadcasted) + all(has_static_value, bc.args) && _isone(_getindex_value(bc)) +end +# conservative check for ones. In most cases, there isn't a unit element to compare with +_isone(x::Union{Number, AbstractArray}) = isone(x) +_isone(_) = false + +_isfill(bc::Broadcast.Broadcasted) = all(_isfill, bc.args) +_isfill(f::AbstractFill) = true +_isfill(f::Number) = true +_isfill(f::Ref) = true +_isfill(::Any) = false + +function _copy_fill(bc) + v = _getindex_value(bc) + if _iszeros(bc) + return Zeros(typeof(v), axes(bc)) + elseif _isones(bc) + return Ones(typeof(v), axes(bc)) + end + return Fill(v, axes(bc)) +end -broadcasted(::DefaultArrayStyle{N}, ::typeof(conj), r::AbstractZeros{T,N}) where {T,N} = r -broadcasted(::DefaultArrayStyle{N}, ::typeof(conj), r::AbstractOnes{T,N}) where {T,N} = r -broadcasted(::DefaultArrayStyle{N}, ::typeof(real), r::AbstractZeros{T,N}) where {T,N} = Zeros{real(T)}(axes(r)) -broadcasted(::DefaultArrayStyle{N}, ::typeof(real), r::AbstractOnes{T,N}) where {T,N} = Ones{real(T)}(axes(r)) -broadcasted(::DefaultArrayStyle{N}, ::typeof(imag), r::AbstractZeros{T,N}) where {T,N} = Zeros{real(T)}(axes(r)) -broadcasted(::DefaultArrayStyle{N}, ::typeof(imag), r::AbstractOnes{T,N}) where {T,N} = Zeros{real(T)}(axes(r)) +# recursively copy the purely fill components +function _preprocess_fill(bc::Broadcast.Broadcasted{<:AbstractFillStyle}) + _isfill(bc) ? _copy_fill(bc) : Broadcast.broadcasted(bc.f, map(_preprocess_fill, bc.args)...) +end +_preprocess_fill(bc::Broadcast.Broadcasted) = Broadcast.broadcasted(bc.f, map(_preprocess_fill, bc.args)...) +_preprocess_fill(x) = x + +function _fallback_copy(bc) + # copy the purely fill components + bc2 = Base.broadcasted(bc.f, map(_preprocess_fill, bc.args)...) + # fallback style + S = Broadcast.Broadcasted{Broadcast.DefaultArrayStyle{ndims(bc)}} + copy(convert(S, bc2)) +end + +function Base.copy(bc::Broadcast.Broadcasted{<:AbstractFillStyle}) + _isfill(bc) ? _copy_fill(bc) : _fallback_copy(bc) +end +# make the zero-dimensional case consistent with Base +Base.copy(bc::Broadcast.Broadcasted{<:AbstractFillStyle{0}}) = _fallback_copy(bc) + +# some cases that preserve 0d +function broadcast_preserving_0d(f, As...) + bc = Base.broadcasted(f, As...) + r = copy(bc) + length(axes(bc)) == 0 ? Fill(r) : r +end +for f in (:real, :imag) + @eval ($f)(A::AbstractFill) = broadcast_preserving_0d($f, A) + @eval ($f)(A::AbstractZeros) = Zeros{real(eltype(A))}(axes(A)) +end +conj(A::AbstractFill) = broadcast_preserving_0d(conj, A) +conj(A::AbstractZeros) = A +real(A::AbstractOnes) = Ones{real(eltype(A))}(axes(A)) +imag(A::AbstractOnes) = Zeros{real(eltype(A))}(axes(A)) +conj(A::AbstractOnes) = A +real(A::AbstractFill{<:Real}) = A +imag(A::AbstractFill{<:Real}) = Zeros{eltype(A)}(axes(A)) +conj(A::AbstractFill{<:Real}) = A ### Binary broadcasting @@ -100,12 +178,6 @@ broadcasted_zeros(f, a, b, elt, ax) = Zeros{elt}(ax) broadcasted_ones(f, a, elt, ax) = Ones{elt}(ax) broadcasted_ones(f, a, b, elt, ax) = Ones{elt}(ax) -function broadcasted(::DefaultArrayStyle, op, a::AbstractFill, b::AbstractFill) - val = op(getindex_value(a), getindex_value(b)) - ax = broadcast_shape(axes(a), axes(b)) - return broadcasted_fill(op, a, b, val, ax) -end - function _broadcasted_zeros(f, a, b) elt = Base.Broadcast.combine_eltypes(f, (a, b)) ax = broadcast_shape(axes(a), axes(b)) @@ -122,57 +194,32 @@ function _broadcasted_nan(f, a, b) return broadcasted_fill(f, a, b, val, ax) end -broadcasted(::DefaultArrayStyle, ::typeof(+), a::AbstractZeros, b::AbstractZeros) = _broadcasted_zeros(+, a, b) -broadcasted(::DefaultArrayStyle, ::typeof(+), a::AbstractOnes, b::AbstractZeros) = _broadcasted_ones(+, a, b) -broadcasted(::DefaultArrayStyle, ::typeof(+), a::AbstractZeros, b::AbstractOnes) = _broadcasted_ones(+, a, b) - -broadcasted(::DefaultArrayStyle, ::typeof(-), a::AbstractZeros, b::AbstractZeros) = _broadcasted_zeros(-, a, b) -broadcasted(::DefaultArrayStyle, ::typeof(-), a::AbstractOnes, b::AbstractZeros) = _broadcasted_ones(-, a, b) -broadcasted(::DefaultArrayStyle, ::typeof(-), a::AbstractOnes, b::AbstractOnes) = _broadcasted_zeros(-, a, b) - -broadcasted(::DefaultArrayStyle{1}, ::typeof(+), a::AbstractZerosVector, b::AbstractZerosVector) = _broadcasted_zeros(+, a, b) -broadcasted(::DefaultArrayStyle{1}, ::typeof(+), a::AbstractOnesVector, b::AbstractZerosVector) = _broadcasted_ones(+, a, b) -broadcasted(::DefaultArrayStyle{1}, ::typeof(+), a::AbstractZerosVector, b::AbstractOnesVector) = _broadcasted_ones(+, a, b) - -broadcasted(::DefaultArrayStyle{1}, ::typeof(-), a::AbstractZerosVector, b::AbstractZerosVector) = _broadcasted_zeros(-, a, b) -broadcasted(::DefaultArrayStyle{1}, ::typeof(-), a::AbstractOnesVector, b::AbstractZerosVector) = _broadcasted_ones(-, a, b) - - -broadcasted(::DefaultArrayStyle, ::typeof(*), a::AbstractZeros, b::AbstractZeros) = _broadcasted_zeros(*, a, b) - # In following, need to restrict to <: Number as otherwise we cannot infer zero from type # TODO: generalise to things like SVector for op in (:*, :/) @eval begin - broadcasted(::DefaultArrayStyle, ::typeof($op), a::AbstractZeros, b::AbstractOnes) = _broadcasted_zeros($op, a, b) - broadcasted(::DefaultArrayStyle, ::typeof($op), a::AbstractZeros, b::AbstractFill{<:Number}) = _broadcasted_zeros($op, a, b) - broadcasted(::DefaultArrayStyle, ::typeof($op), a::AbstractZeros, b::Number) = _broadcasted_zeros($op, a, b) - broadcasted(::DefaultArrayStyle, ::typeof($op), a::AbstractZeros, b::AbstractRange) = _broadcasted_zeros($op, a, b) - broadcasted(::DefaultArrayStyle, ::typeof($op), a::AbstractZeros, b::AbstractArray{<:Number}) = _broadcasted_zeros($op, a, b) - broadcasted(::DefaultArrayStyle, ::typeof($op), a::AbstractZeros, b::Base.Broadcast.Broadcasted) = _broadcasted_zeros($op, a, b) - broadcasted(::DefaultArrayStyle{1}, ::typeof($op), a::AbstractZeros, b::AbstractRange) = _broadcasted_zeros($op, a, b) + broadcasted(::typeof($op), a::AbstractZeros, b::AbstractFill{<:Number}) = _broadcasted_zeros($op, a, b) + broadcasted(::typeof($op), a::AbstractZeros, b::Number) = _broadcasted_zeros($op, a, b) + broadcasted(::typeof($op), a::AbstractZeros, b::AbstractOnes) = _broadcasted_zeros($op, a, b) + broadcasted(::typeof($op), a::AbstractZeros, b::AbstractRange) = _broadcasted_zeros($op, a, b) + broadcasted(::typeof($op), a::AbstractZeros, b::AbstractArray{<:Number}) = _broadcasted_zeros($op, a, b) + broadcasted(::typeof($op), a::AbstractZeros, b::Base.Broadcast.Broadcasted) = _broadcasted_zeros($op, a, b) end end for op in (:*, :\) @eval begin - broadcasted(::DefaultArrayStyle, ::typeof($op), a::AbstractOnes, b::AbstractZeros) = _broadcasted_zeros($op, a, b) - broadcasted(::DefaultArrayStyle, ::typeof($op), a::AbstractFill{<:Number}, b::AbstractZeros) = _broadcasted_zeros($op, a, b) - broadcasted(::DefaultArrayStyle, ::typeof($op), a::Number, b::AbstractZeros) = _broadcasted_zeros($op, a, b) - broadcasted(::DefaultArrayStyle, ::typeof($op), a::AbstractRange, b::AbstractZeros) = _broadcasted_zeros($op, a, b) - broadcasted(::DefaultArrayStyle, ::typeof($op), a::AbstractArray{<:Number}, b::AbstractZeros) = _broadcasted_zeros($op, a, b) - broadcasted(::DefaultArrayStyle, ::typeof($op), a::Base.Broadcast.Broadcasted, b::AbstractZeros) = _broadcasted_zeros($op, a, b) - broadcasted(::DefaultArrayStyle{1}, ::typeof($op), a::AbstractRange, b::AbstractZeros) = _broadcasted_zeros($op, a, b) + broadcasted(::typeof($op), a::AbstractOnes, b::AbstractZeros) = _broadcasted_zeros($op, a, b) + broadcasted(::typeof($op), a::AbstractFill{<:Number}, b::AbstractZeros) = _broadcasted_zeros($op, a, b) + broadcasted(::typeof($op), a::Number, b::AbstractZeros) = _broadcasted_zeros($op, a, b) + broadcasted(::typeof($op), a::AbstractRange, b::AbstractZeros) = _broadcasted_zeros($op, a, b) + broadcasted(::typeof($op), a::AbstractArray{<:Number}, b::AbstractZeros) = _broadcasted_zeros($op, a, b) + broadcasted(::typeof($op), a::Base.Broadcast.Broadcasted, b::AbstractZeros) = _broadcasted_zeros($op, a, b) end end - -for op in (:*, :/, :\) - @eval broadcasted(::DefaultArrayStyle, ::typeof($op), a::AbstractOnes, b::AbstractOnes) = _broadcasted_ones($op, a, b) -end - -for op in (:/, :\) - @eval broadcasted(::DefaultArrayStyle, ::typeof($op), a::AbstractZeros{<:Number}, b::AbstractZeros{<:Number}) = _broadcasted_nan($op, a, b) -end +broadcasted(::typeof(*), a::AbstractZeros, b::AbstractZeros) = _broadcasted_zeros(*, a, b) +broadcasted(::typeof(/), a::AbstractZeros, b::AbstractZeros) = _broadcasted_nan(/, a, b) +broadcasted(::typeof(\), a::AbstractZeros, b::AbstractZeros) = _broadcasted_nan(\, a, b) # special case due to missing converts for ranges _range_convert(::Type{AbstractVector{T}}, a::AbstractRange{T}) where T = a @@ -205,13 +252,13 @@ _range_convert(::Type{AbstractVector{T}}, a::ZerosVector) where T = ZerosVector{ # end # end -function broadcasted(::DefaultArrayStyle{1}, ::typeof(*), a::AbstractOnesVector, b::AbstractRange) +function broadcasted(::FillStyle{1}, ::typeof(*), a::AbstractOnes, b::AbstractRange) broadcast_shape(axes(a), axes(b)) == axes(b) || throw(ArgumentError("Cannot broadcast $a and $b. Convert $b to a Vector first.")) TT = typeof(zero(eltype(a)) * zero(eltype(b))) return _range_convert(AbstractVector{TT}, b) end -function broadcasted(::DefaultArrayStyle{1}, ::typeof(*), a::AbstractRange, b::AbstractOnesVector) +function broadcasted(::FillStyle{1}, ::typeof(*), a::AbstractRange, b::AbstractOnes) broadcast_shape(axes(a), axes(b)) == axes(a) || throw(ArgumentError("Cannot broadcast $a and $b. Convert $b to a Vector first.")) TT = typeof(zero(eltype(a)) * zero(eltype(b))) return _range_convert(AbstractVector{TT}, a) @@ -219,51 +266,46 @@ end for op in (:+, :-) @eval begin - function broadcasted(::DefaultArrayStyle{1}, ::typeof($op), a::AbstractVector, b::AbstractZerosVector) - broadcast_shape(axes(a), axes(b)) == axes(a) || throw(ArgumentError("Cannot broadcast $a and $b. Convert $b to a Vector first.")) + function broadcasted(::typeof($op), a::AbstractVector, b::AbstractZerosVector) + ax = broadcast_shape(axes(a), axes(b)) + ax == axes(a) || throw(ArgumentError("cannot broadcast an array with size $(size(a)) with $b")) TT = typeof($op(zero(eltype(a)), zero(eltype(b)))) # Use `TT ∘ (+)` to fix AD issues with `broadcasted(TT, x)` eltype(a) === TT ? a : broadcasted(TT ∘ (+), a) end - function broadcasted(::DefaultArrayStyle{1}, ::typeof($op), a::AbstractZerosVector, b::AbstractVector) - broadcast_shape(axes(a), axes(b)) == axes(b) || throw(ArgumentError("Cannot broadcast $a and $b. Convert $a to a Vector first.")) + function broadcasted(::typeof($op), a::AbstractZerosVector, b::AbstractVector) + ax = broadcast_shape(axes(a), axes(b)) + ax == axes(b) || throw(ArgumentError("cannot broadcast $a with an array with size $(size(b))")) TT = typeof($op(zero(eltype(a)), zero(eltype(b)))) $op === (+) && eltype(b) === TT ? b : broadcasted(TT ∘ ($op), b) end - - broadcasted(::DefaultArrayStyle{1}, ::typeof($op), a::AbstractFillVector, b::AbstractZerosVector) = - Base.invoke(broadcasted, Tuple{DefaultArrayStyle, typeof($op), AbstractFill, AbstractFill}, DefaultArrayStyle{1}(), $op, a, b) - - broadcasted(::DefaultArrayStyle{1}, ::typeof($op), a::AbstractZerosVector, b::AbstractFillVector) = - Base.invoke(broadcasted, Tuple{DefaultArrayStyle, typeof($op), AbstractFill, AbstractFill}, DefaultArrayStyle{1}(), $op, a, b) + function broadcasted(::typeof($op), a::AbstractZerosVector, b::AbstractZerosVector) + ax = broadcast_shape(axes(a), axes(b)) + TT = typeof($op(zero(eltype(a)), zero(eltype(b)))) + Zeros(TT, ax) + end end end # Need to prevent array-valued fills from broadcasting over entry -_broadcast_getindex_value(a::AbstractFill{<:Number}) = getindex_value(a) -_broadcast_getindex_value(a::AbstractFill) = Ref(getindex_value(a)) - +_mayberef(x) = Ref(x) +_mayberef(x::Number) = x -function broadcasted(::DefaultArrayStyle{1}, ::typeof(*), a::AbstractFill, b::AbstractRange) +function broadcasted(::FillStyle{1}, ::typeof(*), a::AbstractFill, b::AbstractRange) broadcast_shape(axes(a), axes(b)) == axes(b) || throw(ArgumentError("Cannot broadcast $a and $b. Convert $b to a Vector first.")) - return broadcasted(*, _broadcast_getindex_value(a), b) + return broadcasted(*, _mayberef(getindex_value(a)), b) end -function broadcasted(::DefaultArrayStyle{1}, ::typeof(*), a::AbstractRange, b::AbstractFill) +function broadcasted(::FillStyle{1}, ::typeof(*), a::AbstractRange, b::AbstractFill) broadcast_shape(axes(a), axes(b)) == axes(a) || throw(ArgumentError("Cannot broadcast $a and $b. Convert $b to a Vector first.")) - return broadcasted(*, a, _broadcast_getindex_value(b)) + return broadcasted(*, a, _mayberef(getindex_value(b))) end -broadcasted(::DefaultArrayStyle{N}, op, r::AbstractFill{T,N}, x::Number) where {T,N} = broadcasted_fill(op, r, op(getindex_value(r),x), axes(r)) -broadcasted(::DefaultArrayStyle{N}, op, x::Number, r::AbstractFill{T,N}) where {T,N} = broadcasted_fill(op, r, op(x, getindex_value(r)), axes(r)) -broadcasted(::DefaultArrayStyle{N}, op, r::AbstractFill{T,N}, x::Ref) where {T,N} = broadcasted_fill(op, r, op(getindex_value(r),x[]), axes(r)) -broadcasted(::DefaultArrayStyle{N}, op, x::Ref, r::AbstractFill{T,N}) where {T,N} = broadcasted_fill(op, r, op(x[], getindex_value(r)), axes(r)) - # support AbstractFill .^ k -broadcasted(::DefaultArrayStyle{N}, op::typeof(Base.literal_pow), ::Base.RefValue{typeof(^)}, r::AbstractFill{T,N}, ::Base.RefValue{Val{k}}) where {T,N,k} = broadcasted_fill(op, r, getindex_value(r)^k, axes(r)) -broadcasted(::DefaultArrayStyle{N}, op::typeof(Base.literal_pow), ::Base.RefValue{typeof(^)}, r::AbstractOnes{T,N}, ::Base.RefValue{Val{k}}) where {T,N,k} = broadcasted_ones(op, r, T, axes(r)) -broadcasted(::DefaultArrayStyle{N}, op::typeof(Base.literal_pow), ::Base.RefValue{typeof(^)}, r::AbstractZeros{T,N}, ::Base.RefValue{Val{0}}) where {T,N} = broadcasted_ones(op, r, T, axes(r)) -broadcasted(::DefaultArrayStyle{N}, op::typeof(Base.literal_pow), ::Base.RefValue{typeof(^)}, r::AbstractZeros{T,N}, ::Base.RefValue{Val{k}}) where {T,N,k} = broadcasted_zeros(op, r, T, axes(r)) +broadcasted(op::typeof(Base.literal_pow), ::typeof(^), r::AbstractFill{T,N}, ::Val{k}) where {T,N,k} = broadcasted_fill(op, r, getindex_value(r)^k, axes(r)) +broadcasted(op::typeof(Base.literal_pow), ::typeof(^), r::AbstractOnes{T,N}, ::Val{k}) where {T,N,k} = broadcasted_ones(op, r, T, axes(r)) +broadcasted(op::typeof(Base.literal_pow), ::typeof(^), r::AbstractZeros{T,N}, ::Val{0}) where {T,N} = broadcasted_ones(op, r, T, axes(r)) +broadcasted(op::typeof(Base.literal_pow), ::typeof(^), r::AbstractZeros{T,N}, ::Val{k}) where {T,N,k} = broadcasted_zeros(op, r, T, axes(r)) # supports structured broadcast if isdefined(LinearAlgebra, :fzero) diff --git a/test/runtests.jl b/test/runtests.jl index 0163df8a..4d5e37d6 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -821,7 +821,7 @@ end @testset "maximum/minimum/svd/sort" begin @test maximum(Fill(1, 1_000_000_000)) == minimum(Fill(1, 1_000_000_000)) == 1 @test svdvals(fill(2,5,6)) ≈ svdvals(Fill(2,5,6)) - @test svdvals(Eye(5)) === Fill(1.0,5) + @test svdvals(Eye(5)) === Ones(5) @test sort(Ones(5)) == sort!(Ones(5)) @test_throws MethodError issorted(Fill(im, 2)) @@ -928,21 +928,21 @@ end rng = MersenneTwister(123456) sizes = [(5, 4), (5, 1), (1, 4), (1, 1), (5,)] - for sx in sizes, sy in sizes + @testset for sx in sizes, sy in sizes x, y = Fill(randn(rng), sx), Fill(randn(rng), sy) x_one, y_one = Ones(sx), Ones(sy) x_zero, y_zero = Zeros(sx), Zeros(sy) x_dense, y_dense = randn(rng, sx), randn(rng, sy) for x in [x, x_one, x_zero, x_dense], y in [y, y_one, y_zero, y_dense] - @test x .+ y == collect(x) .+ collect(y) + @test x .+ y ≈ collect(x) .+ collect(y) end @test x_zero .+ y_zero isa Zeros @test x_zero .+ y_one isa Ones @test x_one .+ y_zero isa Ones for x in [x, x_one, x_zero, x_dense], y in [y, y_one, y_zero, y_dense] - @test x .* y == collect(x) .* collect(y) + @test x .* y ≈ collect(x) .* collect(y) end for x in [x, x_one, x_zero, x_dense] @test x .* y_zero isa Zeros @@ -1062,7 +1062,7 @@ end @test_throws DimensionMismatch Zeros{Int}(2) .+ (1:5) @test_throws DimensionMismatch (1:5) .+ Zeros{Int}(2) - for v in (rand(Bool, 5), [1:5;], SVector{5}(1:5), SVector{5,ComplexF16}(1:5)), T in (Bool, Int, Float64) + @testset "$(typeof(v)) $T" for v in (rand(Bool, 5), [1:5;], SVector{5}(1:5), SVector{5,ComplexF16}(1:5)), T in (Bool, Int, Float64) TT = eltype(v + zeros(T, 5)) S = v isa SVector ? SVector{5,TT} : Vector{TT} @@ -1115,7 +1115,7 @@ end @testset "issue #208" begin TS = (Bool, Int, Float32, Float64) - for S in TS, T in TS + @testset for S in TS, T in TS u = rand(S, 2) v = Zeros(T, 2) if zero(S) + zero(T) isa S @@ -1148,6 +1148,75 @@ end end end end + + @testset "Zeros to Fill" begin + @test @inferred((f -> ((x -> (1,)).(f)))((Zeros(4)))) == Fill((1,), 4) + @test @inferred((f -> ((x -> Val(1)).(f)))((Zeros(4)))) == Fill(Val(1), 4) + end + + @testset "multi-element broadcast" begin + x = Fill(2, 2) + y = @. 2 * x * 2 + @test y === Fill(8, 2) + end + + @testset "nested broadcast" begin + bc = Broadcast.broadcasted(*, Zeros(4), Ones(4), Broadcast.broadcasted(*, Zeros(4), Ones(4), Zeros(4))) + @test copy(bc) === Zeros(4) + end + + @testset "0d" begin + @test real.(Fill(2)) == real.(fill(2)) + @test (@. 2 * Fill(2) * 2) == (@. 2 * fill(2) * 2) + for (F, A) in ((Fill(2), fill(2)), (Zeros(), zeros()), (Ones(), ones())) + @test F * 2 == A * 2 + @test 2 * F == 2 * A + end + end + + @testset "preserve 0d" begin + @testset for f in (real, imag, conj), (F, A) in ( + (Fill(4), fill(4)), + (Fill(4 + 5im), fill(4 + 5im)), + (Fill(SMatrix{2,2,ComplexF64,4}(fill(4 + 5im, 4))), fill(SMatrix{2,2,ComplexF64,4}(fill(4 + 5im, 4)))), + (Zeros{ComplexF64}(), zeros(ComplexF64)), + (Zeros(), zeros()), + (Ones(), ones()), + (Ones{ComplexF64}(), ones(ComplexF64)), + ) + x = f(F) + y = f(A) + @test x == y + @test eltype(x) == eltype(y) + @test x isa FillArrays.AbstractFill + if F isa Ones + if f === imag + @test x isa Zeros + else + @test x isa Ones + end + end + if F[] isa Real + if f === imag + @test x isa Zeros + end + end + end + end + + @testset "issue #40" begin + f(x) = x + g(x, y) = x + F = Fill(1, 2) + @test g.(F, "a") === f.(F) + end + + @testset "early binding" begin + A = ones(2) .+ (x -> rand()).(Fill(2,2)) + @test all(==(A[1]), A) + A = ones(1,5) .+ (ones(1) .+ (_ -> rand()).(Fill("vec", 2))) + @test all(==(A[1]), A) + end end @testset "map" begin @@ -1156,7 +1225,7 @@ end @test map(isone,x1) === Fill(true,5) x0 = Zeros(5) - @test map(exp,x0) === exp.(x0) + @test map(exp,x0) == exp.(x0) x2 = Fill(2,5,3) @test map(exp,x2) === Fill(exp(2),5,3) @@ -2149,8 +2218,10 @@ end @test D - Zeros(5,5) isa Diagonal @test D .+ Zeros(5,5) isa Diagonal @test D .- Zeros(5,5) isa Diagonal - @test D .* Zeros(5,5) isa Diagonal - @test Zeros(5,5) .* D isa Diagonal + @test D .* Zeros(5,5) isa FillArrays.ZerosMatrix + @test ((x,y) -> x * y).(D, Zeros(5,5)) isa Diagonal + @test Zeros(5,5) .* D isa FillArrays.ZerosMatrix + @test ((x,y) -> x * y).(Zeros(5,5), D) isa Diagonal @test Zeros(5,5) - D isa Diagonal @test Zeros(5,5) + D isa Diagonal @test Zeros(5,5) .- D isa Diagonal