From 39df7ec50dac0170e7ef6dd82c7fd4b8a9f7f7cd Mon Sep 17 00:00:00 2001 From: Sheehan Olver Date: Sat, 4 May 2024 18:18:01 +0100 Subject: [PATCH] Overload accumulate and make types of cumsum consistent with Vector (#363) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Support accumulate(±, ::AbstractFill) * Overload accumulate and make types of cumsum consistent with Vector --- Project.toml | 2 +- src/FillArrays.jl | 28 +++++++++++++++++++---- src/fillbroadcast.jl | 2 ++ test/runtests.jl | 54 +++++++++++++++++++++++++++++++------------- 4 files changed, 64 insertions(+), 22 deletions(-) diff --git a/Project.toml b/Project.toml index caea01ad..dd70d301 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "FillArrays" uuid = "1a297f60-69ca-5386-bcde-b61e274b549b" -version = "1.10.2" +version = "1.11" [deps] LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" diff --git a/src/FillArrays.jl b/src/FillArrays.jl index d8db9f5d..4671053a 100644 --- a/src/FillArrays.jl +++ b/src/FillArrays.jl @@ -7,7 +7,7 @@ import Base: size, getindex, setindex!, IndexStyle, checkbounds, convert, any, all, axes, isone, iszero, iterate, unique, allunique, permutedims, inv, copy, vec, setindex!, count, ==, reshape, map, zero, show, view, in, mapreduce, one, reverse, promote_op, promote_rule, repeat, - parent, similar, issorted + parent, similar, issorted, add_sum, accumulate, OneTo import LinearAlgebra: rank, svdvals!, tril, triu, tril!, triu!, diag, transpose, adjoint, fill!, dot, norm2, norm1, normInf, normMinusInf, normp, lmul!, rmul!, diagzero, AdjointAbsVec, TransposeAbsVec, @@ -576,14 +576,32 @@ sum(x::AbstractZeros) = getindex_value(x) # needed to support infinite case steprangelen(st...) = StepRangeLen(st...) -cumsum(x::AbstractFill{<:Any,1}) = steprangelen(getindex_value(x), getindex_value(x), length(x)) +function cumsum(x::AbstractFill{T,1}) where T + V = promote_op(add_sum, T, T) + steprangelen(convert(V,getindex_value(x)), getindex_value(x), length(x)) +end -cumsum(x::AbstractZerosVector) = x -cumsum(x::AbstractZerosVector{Bool}) = x -cumsum(x::AbstractOnesVector{II}) where II<:Integer = convert(AbstractVector{II}, oneto(length(x))) +cumsum(x::AbstractZerosVector{T}) where T = _range_convert(AbstractVector{promote_op(add_sum, T, T)}, x) +cumsum(x::AbstractZerosVector{Bool}) = _range_convert(AbstractVector{Int}, x) +cumsum(x::AbstractOnesVector{T}) where T<:Integer = _range_convert(AbstractVector{promote_op(add_sum, T, T)}, oneto(length(x))) cumsum(x::AbstractOnesVector{Bool}) = oneto(length(x)) +for op in (:+, :-) + @eval begin + function accumulate(::typeof($op), x::AbstractFill{T,1}) where T + V = promote_op($op, T, T) + steprangelen(convert(V,getindex_value(x)), $op(getindex_value(x)), length(x)) + end + + accumulate(::typeof($op), x::AbstractZerosVector{T}) where T = _range_convert(AbstractVector{promote_op($op, T, T)}, x) + accumulate(::typeof($op), x::AbstractZerosVector{Bool}) = _range_convert(AbstractVector{Int}, x) + end +end + +accumulate(::typeof(+), x::AbstractOnesVector{T}) where T<:Integer = _range_convert(AbstractVector{promote_op(+, T, T)}, oneto(length(x))) +accumulate(::typeof(+), x::AbstractOnesVector{Bool}) = oneto(length(x)) + ######### # Diff ######### diff --git a/src/fillbroadcast.jl b/src/fillbroadcast.jl index d286418a..2b5ea59c 100644 --- a/src/fillbroadcast.jl +++ b/src/fillbroadcast.jl @@ -177,7 +177,9 @@ end # special case due to missing converts for ranges _range_convert(::Type{AbstractVector{T}}, a::AbstractRange{T}) where T = a _range_convert(::Type{AbstractVector{T}}, a::AbstractUnitRange) where T = convert(T,first(a)):convert(T,last(a)) +_range_convert(::Type{AbstractVector{T}}, a::OneTo) where T = OneTo(convert(T, a.stop)) _range_convert(::Type{AbstractVector{T}}, a::AbstractRange) where T = convert(T,first(a)):step(a):convert(T,last(a)) +_range_convert(::Type{AbstractVector{T}}, a::ZerosVector) where T = ZerosVector{T}(length(a)) # TODO: replacing with the following will support more general broadcasting. diff --git a/test/runtests.jl b/test/runtests.jl index 8af5ed45..2e2aba22 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -826,31 +826,53 @@ end @test_throws MethodError sort!(Fill(im, 2)) end -@testset "Cumsum and diff" begin - @test sum(Fill(3,10)) ≡ 30 - @test reduce(+, Fill(3,10)) ≡ 30 - @test sum(x -> x + 1, Fill(3,10)) ≡ 40 - @test cumsum(Fill(3,10)) ≡ StepRangeLen(3,3,10) - - @test sum(Ones(10)) ≡ 10.0 - @test sum(x -> x + 1, Ones(10)) ≡ 20.0 - @test cumsum(Ones(10)) ≡ StepRangeLen(1.0, 1.0, 10) +@testset "Cumsum, accumulate and diff" begin + @test @inferred(sum(Fill(3,10))) ≡ 30 + @test @inferred(reduce(+, Fill(3,10))) ≡ 30 + @test @inferred(sum(x -> x + 1, Fill(3,10))) ≡ 40 + @test @inferred(cumsum(Fill(3,10))) ≡ @inferred(accumulate(+, Fill(3,10))) ≡ StepRangeLen(3,3,10) + @test @inferred(accumulate(-, Fill(3,10))) ≡ StepRangeLen(3,-3,10) + + @test @inferred(sum(Ones(10))) ≡ 10.0 + @test @inferred(sum(x -> x + 1, Ones(10))) ≡ 20.0 + @test @inferred(cumsum(Ones(10))) ≡ @inferred(accumulate(+, Ones(10))) ≡ StepRangeLen(1.0, 1.0, 10) + @test @inferred(accumulate(-, Ones(10))) ≡ StepRangeLen(1.0,-1.0,10) @test sum(Ones{Int}(10)) ≡ 10 @test sum(x -> x + 1, Ones{Int}(10)) ≡ 20 - @test cumsum(Ones{Int}(10)) ≡ Base.OneTo(10) + @test cumsum(Ones{Int}(10)) ≡ accumulate(+,Ones{Int}(10)) ≡ Base.OneTo(10) + @test accumulate(-, Ones{Int}(10)) ≡ StepRangeLen(1,-1,10) @test sum(Zeros(10)) ≡ 0.0 @test sum(x -> x + 1, Zeros(10)) ≡ 10.0 - @test cumsum(Zeros(10)) ≡ Zeros(10) + @test cumsum(Zeros(10)) ≡ accumulate(+,Zeros(10)) ≡ accumulate(-,Zeros(10)) ≡ Zeros(10) @test sum(Zeros{Int}(10)) ≡ 0 @test sum(x -> x + 1, Zeros{Int}(10)) ≡ 10 - @test cumsum(Zeros{Int}(10)) ≡ Zeros{Int}(10) - - @test cumsum(Zeros{Bool}(10)) ≡ Zeros{Bool}(10) - @test cumsum(Ones{Bool}(10)) ≡ Base.OneTo{Int}(10) - @test cumsum(Fill(true,10)) ≡ StepRangeLen(true, true, 10) + @test cumsum(Zeros{Int}(10)) ≡ accumulate(+,Zeros{Int}(10)) ≡ accumulate(-,Zeros{Int}(10)) ≡ Zeros{Int}(10) + + # we want cumsum of fills to match the types of the standard cusum + @test all(cumsum(Zeros{Bool}(10)) .≡ cumsum(zeros(Bool,10))) + @test all(accumulate(+, Zeros{Bool}(10)) .≡ accumulate(+, zeros(Bool,10)) .≡ accumulate(-, zeros(Bool,10))) + @test cumsum(Zeros{Bool}(10)) ≡ accumulate(+, Zeros{Bool}(10)) ≡ accumulate(-, Zeros{Bool}(10)) ≡ Zeros{Int}(10) + @test cumsum(Ones{Bool}(10)) ≡ accumulate(+, Ones{Bool}(10)) ≡ Base.OneTo{Int}(10) + @test all(cumsum(Fill(true,10)) .≡ cumsum(fill(true,10))) + @test cumsum(Fill(true,10)) ≡ StepRangeLen(1, true, 10) + + @test all(cumsum(Zeros{UInt8}(10)) .≡ cumsum(zeros(UInt8,10))) + @test all(accumulate(+, Zeros{UInt8}(10)) .≡ accumulate(+, zeros(UInt8,10))) + @test cumsum(Zeros{UInt8}(10)) ≡ Zeros{UInt64}(10) + @test accumulate(+, Zeros{UInt8}(10)) ≡ accumulate(-, Zeros{UInt8}(10)) ≡ Zeros{UInt8}(10) + + @test all(cumsum(Ones{UInt8}(10)) .≡ cumsum(ones(UInt8,10))) + @test all(accumulate(+, Ones{UInt8}(10)) .≡ accumulate(+, ones(UInt8,10))) + @test cumsum(Ones{UInt8}(10)) ≡ Base.OneTo(UInt64(10)) + @test accumulate(+, Ones{UInt8}(10)) ≡ Base.OneTo(UInt8(10)) + + @test all(cumsum(Fill(UInt8(2),10)) .≡ cumsum(fill(UInt8(2),10))) + @test all(accumulate(+, Fill(UInt8(2))) .≡ accumulate(+, fill(UInt8(2)))) + @test cumsum(Fill(UInt8(2),10)) ≡ StepRangeLen(UInt64(2), UInt8(2), 10) + @test accumulate(+, Fill(UInt8(2),10)) ≡ StepRangeLen(UInt8(2), UInt8(2), 10) @test diff(Fill(1,10)) ≡ Zeros{Int}(9) @test diff(Ones{Float64}(10)) ≡ Zeros{Float64}(9)