From fc1b1ed94bfdd83e826bdc13c7557ac78693de82 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Mon, 26 Aug 2024 13:36:50 +0530 Subject: [PATCH 01/10] Add a BroadcastStyle for AbstractFill --- src/FillArrays.jl | 3 +- src/fillbroadcast.jl | 197 +++++++++++++++++++++++++------------------ test/runtests.jl | 46 ++++++++-- 3 files changed, 152 insertions(+), 94 deletions(-) diff --git a/src/FillArrays.jl b/src/FillArrays.jl index 170e07d3..fa339cf1 100644 --- a/src/FillArrays.jl +++ b/src/FillArrays.jl @@ -7,7 +7,8 @@ import Base: size, getindex, setindex!, IndexStyle, checkbounds, convert, any, all, axes, isone, iszero, iterate, unique, allunique, permutedims, inv, copy, vec, setindex!, count, ==, reshape, map, zero, show, view, in, mapreduce, one, reverse, promote_op, promote_rule, repeat, - parent, similar, issorted, add_sum, accumulate, OneTo, permutedims + parent, similar, issorted, add_sum, accumulate, OneTo, permutedims, + real, imag, conj import LinearAlgebra: rank, svdvals!, tril, triu, tril!, triu!, diag, transpose, adjoint, fill!, dot, norm2, norm1, normInf, normMinusInf, normp, lmul!, rmul!, diagzero, AdjointAbsVec, TransposeAbsVec, diff --git a/src/fillbroadcast.jl b/src/fillbroadcast.jl index 2b5ea59c..426420f9 100644 --- a/src/fillbroadcast.jl +++ b/src/fillbroadcast.jl @@ -73,22 +73,79 @@ function mapreduce(f, op, A::AbstractFill, B::AbstractFill, Cs::AbstractArray... end -### Unary broadcasting +## BroadcastStyle + +abstract type AbstractFillStyle{N} <: Broadcast.AbstractArrayStyle{N} end +struct FillStyle{N} <: AbstractFillStyle{N} end +struct ZerosStyle{N} <: AbstractFillStyle{N} end +FillStyle{N}(::Val{M}) where {N,M} = FillStyle{M}() +ZerosStyle{N}(::Val{M}) where {N,M} = ZerosStyle{M}() +Broadcast.BroadcastStyle(::Type{<:AbstractFill{<:Any,N}}) where {N} = FillStyle{N}() +Broadcast.BroadcastStyle(::Type{<:AbstractZeros{<:Any,N}}) where {N} = ZerosStyle{N}() +Broadcast.BroadcastStyle(::FillStyle{M}, ::ZerosStyle{N}) where {M,N} = FillStyle{max(M,N)}() +Broadcast.BroadcastStyle(S::LinearAlgebra.StructuredMatrixStyle, ::ZerosStyle{2}) = S +Broadcast.BroadcastStyle(S::LinearAlgebra.StructuredMatrixStyle, ::ZerosStyle{1}) = S +Broadcast.BroadcastStyle(S::LinearAlgebra.StructuredMatrixStyle, ::ZerosStyle{0}) = S + +_getindex_value(f::AbstractFill) = getindex_value(f) +_getindex_value(x::Number) = x +_getindex_value(x::Ref) = x[] +function _getindex_value(bc::Broadcast.Broadcasted) + bc.f(map(_getindex_value, bc.args)...) +end + +has_static_value(x) = false +has_static_value(x::Union{AbstractZeros, AbstractOnes}) = true +has_static_value(x::Broadcast.Broadcasted) = all(has_static_value, x.args) -function broadcasted(::DefaultArrayStyle{N}, op, r::AbstractFill{T,N}) where {T,N} - return Fill(op(getindex_value(r)), axes(r)) +function _iszeros(bc::Broadcast.Broadcasted) + all(has_static_value, bc.args) && _iszero(_getindex_value(bc)) end +# conservative check for zeros. In most cases we can't really compare with zero +_iszero(x::Union{Number, AbstractArray}) = iszero(x) +_iszero(_) = false -broadcasted(::DefaultArrayStyle, ::typeof(+), r::AbstractZeros) = r -broadcasted(::DefaultArrayStyle, ::typeof(-), r::AbstractZeros) = r -broadcasted(::DefaultArrayStyle, ::typeof(+), r::AbstractOnes) = r +function _isones(bc::Broadcast.Broadcasted) + all(has_static_value, bc.args) && _isone(_getindex_value(bc)) +end +# conservative check for ones. In most cases we can't really compare with one +_isone(x::Union{Number, AbstractArray}) = isone(x) +_isone(_) = false + +_isfill(bc::Broadcast.Broadcasted) = all(_isfill, bc.args) +_isfill(f::AbstractFill) = true +_isfill(f::Number) = true +_isfill(f::Ref) = true +_isfill(::Any) = false + +function Base.copy(bc::Broadcast.Broadcasted{<:AbstractFillStyle{N}}) where {N} + if _iszeros(bc) + return Zeros(typeof(_getindex_value(bc)), axes(bc)) + elseif _isones(bc) + return Ones(typeof(_getindex_value(bc)), axes(bc)) + elseif _isfill(bc) + return Fill(_getindex_value(bc), axes(bc)) + else + # fallback style + S = Broadcast.Broadcasted{Broadcast.DefaultArrayStyle{N}} + copy(convert(S, bc)) + end +end +# make the zero-dimensional case consistent with Base +function Base.copy(bc::Broadcast.Broadcasted{<:AbstractFillStyle{0}}) + S = Broadcast.Broadcasted{Broadcast.DefaultArrayStyle{0}} + copy(convert(S, bc)) +end -broadcasted(::DefaultArrayStyle{N}, ::typeof(conj), r::AbstractZeros{T,N}) where {T,N} = r -broadcasted(::DefaultArrayStyle{N}, ::typeof(conj), r::AbstractOnes{T,N}) where {T,N} = r -broadcasted(::DefaultArrayStyle{N}, ::typeof(real), r::AbstractZeros{T,N}) where {T,N} = Zeros{real(T)}(axes(r)) -broadcasted(::DefaultArrayStyle{N}, ::typeof(real), r::AbstractOnes{T,N}) where {T,N} = Ones{real(T)}(axes(r)) -broadcasted(::DefaultArrayStyle{N}, ::typeof(imag), r::AbstractZeros{T,N}) where {T,N} = Zeros{real(T)}(axes(r)) -broadcasted(::DefaultArrayStyle{N}, ::typeof(imag), r::AbstractOnes{T,N}) where {T,N} = Zeros{real(T)}(axes(r)) +# some cases that preserve 0d +function broadcast_preserving_0d(f, As...) + bc = Base.broadcasted(f, As...) + r = copy(bc) + length(axes(bc)) == 0 ? Fill(r) : r +end +for f in (:real, :imag, :conj) + @eval ($f)(A::AbstractFill) = broadcast_preserving_0d($f, A) +end ### Binary broadcasting @@ -100,12 +157,6 @@ broadcasted_zeros(f, a, b, elt, ax) = Zeros{elt}(ax) broadcasted_ones(f, a, elt, ax) = Ones{elt}(ax) broadcasted_ones(f, a, b, elt, ax) = Ones{elt}(ax) -function broadcasted(::DefaultArrayStyle, op, a::AbstractFill, b::AbstractFill) - val = op(getindex_value(a), getindex_value(b)) - ax = broadcast_shape(axes(a), axes(b)) - return broadcasted_fill(op, a, b, val, ax) -end - function _broadcasted_zeros(f, a, b) elt = Base.Broadcast.combine_eltypes(f, (a, b)) ax = broadcast_shape(axes(a), axes(b)) @@ -122,57 +173,40 @@ function _broadcasted_nan(f, a, b) return broadcasted_fill(f, a, b, val, ax) end -broadcasted(::DefaultArrayStyle, ::typeof(+), a::AbstractZeros, b::AbstractZeros) = _broadcasted_zeros(+, a, b) -broadcasted(::DefaultArrayStyle, ::typeof(+), a::AbstractOnes, b::AbstractZeros) = _broadcasted_ones(+, a, b) -broadcasted(::DefaultArrayStyle, ::typeof(+), a::AbstractZeros, b::AbstractOnes) = _broadcasted_ones(+, a, b) - -broadcasted(::DefaultArrayStyle, ::typeof(-), a::AbstractZeros, b::AbstractZeros) = _broadcasted_zeros(-, a, b) -broadcasted(::DefaultArrayStyle, ::typeof(-), a::AbstractOnes, b::AbstractZeros) = _broadcasted_ones(-, a, b) -broadcasted(::DefaultArrayStyle, ::typeof(-), a::AbstractOnes, b::AbstractOnes) = _broadcasted_zeros(-, a, b) - -broadcasted(::DefaultArrayStyle{1}, ::typeof(+), a::AbstractZerosVector, b::AbstractZerosVector) = _broadcasted_zeros(+, a, b) -broadcasted(::DefaultArrayStyle{1}, ::typeof(+), a::AbstractOnesVector, b::AbstractZerosVector) = _broadcasted_ones(+, a, b) -broadcasted(::DefaultArrayStyle{1}, ::typeof(+), a::AbstractZerosVector, b::AbstractOnesVector) = _broadcasted_ones(+, a, b) - -broadcasted(::DefaultArrayStyle{1}, ::typeof(-), a::AbstractZerosVector, b::AbstractZerosVector) = _broadcasted_zeros(-, a, b) -broadcasted(::DefaultArrayStyle{1}, ::typeof(-), a::AbstractOnesVector, b::AbstractZerosVector) = _broadcasted_ones(-, a, b) - - -broadcasted(::DefaultArrayStyle, ::typeof(*), a::AbstractZeros, b::AbstractZeros) = _broadcasted_zeros(*, a, b) - # In following, need to restrict to <: Number as otherwise we cannot infer zero from type # TODO: generalise to things like SVector for op in (:*, :/) @eval begin - broadcasted(::DefaultArrayStyle, ::typeof($op), a::AbstractZeros, b::AbstractOnes) = _broadcasted_zeros($op, a, b) - broadcasted(::DefaultArrayStyle, ::typeof($op), a::AbstractZeros, b::AbstractFill{<:Number}) = _broadcasted_zeros($op, a, b) - broadcasted(::DefaultArrayStyle, ::typeof($op), a::AbstractZeros, b::Number) = _broadcasted_zeros($op, a, b) - broadcasted(::DefaultArrayStyle, ::typeof($op), a::AbstractZeros, b::AbstractRange) = _broadcasted_zeros($op, a, b) - broadcasted(::DefaultArrayStyle, ::typeof($op), a::AbstractZeros, b::AbstractArray{<:Number}) = _broadcasted_zeros($op, a, b) - broadcasted(::DefaultArrayStyle, ::typeof($op), a::AbstractZeros, b::Base.Broadcast.Broadcasted) = _broadcasted_zeros($op, a, b) - broadcasted(::DefaultArrayStyle{1}, ::typeof($op), a::AbstractZeros, b::AbstractRange) = _broadcasted_zeros($op, a, b) + broadcasted(::typeof($op), a::AbstractZeros, b::AbstractFill{<:Number}) = _broadcasted_zeros($op, a, b) + broadcasted(::typeof($op), a::AbstractZeros, b::Number) = _broadcasted_zeros($op, a, b) + broadcasted(::typeof($op), a::AbstractZeros, b::AbstractOnes) = _broadcasted_zeros($op, a, b) + broadcasted(::typeof($op), a::AbstractZeros, b::AbstractRange) = _broadcasted_zeros($op, a, b) + broadcasted(::typeof($op), a::AbstractZeros, b::AbstractArray{<:Number}) = _broadcasted_zeros($op, a, b) + broadcasted(::typeof($op), a::AbstractZeros, b::Base.Broadcast.Broadcasted) = _broadcasted_zeros($op, a, b) end end for op in (:*, :\) @eval begin - broadcasted(::DefaultArrayStyle, ::typeof($op), a::AbstractOnes, b::AbstractZeros) = _broadcasted_zeros($op, a, b) - broadcasted(::DefaultArrayStyle, ::typeof($op), a::AbstractFill{<:Number}, b::AbstractZeros) = _broadcasted_zeros($op, a, b) - broadcasted(::DefaultArrayStyle, ::typeof($op), a::Number, b::AbstractZeros) = _broadcasted_zeros($op, a, b) - broadcasted(::DefaultArrayStyle, ::typeof($op), a::AbstractRange, b::AbstractZeros) = _broadcasted_zeros($op, a, b) - broadcasted(::DefaultArrayStyle, ::typeof($op), a::AbstractArray{<:Number}, b::AbstractZeros) = _broadcasted_zeros($op, a, b) - broadcasted(::DefaultArrayStyle, ::typeof($op), a::Base.Broadcast.Broadcasted, b::AbstractZeros) = _broadcasted_zeros($op, a, b) - broadcasted(::DefaultArrayStyle{1}, ::typeof($op), a::AbstractRange, b::AbstractZeros) = _broadcasted_zeros($op, a, b) + broadcasted(::typeof($op), a::AbstractOnes, b::AbstractZeros) = _broadcasted_zeros($op, a, b) + broadcasted(::typeof($op), a::AbstractFill{<:Number}, b::AbstractZeros) = _broadcasted_zeros($op, a, b) + broadcasted(::typeof($op), a::Number, b::AbstractZeros) = _broadcasted_zeros($op, a, b) + broadcasted(::typeof($op), a::AbstractRange, b::AbstractZeros) = _broadcasted_zeros($op, a, b) + broadcasted(::typeof($op), a::AbstractArray{<:Number}, b::AbstractZeros) = _broadcasted_zeros($op, a, b) + broadcasted(::typeof($op), a::Base.Broadcast.Broadcasted, b::AbstractZeros) = _broadcasted_zeros($op, a, b) end end +broadcasted(::typeof(*), a::AbstractZeros, b::AbstractZeros) = _broadcasted_zeros(*, a, b) +broadcasted(::typeof(/), a::AbstractZeros, b::AbstractZeros) = _broadcasted_nan(/, a, b) +broadcasted(::typeof(\), a::AbstractZeros, b::AbstractZeros) = _broadcasted_nan(\, a, b) -for op in (:*, :/, :\) - @eval broadcasted(::DefaultArrayStyle, ::typeof($op), a::AbstractOnes, b::AbstractOnes) = _broadcasted_ones($op, a, b) -end +# for op in (:*, :/, :\) +# @eval broadcasted(::typeof($op), a::AbstractOnes, b::AbstractOnes) = _broadcasted_ones($op, a, b) +# end -for op in (:/, :\) - @eval broadcasted(::DefaultArrayStyle, ::typeof($op), a::AbstractZeros{<:Number}, b::AbstractZeros{<:Number}) = _broadcasted_nan($op, a, b) -end +# for op in (:/, :\) +# @eval broadcasted(::DefaultArrayStyle, ::typeof($op), a::AbstractZeros{<:Number}, b::AbstractZeros{<:Number}) = _broadcasted_nan($op, a, b) +# end # special case due to missing converts for ranges _range_convert(::Type{AbstractVector{T}}, a::AbstractRange{T}) where T = a @@ -205,13 +239,13 @@ _range_convert(::Type{AbstractVector{T}}, a::ZerosVector) where T = ZerosVector{ # end # end -function broadcasted(::DefaultArrayStyle{1}, ::typeof(*), a::AbstractOnesVector, b::AbstractRange) +function broadcasted(::FillStyle{1}, ::typeof(*), a::AbstractOnes, b::AbstractRange) broadcast_shape(axes(a), axes(b)) == axes(b) || throw(ArgumentError("Cannot broadcast $a and $b. Convert $b to a Vector first.")) TT = typeof(zero(eltype(a)) * zero(eltype(b))) return _range_convert(AbstractVector{TT}, b) end -function broadcasted(::DefaultArrayStyle{1}, ::typeof(*), a::AbstractRange, b::AbstractOnesVector) +function broadcasted(::FillStyle{1}, ::typeof(*), a::AbstractRange, b::AbstractOnes) broadcast_shape(axes(a), axes(b)) == axes(a) || throw(ArgumentError("Cannot broadcast $a and $b. Convert $b to a Vector first.")) TT = typeof(zero(eltype(a)) * zero(eltype(b))) return _range_convert(AbstractVector{TT}, a) @@ -219,51 +253,46 @@ end for op in (:+, :-) @eval begin - function broadcasted(::DefaultArrayStyle{1}, ::typeof($op), a::AbstractVector, b::AbstractZerosVector) - broadcast_shape(axes(a), axes(b)) == axes(a) || throw(ArgumentError("Cannot broadcast $a and $b. Convert $b to a Vector first.")) + function broadcasted(::typeof($op), a::AbstractVector, b::AbstractZerosVector) + ax = broadcast_shape(axes(a), axes(b)) + ax == axes(a) || throw(ArgumentError("cannot broadcast an array with size $(size(a)) with $b")) TT = typeof($op(zero(eltype(a)), zero(eltype(b)))) # Use `TT ∘ (+)` to fix AD issues with `broadcasted(TT, x)` eltype(a) === TT ? a : broadcasted(TT ∘ (+), a) end - function broadcasted(::DefaultArrayStyle{1}, ::typeof($op), a::AbstractZerosVector, b::AbstractVector) - broadcast_shape(axes(a), axes(b)) == axes(b) || throw(ArgumentError("Cannot broadcast $a and $b. Convert $a to a Vector first.")) + function broadcasted(::typeof($op), a::AbstractZerosVector, b::AbstractVector) + ax = broadcast_shape(axes(a), axes(b)) + ax == axes(b) || throw(ArgumentError("cannot broadcast $a with an array with size $(size(b))")) TT = typeof($op(zero(eltype(a)), zero(eltype(b)))) $op === (+) && eltype(b) === TT ? b : broadcasted(TT ∘ ($op), b) end - - broadcasted(::DefaultArrayStyle{1}, ::typeof($op), a::AbstractFillVector, b::AbstractZerosVector) = - Base.invoke(broadcasted, Tuple{DefaultArrayStyle, typeof($op), AbstractFill, AbstractFill}, DefaultArrayStyle{1}(), $op, a, b) - - broadcasted(::DefaultArrayStyle{1}, ::typeof($op), a::AbstractZerosVector, b::AbstractFillVector) = - Base.invoke(broadcasted, Tuple{DefaultArrayStyle, typeof($op), AbstractFill, AbstractFill}, DefaultArrayStyle{1}(), $op, a, b) + function broadcasted(::typeof($op), a::AbstractZerosVector, b::AbstractZerosVector) + ax = broadcast_shape(axes(a), axes(b)) + TT = typeof($op(zero(eltype(a)), zero(eltype(b)))) + Zeros(TT, ax) + end end end # Need to prevent array-valued fills from broadcasting over entry -_broadcast_getindex_value(a::AbstractFill{<:Number}) = getindex_value(a) -_broadcast_getindex_value(a::AbstractFill) = Ref(getindex_value(a)) - +_mayberef(x) = Ref(x) +_mayberef(x::Number) = x -function broadcasted(::DefaultArrayStyle{1}, ::typeof(*), a::AbstractFill, b::AbstractRange) +function broadcasted(::FillStyle{1}, ::typeof(*), a::AbstractFill, b::AbstractRange) broadcast_shape(axes(a), axes(b)) == axes(b) || throw(ArgumentError("Cannot broadcast $a and $b. Convert $b to a Vector first.")) - return broadcasted(*, _broadcast_getindex_value(a), b) + return broadcasted(*, _mayberef(getindex_value(a)), b) end -function broadcasted(::DefaultArrayStyle{1}, ::typeof(*), a::AbstractRange, b::AbstractFill) +function broadcasted(::FillStyle{1}, ::typeof(*), a::AbstractRange, b::AbstractFill) broadcast_shape(axes(a), axes(b)) == axes(a) || throw(ArgumentError("Cannot broadcast $a and $b. Convert $b to a Vector first.")) - return broadcasted(*, a, _broadcast_getindex_value(b)) + return broadcasted(*, a, _mayberef(getindex_value(b))) end -broadcasted(::DefaultArrayStyle{N}, op, r::AbstractFill{T,N}, x::Number) where {T,N} = broadcasted_fill(op, r, op(getindex_value(r),x), axes(r)) -broadcasted(::DefaultArrayStyle{N}, op, x::Number, r::AbstractFill{T,N}) where {T,N} = broadcasted_fill(op, r, op(x, getindex_value(r)), axes(r)) -broadcasted(::DefaultArrayStyle{N}, op, r::AbstractFill{T,N}, x::Ref) where {T,N} = broadcasted_fill(op, r, op(getindex_value(r),x[]), axes(r)) -broadcasted(::DefaultArrayStyle{N}, op, x::Ref, r::AbstractFill{T,N}) where {T,N} = broadcasted_fill(op, r, op(x[], getindex_value(r)), axes(r)) - # support AbstractFill .^ k -broadcasted(::DefaultArrayStyle{N}, op::typeof(Base.literal_pow), ::Base.RefValue{typeof(^)}, r::AbstractFill{T,N}, ::Base.RefValue{Val{k}}) where {T,N,k} = broadcasted_fill(op, r, getindex_value(r)^k, axes(r)) -broadcasted(::DefaultArrayStyle{N}, op::typeof(Base.literal_pow), ::Base.RefValue{typeof(^)}, r::AbstractOnes{T,N}, ::Base.RefValue{Val{k}}) where {T,N,k} = broadcasted_ones(op, r, T, axes(r)) -broadcasted(::DefaultArrayStyle{N}, op::typeof(Base.literal_pow), ::Base.RefValue{typeof(^)}, r::AbstractZeros{T,N}, ::Base.RefValue{Val{0}}) where {T,N} = broadcasted_ones(op, r, T, axes(r)) -broadcasted(::DefaultArrayStyle{N}, op::typeof(Base.literal_pow), ::Base.RefValue{typeof(^)}, r::AbstractZeros{T,N}, ::Base.RefValue{Val{k}}) where {T,N,k} = broadcasted_zeros(op, r, T, axes(r)) +broadcasted(op::typeof(Base.literal_pow), ::typeof(^), r::AbstractFill{T,N}, ::Val{k}) where {T,N,k} = broadcasted_fill(op, r, getindex_value(r)^k, axes(r)) +broadcasted(op::typeof(Base.literal_pow), ::typeof(^), r::AbstractOnes{T,N}, ::Val{k}) where {T,N,k} = broadcasted_ones(op, r, T, axes(r)) +broadcasted(op::typeof(Base.literal_pow), ::typeof(^), r::AbstractZeros{T,N}, ::Val{0}) where {T,N} = broadcasted_ones(op, r, T, axes(r)) +broadcasted(op::typeof(Base.literal_pow), ::typeof(^), r::AbstractZeros{T,N}, ::Val{k}) where {T,N,k} = broadcasted_zeros(op, r, T, axes(r)) # supports structured broadcast if isdefined(LinearAlgebra, :fzero) diff --git a/test/runtests.jl b/test/runtests.jl index 0163df8a..c4700fef 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -821,7 +821,7 @@ end @testset "maximum/minimum/svd/sort" begin @test maximum(Fill(1, 1_000_000_000)) == minimum(Fill(1, 1_000_000_000)) == 1 @test svdvals(fill(2,5,6)) ≈ svdvals(Fill(2,5,6)) - @test svdvals(Eye(5)) === Fill(1.0,5) + @test svdvals(Eye(5)) === Ones(5) @test sort(Ones(5)) == sort!(Ones(5)) @test_throws MethodError issorted(Fill(im, 2)) @@ -928,21 +928,21 @@ end rng = MersenneTwister(123456) sizes = [(5, 4), (5, 1), (1, 4), (1, 1), (5,)] - for sx in sizes, sy in sizes + @testset for sx in sizes, sy in sizes x, y = Fill(randn(rng), sx), Fill(randn(rng), sy) x_one, y_one = Ones(sx), Ones(sy) x_zero, y_zero = Zeros(sx), Zeros(sy) x_dense, y_dense = randn(rng, sx), randn(rng, sy) for x in [x, x_one, x_zero, x_dense], y in [y, y_one, y_zero, y_dense] - @test x .+ y == collect(x) .+ collect(y) + @test x .+ y ≈ collect(x) .+ collect(y) end @test x_zero .+ y_zero isa Zeros @test x_zero .+ y_one isa Ones @test x_one .+ y_zero isa Ones for x in [x, x_one, x_zero, x_dense], y in [y, y_one, y_zero, y_dense] - @test x .* y == collect(x) .* collect(y) + @test x .* y ≈ collect(x) .* collect(y) end for x in [x, x_one, x_zero, x_dense] @test x .* y_zero isa Zeros @@ -1062,7 +1062,7 @@ end @test_throws DimensionMismatch Zeros{Int}(2) .+ (1:5) @test_throws DimensionMismatch (1:5) .+ Zeros{Int}(2) - for v in (rand(Bool, 5), [1:5;], SVector{5}(1:5), SVector{5,ComplexF16}(1:5)), T in (Bool, Int, Float64) + @testset "$(typeof(v)) $T" for v in (rand(Bool, 5), [1:5;], SVector{5}(1:5), SVector{5,ComplexF16}(1:5)), T in (Bool, Int, Float64) TT = eltype(v + zeros(T, 5)) S = v isa SVector ? SVector{5,TT} : Vector{TT} @@ -1115,7 +1115,7 @@ end @testset "issue #208" begin TS = (Bool, Int, Float32, Float64) - for S in TS, T in TS + @testset for S in TS, T in TS u = rand(S, 2) v = Zeros(T, 2) if zero(S) + zero(T) isa S @@ -1148,6 +1148,32 @@ end end end end + + @testset "Zeros to Fill" begin + @test @inferred((f -> ((x -> (1,)).(f)))((Zeros(4)))) == Fill((1,), 4) + @test @inferred((f -> ((x -> Val(1)).(f)))((Zeros(4)))) == Fill(Val(1), 4) + end + + @testset "multi-element broadcast" begin + x = Fill(2, 2) + y = @. 2 * x * 2 + @test y === Fill(8, 2) + end + + @testset "nested broadcast" begin + bc = Broadcast.broadcasted(*, Zeros(4), Ones(4), Broadcast.broadcasted(*, Zeros(4), Ones(4), Zeros(4))) + @test copy(bc) === Zeros(4) + end + + @testset "0d" begin + @test real.(Fill(2)) == real.(fill(2)) + end + + @testset "preserve 0d" begin + @test real(Fill(4 + 5im)) == real(fill(4 + 5im)) + @test imag(Fill(4 + 5im)) == imag(fill(4 + 5im)) + @test conj(Fill(4 + 5im)) == conj(fill(4 + 5im)) + end end @testset "map" begin @@ -1156,7 +1182,7 @@ end @test map(isone,x1) === Fill(true,5) x0 = Zeros(5) - @test map(exp,x0) === exp.(x0) + @test map(exp,x0) == exp.(x0) x2 = Fill(2,5,3) @test map(exp,x2) === Fill(exp(2),5,3) @@ -2149,8 +2175,10 @@ end @test D - Zeros(5,5) isa Diagonal @test D .+ Zeros(5,5) isa Diagonal @test D .- Zeros(5,5) isa Diagonal - @test D .* Zeros(5,5) isa Diagonal - @test Zeros(5,5) .* D isa Diagonal + @test D .* Zeros(5,5) isa FillArrays.ZerosMatrix + @test ((x,y) -> x * y).(D, Zeros(5,5)) isa Diagonal + @test Zeros(5,5) .* D isa FillArrays.ZerosMatrix + @test ((x,y) -> x * y).(Zeros(5,5), D) isa Diagonal @test Zeros(5,5) - D isa Diagonal @test Zeros(5,5) + D isa Diagonal @test Zeros(5,5) .- D isa Diagonal From e09b5f68e1f992c220b6f4d95da84c5867cbd16c Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Mon, 26 Aug 2024 13:47:45 +0530 Subject: [PATCH 02/10] specialize real/imag/conj for real arrays --- src/fillbroadcast.jl | 6 ++++++ test/runtests.jl | 22 +++++++++++++++++++--- 2 files changed, 25 insertions(+), 3 deletions(-) diff --git a/src/fillbroadcast.jl b/src/fillbroadcast.jl index 426420f9..f777b92d 100644 --- a/src/fillbroadcast.jl +++ b/src/fillbroadcast.jl @@ -145,6 +145,12 @@ function broadcast_preserving_0d(f, As...) end for f in (:real, :imag, :conj) @eval ($f)(A::AbstractFill) = broadcast_preserving_0d($f, A) + @eval ($f)(A::AbstractZeros) = A +end +for T in (:AbstractOnes, :(AbstractFill{<:Real})) + @eval real(A::$T) = A + @eval imag(A::$T) = Zeros{eltype(A)}(axes(A)) + @eval conj(A::$T) = A end ### Binary broadcasting diff --git a/test/runtests.jl b/test/runtests.jl index c4700fef..56692e2c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1170,9 +1170,25 @@ end end @testset "preserve 0d" begin - @test real(Fill(4 + 5im)) == real(fill(4 + 5im)) - @test imag(Fill(4 + 5im)) == imag(fill(4 + 5im)) - @test conj(Fill(4 + 5im)) == conj(fill(4 + 5im)) + for f in (real, imag, conj), + (F, A) in ((Fill(4 + 5im), fill(4 + 5im)), + (Zeros{ComplexF64}(), zeros(ComplexF64)), + (Zeros(), zeros()), + (Ones(), ones()), + (Ones{ComplexF64}(), ones(ComplexF64)), + ) + x = f(F) + y = f(A) + @test x == y + @test x isa FillArrays.AbstractFill + if F[] isa Real + if f === imag + @test x isa Zeros + else + @test x isa typeof(F) + end + end + end end end From f6f64888538327558d03d37f791b46ea7a22e8d8 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Mon, 26 Aug 2024 13:55:24 +0530 Subject: [PATCH 03/10] Binary broadcast test --- src/fillbroadcast.jl | 17 ++++++++++------- test/runtests.jl | 33 ++++++++++++++++++++++++--------- 2 files changed, 34 insertions(+), 16 deletions(-) diff --git a/src/fillbroadcast.jl b/src/fillbroadcast.jl index f777b92d..d1fdf4af 100644 --- a/src/fillbroadcast.jl +++ b/src/fillbroadcast.jl @@ -143,15 +143,18 @@ function broadcast_preserving_0d(f, As...) r = copy(bc) length(axes(bc)) == 0 ? Fill(r) : r end -for f in (:real, :imag, :conj) +for f in (:real, :imag) @eval ($f)(A::AbstractFill) = broadcast_preserving_0d($f, A) - @eval ($f)(A::AbstractZeros) = A -end -for T in (:AbstractOnes, :(AbstractFill{<:Real})) - @eval real(A::$T) = A - @eval imag(A::$T) = Zeros{eltype(A)}(axes(A)) - @eval conj(A::$T) = A + @eval ($f)(A::AbstractZeros) = Zeros{real(eltype(A))}(axes(A)) end +conj(A::AbstractFill) = broadcast_preserving_0d(conj, A) +conj(A::AbstractZeros) = A +real(A::AbstractOnes) = Ones{real(eltype(A))}(axes(A)) +imag(A::AbstractOnes) = Zeros{real(eltype(A))}(axes(A)) +conj(A::AbstractOnes) = A +real(A::AbstractFill{<:Real}) = A +imag(A::AbstractFill{<:Real}) = Zeros{eltype(A)}(axes(A)) +conj(A::AbstractFill{<:Real}) = A ### Binary broadcasting diff --git a/test/runtests.jl b/test/runtests.jl index 56692e2c..ab6c600b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1170,26 +1170,41 @@ end end @testset "preserve 0d" begin - for f in (real, imag, conj), - (F, A) in ((Fill(4 + 5im), fill(4 + 5im)), - (Zeros{ComplexF64}(), zeros(ComplexF64)), - (Zeros(), zeros()), - (Ones(), ones()), - (Ones{ComplexF64}(), ones(ComplexF64)), - ) + @testset for f in (real, imag, conj), (F, A) in ( + (Fill(4), fill(4)), + (Fill(4 + 5im), fill(4 + 5im)), + (Fill(SMatrix{2,2,ComplexF64,4}(fill(4 + 5im, 4))), fill(SMatrix{2,2,ComplexF64,4}(fill(4 + 5im, 4)))), + (Zeros{ComplexF64}(), zeros(ComplexF64)), + (Zeros(), zeros()), + (Ones(), ones()), + (Ones{ComplexF64}(), ones(ComplexF64)), + ) x = f(F) y = f(A) @test x == y + @test eltype(x) == eltype(y) @test x isa FillArrays.AbstractFill - if F[] isa Real + if F isa Ones if f === imag @test x isa Zeros else - @test x isa typeof(F) + @test x isa Ones + end + end + if F[] isa Real + if f === imag + @test x isa Zeros end end end end + + @testset "issue #40" begin + f(x) = x + g(x, y) = x + F = Fill(1, 2) + @test g.(F, "a") === f.(F) + end end @testset "map" begin From 721dbf2c71dd92bc1a09ab4d012459c46796d449 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Mon, 26 Aug 2024 15:06:09 +0530 Subject: [PATCH 04/10] Bump version to v1.13.0 --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index f22c04ab..4ec9f49a 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "FillArrays" uuid = "1a297f60-69ca-5386-bcde-b61e274b549b" -version = "1.12.0" +version = "1.13.0" [deps] LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" From 29b645257f87f39b1bbe8bbd7c41993d77fc07e7 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Mon, 26 Aug 2024 15:36:50 +0530 Subject: [PATCH 05/10] Specialize scaling by a number --- src/fillalgebra.jl | 7 +++++++ test/runtests.jl | 5 +++++ 2 files changed, 12 insertions(+) diff --git a/src/fillalgebra.jl b/src/fillalgebra.jl index 94d763bd..dcc78be2 100644 --- a/src/fillalgebra.jl +++ b/src/fillalgebra.jl @@ -78,6 +78,13 @@ mult_zeros(a::AbstractArray{<:Number}, b::AbstractArray{<:Number}) = mult_zeros( mult_zeros(a, b) = mult_fill(a, b, mult_axes(a, b)) mult_ones(a, b) = mult_ones(a, b, mult_axes(a, b)) +# scaling +*(a::AbstractFill, b::Number) = Fill(getindex_value(a) * b, axes(a)) +*(a::Number, b::AbstractFill) = Fill(a * getindex_value(b), axes(b)) +*(a::AbstractZeros, b::Number) = Zeros(typeof(getindex_value(a) * b), axes(a)) +*(a::Number, b::AbstractZeros) = Zeros(typeof(a * getindex_value(b)), axes(b)) + +# matmul *(a::AbstractFillMatrix, b::AbstractFillMatrix) = mult_fill(a,b) *(a::AbstractFillMatrix, b::AbstractFillVector) = mult_fill(a,b) diff --git a/test/runtests.jl b/test/runtests.jl index ab6c600b..e46d9c85 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1167,6 +1167,11 @@ end @testset "0d" begin @test real.(Fill(2)) == real.(fill(2)) + @test (@. 2 * Fill(2) * 2) == (@. 2 * fill(2) * 2) + for (F, A) in ((Fill(2), fill(2)), (Zeros(), zeros()), (Ones(), ones())) + @test F * 2 == A * 2 + @test 2 * F == 2 * A + end end @testset "preserve 0d" begin From c7db4ecbf76a4d2789620d9b16431cbe06700dc1 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Mon, 26 Aug 2024 15:38:33 +0530 Subject: [PATCH 06/10] Update comment --- src/fillbroadcast.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/fillbroadcast.jl b/src/fillbroadcast.jl index d1fdf4af..6ea61834 100644 --- a/src/fillbroadcast.jl +++ b/src/fillbroadcast.jl @@ -101,14 +101,14 @@ has_static_value(x::Broadcast.Broadcasted) = all(has_static_value, x.args) function _iszeros(bc::Broadcast.Broadcasted) all(has_static_value, bc.args) && _iszero(_getindex_value(bc)) end -# conservative check for zeros. In most cases we can't really compare with zero +# conservative check for zeros. In most cases, there isn't a zero element to compare with _iszero(x::Union{Number, AbstractArray}) = iszero(x) _iszero(_) = false function _isones(bc::Broadcast.Broadcasted) all(has_static_value, bc.args) && _isone(_getindex_value(bc)) end -# conservative check for ones. In most cases we can't really compare with one +# conservative check for ones. In most cases, there isn't a unit element to compare with _isone(x::Union{Number, AbstractArray}) = isone(x) _isone(_) = false From 2f61d4b400d6cf5b35f5c109896c856a01378ce2 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Mon, 26 Aug 2024 15:39:59 +0530 Subject: [PATCH 07/10] Delete commented out code --- src/fillbroadcast.jl | 8 -------- 1 file changed, 8 deletions(-) diff --git a/src/fillbroadcast.jl b/src/fillbroadcast.jl index 6ea61834..02dc892b 100644 --- a/src/fillbroadcast.jl +++ b/src/fillbroadcast.jl @@ -209,14 +209,6 @@ broadcasted(::typeof(*), a::AbstractZeros, b::AbstractZeros) = _broadcasted_zero broadcasted(::typeof(/), a::AbstractZeros, b::AbstractZeros) = _broadcasted_nan(/, a, b) broadcasted(::typeof(\), a::AbstractZeros, b::AbstractZeros) = _broadcasted_nan(\, a, b) -# for op in (:*, :/, :\) -# @eval broadcasted(::typeof($op), a::AbstractOnes, b::AbstractOnes) = _broadcasted_ones($op, a, b) -# end - -# for op in (:/, :\) -# @eval broadcasted(::DefaultArrayStyle, ::typeof($op), a::AbstractZeros{<:Number}, b::AbstractZeros{<:Number}) = _broadcasted_nan($op, a, b) -# end - # special case due to missing converts for ranges _range_convert(::Type{AbstractVector{T}}, a::AbstractRange{T}) where T = a _range_convert(::Type{AbstractVector{T}}, a::AbstractUnitRange) where T = convert(T,first(a)):convert(T,last(a)) From c141ff8bc8a0cc6e6342b3f3e873fe2747c5125f Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Mon, 26 Aug 2024 16:18:02 +0530 Subject: [PATCH 08/10] Process Fill broadcasting before others --- src/fillbroadcast.jl | 20 ++++++++++++++------ test/runtests.jl | 5 +++++ 2 files changed, 19 insertions(+), 6 deletions(-) diff --git a/src/fillbroadcast.jl b/src/fillbroadcast.jl index 02dc892b..dd23c057 100644 --- a/src/fillbroadcast.jl +++ b/src/fillbroadcast.jl @@ -118,7 +118,18 @@ _isfill(f::Number) = true _isfill(f::Ref) = true _isfill(::Any) = false -function Base.copy(bc::Broadcast.Broadcasted{<:AbstractFillStyle{N}}) where {N} +_broadcast_maybecopy(bc::Broadcast.Broadcasted{<:AbstractFillStyle}) = copy(bc) +_broadcast_maybecopy(x) = x + +function _fallback_copy(bc) + # treat the fill components + bc2 = Base.broadcasted(bc.f, map(_broadcast_maybecopy, bc.args)...) + # fallback style + S = Broadcast.Broadcasted{Broadcast.DefaultArrayStyle{ndims(bc)}} + copy(convert(S, bc2)) +end + +function Base.copy(bc::Broadcast.Broadcasted{<:AbstractFillStyle}) if _iszeros(bc) return Zeros(typeof(_getindex_value(bc)), axes(bc)) elseif _isones(bc) @@ -126,15 +137,12 @@ function Base.copy(bc::Broadcast.Broadcasted{<:AbstractFillStyle{N}}) where {N} elseif _isfill(bc) return Fill(_getindex_value(bc), axes(bc)) else - # fallback style - S = Broadcast.Broadcasted{Broadcast.DefaultArrayStyle{N}} - copy(convert(S, bc)) + _fallback_copy(bc) end end # make the zero-dimensional case consistent with Base function Base.copy(bc::Broadcast.Broadcasted{<:AbstractFillStyle{0}}) - S = Broadcast.Broadcasted{Broadcast.DefaultArrayStyle{0}} - copy(convert(S, bc)) + _fallback_copy(bc) end # some cases that preserve 0d diff --git a/test/runtests.jl b/test/runtests.jl index e46d9c85..1b695cde 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1210,6 +1210,11 @@ end F = Fill(1, 2) @test g.(F, "a") === f.(F) end + + @testset "early binding" begin + A = ones(2) .+ (x -> rand()).(Fill(2,2)) + @test all(==(A[1]), A) + end end @testset "map" begin From 5a53ef2a9debb7a5bf292f0299abec6bbc769db9 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Mon, 26 Aug 2024 16:23:35 +0530 Subject: [PATCH 09/10] Recursively process fill components --- src/fillbroadcast.jl | 1 + test/runtests.jl | 2 ++ 2 files changed, 3 insertions(+) diff --git a/src/fillbroadcast.jl b/src/fillbroadcast.jl index dd23c057..13074846 100644 --- a/src/fillbroadcast.jl +++ b/src/fillbroadcast.jl @@ -119,6 +119,7 @@ _isfill(f::Ref) = true _isfill(::Any) = false _broadcast_maybecopy(bc::Broadcast.Broadcasted{<:AbstractFillStyle}) = copy(bc) +_broadcast_maybecopy(bc::Broadcast.Broadcasted) = Broadcast.broadcasted(bc.f, map(_broadcast_maybecopy, bc.args)...) _broadcast_maybecopy(x) = x function _fallback_copy(bc) diff --git a/test/runtests.jl b/test/runtests.jl index 1b695cde..4d5e37d6 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1214,6 +1214,8 @@ end @testset "early binding" begin A = ones(2) .+ (x -> rand()).(Fill(2,2)) @test all(==(A[1]), A) + A = ones(1,5) .+ (ones(1) .+ (_ -> rand()).(Fill("vec", 2))) + @test all(==(A[1]), A) end end From 712a201d5b9dfa9a8404f927c138c8bcfd0ba5ff Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Mon, 26 Aug 2024 16:41:27 +0530 Subject: [PATCH 10/10] Refactor common parts --- src/fillbroadcast.jl | 37 ++++++++++++++++++++----------------- 1 file changed, 20 insertions(+), 17 deletions(-) diff --git a/src/fillbroadcast.jl b/src/fillbroadcast.jl index 13074846..115fc6ac 100644 --- a/src/fillbroadcast.jl +++ b/src/fillbroadcast.jl @@ -118,33 +118,36 @@ _isfill(f::Number) = true _isfill(f::Ref) = true _isfill(::Any) = false -_broadcast_maybecopy(bc::Broadcast.Broadcasted{<:AbstractFillStyle}) = copy(bc) -_broadcast_maybecopy(bc::Broadcast.Broadcasted) = Broadcast.broadcasted(bc.f, map(_broadcast_maybecopy, bc.args)...) -_broadcast_maybecopy(x) = x +function _copy_fill(bc) + v = _getindex_value(bc) + if _iszeros(bc) + return Zeros(typeof(v), axes(bc)) + elseif _isones(bc) + return Ones(typeof(v), axes(bc)) + end + return Fill(v, axes(bc)) +end + +# recursively copy the purely fill components +function _preprocess_fill(bc::Broadcast.Broadcasted{<:AbstractFillStyle}) + _isfill(bc) ? _copy_fill(bc) : Broadcast.broadcasted(bc.f, map(_preprocess_fill, bc.args)...) +end +_preprocess_fill(bc::Broadcast.Broadcasted) = Broadcast.broadcasted(bc.f, map(_preprocess_fill, bc.args)...) +_preprocess_fill(x) = x function _fallback_copy(bc) - # treat the fill components - bc2 = Base.broadcasted(bc.f, map(_broadcast_maybecopy, bc.args)...) + # copy the purely fill components + bc2 = Base.broadcasted(bc.f, map(_preprocess_fill, bc.args)...) # fallback style S = Broadcast.Broadcasted{Broadcast.DefaultArrayStyle{ndims(bc)}} copy(convert(S, bc2)) end function Base.copy(bc::Broadcast.Broadcasted{<:AbstractFillStyle}) - if _iszeros(bc) - return Zeros(typeof(_getindex_value(bc)), axes(bc)) - elseif _isones(bc) - return Ones(typeof(_getindex_value(bc)), axes(bc)) - elseif _isfill(bc) - return Fill(_getindex_value(bc), axes(bc)) - else - _fallback_copy(bc) - end + _isfill(bc) ? _copy_fill(bc) : _fallback_copy(bc) end # make the zero-dimensional case consistent with Base -function Base.copy(bc::Broadcast.Broadcasted{<:AbstractFillStyle{0}}) - _fallback_copy(bc) -end +Base.copy(bc::Broadcast.Broadcasted{<:AbstractFillStyle{0}}) = _fallback_copy(bc) # some cases that preserve 0d function broadcast_preserving_0d(f, As...)