Skip to content

Commit

Permalink
Merge branch 'master' into jishnub/unique
Browse files Browse the repository at this point in the history
  • Loading branch information
jishnub authored Aug 27, 2024
2 parents 0242502 + 7b64042 commit 2ae63f1
Show file tree
Hide file tree
Showing 4 changed files with 106 additions and 24 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "FillArrays"
uuid = "1a297f60-69ca-5386-bcde-b61e274b549b"
version = "1.12.0"
version = "1.13.0"

[deps]
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand Down
4 changes: 2 additions & 2 deletions src/FillArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,8 @@ end
rank(F::AbstractFill) = iszero(getindex_value(F)) ? 0 : 1
IndexStyle(::Type{<:AbstractFill{<:Any,N,<:NTuple{N,Base.OneTo{Int}}}}) where N = IndexLinear()

issymmetric(F::AbstractFillMatrix) = axes(F,1) == axes(F,2)
ishermitian(F::AbstractFillMatrix) = issymmetric(F) && iszero(imag(getindex_value(F)))
issymmetric(F::AbstractFillMatrix) = axes(F,1) == axes(F,2) && (isempty(F) || issymmetric(getindex_value(F)))
ishermitian(F::AbstractFillMatrix) = axes(F,1) == axes(F,2) && (isempty(F) || ishermitian(getindex_value(F)))

Base.IteratorSize(::Type{<:AbstractFill{T,N,Axes}}) where {T,N,Axes} = _IteratorSize(Axes)
_IteratorSize(::Type{Tuple{}}) = Base.HasShape{0}()
Expand Down
52 changes: 33 additions & 19 deletions src/oneelement.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ OneElement{T}(val, inds::NTuple{N,Int}, sz::NTuple{N,Integer}) where {T,N} = One
OneElement{T}(val, inds::Int, sz::Int) where T = OneElement{T}(val, (inds,), (sz,))

"""
OneElement{T}(val, ind::Int, n::Int)
OneElement{T}(ind::Int, n::Int)
Creates a length `n` vector where the `ind` entry is equal to `one(T)`, and all other entries are zero.
"""
Expand Down Expand Up @@ -141,6 +141,8 @@ function isone(A::OneElementMatrix)
isone(getindex_value(A))
end

-(O::OneElement) = OneElement(-O.val, O.ind, O.axes)

*(x::OneElement, b::Number) = OneElement(x.val * b, x.ind, x.axes)
*(b::Number, x::OneElement) = OneElement(b * x.val, x.ind, x.axes)
/(x::OneElement, b::Number) = OneElement(x.val / b, x.ind, x.axes)
Expand All @@ -157,13 +159,6 @@ function *(A::OneElementMatrix, B::OneElementVecOrMat)
OneElement(val, (A.ind[1], B.ind[2:end]...), (axes(A,1), axes(B)[2:end]...))
end

function *(A::AbstractFillMatrix, x::OneElementVector)
check_matmul_sizes(A, x)
val = getindex_value(A) * getindex_value(x)
Fill(val, (axes(A,1),))
end
*(A::AbstractZerosMatrix, x::OneElementVector) = mult_zeros(A, x)

*(A::OneElementMatrix, x::AbstractZerosVector) = mult_zeros(A, x)

function *(A::OneElementMatrix, B::AbstractFillVector)
Expand Down Expand Up @@ -392,16 +387,20 @@ function triu(A::OneElementMatrix, k::Integer=0)
OneElement(nzband < k ? zero(A.val) : A.val, A.ind, axes(A))
end

# diag
function diag(O::OneElementMatrix, k::Integer=0)
Base.require_one_based_indexing(O)
len = length(diagind(O, k))
ind = O.ind[2] - O.ind[1] == k ? (k >= 0 ? O.ind[2] - k : O.ind[1] + k) : len + 1
OneElement(getindex_value(O), ind, len)
end

# broadcast

function broadcasted(::DefaultArrayStyle{N}, ::typeof(conj), r::OneElement{<:Any,N}) where {N}
OneElement(conj(r.val), r.ind, axes(r))
end
function broadcasted(::DefaultArrayStyle{N}, ::typeof(real), r::OneElement{<:Any,N}) where {N}
OneElement(real(r.val), r.ind, axes(r))
end
function broadcasted(::DefaultArrayStyle{N}, ::typeof(imag), r::OneElement{<:Any,N}) where {N}
OneElement(imag(r.val), r.ind, axes(r))
for f in (:abs, :abs2, :conj, :real, :imag)
@eval function broadcasted(::DefaultArrayStyle{N}, ::typeof($f), r::OneElement{<:Any,N}) where {N}
OneElement($f(r.val), r.ind, axes(r))
end
end
function broadcasted(::DefaultArrayStyle{N}, ::typeof(^), r::OneElement{<:Any,N}, x::Number) where {N}
OneElement(r.val^x, r.ind, axes(r))
Expand All @@ -420,9 +419,14 @@ end

function Base.reshape(A::OneElement, shape::Tuple{Vararg{Int}})
prod(shape) == length(A) || throw(DimensionMismatch("new dimension $shape must be consistent with array size $(length(A))"))
# we use the fact that the linear index of the non-zero value is preserved
oldlinind = LinearIndices(A)[A.ind...]
newcartind = CartesianIndices(shape)[oldlinind]
if all(in.(A.ind, axes(A)))
# we use the fact that the linear index of the non-zero value is preserved
oldlinind = LinearIndices(A)[A.ind...]
newcartind = CartesianIndices(shape)[oldlinind]
else
# arbitrarily set to some value outside the domain
newcartind = shape .+ 1
end
OneElement(A.val, Tuple(newcartind), shape)
end

Expand All @@ -446,3 +450,13 @@ _maybesize(t) = t
Base.show(io::IO, A::OneElement) = print(io, OneElement, "(", A.val, ", ", A.ind, ", ", _maybesize(axes(A)), ")")
Base.show(io::IO, A::OneElement{<:Any,1,Tuple{Int},Tuple{Base.OneTo{Int}}}) =
print(io, OneElement, "(", A.val, ", ", A.ind[1], ", ", size(A,1), ")")

# mapreduce
Base.sum(O::OneElement; dims=:, kw...) = _sum(O, dims; kw...)
_sum(O::OneElement, ::Colon; kw...) = sum((getindex_value(O),); kw...)
function _sum(O::OneElement, dims; kw...)
v = _sum(O, :; kw...)
ax = Base.reduced_indices(axes(O), dims)
ind = ntuple(x -> x in dims ? first(ax[x]) + (O.ind[x] in axes(O)[x]) - 1 : O.ind[x], ndims(O))
OneElement(v, ind, ax)
end
72 changes: 70 additions & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -608,7 +608,7 @@ end
end

@testset "ishermitian" begin
for el in (2, 3+0im, 4+5im), size in [(3,3), (3,4)]
@testset for el in (2, 3+0im, 4+5im, [1 2; 3 4], fill(2, 2, 2)), size in [(3,3), (3,4), (0,0), (0,1)]
@test issymmetric(Fill(el, size...)) == issymmetric(fill(el, size...))
@test ishermitian(Fill(el, size...)) == ishermitian(fill(el, size...))
end
Expand Down Expand Up @@ -2171,6 +2171,7 @@ end
@test FillArrays.nzind(e₁) == CartesianIndex(2)
@test e₁[2] === e₁[2,1] === e₁[2,1,1] === 1
@test_throws BoundsError e₁[6]
@test -e₁ === OneElement(-1, 2, 5)

f₁ = AbstractArray{Float64}(e₁)
@test f₁ isa OneElement{Float64,1}
Expand All @@ -2190,6 +2191,7 @@ end
V = OneElement(2, (2,3), (3,4))
@test V == [0 0 0 0; 0 0 2 0; 0 0 0 0]
@test FillArrays.nzind(V) == CartesianIndex(2,3)
@test -V == OneElement(-2, (2,3), (3,4))

Vf = AbstractArray{Float64}(V)
@test Vf isa OneElement{Float64,2}
Expand Down Expand Up @@ -2326,6 +2328,10 @@ end
end
O = OneElement(2, (), ())
@test reshape(O, ()) === O

O = OneElement(5, 3)
@test reshape(O, 1, 3) == reshape(Array(O), 1, 3)
@test reshape(reshape(O, 1, 3), 3) == O
end

@testset "isassigned" begin
Expand Down Expand Up @@ -2667,9 +2673,11 @@ end
end

@testset "broadcasting" begin
for v in (OneElement(2, 3, 4), OneElement(2im, (1,2), (3,4)))
for v in (OneElement(-2, 3, 4), OneElement(2im, (1,2), (3,4)))
w = Array(v)
n = 2
@test abs.(v) == abs.(w)
@test abs2.(v) == abs2.(w)
@test real.(v) == real.(w)
@test imag.(v) == imag.(w)
@test conj.(v) == conj.(w)
Expand Down Expand Up @@ -2718,6 +2726,66 @@ end
end
end
end

@testset "sum" begin
@testset "OneElement($v, $ind, $sz)" for (v, ind, sz) in (
(Int8(2), 3, 4),
(3.0, 5, 4),
(3.0, 0, 0),
(SMatrix{2,2}(1:4), (4, 2), (12,6)),
)
O = OneElement(v,ind,sz)
A = Array(O)
if VERSION >= v"1.10"
@test @inferred(sum(O)) === sum(A)
else
@test @inferred(sum(O)) == sum(A)
end
@test @inferred(sum(O, init=zero(eltype(O)))) === sum(A, init=zero(eltype(O)))
@test @inferred(sum(x->1, O, init=0)) === sum(Fill(1, axes(O)), init=0)
end

@testset for O in (OneElement(Int8(2), (1,2), (2,4)),
OneElement(3, (1,2,3), (2,4,4)),
OneElement(2.0, (3,2,5), (2,3,2)),
OneElement(SMatrix{2,2}(1:4), (1,2), (2,4)),
)
A = Array(O)
init = sum((zero(FillArrays.getindex_value(O)),))
for i in 1:3
@test @inferred(sum(O, dims=i)) == sum(A, dims=i)
@test @inferred(sum(O, dims=i, init=init)) == sum(A, dims=i, init=init)
@test @inferred(sum(x->1, O, dims=i, init=0)) == sum(Fill(1, axes(O)), dims=i, init=0)
end
@test @inferred(sum(O, dims=1:1)) == sum(A, dims=1:1)
@test @inferred(sum(O, dims=1:2)) == sum(A, dims=1:2)
@test @inferred(sum(O, dims=1:3)) == sum(A, dims=1:3)
@test @inferred(sum(O, dims=(1,))) == sum(A, dims=(1,))
@test @inferred(sum(O, dims=(1,2))) == sum(A, dims=(1,2))
@test @inferred(sum(O, dims=(1,3))) == sum(A, dims=(1,3))
@test @inferred(sum(O, dims=(2,3))) == sum(A, dims=(2,3))
@test @inferred(sum(O, dims=(1,2,3))) == sum(A, dims=(1,2,3))
@test @inferred(sum(O, dims=1:1, init=init)) == sum(A, dims=1:1, init=init)
@test @inferred(sum(O, dims=1:2, init=init)) == sum(A, dims=1:2, init=init)
@test @inferred(sum(O, dims=1:3, init=init)) == sum(A, dims=1:3, init=init)
@test @inferred(sum(O, dims=(1,), init=init)) == sum(A, dims=(1,), init=init)
@test @inferred(sum(O, dims=(1,2), init=init)) == sum(A, dims=(1,2), init=init)
@test @inferred(sum(O, dims=(1,3), init=init)) == sum(A, dims=(1,3), init=init)
@test @inferred(sum(O, dims=(2,3), init=init)) == sum(A, dims=(2,3), init=init)
@test @inferred(sum(O, dims=(1,2,3), init=init)) == sum(A, dims=(1,2,3), init=init)
@test @inferred(sum(x->1, O, dims=(1,2,3), init=0)) == sum(Fill(1, axes(O)), dims=(1,2,3), init=0)
end
end

@testset "diag" begin
@testset for sz in [(0,0), (0,1), (1,0), (1,1), (4,4), (4,6), (6,3)], ind in CartesianIndices(sz)
O = OneElement(4, Tuple(ind), sz)
@testset for k in -maximum(sz):maximum(sz)
@test diag(O, k) == diag(Array(O), k)
@test diag(O, k) isa OneElement{Int,1}
end
end
end
end

@testset "repeat" begin
Expand Down

0 comments on commit 2ae63f1

Please sign in to comment.