Skip to content

Commit

Permalink
[BlockSparseArrays] Fix nstored and norm(NaN) (#1565)
Browse files Browse the repository at this point in the history
  • Loading branch information
ogauthe authored Nov 5, 2024
1 parent e51fba3 commit ebde621
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ using NDTensors.GradedAxes:
isdual
using NDTensors.LabelledNumbers: label
using NDTensors.SparseArrayInterface: nstored
using NDTensors.SymmetrySectors: U1
using NDTensors.TensorAlgebra: fusedims, splitdims
using LinearAlgebra: adjoint
using Random: randn!
Expand All @@ -26,13 +27,6 @@ function blockdiagonal!(f, a::AbstractArray)
return a
end

struct U1
n::Int
end
GradedAxes.dual(c::U1) = U1(-c.n)
GradedAxes.fuse_labels(c1::U1, c2::U1) = U1(c1.n + c2.n)
Base.isless(c1::U1, c2::U1) = isless(c1.n, c2.n)

const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
@testset "BlockSparseArraysGradedAxesExt (eltype=$elt)" for elt in elts
@testset "map" begin
Expand Down Expand Up @@ -66,8 +60,7 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
@test blocksize(b) == (2, 2, 2, 2)
@test blocklengths.(axes(b)) == ([2, 2], [2, 2], [2, 2], [2, 2])
@test nstored(b) == 256
# TODO: Fix this for `BlockedArray`.
@test_broken block_nstored(b) == 16
@test block_nstored(b) == 16
for i in 1:ndims(a)
@test axes(b, i) isa BlockedOneTo{Int}
end
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -305,5 +305,4 @@ function blocksparse_blocks(
end

using BlockArrays: BlocksView
# TODO: Is this correct in general?
SparseArrayInterface.nstored(a::BlocksView) = 1
SparseArrayInterface.nstored(a::BlocksView) = length(a)
5 changes: 4 additions & 1 deletion NDTensors/src/lib/BlockSparseArrays/test/test_basics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ using BlockArrays:
blocksizes,
mortar
using Compat: @compat
using LinearAlgebra: Adjoint, mul!
using LinearAlgebra: Adjoint, mul!, norm
using NDTensors.BlockSparseArrays:
@view!,
BlockSparseArray,
Expand Down Expand Up @@ -94,6 +94,9 @@ include("TestBlockSparseArraysUtils.jl")
iszero(a[I])
end
end

a[3, 3] = NaN
@test isnan(norm(a))
end
@testset "Tensor algebra" begin
a = BlockSparseArray{elt}(undef, ([2, 3], [3, 4]))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,6 @@ end
function sparse_mapreduce(f, op, a::AbstractArray; init=reduce_init(f, op, a), kwargs...)
output = mapreduce(f, op, sparse_storage(a); init, kwargs...)
f_notstored = apply_notstored(f, a)
@assert op(output, eltype(output)(f_notstored)) == output
@assert isequal(op(output, eltype(output)(f_notstored)), output)
return output
end

0 comments on commit ebde621

Please sign in to comment.