Skip to content

Commit

Permalink
don't use reducedim_initarray
Browse files Browse the repository at this point in the history
  • Loading branch information
adienes committed Dec 21, 2024
1 parent 39a7e3c commit ae41002
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 4 deletions.
8 changes: 7 additions & 1 deletion src/bidiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1461,7 +1461,13 @@ eigen(M::Bidiagonal) = Eigen(eigvals(M), eigvecs(M))

Base._sum(A::Bidiagonal, ::Colon) = sum(A.dv) + sum(A.ev)
function Base._sum(A::Bidiagonal, dims::Integer)
res = Base.reducedim_initarray(A, dims, zero(eltype(A)))
Base._check_valid_region(dims)
ax = (dims == 1) ? (1, axes(A, 2)) :
(dims == 2) ? (axes(A, 1), 1) :
axes(A)
res = Base.mapreduce_similar(A, eltype(A), ax)
fill!(res, zero(eltype(A)))

n = length(A.dv)
if n == 0
# Just to be sure. This shouldn't happen since there is a check whether
Expand Down
8 changes: 7 additions & 1 deletion src/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1134,7 +1134,13 @@ end

Base._sum(A::Diagonal, ::Colon) = sum(A.diag)
function Base._sum(A::Diagonal, dims::Integer)
res = Base.reducedim_initarray(A, dims, zero(eltype(A)))
Base._check_valid_region(dims)
ax = (dims == 1) ? (1, axes(A, 2)) :
(dims == 2) ? (axes(A, 1), 1) :
axes(A)
res = Base.mapreduce_similar(A, eltype(A), ax)
fill!(res, zero(eltype(A)))

if dims <= 2
for i = 1:length(A.diag)
@inbounds res[i] = A.diag[i]
Expand Down
16 changes: 14 additions & 2 deletions src/tridiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -893,7 +893,13 @@ function Base._sum(A::SymTridiagonal, ::Colon)
end

function Base._sum(A::Tridiagonal, dims::Integer)
res = Base.reducedim_initarray(A, dims, zero(eltype(A)))
Base._check_valid_region(dims)
ax = (dims == 1) ? (1, axes(A, 2)) :
(dims == 2) ? (axes(A, 1), 1) :
axes(A)
res = Base.mapreduce_similar(A, eltype(A), ax)
fill!(res, zero(eltype(A)))

n = length(A.d)
if n == 0
return res
Expand Down Expand Up @@ -927,7 +933,13 @@ function Base._sum(A::Tridiagonal, dims::Integer)
end

function Base._sum(A::SymTridiagonal, dims::Integer)
res = Base.reducedim_initarray(A, dims, zero(eltype(A)))
Base._check_valid_region(dims)
ax = (dims == 1) ? (1, axes(A, 2)) :
(dims == 2) ? (axes(A, 1), 1) :
axes(A)
res = Base.mapreduce_similar(A, eltype(A), ax)
fill!(res, zero(eltype(A)))

n = length(A.dv)
if n == 0
return res
Expand Down

0 comments on commit ae41002

Please sign in to comment.