Skip to content

Commit

Permalink
don't materialize when broadcasting Zeros with Vector (#211)
Browse files Browse the repository at this point in the history
* don't materialize when broadcasting Zeros with Vector

* fix method overwrite

* fix type signature

* version bump to v0.13.9

* Add tests
  • Loading branch information
jishnub authored Apr 18, 2023
1 parent c3b38ad commit 8fac81d
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 6 deletions.
19 changes: 13 additions & 6 deletions src/fillbroadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -198,21 +198,28 @@ _copy_oftype(A::AbstractArray{T,N}, ::Type{S}) where {T,N,S} = convert(AbstractA
_copy_oftype(A::AbstractRange{T}, ::Type{T}) where T = copy(A)
_copy_oftype(A::AbstractRange{T}, ::Type{S}) where {T,S} = map(S, A)

for op in (:+, -)
for op in (:+, :-)
@eval begin
function broadcasted(::DefaultArrayStyle{1}, ::typeof($op), a::AbstractVector{T}, b::ZerosVector{V}) where {T,V}
function broadcasted(::DefaultArrayStyle{1}, ::typeof($op), a::AbstractRange{T}, b::ZerosVector{V}) where {T,V}
broadcast_shape(axes(a), axes(b)) == axes(a) || throw(ArgumentError("Cannot broadcast $a and $b. Convert $b to a Vector first."))
_copy_oftype(a, promote_type(T,V))
end
function broadcasted(::DefaultArrayStyle{1}, ::typeof($op), a::AbstractVector{T}, b::ZerosVector{V}) where {T,V}
broadcast_shape(axes(a), axes(b)) == axes(a) || throw(ArgumentError("Cannot broadcast $a and $b. Convert $b to a Vector first."))
TT = promote_type(T,V)
broadcasted(TT$op, a)
end

broadcasted(::DefaultArrayStyle{1}, ::typeof($op), a::AbstractFill{T,1}, b::ZerosVector) where T =
broadcasted(::DefaultArrayStyle{1}, ::typeof($op), a::AbstractFillVector{T}, b::ZerosVector) where T =
Base.invoke(broadcasted, Tuple{DefaultArrayStyle, typeof($op), AbstractFill, AbstractFill}, DefaultArrayStyle{1}(), $op, a, b)
end
end

function broadcasted(::DefaultArrayStyle{1}, ::typeof(+), a::ZerosVector{T}, b::AbstractVector{V}) where {T,V}
broadcast_shape(axes(a), axes(b))
_copy_oftype(b, promote_type(T,V))
function broadcasted(S::DefaultArrayStyle{1}, ::typeof(+), a::ZerosVector, b::AbstractRange)
broadcasted(S, +, b, a)
end
function broadcasted(S::DefaultArrayStyle{1}, ::typeof(+), a::ZerosVector, b::AbstractVector)
broadcasted(S, +, b, a)
end

broadcasted(::DefaultArrayStyle{1}, ::typeof(+), a::ZerosVector, b::AbstractFillVector) =
Expand Down
9 changes: 9 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -908,6 +908,15 @@ end
@test Zeros(10) .- Zeros(1,9) Zeros(10,9)
@test Ones(10) .- Zeros(1,9) Ones(10,9)
@test Ones(10) .- Ones(1,9) Zeros(10,9)

end

@testset "issue #208" begin
u = rand(2); v = Zeros(2)
@test Broadcast.broadcasted(-, u, v) isa Broadcast.Broadcasted
@test Broadcast.broadcasted(+, u, v) isa Broadcast.Broadcasted
@test Broadcast.broadcasted(-, v, u) isa Broadcast.Broadcasted
@test Broadcast.broadcasted(+, v, u) isa Broadcast.Broadcasted
end

@testset "Zero .*" begin
Expand Down

0 comments on commit 8fac81d

Please sign in to comment.