From 2869a8c6eb9baccbe6307992b94302be101e046f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Gauth=C3=A9?= Date: Thu, 5 Sep 2024 19:38:58 -0400 Subject: [PATCH 01/43] pass GradedAxes test --- NDTensors/src/lib/GradedAxes/src/fusion.jl | 58 ++++++++-- .../src/lib/GradedAxes/src/unitrangedual.jl | 103 ++++++++++++------ .../src/lib/GradedAxes/test/test_dual.jl | 49 ++++++++- 3 files changed, 165 insertions(+), 45 deletions(-) diff --git a/NDTensors/src/lib/GradedAxes/src/fusion.jl b/NDTensors/src/lib/GradedAxes/src/fusion.jl index 5193449f2f..1fd1ef911e 100644 --- a/NDTensors/src/lib/GradedAxes/src/fusion.jl +++ b/NDTensors/src/lib/GradedAxes/src/fusion.jl @@ -26,11 +26,11 @@ function tensor_product(a1::Base.OneTo, a2::Base.OneTo) return Base.OneTo(length(a1) * length(a2)) end -function tensor_product(a1::OneToOne, a2::AbstractUnitRange) +function tensor_product(::OneToOne, a2::AbstractBlockedUnitRange) return a2 end -function tensor_product(a1::AbstractUnitRange, a2::OneToOne) +function tensor_product(a1::AbstractBlockedUnitRange, ::OneToOne) return a1 end @@ -38,6 +38,19 @@ function tensor_product(a1::OneToOne, a2::OneToOne) return OneToOne() end +# Handle dual. Always return a non-dual GradedUnitRange. +function tensor_product(a1::AbstractBlockedUnitRange, a2::BlockedUnitRangeDual) + return tensor_product(a1, flip(a2)) +end + +function tensor_product(a1::BlockedUnitRangeDual, a2::AbstractBlockedUnitRange) + return tensor_product(flip(a1), a2) +end + +function tensor_product(a1::BlockedUnitRangeDual, a2::BlockedUnitRangeDual) + return tensor_product(flip(a1), flip(a2)) +end + function fuse_labels(x, y) return error( "`fuse_labels` not implemented for object of type `$(typeof(x))` and `$(typeof(y))`." @@ -68,6 +81,11 @@ function blocksortperm(a::AbstractBlockedUnitRange) return Block.(sortperm(blocklabels(a))) end +# convention: sort BlockedUnitRangeDual according to nondual blocks +function blocksortperm(a::BlockedUnitRangeDual) + return Block.(sortperm(blocklabels(nondual(a)))) +end + using BlockArrays: Block, BlockVector using SplitApplyCombine: groupcount # Get the permutation for sorting, then group by common elements. @@ -95,12 +113,32 @@ end # Used by `TensorAlgebra.splitdims` in `BlockSparseArraysGradedAxesExt`. invblockperm(a::Vector{<:Block{1}}) = Block.(invperm(Int.(a))) -# Used by `TensorAlgebra.fusedims` in `BlockSparseArraysGradedAxesExt`. -function blockmergesortperm(a::GradedUnitRange) - # If it is dual, reverse the sorting so the sectors - # end up sorted in the same way whether or not the space - # is dual. - # TODO: Figure out how to deal with dual sectors. - # TODO: `rev=isdual(a)` may not be correct for symmetries beyond `U(1)`. - return Block.(groupsortperm(blocklabels(a))) +function blockmergesortperm(a::BlockedUnitRangeDual) + return Block.(groupsortperm(blocklabels(nondual(a)))) +end + +function blockmergesort(g::AbstractGradedUnitRange) + glabels = blocklabels(g) + gblocklengths = blocklengths(g) + new_blocklengths = map( + la -> labelled(sum(gblocklengths[findall(==(la), glabels)]; init=0), la), + sort(unique(glabels)), + ) + return GradedAxes.gradedrange(new_blocklengths) +end + +blockmergesort(g::BlockedUnitRangeDual) = dual(blockmergesort(flip(g))) +blockmergesort(g::OneToOne) = g + +# fusion_product produces a sorted, non-dual GradedUnitRange +function fusion_product(g1, g2) + return blockmergesort(tensor_product(g1, g2)) +end + +fusion_product(g::AbstractUnitRange) = blockmergesort(g) +fusion_product(g::BlockedUnitRangeDual) = fusion_product(flip(g)) + +# recursive fusion_product. Simpler than reduce + fix type stability issues with reduce +function fusion_product(g1, g2, g3...) + return fusion_product(fusion_product(g1, g2), g3...) end diff --git a/NDTensors/src/lib/GradedAxes/src/unitrangedual.jl b/NDTensors/src/lib/GradedAxes/src/unitrangedual.jl index aa04cc1600..67ff1b9ddc 100644 --- a/NDTensors/src/lib/GradedAxes/src/unitrangedual.jl +++ b/NDTensors/src/lib/GradedAxes/src/unitrangedual.jl @@ -1,36 +1,65 @@ -struct UnitRangeDual{T,NondualUnitRange<:AbstractUnitRange} <: AbstractUnitRange{T} +struct BlockedUnitRangeDual{T<:Integer,NondualUnitRange<:AbstractUnitRange} <: + AbstractBlockedUnitRange{T,Vector{T}} nondual_unitrange::NondualUnitRange end -UnitRangeDual(a::AbstractUnitRange) = UnitRangeDual{eltype(a),typeof(a)}(a) +BlockedUnitRangeDual(a::AbstractUnitRange) = BlockedUnitRangeDual{eltype(a),typeof(a)}(a) -dual(a::AbstractUnitRange) = UnitRangeDual(a) -nondual(a::UnitRangeDual) = a.nondual_unitrange -dual(a::UnitRangeDual) = nondual(a) +dual(a::AbstractUnitRange) = BlockedUnitRangeDual(a) +nondual(a::BlockedUnitRangeDual) = a.nondual_unitrange +dual(a::BlockedUnitRangeDual) = nondual(a) +flip(a::BlockedUnitRangeDual) = dual(flip(nondual(a))) nondual(a::AbstractUnitRange) = a +isdual(::AbstractGradedUnitRange) = false +isdual(::BlockedUnitRangeDual) = true ## TODO: Define this to instantiate a dual unit range. -## materialize_dual(a::UnitRangeDual) = materialize_dual(nondual(a)) +## materialize_dual(a::BlockedUnitRangeDual) = materialize_dual(nondual(a)) -Base.first(a::UnitRangeDual) = label_dual(first(nondual(a))) -Base.last(a::UnitRangeDual) = label_dual(last(nondual(a))) -Base.step(a::UnitRangeDual) = label_dual(step(nondual(a))) +Base.first(a::BlockedUnitRangeDual) = label_dual(first(nondual(a))) +Base.last(a::BlockedUnitRangeDual) = label_dual(last(nondual(a))) +Base.step(a::BlockedUnitRangeDual) = label_dual(step(nondual(a))) -Base.view(a::UnitRangeDual, index::Block{1}) = a[index] +Base.view(a::BlockedUnitRangeDual, index::Block{1}) = a[index] -function Base.getindex(a::UnitRangeDual, indices::AbstractUnitRange{<:Integer}) +function Base.show(io::IO, a::BlockedUnitRangeDual) + return print(io, BlockedUnitRangeDual, "(", blocklasts(a), ")") +end + +function Base.show(io::IO, mimetype::MIME"text/plain", a::BlockedUnitRangeDual) + return Base.invoke( + show, Tuple{typeof(io),MIME"text/plain",AbstractArray}, io, mimetype, a + ) +end + +function Base.getindex(a::BlockedUnitRangeDual, indices::AbstractUnitRange{<:Integer}) return dual(getindex(nondual(a), indices)) end using BlockArrays: Block, BlockIndexRange, BlockRange -function Base.getindex(a::UnitRangeDual, indices::Integer) +function Base.getindex(a::BlockedUnitRangeDual, indices::Integer) return label_dual(getindex(nondual(a), indices)) end # TODO: Use `label_dual.` here, make broadcasting work? -Base.getindex(a::UnitRangeDual, indices::Block{1}) = dual(getindex(nondual(a), indices)) +function Base.getindex(a::BlockedUnitRangeDual, indices::Block{1}) + return dual(getindex(nondual(a), indices)) +end # TODO: Use `label_dual.` here, make broadcasting work? -Base.getindex(a::UnitRangeDual, indices::BlockRange) = dual(getindex(nondual(a), indices)) +function Base.getindex(a::BlockedUnitRangeDual, indices::BlockRange) + return dual(getindex(nondual(a), indices)) +end + +# fix ambiguity +function Base.getindex( + a::BlockedUnitRangeDual, indices::BlockRange{1,<:Tuple{AbstractUnitRange{Int}}} +) + return dual(getindex(nondual(a), indices)) +end + +function BlockArrays.blocklengths(a::BlockedUnitRangeDual) + return dual.(blocklengths(nondual(a))) +end # TODO: Use `label_dual.` here, make broadcasting work? function unitrangedual_getindices_blocks(a, indices) @@ -39,23 +68,23 @@ function unitrangedual_getindices_blocks(a, indices) end # TODO: Move this to a `BlockArraysExtensions` library. -function blockedunitrange_getindices(a::UnitRangeDual, indices::Block{1}) +function blockedunitrange_getindices(a::BlockedUnitRangeDual, indices::Block{1}) return a[indices] end -function Base.getindex(a::UnitRangeDual, indices::Vector{<:Block{1}}) +function Base.getindex(a::BlockedUnitRangeDual, indices::Vector{<:Block{1}}) return unitrangedual_getindices_blocks(a, indices) end -function Base.getindex(a::UnitRangeDual, indices::Vector{<:BlockIndexRange{1}}) +function Base.getindex(a::BlockedUnitRangeDual, indices::Vector{<:BlockIndexRange{1}}) return unitrangedual_getindices_blocks(a, indices) end -function to_blockindices(a::UnitRangeDual, indices::UnitRange{<:Integer}) +function to_blockindices(a::BlockedUnitRangeDual, indices::UnitRange{<:Integer}) return to_blockindices(nondual(a), indices) end -Base.axes(a::UnitRangeDual) = axes(nondual(a)) +Base.axes(a::BlockedUnitRangeDual) = axes(nondual(a)) using BlockArrays: BlockArrays, Block, BlockSlice using NDTensors.LabelledNumbers: LabelledUnitRange @@ -64,35 +93,45 @@ function BlockArrays.BlockSlice(b::Block, a::LabelledUnitRange) end using BlockArrays: BlockArrays, BlockSlice -using NDTensors.GradedAxes: UnitRangeDual, dual -function BlockArrays.BlockSlice(b::Block, r::UnitRangeDual) +using NDTensors.GradedAxes: BlockedUnitRangeDual, dual +function BlockArrays.BlockSlice(b::Block, r::BlockedUnitRangeDual) return BlockSlice(b, dual(r)) end using NDTensors.LabelledNumbers: LabelledNumbers, label -LabelledNumbers.label(a::UnitRangeDual) = dual(label(nondual(a))) +LabelledNumbers.label(a::BlockedUnitRangeDual) = dual(label(nondual(a))) using NDTensors.LabelledNumbers: LabelledUnitRange # The Base version of `length(::AbstractUnitRange)` drops the label. -function Base.length(a::UnitRangeDual{<:Any,<:LabelledUnitRange}) +function Base.length(a::BlockedUnitRangeDual{<:Any,<:LabelledUnitRange}) return dual(length(nondual(a))) end -function Base.iterate(a::UnitRangeDual, i) +function Base.iterate(a::BlockedUnitRangeDual, i) i == last(a) && return nothing return dual.(iterate(nondual(a), i)) end # TODO: Is this a good definition? -Base.unitrange(a::UnitRangeDual{<:Any,<:AbstractUnitRange}) = a +Base.unitrange(a::BlockedUnitRangeDual{<:Any,<:AbstractUnitRange}) = a using NDTensors.LabelledNumbers: LabelledInteger, label, labelled, unlabel dual(i::LabelledInteger) = labelled(unlabel(i), dual(label(i))) using BlockArrays: BlockArrays, blockaxes, blocklasts, combine_blockaxes, findblock -BlockArrays.blockaxes(a::UnitRangeDual) = blockaxes(nondual(a)) -BlockArrays.blockfirsts(a::UnitRangeDual) = label_dual.(blockfirsts(nondual(a))) -BlockArrays.blocklasts(a::UnitRangeDual) = label_dual.(blocklasts(nondual(a))) -BlockArrays.findblock(a::UnitRangeDual, index::Integer) = findblock(nondual(a), index) -function BlockArrays.combine_blockaxes(a1::UnitRangeDual, a2::UnitRangeDual) +BlockArrays.blockaxes(a::BlockedUnitRangeDual) = blockaxes(nondual(a)) +BlockArrays.blockfirsts(a::BlockedUnitRangeDual) = label_dual.(blockfirsts(nondual(a))) +BlockArrays.blocklasts(a::BlockedUnitRangeDual) = label_dual.(blocklasts(nondual(a))) +function BlockArrays.findblock(a::BlockedUnitRangeDual, index::Integer) + return findblock(nondual(a), index) +end + +blocklabels(a::BlockedUnitRangeDual) = dual.(blocklabels(nondual(a))) + +gradedisequal(::BlockedUnitRangeDual, ::AbstractGradedUnitRange) = false +gradedisequal(::AbstractGradedUnitRange, ::BlockedUnitRangeDual) = false +function gradedisequal(a1::BlockedUnitRangeDual, a2::BlockedUnitRangeDual) + return gradedisequal(nondual(a1), nondual(a2)) +end +function BlockArrays.combine_blockaxes(a1::BlockedUnitRangeDual, a2::BlockedUnitRangeDual) return dual(combine_blockaxes(dual(a1), dual(a2))) end @@ -102,7 +141,7 @@ end # `CartesianIndices`, maybe by defining conversion of `LabelledInteger` # to `Int`, defining a more general `convert` function, etc. function Base.OrdinalRange{Int,Int}( - r::UnitRangeDual{<:LabelledInteger{Int},<:LabelledUnitRange{Int,UnitRange{Int}}} + r::BlockedUnitRangeDual{<:LabelledInteger{Int},<:LabelledUnitRange{Int,UnitRange{Int}}} ) # TODO: Implement this broadcasting operation and use it here. # return Int.(r) @@ -114,6 +153,6 @@ end # TODO: Delete this once we drop Julia 1.6 support. # The type constraint `T<:Integer` is needed to avoid an ambiguity # error with a conversion method in Base. -function Base.UnitRange{T}(a::UnitRangeDual{<:LabelledInteger{T}}) where {T<:Integer} +function Base.UnitRange{T}(a::BlockedUnitRangeDual{<:LabelledInteger{T}}) where {T<:Integer} return UnitRange{T}(nondual(a)) end diff --git a/NDTensors/src/lib/GradedAxes/test/test_dual.jl b/NDTensors/src/lib/GradedAxes/test/test_dual.jl index 0fb15d31bf..447a9a31ac 100644 --- a/NDTensors/src/lib/GradedAxes/test/test_dual.jl +++ b/NDTensors/src/lib/GradedAxes/test/test_dual.jl @@ -1,6 +1,18 @@ @eval module $(gensym()) -using BlockArrays: Block, blockaxes, blockfirsts, blocklasts, blocks, findblock -using NDTensors.GradedAxes: GradedAxes, UnitRangeDual, dual, gradedrange, nondual +using BlockArrays: + Block, blockaxes, blockfirsts, blocklasts, blocklength, blocklengths, blocks, findblock +using NDTensors.GradedAxes: + GradedAxes, + BlockedUnitRangeDual, + blocklabels, + blockmergesortperm, + blocksortperm, + dual, + flip, + gradedisequal, + gradedrange, + isdual, + nondual using NDTensors.LabelledNumbers: LabelledInteger, label, labelled using Test: @test, @test_broken, @testset struct U1 @@ -22,7 +34,7 @@ GradedAxes.dual(c::U1) = U1(-c.n) @test ad[4] == 4 @test label(ad[4]) == U1(-1) @test ad[2:4] == 2:4 - @test ad[2:4] isa UnitRangeDual + @test ad[2:4] isa BlockedUnitRangeDual @test label(ad[2:4][Block(2)]) == U1(-1) @test ad[[2, 4]] == [2, 4] @test label(ad[[2, 4]][2]) == U1(-1) @@ -34,5 +46,36 @@ GradedAxes.dual(c::U1) = U1(-c.n) @test label(ad[[Block(2), Block(1)]][Block(1)]) == U1(-1) @test ad[[Block(2)[1:2], Block(1)[1:2]]][Block(1)] == 3:4 @test label(ad[[Block(2)[1:2], Block(1)[1:2]]][Block(1)]) == U1(-1) + @test blocksortperm(a) == [Block(1), Block(2)] + @test blocksortperm(ad) == [Block(1), Block(2)] + @test blocklength(blockmergesortperm(a)) == 2 + @test blocklength(blockmergesortperm(ad)) == 2 + @test blockmergesortperm(a) == [Block(1), Block(2)] + @test blockmergesortperm(ad) == [Block(1), Block(2)] +end + +@testset "flip" begin + a = gradedrange([U1(0) => 2, U1(1) => 3]) + ad = dual(a) + @test gradedisequal(flip(a), dual(gradedrange([U1(0) => 2, U1(-1) => 3]))) + @test gradedisequal(flip(ad), gradedrange([U1(0) => 2, U1(-1) => 3])) + + @test blocklabels(a) == [U1(0), U1(1)] + @test blocklabels(dual(a)) == [U1(0), U1(-1)] + @test blocklabels(flip(a)) == [U1(0), U1(1)] + @test blocklabels(flip(dual(a))) == [U1(0), U1(-1)] + @test blocklabels(dual(flip(a))) == [U1(0), U1(-1)] + + @test blocklengths(a) == [2, 3] + @test blocklengths(ad) == [2, 3] + @test blocklengths(flip(a)) == [2, 3] + @test blocklengths(flip(ad)) == [2, 3] + @test blocklengths(dual(flip(a))) == [2, 3] + + @test !isdual(a) + @test isdual(ad) + @test isdual(flip(a)) + @test !isdual(flip(ad)) + @test !isdual(dual(flip(a))) end end From b21f404f9912eee2abbdb43faa40ca55ac7a980e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Gauth=C3=A9?= Date: Tue, 10 Sep 2024 19:24:35 -0400 Subject: [PATCH 02/43] separate types UnitRangeDual and GradedUnitRangeDual --- .../src/lib/GradedAxes/src/GradedAxes.jl | 2 +- NDTensors/src/lib/GradedAxes/src/fusion.jl | 16 +- .../lib/GradedAxes/src/gradedunitrangedual.jl | 154 ++++++++++++++++++ .../src/lib/GradedAxes/src/unitrangedual.jl | 116 +++++-------- .../src/lib/GradedAxes/test/test_dual.jl | 4 +- 5 files changed, 203 insertions(+), 89 deletions(-) create mode 100644 NDTensors/src/lib/GradedAxes/src/gradedunitrangedual.jl diff --git a/NDTensors/src/lib/GradedAxes/src/GradedAxes.jl b/NDTensors/src/lib/GradedAxes/src/GradedAxes.jl index 7756b71575..cbdec545e0 100644 --- a/NDTensors/src/lib/GradedAxes/src/GradedAxes.jl +++ b/NDTensors/src/lib/GradedAxes/src/GradedAxes.jl @@ -1,7 +1,7 @@ module GradedAxes include("blockedunitrange.jl") include("gradedunitrange.jl") -include("fusion.jl") +include("gradedunitrangedual.jl") include("dual.jl") include("unitrangedual.jl") include("../ext/GradedAxesSectorsExt/src/GradedAxesSectorsExt.jl") diff --git a/NDTensors/src/lib/GradedAxes/src/fusion.jl b/NDTensors/src/lib/GradedAxes/src/fusion.jl index 1fd1ef911e..419a28c03e 100644 --- a/NDTensors/src/lib/GradedAxes/src/fusion.jl +++ b/NDTensors/src/lib/GradedAxes/src/fusion.jl @@ -39,15 +39,15 @@ function tensor_product(a1::OneToOne, a2::OneToOne) end # Handle dual. Always return a non-dual GradedUnitRange. -function tensor_product(a1::AbstractBlockedUnitRange, a2::BlockedUnitRangeDual) +function tensor_product(a1::AbstractBlockedUnitRange, a2::GradedUnitRangeDual) return tensor_product(a1, flip(a2)) end -function tensor_product(a1::BlockedUnitRangeDual, a2::AbstractBlockedUnitRange) +function tensor_product(a1::GradedUnitRangeDual, a2::AbstractBlockedUnitRange) return tensor_product(flip(a1), a2) end -function tensor_product(a1::BlockedUnitRangeDual, a2::BlockedUnitRangeDual) +function tensor_product(a1::GradedUnitRangeDual, a2::GradedUnitRangeDual) return tensor_product(flip(a1), flip(a2)) end @@ -81,8 +81,8 @@ function blocksortperm(a::AbstractBlockedUnitRange) return Block.(sortperm(blocklabels(a))) end -# convention: sort BlockedUnitRangeDual according to nondual blocks -function blocksortperm(a::BlockedUnitRangeDual) +# convention: sort GradedUnitRangeDual according to nondual blocks +function blocksortperm(a::GradedUnitRangeDual) return Block.(sortperm(blocklabels(nondual(a)))) end @@ -113,7 +113,7 @@ end # Used by `TensorAlgebra.splitdims` in `BlockSparseArraysGradedAxesExt`. invblockperm(a::Vector{<:Block{1}}) = Block.(invperm(Int.(a))) -function blockmergesortperm(a::BlockedUnitRangeDual) +function blockmergesortperm(a::GradedUnitRangeDual) return Block.(groupsortperm(blocklabels(nondual(a)))) end @@ -127,7 +127,7 @@ function blockmergesort(g::AbstractGradedUnitRange) return GradedAxes.gradedrange(new_blocklengths) end -blockmergesort(g::BlockedUnitRangeDual) = dual(blockmergesort(flip(g))) +blockmergesort(g::GradedUnitRangeDual) = dual(blockmergesort(flip(g))) blockmergesort(g::OneToOne) = g # fusion_product produces a sorted, non-dual GradedUnitRange @@ -136,7 +136,7 @@ function fusion_product(g1, g2) end fusion_product(g::AbstractUnitRange) = blockmergesort(g) -fusion_product(g::BlockedUnitRangeDual) = fusion_product(flip(g)) +fusion_product(g::GradedUnitRangeDual) = fusion_product(flip(g)) # recursive fusion_product. Simpler than reduce + fix type stability issues with reduce function fusion_product(g1, g2, g3...) diff --git a/NDTensors/src/lib/GradedAxes/src/gradedunitrangedual.jl b/NDTensors/src/lib/GradedAxes/src/gradedunitrangedual.jl new file mode 100644 index 0000000000..f5e5f986a3 --- /dev/null +++ b/NDTensors/src/lib/GradedAxes/src/gradedunitrangedual.jl @@ -0,0 +1,154 @@ +struct GradedUnitRangeDual{ + T<:LabelledInteger,NondualUnitRange<:AbstractGradedUnitRange{T} +} <: AbstractGradedUnitRange{T,Vector{T}} + nondual_unitrange::NondualUnitRange +end + +dual(a::AbstractGradedUnitRange) = GradedUnitRangeDual(a) +nondual(a::GradedUnitRangeDual) = a.nondual_unitrange +dual(a::GradedUnitRangeDual) = nondual(a) +flip(a::GradedUnitRangeDual) = dual(flip(nondual(a))) +isdual(::AbstractGradedUnitRange) = false +isdual(::GradedUnitRangeDual) = true +## TODO: Define this to instantiate a dual unit range. +## materialize_dual(a::GradedUnitRangeDual) = materialize_dual(nondual(a)) + +Base.first(a::GradedUnitRangeDual) = label_dual(first(nondual(a))) +Base.last(a::GradedUnitRangeDual) = label_dual(last(nondual(a))) +Base.step(a::GradedUnitRangeDual) = label_dual(step(nondual(a))) + +Base.view(a::GradedUnitRangeDual, index::Block{1}) = a[index] + +function Base.show(io::IO, a::GradedUnitRangeDual) + return print(io, GradedUnitRangeDual, "(", blocklasts(a), ")") +end + +function Base.show(io::IO, mimetype::MIME"text/plain", a::GradedUnitRangeDual) + return Base.invoke( + show, Tuple{typeof(io),MIME"text/plain",AbstractArray}, io, mimetype, a + ) +end + +function Base.getindex(a::GradedUnitRangeDual, indices::AbstractUnitRange{<:Integer}) + return dual(getindex(nondual(a), indices)) +end + +using BlockArrays: Block, BlockIndexRange, BlockRange + +function Base.getindex(a::GradedUnitRangeDual, indices::Integer) + return label_dual(getindex(nondual(a), indices)) +end + +function Base.getindex(a::GradedUnitRangeDual, indices::Block{1}) + return label_dual(getindex(nondual(a), indices)) +end + +function Base.getindex(a::GradedUnitRangeDual, indices::BlockRange) + return label_dual(getindex(nondual(a), indices)) +end + +# fix ambiguity +function Base.getindex( + a::GradedUnitRangeDual, indices::BlockRange{1,<:Tuple{AbstractUnitRange{Int}}} +) + return dual(getindex(nondual(a), indices)) +end + +function BlockArrays.blocklengths(a::GradedUnitRangeDual) + return dual.(blocklengths(nondual(a))) +end + +function unitrangedual_getindices_blocks(a::GradedUnitRangeDual, indices) + a_indices = getindex(nondual(a), indices) + return mortar([label_dual(b) for b in blocks(a_indices)]) +end + +# TODO: Move this to a `BlockArraysExtensions` library. +function blockedunitrange_getindices(a::GradedUnitRangeDual, indices::Block{1}) + return a[indices] +end + +function Base.getindex(a::GradedUnitRangeDual, indices::Vector{<:Block{1}}) + return unitrangedual_getindices_blocks(a, indices) +end + +function Base.getindex(a::GradedUnitRangeDual, indices::Vector{<:BlockIndexRange{1}}) + return unitrangedual_getindices_blocks(a, indices) +end + +function to_blockindices(a::GradedUnitRangeDual, indices::UnitRange{<:Integer}) + return to_blockindices(nondual(a), indices) +end + +Base.axes(a::GradedUnitRangeDual) = axes(nondual(a)) + +using BlockArrays: BlockArrays, Block, BlockSlice +using NDTensors.LabelledNumbers: LabelledUnitRange +function BlockArrays.BlockSlice(b::Block, a::LabelledUnitRange) + return BlockSlice(b, unlabel(a)) +end + +using BlockArrays: BlockArrays, BlockSlice +using NDTensors.GradedAxes: GradedUnitRangeDual, dual +function BlockArrays.BlockSlice(b::Block, r::GradedUnitRangeDual) + return BlockSlice(b, dual(r)) +end + +using NDTensors.LabelledNumbers: LabelledNumbers, label +LabelledNumbers.label(a::GradedUnitRangeDual) = dual(label(nondual(a))) + +using NDTensors.LabelledNumbers: LabelledUnitRange +# The Base version of `length(::AbstractUnitRange)` drops the label. +function Base.length(a::GradedUnitRangeDual{<:Any,<:LabelledUnitRange}) + return dual(length(nondual(a))) +end +function Base.iterate(a::GradedUnitRangeDual, i) + i == last(a) && return nothing + return dual.(iterate(nondual(a), i)) +end +# TODO: Is this a good definition? +Base.unitrange(a::GradedUnitRangeDual) = a + +using NDTensors.LabelledNumbers: LabelledInteger, label, labelled, unlabel +dual(i::LabelledInteger) = labelled(unlabel(i), dual(label(i))) + +using BlockArrays: BlockArrays, blockaxes, blocklasts, combine_blockaxes, findblock +BlockArrays.blockaxes(a::GradedUnitRangeDual) = blockaxes(nondual(a)) +BlockArrays.blockfirsts(a::GradedUnitRangeDual) = label_dual.(blockfirsts(nondual(a))) +BlockArrays.blocklasts(a::GradedUnitRangeDual) = label_dual.(blocklasts(nondual(a))) +function BlockArrays.findblock(a::GradedUnitRangeDual, index::Integer) + return findblock(nondual(a), index) +end + +blocklabels(a::GradedUnitRangeDual) = dual.(blocklabels(nondual(a))) + +gradedisequal(::GradedUnitRangeDual, ::AbstractGradedUnitRange) = false +gradedisequal(::AbstractGradedUnitRange, ::GradedUnitRangeDual) = false +function gradedisequal(a1::GradedUnitRangeDual, a2::GradedUnitRangeDual) + return gradedisequal(nondual(a1), nondual(a2)) +end +function BlockArrays.combine_blockaxes(a1::GradedUnitRangeDual, a2::GradedUnitRangeDual) + return dual(combine_blockaxes(dual(a1), dual(a2))) +end + +# This is needed when constructing `CartesianIndices` from +# a tuple of unit ranges that have this kind of dual unit range. +# TODO: See if we can find some more elegant way of constructing +# `CartesianIndices`, maybe by defining conversion of `LabelledInteger` +# to `Int`, defining a more general `convert` function, etc. +function Base.OrdinalRange{Int,Int}( + r::GradedUnitRangeDual{<:LabelledInteger{Int},<:LabelledUnitRange{Int,UnitRange{Int}}} +) + # TODO: Implement this broadcasting operation and use it here. + # return Int.(r) + return unlabel(nondual(r)) +end + +# This is only needed in certain Julia versions below 1.10 +# (for example Julia 1.6). +# TODO: Delete this once we drop Julia 1.6 support. +# The type constraint `T<:Integer` is needed to avoid an ambiguity +# error with a conversion method in Base. +function Base.UnitRange{T}(a::GradedUnitRangeDual{<:LabelledInteger{T}}) where {T<:Integer} + return UnitRange{T}(nondual(a)) +end diff --git a/NDTensors/src/lib/GradedAxes/src/unitrangedual.jl b/NDTensors/src/lib/GradedAxes/src/unitrangedual.jl index 67ff1b9ddc..6f32958d0c 100644 --- a/NDTensors/src/lib/GradedAxes/src/unitrangedual.jl +++ b/NDTensors/src/lib/GradedAxes/src/unitrangedual.jl @@ -1,63 +1,61 @@ -struct BlockedUnitRangeDual{T<:Integer,NondualUnitRange<:AbstractUnitRange} <: - AbstractBlockedUnitRange{T,Vector{T}} +struct UnitRangeDual{T<:Integer,NondualUnitRange<:AbstractUnitRange} <: AbstractUnitRange{T} nondual_unitrange::NondualUnitRange end -BlockedUnitRangeDual(a::AbstractUnitRange) = BlockedUnitRangeDual{eltype(a),typeof(a)}(a) +UnitRangeDual(a::AbstractUnitRange) = UnitRangeDual{eltype(a),typeof(a)}(a) -dual(a::AbstractUnitRange) = BlockedUnitRangeDual(a) -nondual(a::BlockedUnitRangeDual) = a.nondual_unitrange -dual(a::BlockedUnitRangeDual) = nondual(a) -flip(a::BlockedUnitRangeDual) = dual(flip(nondual(a))) +dual(a::AbstractUnitRange) = UnitRangeDual(a) +nondual(a::UnitRangeDual) = a.nondual_unitrange +dual(a::UnitRangeDual) = nondual(a) +flip(a::UnitRangeDual) = dual(flip(nondual(a))) nondual(a::AbstractUnitRange) = a -isdual(::AbstractGradedUnitRange) = false -isdual(::BlockedUnitRangeDual) = true +isdual(::UnitRangeDual) = true ## TODO: Define this to instantiate a dual unit range. -## materialize_dual(a::BlockedUnitRangeDual) = materialize_dual(nondual(a)) +## materialize_dual(a::UnitRangeDual) = materialize_dual(nondual(a)) -Base.first(a::BlockedUnitRangeDual) = label_dual(first(nondual(a))) -Base.last(a::BlockedUnitRangeDual) = label_dual(last(nondual(a))) -Base.step(a::BlockedUnitRangeDual) = label_dual(step(nondual(a))) +Base.first(a::UnitRangeDual) = first(nondual(a)) +Base.last(a::UnitRangeDual) = last(nondual(a)) +Base.step(a::UnitRangeDual) = step(nondual(a)) -Base.view(a::BlockedUnitRangeDual, index::Block{1}) = a[index] +Base.view(a::UnitRangeDual, index::Block{1}) = a[index] -function Base.show(io::IO, a::BlockedUnitRangeDual) - return print(io, BlockedUnitRangeDual, "(", blocklasts(a), ")") +function Base.show(io::IO, a::UnitRangeDual) + return print(io, UnitRangeDual, "(", blocklasts(a), ")") end -function Base.show(io::IO, mimetype::MIME"text/plain", a::BlockedUnitRangeDual) +function Base.show(io::IO, mimetype::MIME"text/plain", a::UnitRangeDual) return Base.invoke( show, Tuple{typeof(io),MIME"text/plain",AbstractArray}, io, mimetype, a ) end -function Base.getindex(a::BlockedUnitRangeDual, indices::AbstractUnitRange{<:Integer}) +function Base.getindex(a::UnitRangeDual, indices::AbstractUnitRange{<:Integer}) return dual(getindex(nondual(a), indices)) end using BlockArrays: Block, BlockIndexRange, BlockRange -function Base.getindex(a::BlockedUnitRangeDual, indices::Integer) +function Base.getindex(a::UnitRangeDual, indices::Integer) return label_dual(getindex(nondual(a), indices)) end # TODO: Use `label_dual.` here, make broadcasting work? -function Base.getindex(a::BlockedUnitRangeDual, indices::Block{1}) +function Base.getindex(a::UnitRangeDual, indices::Block{1}) return dual(getindex(nondual(a), indices)) end # TODO: Use `label_dual.` here, make broadcasting work? -function Base.getindex(a::BlockedUnitRangeDual, indices::BlockRange) +function Base.getindex(a::UnitRangeDual, indices::BlockRange) return dual(getindex(nondual(a), indices)) end # fix ambiguity function Base.getindex( - a::BlockedUnitRangeDual, indices::BlockRange{1,<:Tuple{AbstractUnitRange{Int}}} + a::UnitRangeDual, indices::BlockRange{1,<:Tuple{AbstractUnitRange{Int}}} ) return dual(getindex(nondual(a), indices)) end -function BlockArrays.blocklengths(a::BlockedUnitRangeDual) +function BlockArrays.blocklengths(a::UnitRangeDual) return dual.(blocklengths(nondual(a))) end @@ -68,91 +66,53 @@ function unitrangedual_getindices_blocks(a, indices) end # TODO: Move this to a `BlockArraysExtensions` library. -function blockedunitrange_getindices(a::BlockedUnitRangeDual, indices::Block{1}) +function blockedunitrange_getindices(a::UnitRangeDual, indices::Block{1}) return a[indices] end -function Base.getindex(a::BlockedUnitRangeDual, indices::Vector{<:Block{1}}) +function Base.getindex(a::UnitRangeDual, indices::Vector{<:Block{1}}) return unitrangedual_getindices_blocks(a, indices) end -function Base.getindex(a::BlockedUnitRangeDual, indices::Vector{<:BlockIndexRange{1}}) +function Base.getindex(a::UnitRangeDual, indices::Vector{<:BlockIndexRange{1}}) return unitrangedual_getindices_blocks(a, indices) end -function to_blockindices(a::BlockedUnitRangeDual, indices::UnitRange{<:Integer}) +function to_blockindices(a::UnitRangeDual, indices::UnitRange{<:Integer}) return to_blockindices(nondual(a), indices) end -Base.axes(a::BlockedUnitRangeDual) = axes(nondual(a)) +Base.axes(a::UnitRangeDual) = axes(nondual(a)) using BlockArrays: BlockArrays, Block, BlockSlice using NDTensors.LabelledNumbers: LabelledUnitRange -function BlockArrays.BlockSlice(b::Block, a::LabelledUnitRange) - return BlockSlice(b, unlabel(a)) -end using BlockArrays: BlockArrays, BlockSlice -using NDTensors.GradedAxes: BlockedUnitRangeDual, dual -function BlockArrays.BlockSlice(b::Block, r::BlockedUnitRangeDual) +using NDTensors.GradedAxes: UnitRangeDual, dual +function BlockArrays.BlockSlice(b::Block, r::UnitRangeDual) return BlockSlice(b, dual(r)) end -using NDTensors.LabelledNumbers: LabelledNumbers, label -LabelledNumbers.label(a::BlockedUnitRangeDual) = dual(label(nondual(a))) - -using NDTensors.LabelledNumbers: LabelledUnitRange -# The Base version of `length(::AbstractUnitRange)` drops the label. -function Base.length(a::BlockedUnitRangeDual{<:Any,<:LabelledUnitRange}) - return dual(length(nondual(a))) -end -function Base.iterate(a::BlockedUnitRangeDual, i) +function Base.iterate(a::UnitRangeDual, i) i == last(a) && return nothing return dual.(iterate(nondual(a), i)) end # TODO: Is this a good definition? -Base.unitrange(a::BlockedUnitRangeDual{<:Any,<:AbstractUnitRange}) = a - -using NDTensors.LabelledNumbers: LabelledInteger, label, labelled, unlabel -dual(i::LabelledInteger) = labelled(unlabel(i), dual(label(i))) +Base.unitrange(a::UnitRangeDual{<:Any,<:AbstractUnitRange}) = a using BlockArrays: BlockArrays, blockaxes, blocklasts, combine_blockaxes, findblock -BlockArrays.blockaxes(a::BlockedUnitRangeDual) = blockaxes(nondual(a)) -BlockArrays.blockfirsts(a::BlockedUnitRangeDual) = label_dual.(blockfirsts(nondual(a))) -BlockArrays.blocklasts(a::BlockedUnitRangeDual) = label_dual.(blocklasts(nondual(a))) -function BlockArrays.findblock(a::BlockedUnitRangeDual, index::Integer) +BlockArrays.blockaxes(a::UnitRangeDual) = blockaxes(nondual(a)) +BlockArrays.blockfirsts(a::UnitRangeDual) = blockfirsts(nondual(a)) +BlockArrays.blocklasts(a::UnitRangeDual) = blocklasts(nondual(a)) +function BlockArrays.findblock(a::UnitRangeDual, index::Integer) return findblock(nondual(a), index) end -blocklabels(a::BlockedUnitRangeDual) = dual.(blocklabels(nondual(a))) - -gradedisequal(::BlockedUnitRangeDual, ::AbstractGradedUnitRange) = false -gradedisequal(::AbstractGradedUnitRange, ::BlockedUnitRangeDual) = false -function gradedisequal(a1::BlockedUnitRangeDual, a2::BlockedUnitRangeDual) +gradedisequal(::UnitRangeDual, ::AbstractGradedUnitRange) = false +gradedisequal(::AbstractGradedUnitRange, ::UnitRangeDual) = false +function gradedisequal(a1::UnitRangeDual, a2::UnitRangeDual) return gradedisequal(nondual(a1), nondual(a2)) end -function BlockArrays.combine_blockaxes(a1::BlockedUnitRangeDual, a2::BlockedUnitRangeDual) +function BlockArrays.combine_blockaxes(a1::UnitRangeDual, a2::UnitRangeDual) return dual(combine_blockaxes(dual(a1), dual(a2))) end - -# This is needed when constructing `CartesianIndices` from -# a tuple of unit ranges that have this kind of dual unit range. -# TODO: See if we can find some more elegant way of constructing -# `CartesianIndices`, maybe by defining conversion of `LabelledInteger` -# to `Int`, defining a more general `convert` function, etc. -function Base.OrdinalRange{Int,Int}( - r::BlockedUnitRangeDual{<:LabelledInteger{Int},<:LabelledUnitRange{Int,UnitRange{Int}}} -) - # TODO: Implement this broadcasting operation and use it here. - # return Int.(r) - return unlabel(nondual(r)) -end - -# This is only needed in certain Julia versions below 1.10 -# (for example Julia 1.6). -# TODO: Delete this once we drop Julia 1.6 support. -# The type constraint `T<:Integer` is needed to avoid an ambiguity -# error with a conversion method in Base. -function Base.UnitRange{T}(a::BlockedUnitRangeDual{<:LabelledInteger{T}}) where {T<:Integer} - return UnitRange{T}(nondual(a)) -end diff --git a/NDTensors/src/lib/GradedAxes/test/test_dual.jl b/NDTensors/src/lib/GradedAxes/test/test_dual.jl index 447a9a31ac..ca99faf7f8 100644 --- a/NDTensors/src/lib/GradedAxes/test/test_dual.jl +++ b/NDTensors/src/lib/GradedAxes/test/test_dual.jl @@ -3,7 +3,7 @@ using BlockArrays: Block, blockaxes, blockfirsts, blocklasts, blocklength, blocklengths, blocks, findblock using NDTensors.GradedAxes: GradedAxes, - BlockedUnitRangeDual, + GradedUnitRangeDual, blocklabels, blockmergesortperm, blocksortperm, @@ -34,7 +34,7 @@ GradedAxes.dual(c::U1) = U1(-c.n) @test ad[4] == 4 @test label(ad[4]) == U1(-1) @test ad[2:4] == 2:4 - @test ad[2:4] isa BlockedUnitRangeDual + @test ad[2:4] isa GradedUnitRangeDual @test label(ad[2:4][Block(2)]) == U1(-1) @test ad[[2, 4]] == [2, 4] @test label(ad[[2, 4]][2]) == U1(-1) From 743592ae7d544edf38c528b264421cea766a1afb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Gauth=C3=A9?= Date: Tue, 10 Sep 2024 19:59:34 -0400 Subject: [PATCH 03/43] pass BlockSparseArrays tests --- .../test/runtests.jl | 14 +++++++++++--- .../src/lib/GradedAxes/src/gradedunitrangedual.jl | 4 ++++ 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl b/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl index 38142b65f5..b98b4fc7b6 100644 --- a/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl +++ b/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl @@ -4,7 +4,13 @@ using Test: @test, @testset, @test_broken using BlockArrays: Block, BlockedOneTo, blockedrange, blocklengths, blocksize using NDTensors.BlockSparseArrays: BlockSparseArray, block_nstored using NDTensors.GradedAxes: - GradedAxes, GradedOneTo, UnitRangeDual, blocklabels, dual, gradedrange + GradedAxes, + GradedOneTo, + GradedUnitRangeDual, + UnitRangeDual, + blocklabels, + dual, + gradedrange using NDTensors.LabelledNumbers: label using NDTensors.SparseArrayInterface: nstored using NDTensors.TensorAlgebra: fusedims, splitdims @@ -149,6 +155,8 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) end # Test case when all axes are dual. + @test dual(gradedrange([U1(0) => 2])) isa GradedUnitRangeDual + @test dual(blockedrange([2, 2])) isa UnitRangeDual for r in (gradedrange([U1(0) => 2, U1(1) => 2]), blockedrange([2, 2])) a = BlockSparseArray{elt}(dual(r), dual(r)) @views for i in [Block(1, 1), Block(2, 2)] @@ -158,7 +166,7 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) @test block_nstored(b) == 2 @test Array(b) == 2 * Array(a) for ax in axes(b) - @test ax isa UnitRangeDual + @test ax isa typeof(dual(r)) end end @@ -173,7 +181,7 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) @test block_nstored(b) == 2 @test Array(b) == 2 * Array(a)' for ax in axes(b) - @test ax isa UnitRangeDual + @test ax isa typeof(dual(r)) end end end diff --git a/NDTensors/src/lib/GradedAxes/src/gradedunitrangedual.jl b/NDTensors/src/lib/GradedAxes/src/gradedunitrangedual.jl index f5e5f986a3..7c031d2811 100644 --- a/NDTensors/src/lib/GradedAxes/src/gradedunitrangedual.jl +++ b/NDTensors/src/lib/GradedAxes/src/gradedunitrangedual.jl @@ -152,3 +152,7 @@ end function Base.UnitRange{T}(a::GradedUnitRangeDual{<:LabelledInteger{T}}) where {T<:Integer} return UnitRange{T}(nondual(a)) end + +function unlabel_blocks(a::GradedUnitRangeDual) + return unlabel_blocks(nondual(a)) +end From 36e7bfb2771493d843619c6ce1be7cd5c5e36ce9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Gauth=C3=A9?= Date: Wed, 11 Sep 2024 14:34:58 -0400 Subject: [PATCH 04/43] test UnitRangeDual --- NDTensors/src/lib/GradedAxes/src/dual.jl | 1 + .../lib/GradedAxes/src/gradedunitrangedual.jl | 1 - .../src/lib/GradedAxes/src/unitrangedual.jl | 3 +- .../src/lib/GradedAxes/test/test_dual.jl | 70 ++++++++++++++++++- 4 files changed, 70 insertions(+), 5 deletions(-) diff --git a/NDTensors/src/lib/GradedAxes/src/dual.jl b/NDTensors/src/lib/GradedAxes/src/dual.jl index 985ebe33cc..3e9e7253d8 100644 --- a/NDTensors/src/lib/GradedAxes/src/dual.jl +++ b/NDTensors/src/lib/GradedAxes/src/dual.jl @@ -1,4 +1,5 @@ function dual end +isdual(::AbstractUnitRange) = false # default behavior using NDTensors.LabelledNumbers: LabelledStyle, IsLabelled, NotLabelled, label, labelled, unlabel diff --git a/NDTensors/src/lib/GradedAxes/src/gradedunitrangedual.jl b/NDTensors/src/lib/GradedAxes/src/gradedunitrangedual.jl index 7c031d2811..81b667f6d9 100644 --- a/NDTensors/src/lib/GradedAxes/src/gradedunitrangedual.jl +++ b/NDTensors/src/lib/GradedAxes/src/gradedunitrangedual.jl @@ -8,7 +8,6 @@ dual(a::AbstractGradedUnitRange) = GradedUnitRangeDual(a) nondual(a::GradedUnitRangeDual) = a.nondual_unitrange dual(a::GradedUnitRangeDual) = nondual(a) flip(a::GradedUnitRangeDual) = dual(flip(nondual(a))) -isdual(::AbstractGradedUnitRange) = false isdual(::GradedUnitRangeDual) = true ## TODO: Define this to instantiate a dual unit range. ## materialize_dual(a::GradedUnitRangeDual) = materialize_dual(nondual(a)) diff --git a/NDTensors/src/lib/GradedAxes/src/unitrangedual.jl b/NDTensors/src/lib/GradedAxes/src/unitrangedual.jl index 6f32958d0c..c06397a49a 100644 --- a/NDTensors/src/lib/GradedAxes/src/unitrangedual.jl +++ b/NDTensors/src/lib/GradedAxes/src/unitrangedual.jl @@ -59,9 +59,8 @@ function BlockArrays.blocklengths(a::UnitRangeDual) return dual.(blocklengths(nondual(a))) end -# TODO: Use `label_dual.` here, make broadcasting work? function unitrangedual_getindices_blocks(a, indices) - a_indices = getindex(nondual(a), indices) + a_indices = blockedunitrange_getindices(nondual(a), indices) return mortar([dual(b) for b in blocks(a_indices)]) end diff --git a/NDTensors/src/lib/GradedAxes/test/test_dual.jl b/NDTensors/src/lib/GradedAxes/test/test_dual.jl index ca99faf7f8..f48ccc593d 100644 --- a/NDTensors/src/lib/GradedAxes/test/test_dual.jl +++ b/NDTensors/src/lib/GradedAxes/test/test_dual.jl @@ -1,9 +1,19 @@ @eval module $(gensym()) using BlockArrays: - Block, blockaxes, blockfirsts, blocklasts, blocklength, blocklengths, blocks, findblock + Block, + blockaxes, + blockedrange, + blockfirsts, + blocklasts, + blocklength, + blocklengths, + blocks, + findblock using NDTensors.GradedAxes: GradedAxes, GradedUnitRangeDual, + OneToOne, + UnitRangeDual, blocklabels, blockmergesortperm, blocksortperm, @@ -19,9 +29,65 @@ struct U1 n::Int end GradedAxes.dual(c::U1) = U1(-c.n) -@testset "dual" begin +Base.isless(c1::U1, c2::U1) = c1.n < c2.n + +@testset "UnitRangeDual" begin + @testset "dual(OneToOne)" begin + a = OneToOne() + ad = dual(a) + @test ad isa UnitRangeDual + @test eltype(ad) == Bool + @test nondual(ad) === a + @test dual(ad) === a + + @test isdual(ad) + @test !isdual(a) + @test length(ad) == 1 + end + @testset "dual(UnitRange)" begin + a = 1:3 + ad = dual(a) + @test ad isa UnitRangeDual + @test eltype(ad) == Int + @test nondual(ad) === a + @test dual(ad) === a + + @test isdual(ad) + @test !isdual(a) + @test length(ad) == 3 + end + @testset "dual(BlockedOneTo)" begin + a = blockedrange([2, 3]) + ad = dual(a) + @test ad isa UnitRangeDual + @test eltype(ad) == Int + @test nondual(ad) === a + @test dual(ad) === a + + @test isdual(ad) + @test !isdual(a) + @test length(ad) == 5 + + @test blockfirsts(ad) == [1, 3] + @test blocklasts(ad) == [2, 5] + @test findblock(ad, 4) == Block(2) + @test only(blockaxes(ad)) == Block(1):Block(2) + @test blocks(ad) == [1:2, 3:5] + @test ad[4] == 4 + @test ad[2:4] == 2:4 + @test ad[2:4] isa UnitRangeDual + @test ad[[2, 4]] == [2, 4] + @test ad[Block(2)] == 3:5 + @test ad[Block(1):Block(2)][Block(2)] == 3:5 + @test ad[[Block(2), Block(1)]][Block(1)] == 3:5 + @test ad[[Block(2)[1:2], Block(1)[1:2]]][Block(1)] == 3:4 + end +end + +@testset "GradedUnitRangeDual" begin a = gradedrange([U1(0) => 2, U1(1) => 3]) ad = dual(a) + @test ad isa GradedUnitRangeDual @test eltype(ad) == LabelledInteger{Int,U1} @test dual(ad) == a @test nondual(ad) == a From cb1c4d92b5e1ee2c3ade5fdd411e3e90f8d0cec9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Gauth=C3=A9?= Date: Wed, 11 Sep 2024 17:29:06 -0400 Subject: [PATCH 05/43] fix blocklengths --- NDTensors/src/lib/GradedAxes/src/unitrangedual.jl | 2 +- NDTensors/src/lib/GradedAxes/test/test_dual.jl | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/NDTensors/src/lib/GradedAxes/src/unitrangedual.jl b/NDTensors/src/lib/GradedAxes/src/unitrangedual.jl index c06397a49a..95c85f9de2 100644 --- a/NDTensors/src/lib/GradedAxes/src/unitrangedual.jl +++ b/NDTensors/src/lib/GradedAxes/src/unitrangedual.jl @@ -56,7 +56,7 @@ function Base.getindex( end function BlockArrays.blocklengths(a::UnitRangeDual) - return dual.(blocklengths(nondual(a))) + return blocklengths(nondual(a)) end function unitrangedual_getindices_blocks(a, indices) diff --git a/NDTensors/src/lib/GradedAxes/test/test_dual.jl b/NDTensors/src/lib/GradedAxes/test/test_dual.jl index f48ccc593d..a304ed0274 100644 --- a/NDTensors/src/lib/GradedAxes/test/test_dual.jl +++ b/NDTensors/src/lib/GradedAxes/test/test_dual.jl @@ -70,6 +70,8 @@ Base.isless(c1::U1, c2::U1) = c1.n < c2.n @test blockfirsts(ad) == [1, 3] @test blocklasts(ad) == [2, 5] + @test blocklength(ad) == 2 + @test blocklengths(ad) == [2, 3] @test findblock(ad, 4) == Block(2) @test only(blockaxes(ad)) == Block(1):Block(2) @test blocks(ad) == [1:2, 3:5] @@ -94,6 +96,8 @@ end @test nondual(a) == a @test blockfirsts(ad) == [labelled(1, U1(0)), labelled(3, U1(-1))] @test blocklasts(ad) == [labelled(2, U1(0)), labelled(5, U1(-1))] + @test blocklength(ad) == 2 + @test blocklengths(ad) == [2, 3] @test findblock(ad, 4) == Block(2) @test only(blockaxes(ad)) == Block(1):Block(2) @test blocks(ad) == [labelled(1:2, U1(0)), labelled(3:5, U1(-1))] From 5a7d4a7e01c1585996f8776b44783461fee652bc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Gauth=C3=A9?= Date: Wed, 11 Sep 2024 16:14:11 -0400 Subject: [PATCH 06/43] add tests --- .../test/runtests.jl | 11 ++++++++++ .../src/abstractblocksparsearray/views.jl | 9 +++++++- .../lib/BlockSparseArrays/test/test_basics.jl | 21 +++++++++++++++++-- .../src/lib/GradedAxes/test/test_dual.jl | 2 +- 4 files changed, 39 insertions(+), 4 deletions(-) diff --git a/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl b/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl index b98b4fc7b6..116e321d10 100644 --- a/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl +++ b/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl @@ -168,6 +168,12 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) for ax in axes(b) @test ax isa typeof(dual(r)) end + + I = [Block(1)[1:1]] + @test_broken a[I, :] + @test_broken a[:, I] + @test size(a[I, I]) == (1, 1) + @test_broken GradedAxes.isdual(axes(a[I, I], 1)) end # Test case when all axes are dual @@ -183,6 +189,11 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) for ax in axes(b) @test ax isa typeof(dual(r)) end + + I = [Block(1)[1:1]] + @test size(a[I, :]) == (1, 4) + @test size(a[:, I]) == (4, 1) + @test size(a[I, I]) == (1, 1) end end @testset "Matrix multiplication" begin diff --git a/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/views.jl b/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/views.jl index e409ed5500..a07ba72913 100644 --- a/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/views.jl +++ b/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/views.jl @@ -1,5 +1,12 @@ using BlockArrays: - BlockArrays, Block, BlockIndexRange, BlockedVector, blocklength, blocksize, viewblock + AbstractBlockedUnitRange, + BlockArrays, + Block, + BlockIndexRange, + BlockedVector, + blocklength, + blocksize, + viewblock # This splits `BlockIndexRange{N}` into # `NTuple{N,BlockIndexRange{1}}`. diff --git a/NDTensors/src/lib/BlockSparseArrays/test/test_basics.jl b/NDTensors/src/lib/BlockSparseArrays/test/test_basics.jl index 10f8d6e35d..ad517cc921 100644 --- a/NDTensors/src/lib/BlockSparseArrays/test/test_basics.jl +++ b/NDTensors/src/lib/BlockSparseArrays/test/test_basics.jl @@ -15,9 +15,15 @@ using BlockArrays: blocksizes, mortar using Compat: @compat -using LinearAlgebra: mul! +using LinearAlgebra: Adjoint, mul! using NDTensors.BlockSparseArrays: - @view!, BlockSparseArray, BlockView, block_nstored, block_reshape, view! + @view!, + BlockSparseArray, + BlockView, + block_nstored, + block_reshape, + block_stored_indices, + view! using NDTensors.SparseArrayInterface: nstored using NDTensors.TensorAlgebra: contract using Test: @test, @test_broken, @test_throws, @testset @@ -44,6 +50,17 @@ include("TestBlockSparseArraysUtils.jl") a[Block(2, 2)] = randn(elt, 3, 3) @test a[2:4, 4] == Array(a)[2:4, 4] @test_broken a[4, 2:4] + + @test a[Block(1), :] isa BlockSparseArray{elt} + @test adjoint(a) isa Adjoint{elt,<:BlockSparseArray} + @test_broken adjoint(a)[Block(1), :] isa Adjoint{elt,<:BlockSparseArray} + # could also be directly a BlockSparseArray + + a = BlockSparseArray{elt}([1], [1, 1]) + a[1, 2] = 1 + @test [a[Block(Tuple(it))] for it in eachindex(block_stored_indices(a))] isa Vector + ah = adjoint(a) + @test_broken [ah[Block(Tuple(it))] for it in eachindex(block_stored_indices(ah))] isa Vector end @testset "Basics" begin a = BlockSparseArray{elt}([2, 3], [2, 3]) diff --git a/NDTensors/src/lib/GradedAxes/test/test_dual.jl b/NDTensors/src/lib/GradedAxes/test/test_dual.jl index a304ed0274..f8f0003de6 100644 --- a/NDTensors/src/lib/GradedAxes/test/test_dual.jl +++ b/NDTensors/src/lib/GradedAxes/test/test_dual.jl @@ -24,7 +24,7 @@ using NDTensors.GradedAxes: isdual, nondual using NDTensors.LabelledNumbers: LabelledInteger, label, labelled -using Test: @test, @test_broken, @testset +using Test: @test, @testset struct U1 n::Int end From c37372929d11ad183be9d1cc239525cf7b300ed0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Gauth=C3=A9?= Date: Thu, 12 Sep 2024 19:01:32 -0400 Subject: [PATCH 07/43] fix some slicing --- .../test/runtests.jl | 52 +++++++++++++++++-- .../lib/GradedAxes/src/blockedunitrange.jl | 2 +- .../lib/GradedAxes/src/gradedunitrangedual.jl | 4 -- 3 files changed, 49 insertions(+), 9 deletions(-) diff --git a/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl b/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl index 116e321d10..216598f333 100644 --- a/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl +++ b/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl @@ -155,9 +155,8 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) end # Test case when all axes are dual. - @test dual(gradedrange([U1(0) => 2])) isa GradedUnitRangeDual - @test dual(blockedrange([2, 2])) isa UnitRangeDual - for r in (gradedrange([U1(0) => 2, U1(1) => 2]), blockedrange([2, 2])) + @testset "BlockedOneTo" begin + r = gradedrange([U1(0) => 2, U1(1) => 2]) a = BlockSparseArray{elt}(dual(r), dual(r)) @views for i in [Block(1, 1), Block(2, 2)] a[i] = randn(elt, size(a[i])) @@ -165,8 +164,51 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) b = 2 * a @test block_nstored(b) == 2 @test Array(b) == 2 * Array(a) + @test a[:, :] isa BlockSparseArray for ax in axes(b) - @test ax isa typeof(dual(r)) + @test ax isa GradedUnitRangeDual + end + + I = [Block(1)[1:1]] + @test_broken a[I, :] + @test_broken a[:, I] + @test size(a[I, I]) == (1, 1) + @test_broken GradedAxes.isdual(axes(a[I, I], 1)) + end + + @testset "GradedUnitRange" begin + r = gradedrange([U1(0) => 2, U1(1) => 2])[1:3] + a = BlockSparseArray{elt}(dual(r), dual(r)) + @views for i in [Block(1, 1), Block(2, 2)] + a[i] = randn(elt, size(a[i])) + end + b = 2 * a + @test block_nstored(b) == 2 + @test Array(b) == 2 * Array(a) + @test a[:, :] isa BlockSparseArray + for ax in axes(b) + @test ax isa GradedUnitRangeDual + end + + I = [Block(1)[1:1]] + @test_broken a[I, :] + @test_broken a[:, I] + @test size(a[I, I]) == (1, 1) + @test_broken GradedAxes.isdual(axes(a[I, I], 1)) + end + + @testset "BlockedUnitRange" begin + r = blockedrange([2, 2]) + a = BlockSparseArray{elt}(dual(r), dual(r)) + @views for i in [Block(1, 1), Block(2, 2)] + a[i] = randn(elt, size(a[i])) + end + b = 2 * a + @test block_nstored(b) == 2 + @test Array(b) == 2 * Array(a) + @test_broken a[:, :] isa BlockSparseArray + for ax in axes(b) + @test ax isa UnitRangeDual end I = [Block(1)[1:1]] @@ -190,6 +232,8 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) @test ax isa typeof(dual(r)) end + @test a[:, :] isa BlockSparseArray + I = [Block(1)[1:1]] @test size(a[I, :]) == (1, 4) @test size(a[:, I]) == (4, 1) diff --git a/NDTensors/src/lib/GradedAxes/src/blockedunitrange.jl b/NDTensors/src/lib/GradedAxes/src/blockedunitrange.jl index 883025df12..d913dd60ab 100644 --- a/NDTensors/src/lib/GradedAxes/src/blockedunitrange.jl +++ b/NDTensors/src/lib/GradedAxes/src/blockedunitrange.jl @@ -167,7 +167,7 @@ end # Slice `a` by `I`, returning a: # `BlockVector{<:BlockIndex{1},<:Vector{<:BlockIndexRange{1}}}` # with the `BlockIndex{1}` corresponding to each value of `I`. -function to_blockindices(a::BlockedOneTo{<:Integer}, I::UnitRange{<:Integer}) +function to_blockindices(a::AbstractBlockedUnitRange{<:Integer}, I::UnitRange{<:Integer}) return mortar( map(blocks(blockedunitrange_getindices(a, I))) do r bi_first = findblockindex(a, first(r)) diff --git a/NDTensors/src/lib/GradedAxes/src/gradedunitrangedual.jl b/NDTensors/src/lib/GradedAxes/src/gradedunitrangedual.jl index 81b667f6d9..0bd78cb6e8 100644 --- a/NDTensors/src/lib/GradedAxes/src/gradedunitrangedual.jl +++ b/NDTensors/src/lib/GradedAxes/src/gradedunitrangedual.jl @@ -75,10 +75,6 @@ function Base.getindex(a::GradedUnitRangeDual, indices::Vector{<:BlockIndexRange return unitrangedual_getindices_blocks(a, indices) end -function to_blockindices(a::GradedUnitRangeDual, indices::UnitRange{<:Integer}) - return to_blockindices(nondual(a), indices) -end - Base.axes(a::GradedUnitRangeDual) = axes(nondual(a)) using BlockArrays: BlockArrays, Block, BlockSlice From 113a1fed8328d2ae7046992322c400ee12f216ee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Gauth=C3=A9?= Date: Mon, 16 Sep 2024 17:13:37 -0400 Subject: [PATCH 08/43] fix tests --- .../ext/GradedAxesSectorsExt/Project.toml | 2 -- .../src/GradedAxesSectorsExt.jl | 9 ------ .../GradedAxesSectorsExt/test/Project.toml | 3 -- .../ext/GradedAxesSectorsExt/test/runtests.jl | 15 ---------- .../src/lib/GradedAxes/src/GradedAxes.jl | 4 +-- NDTensors/src/lib/GradedAxes/src/dual.jl | 2 ++ NDTensors/src/lib/GradedAxes/src/fusion.jl | 28 +++++++++---------- .../src/lib/GradedAxes/src/gradedunitrange.jl | 7 ++++- .../src/lib/GradedAxes/test/test_dual.jl | 14 ++++++++-- 9 files changed, 35 insertions(+), 49 deletions(-) delete mode 100644 NDTensors/src/lib/GradedAxes/ext/GradedAxesSectorsExt/Project.toml delete mode 100644 NDTensors/src/lib/GradedAxes/ext/GradedAxesSectorsExt/src/GradedAxesSectorsExt.jl delete mode 100644 NDTensors/src/lib/GradedAxes/ext/GradedAxesSectorsExt/test/Project.toml delete mode 100644 NDTensors/src/lib/GradedAxes/ext/GradedAxesSectorsExt/test/runtests.jl diff --git a/NDTensors/src/lib/GradedAxes/ext/GradedAxesSectorsExt/Project.toml b/NDTensors/src/lib/GradedAxes/ext/GradedAxesSectorsExt/Project.toml deleted file mode 100644 index 9b1d5ccd25..0000000000 --- a/NDTensors/src/lib/GradedAxes/ext/GradedAxesSectorsExt/Project.toml +++ /dev/null @@ -1,2 +0,0 @@ -[deps] -NDTensors = "23ae76d9-e61a-49c4-8f12-3f1a16adf9cf" diff --git a/NDTensors/src/lib/GradedAxes/ext/GradedAxesSectorsExt/src/GradedAxesSectorsExt.jl b/NDTensors/src/lib/GradedAxes/ext/GradedAxesSectorsExt/src/GradedAxesSectorsExt.jl deleted file mode 100644 index aa3056438e..0000000000 --- a/NDTensors/src/lib/GradedAxes/ext/GradedAxesSectorsExt/src/GradedAxesSectorsExt.jl +++ /dev/null @@ -1,9 +0,0 @@ -module GradedAxesSectorsExt -using ..GradedAxes: GradedAxes -using ...Sectors: Sectors, AbstractCategory, ⊗ # , dual - -GradedAxes.fuse_labels(c1::AbstractCategory, c2::AbstractCategory) = only(c1 ⊗ c2) - -# TODO: Decide the fate of `dual`. -## GradedAxes.dual(c::AbstractCategory) = dual(c) -end diff --git a/NDTensors/src/lib/GradedAxes/ext/GradedAxesSectorsExt/test/Project.toml b/NDTensors/src/lib/GradedAxes/ext/GradedAxesSectorsExt/test/Project.toml deleted file mode 100644 index ef491a529c..0000000000 --- a/NDTensors/src/lib/GradedAxes/ext/GradedAxesSectorsExt/test/Project.toml +++ /dev/null @@ -1,3 +0,0 @@ -[deps] -NDTensors = "23ae76d9-e61a-49c4-8f12-3f1a16adf9cf" -Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/NDTensors/src/lib/GradedAxes/ext/GradedAxesSectorsExt/test/runtests.jl b/NDTensors/src/lib/GradedAxes/ext/GradedAxesSectorsExt/test/runtests.jl deleted file mode 100644 index 371e7e57cd..0000000000 --- a/NDTensors/src/lib/GradedAxes/ext/GradedAxesSectorsExt/test/runtests.jl +++ /dev/null @@ -1,15 +0,0 @@ -@eval module $(gensym()) -using NDTensors.GradedAxes: dual, fuse_labels -using NDTensors.Sectors: U1, Z -using Test: @test, @testset - -@testset "GradedAxesSectorsExt" begin - @test fuse_labels(U1(1), U1(2)) == U1(3) - @test dual(U1(2)) == U1(-2) - - @test fuse_labels(Z{2}(1), Z{2}(1)) == Z{2}(0) - @test fuse_labels(Z{2}(0), Z{2}(1)) == Z{2}(1) - @test dual(Z{2}(1)) == Z{2}(1) - @test dual(Z{2}(0)) == Z{2}(0) -end -end diff --git a/NDTensors/src/lib/GradedAxes/src/GradedAxes.jl b/NDTensors/src/lib/GradedAxes/src/GradedAxes.jl index cbdec545e0..d391004447 100644 --- a/NDTensors/src/lib/GradedAxes/src/GradedAxes.jl +++ b/NDTensors/src/lib/GradedAxes/src/GradedAxes.jl @@ -1,8 +1,8 @@ module GradedAxes include("blockedunitrange.jl") include("gradedunitrange.jl") -include("gradedunitrangedual.jl") include("dual.jl") +include("gradedunitrangedual.jl") include("unitrangedual.jl") -include("../ext/GradedAxesSectorsExt/src/GradedAxesSectorsExt.jl") +include("fusion.jl") end diff --git a/NDTensors/src/lib/GradedAxes/src/dual.jl b/NDTensors/src/lib/GradedAxes/src/dual.jl index 3e9e7253d8..d986fa2f8d 100644 --- a/NDTensors/src/lib/GradedAxes/src/dual.jl +++ b/NDTensors/src/lib/GradedAxes/src/dual.jl @@ -6,3 +6,5 @@ using NDTensors.LabelledNumbers: label_dual(x) = label_dual(LabelledStyle(x), x) label_dual(::NotLabelled, x) = x label_dual(::IsLabelled, x) = labelled(unlabel(x), dual(label(x))) + +flip(g::AbstractGradedUnitRange) = dual(gradedrange(label_dual.(blocklengths(g)))) diff --git a/NDTensors/src/lib/GradedAxes/src/fusion.jl b/NDTensors/src/lib/GradedAxes/src/fusion.jl index 419a28c03e..b1f54ef7cc 100644 --- a/NDTensors/src/lib/GradedAxes/src/fusion.jl +++ b/NDTensors/src/lib/GradedAxes/src/fusion.jl @@ -6,6 +6,10 @@ OneToOne() = OneToOne{Bool}() Base.first(a::OneToOne) = one(eltype(a)) Base.last(a::OneToOne) = one(eltype(a)) +gradedisequal(::AbstractUnitRange, ::OneToOne) = false +gradedisequal(::OneToOne, ::AbstractUnitRange) = false +gradedisequal(::OneToOne, ::OneToOne) = true + # https://github.com/ITensor/ITensors.jl/blob/v0.3.57/NDTensors/src/lib/GradedAxes/src/tensor_product.jl # https://en.wikipedia.org/wiki/Tensor_product # https://github.com/KeitaNakamura/Tensorial.jl @@ -18,7 +22,7 @@ function tensor_product( return foldl(tensor_product, (a1, a2, a3, a_rest...)) end -function tensor_product(a1::AbstractUnitRange, a2::AbstractUnitRange) +function tensor_product(::AbstractUnitRange, ::AbstractUnitRange) return error("Not implemented yet.") end @@ -34,7 +38,7 @@ function tensor_product(a1::AbstractBlockedUnitRange, ::OneToOne) return a1 end -function tensor_product(a1::OneToOne, a2::OneToOne) +function tensor_product(::OneToOne, ::OneToOne) return OneToOne() end @@ -66,18 +70,20 @@ function fuse_blocklengths(x::LabelledInteger, y::LabelledInteger) return labelled(unlabel(x) * unlabel(y), fuse_labels(label(x), label(y))) end +flatten_maybe_nested(v::Vector{<:Integer}) = v +flatten_maybe_nested(v::Vector{<:AbstractGradedUnitRange}) = reduce(vcat, blocklengths.(v)) + using BlockArrays: blockedrange, blocks function tensor_product(a1::AbstractBlockedUnitRange, a2::AbstractBlockedUnitRange) - blocklengths = map(vec(collect(Iterators.product(blocks(a1), blocks(a2))))) do x - return mapreduce(length, fuse_blocklengths, x) - end + maybe_nested = map( + it -> mapreduce(length, fuse_blocklengths, it), + Iterators.flatten((Iterators.product(blocks(a1), blocks(a2)),)), + ) + blocklengths = flatten_maybe_nested(maybe_nested) return blockedrange(blocklengths) end function blocksortperm(a::AbstractBlockedUnitRange) - # TODO: Figure out how to deal with dual sectors. - # TODO: `rev=isdual(a)` may not be correct for symmetries beyond `U(1)`. - ## return Block.(sortperm(nondual_sectors(a); rev=isdual(a))) return Block.(sortperm(blocklabels(a))) end @@ -101,12 +107,6 @@ end # Get the permutation for sorting, then group by common elements. # groupsortperm([2, 1, 2, 3]) == [[2], [1, 3], [4]] function blockmergesortperm(a::AbstractBlockedUnitRange) - # If it is dual, reverse the sorting so the sectors - # end up sorted in the same way whether or not the space - # is dual. - # TODO: Figure out how to deal with dual sectors. - # TODO: `rev=isdual(a)` may not be correct for symmetries beyond `U(1)`. - ## return Block.(groupsortperm(nondual_sectors(a); rev=isdual(a))) return Block.(groupsortperm(blocklabels(a))) end diff --git a/NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl b/NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl index 57e9420d88..5961d4098d 100644 --- a/NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl +++ b/NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl @@ -12,7 +12,7 @@ using BlockArrays: blockedrange, BlockIndexRange, blockfirsts, - blocklasts, + blockisequal, blocklength, blocklengths, findblock, @@ -37,6 +37,11 @@ function Base.OrdinalRange{T,T}(a::GradedOneTo{<:LabelledInteger{T}}) where {T} return unlabel_blocks(a) end +# == is just a range comparison that ignores labels. Need dedicated function to check equality. +function gradedisequal(a1::AbstractUnitRange, a2::AbstractUnitRange) + return blockisequal(a1, a2) && (blocklabels(a1) == blocklabels(a2)) +end + # This is only needed in certain Julia versions below 1.10 # (for example Julia 1.6). # TODO: Delete this once we drop Julia 1.6 support. diff --git a/NDTensors/src/lib/GradedAxes/test/test_dual.jl b/NDTensors/src/lib/GradedAxes/test/test_dual.jl index f8f0003de6..cfb84fecde 100644 --- a/NDTensors/src/lib/GradedAxes/test/test_dual.jl +++ b/NDTensors/src/lib/GradedAxes/test/test_dual.jl @@ -91,9 +91,17 @@ end ad = dual(a) @test ad isa GradedUnitRangeDual @test eltype(ad) == LabelledInteger{Int,U1} - @test dual(ad) == a - @test nondual(ad) == a - @test nondual(a) == a + + @test gradedisequal(dual(ad), a) + @test gradedisequal(nondual(ad), a) + @test gradedisequal(nondual(a), a) + @test gradedisequal(ad, ad) + @test !gradedisequal(a, ad) + @test !gradedisequal(ad, a) + + @test isdual(ad) + @test !isdual(a) + @test blockfirsts(ad) == [labelled(1, U1(0)), labelled(3, U1(-1))] @test blocklasts(ad) == [labelled(2, U1(0)), labelled(5, U1(-1))] @test blocklength(ad) == 2 From 14c95cac9c7ed421ed6b05f8c723e46fdf59bad9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Gauth=C3=A9?= Date: Mon, 16 Sep 2024 19:13:08 -0400 Subject: [PATCH 09/43] generalize gradedisequal --- .../src/lib/GradedAxes/src/GradedAxes.jl | 3 +- NDTensors/src/lib/GradedAxes/src/fusion.jl | 10 ----- .../src/lib/GradedAxes/src/gradedunitrange.jl | 5 --- .../lib/GradedAxes/src/gradedunitrangedual.jl | 5 --- NDTensors/src/lib/GradedAxes/src/onetoone.jl | 42 +++++++++++++++++++ .../src/lib/GradedAxes/src/unitrangedual.jl | 5 --- .../src/lib/GradedAxes/test/test_basics.jl | 21 +++++++++- .../src/lib/GradedAxes/test/test_dual.jl | 13 ++++++ 8 files changed, 77 insertions(+), 27 deletions(-) create mode 100644 NDTensors/src/lib/GradedAxes/src/onetoone.jl diff --git a/NDTensors/src/lib/GradedAxes/src/GradedAxes.jl b/NDTensors/src/lib/GradedAxes/src/GradedAxes.jl index d391004447..668aac02cb 100644 --- a/NDTensors/src/lib/GradedAxes/src/GradedAxes.jl +++ b/NDTensors/src/lib/GradedAxes/src/GradedAxes.jl @@ -2,7 +2,8 @@ module GradedAxes include("blockedunitrange.jl") include("gradedunitrange.jl") include("dual.jl") -include("gradedunitrangedual.jl") include("unitrangedual.jl") +include("gradedunitrangedual.jl") +include("onetoone.jl") include("fusion.jl") end diff --git a/NDTensors/src/lib/GradedAxes/src/fusion.jl b/NDTensors/src/lib/GradedAxes/src/fusion.jl index b1f54ef7cc..f6e0006929 100644 --- a/NDTensors/src/lib/GradedAxes/src/fusion.jl +++ b/NDTensors/src/lib/GradedAxes/src/fusion.jl @@ -1,15 +1,5 @@ using BlockArrays: AbstractBlockedUnitRange -# Represents the range `1:1` or `Base.OneTo(1)`. -struct OneToOne{T} <: AbstractUnitRange{T} end -OneToOne() = OneToOne{Bool}() -Base.first(a::OneToOne) = one(eltype(a)) -Base.last(a::OneToOne) = one(eltype(a)) - -gradedisequal(::AbstractUnitRange, ::OneToOne) = false -gradedisequal(::OneToOne, ::AbstractUnitRange) = false -gradedisequal(::OneToOne, ::OneToOne) = true - # https://github.com/ITensor/ITensors.jl/blob/v0.3.57/NDTensors/src/lib/GradedAxes/src/tensor_product.jl # https://en.wikipedia.org/wiki/Tensor_product # https://github.com/KeitaNakamura/Tensorial.jl diff --git a/NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl b/NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl index 5961d4098d..c1fd93c4c9 100644 --- a/NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl +++ b/NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl @@ -37,11 +37,6 @@ function Base.OrdinalRange{T,T}(a::GradedOneTo{<:LabelledInteger{T}}) where {T} return unlabel_blocks(a) end -# == is just a range comparison that ignores labels. Need dedicated function to check equality. -function gradedisequal(a1::AbstractUnitRange, a2::AbstractUnitRange) - return blockisequal(a1, a2) && (blocklabels(a1) == blocklabels(a2)) -end - # This is only needed in certain Julia versions below 1.10 # (for example Julia 1.6). # TODO: Delete this once we drop Julia 1.6 support. diff --git a/NDTensors/src/lib/GradedAxes/src/gradedunitrangedual.jl b/NDTensors/src/lib/GradedAxes/src/gradedunitrangedual.jl index 0bd78cb6e8..78f3a00417 100644 --- a/NDTensors/src/lib/GradedAxes/src/gradedunitrangedual.jl +++ b/NDTensors/src/lib/GradedAxes/src/gradedunitrangedual.jl @@ -117,11 +117,6 @@ end blocklabels(a::GradedUnitRangeDual) = dual.(blocklabels(nondual(a))) -gradedisequal(::GradedUnitRangeDual, ::AbstractGradedUnitRange) = false -gradedisequal(::AbstractGradedUnitRange, ::GradedUnitRangeDual) = false -function gradedisequal(a1::GradedUnitRangeDual, a2::GradedUnitRangeDual) - return gradedisequal(nondual(a1), nondual(a2)) -end function BlockArrays.combine_blockaxes(a1::GradedUnitRangeDual, a2::GradedUnitRangeDual) return dual(combine_blockaxes(dual(a1), dual(a2))) end diff --git a/NDTensors/src/lib/GradedAxes/src/onetoone.jl b/NDTensors/src/lib/GradedAxes/src/onetoone.jl new file mode 100644 index 0000000000..19dadf9eb0 --- /dev/null +++ b/NDTensors/src/lib/GradedAxes/src/onetoone.jl @@ -0,0 +1,42 @@ +using BlockArrays: AbstractBlockedUnitRange +using ..LabelledNumbers: islabelled + +# Represents the range `1:1` or `Base.OneTo(1)`. +struct OneToOne{T} <: AbstractUnitRange{T} end +OneToOne() = OneToOne{Bool}() +Base.first(a::OneToOne) = one(eltype(a)) +Base.last(a::OneToOne) = one(eltype(a)) + +# == is just a range comparison that ignores labels. Need dedicated function to check equality. +gradedisequal(::AbstractBlockedUnitRange, ::AbstractUnitRange) = false +gradedisequal(::AbstractUnitRange, ::AbstractBlockedUnitRange) = false +gradedisequal(::AbstractBlockedUnitRange, ::OneToOne) = false +gradedisequal(::OneToOne, ::AbstractBlockedUnitRange) = false +function gradedisequal(a1::AbstractBlockedUnitRange, a2::AbstractBlockedUnitRange) + return blockisequal(a1, a2) +end +function gradedisequal(a1::AbstractGradedUnitRange, a2::AbstractGradedUnitRange) + return blockisequal(a1, a2) && (blocklabels(a1) == blocklabels(a2)) +end +gradedisequal(::GradedUnitRangeDual, ::GradedUnitRange) = false +gradedisequal(::GradedUnitRange, ::GradedUnitRangeDual) = false +function gradedisequal(a1::GradedUnitRangeDual, a2::GradedUnitRangeDual) + return gradedisequal(nondual(a1), nondual(a2)) +end + +gradedisequal(::OneToOne, ::OneToOne) = true + +function gradedisequal(::OneToOne, g::AbstractUnitRange) + return !islabelled(eltype(g)) && (first(g) == last(g) == 1) +end +gradedisequal(g::AbstractUnitRange, a0::OneToOne) = gradedisequal(a0, g) + +gradedisequal(::UnitRangeDual, ::AbstractUnitRange) = false +gradedisequal(::AbstractUnitRange, ::UnitRangeDual) = false +gradedisequal(::OneToOne, ::UnitRangeDual) = false +gradedisequal(::UnitRangeDual, ::OneToOne) = false +function gradedisequal(a1::UnitRangeDual, a2::UnitRangeDual) + return gradedisequal(nondual(a1), nondual(a2)) +end + +gradedisequal(a1::AbstractUnitRange, a2::AbstractUnitRange) = a1 == a2 diff --git a/NDTensors/src/lib/GradedAxes/src/unitrangedual.jl b/NDTensors/src/lib/GradedAxes/src/unitrangedual.jl index 95c85f9de2..91b294bb11 100644 --- a/NDTensors/src/lib/GradedAxes/src/unitrangedual.jl +++ b/NDTensors/src/lib/GradedAxes/src/unitrangedual.jl @@ -107,11 +107,6 @@ function BlockArrays.findblock(a::UnitRangeDual, index::Integer) return findblock(nondual(a), index) end -gradedisequal(::UnitRangeDual, ::AbstractGradedUnitRange) = false -gradedisequal(::AbstractGradedUnitRange, ::UnitRangeDual) = false -function gradedisequal(a1::UnitRangeDual, a2::UnitRangeDual) - return gradedisequal(nondual(a1), nondual(a2)) -end function BlockArrays.combine_blockaxes(a1::UnitRangeDual, a2::UnitRangeDual) return dual(combine_blockaxes(dual(a1), dual(a2))) end diff --git a/NDTensors/src/lib/GradedAxes/test/test_basics.jl b/NDTensors/src/lib/GradedAxes/test/test_basics.jl index e1dcc67174..6de5790767 100644 --- a/NDTensors/src/lib/GradedAxes/test/test_basics.jl +++ b/NDTensors/src/lib/GradedAxes/test/test_basics.jl @@ -9,16 +9,35 @@ using BlockArrays: blocklength, blocklengths, blocks -using NDTensors.GradedAxes: GradedOneTo, GradedUnitRange, blocklabels, gradedrange +using NDTensors.GradedAxes: + GradedOneTo, GradedUnitRange, OneToOne, blocklabels, gradedisequal, gradedrange using NDTensors.LabelledNumbers: LabelledUnitRange, islabelled, label, labelled, unlabel using Test: @test, @test_broken, @testset + +@testset "OneToOne" begin + a0 = OneToOne() + @test a0 isa OneToOne{Bool} + @test eltype(a0) == Bool + @test length(a0) == 1 + @test gradedisequal(a0, a0) + + @test gradedisequal(a0, 1:1) + @test gradedisequal(1:1, a0) + @test !gradedisequal(a0, 1:2) + @test !gradedisequal(1:2, a0) +end @testset "GradedAxes basics" begin + a0 = OneToOne() for a in ( blockedrange([labelled(2, "x"), labelled(3, "y")]), gradedrange([labelled(2, "x"), labelled(3, "y")]), gradedrange(["x" => 2, "y" => 3]), ) @test a isa GradedOneTo + @test gradedisequal(a, a) + @test !gradedisequal(a0, a) + @test !gradedisequal(a, a0) + @test !gradedisequal(a, 1:5) for x in iterate(a) @test x == 1 @test label(x) == "x" diff --git a/NDTensors/src/lib/GradedAxes/test/test_dual.jl b/NDTensors/src/lib/GradedAxes/test/test_dual.jl index cfb84fecde..10fd7b2068 100644 --- a/NDTensors/src/lib/GradedAxes/test/test_dual.jl +++ b/NDTensors/src/lib/GradedAxes/test/test_dual.jl @@ -43,6 +43,9 @@ Base.isless(c1::U1, c2::U1) = c1.n < c2.n @test isdual(ad) @test !isdual(a) @test length(ad) == 1 + @test !gradedisequal(a, ad) + @test !gradedisequal(ad, a) + @test gradedisequal(ad, ad) end @testset "dual(UnitRange)" begin a = 1:3 @@ -55,6 +58,16 @@ Base.isless(c1::U1, c2::U1) = c1.n < c2.n @test isdual(ad) @test !isdual(a) @test length(ad) == 3 + + @test !gradedisequal(ad, a) + @test !gradedisequal(a, ad) + @test gradedisequal(ad, ad) + + a0 = OneToOne() + @test !gradedisequal(ad, a0) + @test !gradedisequal(a0, ad) + @test !gradedisequal(dual(a0), ad) + @test !gradedisequal(ad, dual(a0)) end @testset "dual(BlockedOneTo)" begin a = blockedrange([2, 3]) From a11107cdd2786e9f46688ab07cd4372ec27328dc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Gauth=C3=A9?= Date: Tue, 17 Sep 2024 13:23:38 -0400 Subject: [PATCH 10/43] remove UnitRangeDual --- .../test/runtests.jl | 12 +- .../src/lib/GradedAxes/src/GradedAxes.jl | 1 - NDTensors/src/lib/GradedAxes/src/dual.jl | 6 +- NDTensors/src/lib/GradedAxes/src/onetoone.jl | 9 - .../src/lib/GradedAxes/src/unitrangedual.jl | 112 ----------- .../src/lib/GradedAxes/test/test_basics.jl | 1 + .../src/lib/GradedAxes/test/test_dual.jl | 177 +++++++----------- 7 files changed, 73 insertions(+), 245 deletions(-) delete mode 100644 NDTensors/src/lib/GradedAxes/src/unitrangedual.jl diff --git a/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl b/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl index 216598f333..60ed5b8a85 100644 --- a/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl +++ b/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl @@ -4,13 +4,7 @@ using Test: @test, @testset, @test_broken using BlockArrays: Block, BlockedOneTo, blockedrange, blocklengths, blocksize using NDTensors.BlockSparseArrays: BlockSparseArray, block_nstored using NDTensors.GradedAxes: - GradedAxes, - GradedOneTo, - GradedUnitRangeDual, - UnitRangeDual, - blocklabels, - dual, - gradedrange + GradedAxes, GradedOneTo, GradedUnitRangeDual, blocklabels, dual, gradedrange using NDTensors.LabelledNumbers: label using NDTensors.SparseArrayInterface: nstored using NDTensors.TensorAlgebra: fusedims, splitdims @@ -208,14 +202,14 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) @test Array(b) == 2 * Array(a) @test_broken a[:, :] isa BlockSparseArray for ax in axes(b) - @test ax isa UnitRangeDual + @test ax isa BlockedUnitRange end I = [Block(1)[1:1]] @test_broken a[I, :] @test_broken a[:, I] @test size(a[I, I]) == (1, 1) - @test_broken GradedAxes.isdual(axes(a[I, I], 1)) + @test !GradedAxes.isdual(axes(a[I, I], 1)) end # Test case when all axes are dual diff --git a/NDTensors/src/lib/GradedAxes/src/GradedAxes.jl b/NDTensors/src/lib/GradedAxes/src/GradedAxes.jl index 668aac02cb..7edd09bf84 100644 --- a/NDTensors/src/lib/GradedAxes/src/GradedAxes.jl +++ b/NDTensors/src/lib/GradedAxes/src/GradedAxes.jl @@ -2,7 +2,6 @@ module GradedAxes include("blockedunitrange.jl") include("gradedunitrange.jl") include("dual.jl") -include("unitrangedual.jl") include("gradedunitrangedual.jl") include("onetoone.jl") include("fusion.jl") diff --git a/NDTensors/src/lib/GradedAxes/src/dual.jl b/NDTensors/src/lib/GradedAxes/src/dual.jl index d986fa2f8d..28f140d69e 100644 --- a/NDTensors/src/lib/GradedAxes/src/dual.jl +++ b/NDTensors/src/lib/GradedAxes/src/dual.jl @@ -1,5 +1,7 @@ -function dual end -isdual(::AbstractUnitRange) = false # default behavior +# default behavior: self-dual +dual(r::AbstractUnitRange) = r +nondual(r::AbstractUnitRange) = r +isdual(::AbstractUnitRange) = false using NDTensors.LabelledNumbers: LabelledStyle, IsLabelled, NotLabelled, label, labelled, unlabel diff --git a/NDTensors/src/lib/GradedAxes/src/onetoone.jl b/NDTensors/src/lib/GradedAxes/src/onetoone.jl index 19dadf9eb0..61ee3c2096 100644 --- a/NDTensors/src/lib/GradedAxes/src/onetoone.jl +++ b/NDTensors/src/lib/GradedAxes/src/onetoone.jl @@ -30,13 +30,4 @@ function gradedisequal(::OneToOne, g::AbstractUnitRange) return !islabelled(eltype(g)) && (first(g) == last(g) == 1) end gradedisequal(g::AbstractUnitRange, a0::OneToOne) = gradedisequal(a0, g) - -gradedisequal(::UnitRangeDual, ::AbstractUnitRange) = false -gradedisequal(::AbstractUnitRange, ::UnitRangeDual) = false -gradedisequal(::OneToOne, ::UnitRangeDual) = false -gradedisequal(::UnitRangeDual, ::OneToOne) = false -function gradedisequal(a1::UnitRangeDual, a2::UnitRangeDual) - return gradedisequal(nondual(a1), nondual(a2)) -end - gradedisequal(a1::AbstractUnitRange, a2::AbstractUnitRange) = a1 == a2 diff --git a/NDTensors/src/lib/GradedAxes/src/unitrangedual.jl b/NDTensors/src/lib/GradedAxes/src/unitrangedual.jl deleted file mode 100644 index 91b294bb11..0000000000 --- a/NDTensors/src/lib/GradedAxes/src/unitrangedual.jl +++ /dev/null @@ -1,112 +0,0 @@ -struct UnitRangeDual{T<:Integer,NondualUnitRange<:AbstractUnitRange} <: AbstractUnitRange{T} - nondual_unitrange::NondualUnitRange -end -UnitRangeDual(a::AbstractUnitRange) = UnitRangeDual{eltype(a),typeof(a)}(a) - -dual(a::AbstractUnitRange) = UnitRangeDual(a) -nondual(a::UnitRangeDual) = a.nondual_unitrange -dual(a::UnitRangeDual) = nondual(a) -flip(a::UnitRangeDual) = dual(flip(nondual(a))) -nondual(a::AbstractUnitRange) = a -isdual(::UnitRangeDual) = true -## TODO: Define this to instantiate a dual unit range. -## materialize_dual(a::UnitRangeDual) = materialize_dual(nondual(a)) - -Base.first(a::UnitRangeDual) = first(nondual(a)) -Base.last(a::UnitRangeDual) = last(nondual(a)) -Base.step(a::UnitRangeDual) = step(nondual(a)) - -Base.view(a::UnitRangeDual, index::Block{1}) = a[index] - -function Base.show(io::IO, a::UnitRangeDual) - return print(io, UnitRangeDual, "(", blocklasts(a), ")") -end - -function Base.show(io::IO, mimetype::MIME"text/plain", a::UnitRangeDual) - return Base.invoke( - show, Tuple{typeof(io),MIME"text/plain",AbstractArray}, io, mimetype, a - ) -end - -function Base.getindex(a::UnitRangeDual, indices::AbstractUnitRange{<:Integer}) - return dual(getindex(nondual(a), indices)) -end - -using BlockArrays: Block, BlockIndexRange, BlockRange - -function Base.getindex(a::UnitRangeDual, indices::Integer) - return label_dual(getindex(nondual(a), indices)) -end - -# TODO: Use `label_dual.` here, make broadcasting work? -function Base.getindex(a::UnitRangeDual, indices::Block{1}) - return dual(getindex(nondual(a), indices)) -end - -# TODO: Use `label_dual.` here, make broadcasting work? -function Base.getindex(a::UnitRangeDual, indices::BlockRange) - return dual(getindex(nondual(a), indices)) -end - -# fix ambiguity -function Base.getindex( - a::UnitRangeDual, indices::BlockRange{1,<:Tuple{AbstractUnitRange{Int}}} -) - return dual(getindex(nondual(a), indices)) -end - -function BlockArrays.blocklengths(a::UnitRangeDual) - return blocklengths(nondual(a)) -end - -function unitrangedual_getindices_blocks(a, indices) - a_indices = blockedunitrange_getindices(nondual(a), indices) - return mortar([dual(b) for b in blocks(a_indices)]) -end - -# TODO: Move this to a `BlockArraysExtensions` library. -function blockedunitrange_getindices(a::UnitRangeDual, indices::Block{1}) - return a[indices] -end - -function Base.getindex(a::UnitRangeDual, indices::Vector{<:Block{1}}) - return unitrangedual_getindices_blocks(a, indices) -end - -function Base.getindex(a::UnitRangeDual, indices::Vector{<:BlockIndexRange{1}}) - return unitrangedual_getindices_blocks(a, indices) -end - -function to_blockindices(a::UnitRangeDual, indices::UnitRange{<:Integer}) - return to_blockindices(nondual(a), indices) -end - -Base.axes(a::UnitRangeDual) = axes(nondual(a)) - -using BlockArrays: BlockArrays, Block, BlockSlice -using NDTensors.LabelledNumbers: LabelledUnitRange - -using BlockArrays: BlockArrays, BlockSlice -using NDTensors.GradedAxes: UnitRangeDual, dual -function BlockArrays.BlockSlice(b::Block, r::UnitRangeDual) - return BlockSlice(b, dual(r)) -end - -function Base.iterate(a::UnitRangeDual, i) - i == last(a) && return nothing - return dual.(iterate(nondual(a), i)) -end -# TODO: Is this a good definition? -Base.unitrange(a::UnitRangeDual{<:Any,<:AbstractUnitRange}) = a - -using BlockArrays: BlockArrays, blockaxes, blocklasts, combine_blockaxes, findblock -BlockArrays.blockaxes(a::UnitRangeDual) = blockaxes(nondual(a)) -BlockArrays.blockfirsts(a::UnitRangeDual) = blockfirsts(nondual(a)) -BlockArrays.blocklasts(a::UnitRangeDual) = blocklasts(nondual(a)) -function BlockArrays.findblock(a::UnitRangeDual, index::Integer) - return findblock(nondual(a), index) -end - -function BlockArrays.combine_blockaxes(a1::UnitRangeDual, a2::UnitRangeDual) - return dual(combine_blockaxes(dual(a1), dual(a2))) -end diff --git a/NDTensors/src/lib/GradedAxes/test/test_basics.jl b/NDTensors/src/lib/GradedAxes/test/test_basics.jl index 6de5790767..3e5bb05bc8 100644 --- a/NDTensors/src/lib/GradedAxes/test/test_basics.jl +++ b/NDTensors/src/lib/GradedAxes/test/test_basics.jl @@ -106,6 +106,7 @@ end @test length(ax) == length(a) @test blocklengths(ax) == blocklengths(a) @test blocklabels(ax) == blocklabels(a) + @test_broken(blockfirsts(a)) == [2, 3] # Regression test for ambiguity error. x = gradedrange(["x" => 2, "y" => 3]) diff --git a/NDTensors/src/lib/GradedAxes/test/test_dual.jl b/NDTensors/src/lib/GradedAxes/test/test_dual.jl index 10fd7b2068..0641d63a87 100644 --- a/NDTensors/src/lib/GradedAxes/test/test_dual.jl +++ b/NDTensors/src/lib/GradedAxes/test/test_dual.jl @@ -13,7 +13,6 @@ using NDTensors.GradedAxes: GradedAxes, GradedUnitRangeDual, OneToOne, - UnitRangeDual, blocklabels, blockmergesortperm, blocksortperm, @@ -31,142 +30,96 @@ end GradedAxes.dual(c::U1) = U1(-c.n) Base.isless(c1::U1, c2::U1) = c1.n < c2.n -@testset "UnitRangeDual" begin - @testset "dual(OneToOne)" begin - a = OneToOne() - ad = dual(a) - @test ad isa UnitRangeDual - @test eltype(ad) == Bool - @test nondual(ad) === a - @test dual(ad) === a - - @test isdual(ad) - @test !isdual(a) - @test length(ad) == 1 - @test !gradedisequal(a, ad) - @test !gradedisequal(ad, a) - @test gradedisequal(ad, ad) - end - @testset "dual(UnitRange)" begin - a = 1:3 - ad = dual(a) - @test ad isa UnitRangeDual - @test eltype(ad) == Int - @test nondual(ad) === a - @test dual(ad) === a +@testset "AbstractUnitRange" begin + a0 = OneToOne() + @test gradedisequal(a0, dual(a0)) + @test !isdual(a0) - @test isdual(ad) - @test !isdual(a) - @test length(ad) == 3 + a = 1:3 + ad = dual(a) + @test !isdual(ad) + @test !isdual(a) + @test gradedisequal(ad, a) - @test !gradedisequal(ad, a) - @test !gradedisequal(a, ad) - @test gradedisequal(ad, ad) + a = blockedrange([2, 3]) + ad = dual(a) + @test !isdual(ad) + @test !isdual(a) + @test gradedisequal(ad, a) +end - a0 = OneToOne() - @test !gradedisequal(ad, a0) - @test !gradedisequal(a0, ad) - @test !gradedisequal(dual(a0), ad) - @test !gradedisequal(ad, dual(a0)) - end - @testset "dual(BlockedOneTo)" begin - a = blockedrange([2, 3]) +@testset "GradedUnitRangeDual" begin + for a in + [gradedrange([U1(0) => 2, U1(1) => 3]), gradedrange([U1(0) => 2, U1(1) => 3])[1:5]] ad = dual(a) - @test ad isa UnitRangeDual - @test eltype(ad) == Int - @test nondual(ad) === a - @test dual(ad) === a + @test ad isa GradedUnitRangeDual + @test eltype(ad) == LabelledInteger{Int,U1} + + @test gradedisequal(dual(ad), a) + @test gradedisequal(nondual(ad), a) + @test gradedisequal(nondual(a), a) + @test gradedisequal(ad, ad) + @test !gradedisequal(a, ad) + @test !gradedisequal(ad, a) @test isdual(ad) @test !isdual(a) - @test length(ad) == 5 - @test blockfirsts(ad) == [1, 3] - @test blocklasts(ad) == [2, 5] + @test blockfirsts(ad) == [labelled(1, U1(0)), labelled(3, U1(-1))] + @test blocklasts(ad) == [labelled(2, U1(0)), labelled(5, U1(-1))] @test blocklength(ad) == 2 @test blocklengths(ad) == [2, 3] @test findblock(ad, 4) == Block(2) @test only(blockaxes(ad)) == Block(1):Block(2) - @test blocks(ad) == [1:2, 3:5] + @test blocks(ad) == [labelled(1:2, U1(0)), labelled(3:5, U1(-1))] @test ad[4] == 4 + @test label(ad[4]) == U1(-1) @test ad[2:4] == 2:4 - @test ad[2:4] isa UnitRangeDual + @test ad[2:4] isa GradedUnitRangeDual + @test label(ad[2:4][Block(2)]) == U1(-1) @test ad[[2, 4]] == [2, 4] + @test label(ad[[2, 4]][2]) == U1(-1) @test ad[Block(2)] == 3:5 + @test label(ad[Block(2)]) == U1(-1) @test ad[Block(1):Block(2)][Block(2)] == 3:5 + @test label(ad[Block(1):Block(2)][Block(2)]) == U1(-1) @test ad[[Block(2), Block(1)]][Block(1)] == 3:5 + @test label(ad[[Block(2), Block(1)]][Block(1)]) == U1(-1) @test ad[[Block(2)[1:2], Block(1)[1:2]]][Block(1)] == 3:4 + @test label(ad[[Block(2)[1:2], Block(1)[1:2]]][Block(1)]) == U1(-1) + @test blocksortperm(a) == [Block(1), Block(2)] + @test blocksortperm(ad) == [Block(1), Block(2)] + @test blocklength(blockmergesortperm(a)) == 2 + @test blocklength(blockmergesortperm(ad)) == 2 + @test blockmergesortperm(a) == [Block(1), Block(2)] + @test blockmergesortperm(ad) == [Block(1), Block(2)] end end -@testset "GradedUnitRangeDual" begin - a = gradedrange([U1(0) => 2, U1(1) => 3]) - ad = dual(a) - @test ad isa GradedUnitRangeDual - @test eltype(ad) == LabelledInteger{Int,U1} - - @test gradedisequal(dual(ad), a) - @test gradedisequal(nondual(ad), a) - @test gradedisequal(nondual(a), a) - @test gradedisequal(ad, ad) - @test !gradedisequal(a, ad) - @test !gradedisequal(ad, a) - - @test isdual(ad) - @test !isdual(a) - - @test blockfirsts(ad) == [labelled(1, U1(0)), labelled(3, U1(-1))] - @test blocklasts(ad) == [labelled(2, U1(0)), labelled(5, U1(-1))] - @test blocklength(ad) == 2 - @test blocklengths(ad) == [2, 3] - @test findblock(ad, 4) == Block(2) - @test only(blockaxes(ad)) == Block(1):Block(2) - @test blocks(ad) == [labelled(1:2, U1(0)), labelled(3:5, U1(-1))] - @test ad[4] == 4 - @test label(ad[4]) == U1(-1) - @test ad[2:4] == 2:4 - @test ad[2:4] isa GradedUnitRangeDual - @test label(ad[2:4][Block(2)]) == U1(-1) - @test ad[[2, 4]] == [2, 4] - @test label(ad[[2, 4]][2]) == U1(-1) - @test ad[Block(2)] == 3:5 - @test label(ad[Block(2)]) == U1(-1) - @test ad[Block(1):Block(2)][Block(2)] == 3:5 - @test label(ad[Block(1):Block(2)][Block(2)]) == U1(-1) - @test ad[[Block(2), Block(1)]][Block(1)] == 3:5 - @test label(ad[[Block(2), Block(1)]][Block(1)]) == U1(-1) - @test ad[[Block(2)[1:2], Block(1)[1:2]]][Block(1)] == 3:4 - @test label(ad[[Block(2)[1:2], Block(1)[1:2]]][Block(1)]) == U1(-1) - @test blocksortperm(a) == [Block(1), Block(2)] - @test blocksortperm(ad) == [Block(1), Block(2)] - @test blocklength(blockmergesortperm(a)) == 2 - @test blocklength(blockmergesortperm(ad)) == 2 - @test blockmergesortperm(a) == [Block(1), Block(2)] - @test blockmergesortperm(ad) == [Block(1), Block(2)] -end - @testset "flip" begin - a = gradedrange([U1(0) => 2, U1(1) => 3]) - ad = dual(a) - @test gradedisequal(flip(a), dual(gradedrange([U1(0) => 2, U1(-1) => 3]))) - @test gradedisequal(flip(ad), gradedrange([U1(0) => 2, U1(-1) => 3])) + for a in + [gradedrange([U1(0) => 2, U1(1) => 3]), gradedrange([U1(0) => 2, U1(1) => 3])[1:5]] + ad = dual(a) + @test gradedisequal(flip(a), dual(gradedrange([U1(0) => 2, U1(-1) => 3]))) + @test gradedisequal(flip(ad), gradedrange([U1(0) => 2, U1(-1) => 3])) - @test blocklabels(a) == [U1(0), U1(1)] - @test blocklabels(dual(a)) == [U1(0), U1(-1)] - @test blocklabels(flip(a)) == [U1(0), U1(1)] - @test blocklabels(flip(dual(a))) == [U1(0), U1(-1)] - @test blocklabels(dual(flip(a))) == [U1(0), U1(-1)] + @test blocklabels(a) == [U1(0), U1(1)] + @test blocklabels(dual(a)) == [U1(0), U1(-1)] + @test blocklabels(flip(a)) == [U1(0), U1(1)] + @test blocklabels(flip(dual(a))) == [U1(0), U1(-1)] + @test blocklabels(dual(flip(a))) == [U1(0), U1(-1)] - @test blocklengths(a) == [2, 3] - @test blocklengths(ad) == [2, 3] - @test blocklengths(flip(a)) == [2, 3] - @test blocklengths(flip(ad)) == [2, 3] - @test blocklengths(dual(flip(a))) == [2, 3] + @test blocklengths(a) == [2, 3] + @test blocklengths(ad) == [2, 3] + @test blocklengths(flip(a)) == [2, 3] + @test blocklengths(flip(ad)) == [2, 3] + @test blocklengths(dual(flip(a))) == [2, 3] - @test !isdual(a) - @test isdual(ad) - @test isdual(flip(a)) - @test !isdual(flip(ad)) - @test !isdual(dual(flip(a))) + @test !isdual(a) + @test isdual(ad) + @test isdual(flip(a)) + @test !isdual(flip(ad)) + @test !isdual(dual(flip(a))) + end end end From dc762be1a2e58b3cd7ba603d7be06c1ba00af6d7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Gauth=C3=A9?= Date: Tue, 17 Sep 2024 13:42:22 -0400 Subject: [PATCH 11/43] fix tests --- .../ext/BlockSparseArraysGradedAxesExt/test/runtests.jl | 8 ++++---- NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl | 2 +- NDTensors/src/lib/GradedAxes/test/test_basics.jl | 2 +- .../ext/TensorAlgebraGradedAxesExt/test/test_contract.jl | 5 ++++- 4 files changed, 10 insertions(+), 7 deletions(-) diff --git a/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl b/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl index 60ed5b8a85..e8a242ac75 100644 --- a/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl +++ b/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl @@ -200,14 +200,14 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) b = 2 * a @test block_nstored(b) == 2 @test Array(b) == 2 * Array(a) - @test_broken a[:, :] isa BlockSparseArray + @test a[:, :] isa BlockSparseArray for ax in axes(b) - @test ax isa BlockedUnitRange + @test ax isa BlockedOneTo end I = [Block(1)[1:1]] - @test_broken a[I, :] - @test_broken a[:, I] + @test a[I, :] isa BlockSparseArray + @test a[:, I] isa BlockSparseArray @test size(a[I, I]) == (1, 1) @test !GradedAxes.isdual(axes(a[I, I], 1)) end diff --git a/NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl b/NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl index c1fd93c4c9..b535e9d843 100644 --- a/NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl +++ b/NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl @@ -177,7 +177,7 @@ end function gradedunitrange_blockfirsts(a::AbstractGradedUnitRange) return labelled.(blockfirsts(unlabel_blocks(a)), blocklabels(a)) end -function BlockArrays.blockfirsts(a::AbstractGradedUnitRange) +function BlockArrays.blockfirsts(a::GradedUnitRange) return gradedunitrange_blockfirsts(a) end function BlockArrays.blockfirsts(a::GradedOneTo) diff --git a/NDTensors/src/lib/GradedAxes/test/test_basics.jl b/NDTensors/src/lib/GradedAxes/test/test_basics.jl index 3e5bb05bc8..6e9b0c90de 100644 --- a/NDTensors/src/lib/GradedAxes/test/test_basics.jl +++ b/NDTensors/src/lib/GradedAxes/test/test_basics.jl @@ -106,7 +106,7 @@ end @test length(ax) == length(a) @test blocklengths(ax) == blocklengths(a) @test blocklabels(ax) == blocklabels(a) - @test_broken(blockfirsts(a)) == [2, 3] + @test blockfirsts(a) == [2, 3] # Regression test for ambiguity error. x = gradedrange(["x" => 2, "y" => 3]) diff --git a/NDTensors/src/lib/TensorAlgebra/ext/TensorAlgebraGradedAxesExt/test/test_contract.jl b/NDTensors/src/lib/TensorAlgebra/ext/TensorAlgebraGradedAxesExt/test/test_contract.jl index 68e1374531..b3368cdc8c 100644 --- a/NDTensors/src/lib/TensorAlgebra/ext/TensorAlgebraGradedAxesExt/test/test_contract.jl +++ b/NDTensors/src/lib/TensorAlgebra/ext/TensorAlgebraGradedAxesExt/test/test_contract.jl @@ -2,13 +2,16 @@ using BlockArrays: Block, blocksize using Compat: Returns using NDTensors.BlockSparseArrays: BlockSparseArray -using NDTensors.GradedAxes: gradedrange +using NDTensors.GradedAxes: GradedAxes, gradedrange using NDTensors.Sectors: U1 using NDTensors.SparseArrayInterface: densearray using NDTensors.TensorAlgebra: contract using Random: randn! using Test: @test, @testset +#TODO remove once fuse_labels is defined in Sectors +GradedAxes.fuse_labels(m::U1, n::U1) = U1(m.n + n.n) + function randn_blockdiagonal(elt::Type, axes::Tuple) a = BlockSparseArray{elt}(axes) blockdiaglength = minimum(blocksize(a)) From 32d0f115289306c86dae5b3216c2f06109e54125 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Gauth=C3=A9?= Date: Tue, 17 Sep 2024 14:59:31 -0400 Subject: [PATCH 12/43] add tests --- .../test/runtests.jl | 28 ++++++++++--------- .../src/lib/GradedAxes/test/test_dual.jl | 10 +++++-- 2 files changed, 22 insertions(+), 16 deletions(-) diff --git a/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl b/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl index e8a242ac75..a91982d10b 100644 --- a/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl +++ b/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl @@ -8,6 +8,7 @@ using NDTensors.GradedAxes: using NDTensors.LabelledNumbers: label using NDTensors.SparseArrayInterface: nstored using NDTensors.TensorAlgebra: fusedims, splitdims +using LinearAlgebra: adjoint using Random: randn! function blockdiagonal!(f, a::AbstractArray) for i in 1:minimum(blocksize(a)) @@ -38,8 +39,6 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) @test blocklengths.(axes(b)) == ([2, 2], [2, 2], [2, 2], [2, 2]) @test nstored(b) == 32 @test block_nstored(b) == 2 - # TODO: Have to investigate why this fails - # on Julia v1.6, or drop support for v1.6. for i in 1:ndims(a) @test axes(b, i) isa GradedOneTo end @@ -158,11 +157,11 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) b = 2 * a @test block_nstored(b) == 2 @test Array(b) == 2 * Array(a) - @test a[:, :] isa BlockSparseArray - for ax in axes(b) - @test ax isa GradedUnitRangeDual + @test a[:, :] isa BlockSparseArray # broken in 1.6 + for i in 1:2 + @test axes(b, i) isa GradedUnitRangeDual + @test_broken axes(a[:, :], i) isa GradedUnitRangeDual end - I = [Block(1)[1:1]] @test_broken a[I, :] @test_broken a[:, I] @@ -179,9 +178,10 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) b = 2 * a @test block_nstored(b) == 2 @test Array(b) == 2 * Array(a) - @test a[:, :] isa BlockSparseArray - for ax in axes(b) - @test ax isa GradedUnitRangeDual + @test a[:, :] isa BlockSparseArray # broken in 1.6 + for i in 1:2 + @test axes(b, i) isa GradedUnitRangeDual + @test_broken axes(a[:, :], i) isa GradedUnitRangeDual end I = [Block(1)[1:1]] @@ -191,7 +191,7 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) @test_broken GradedAxes.isdual(axes(a[I, I], 1)) end - @testset "BlockedUnitRange" begin + @testset "BlockedUnitRange" begin # self dual r = blockedrange([2, 2]) a = BlockSparseArray{elt}(dual(r), dual(r)) @views for i in [Block(1, 1), Block(2, 2)] @@ -201,8 +201,9 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) @test block_nstored(b) == 2 @test Array(b) == 2 * Array(a) @test a[:, :] isa BlockSparseArray - for ax in axes(b) - @test ax isa BlockedOneTo + for i in 1:2 + @test axes(b, i) isa BlockedOneTo + @test axes(a[:, :], i) isa BlockedOneTo end I = [Block(1)[1:1]] @@ -226,7 +227,8 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) @test ax isa typeof(dual(r)) end - @test a[:, :] isa BlockSparseArray + @test a[:, :] isa BlockSparseArray # broken in 1.6 + @test axes(a[:, :]) isa Tuple{BlockedOneTo,BlockedOneTo} # broken in 1.6 I = [Block(1)[1:1]] @test size(a[I, :]) == (1, 4) diff --git a/NDTensors/src/lib/GradedAxes/test/test_dual.jl b/NDTensors/src/lib/GradedAxes/test/test_dual.jl index 0641d63a87..78d7f74ff5 100644 --- a/NDTensors/src/lib/GradedAxes/test/test_dual.jl +++ b/NDTensors/src/lib/GradedAxes/test/test_dual.jl @@ -1,6 +1,7 @@ @eval module $(gensym()) using BlockArrays: Block, + BlockedOneTo, blockaxes, blockedrange, blockfirsts, @@ -32,19 +33,22 @@ Base.isless(c1::U1, c2::U1) = c1.n < c2.n @testset "AbstractUnitRange" begin a0 = OneToOne() - @test gradedisequal(a0, dual(a0)) @test !isdual(a0) + @test dual(a0) isa OneToOne + @test gradedisequal(a0, dual(a0)) a = 1:3 ad = dual(a) - @test !isdual(ad) @test !isdual(a) + @test !isdual(ad) + @test ad isa UnitRange @test gradedisequal(ad, a) a = blockedrange([2, 3]) ad = dual(a) - @test !isdual(ad) @test !isdual(a) + @test !isdual(ad) + @test ad isa BlockedOneTo @test gradedisequal(ad, a) end From 47a9ed14103c8644e2e512b384a1d72a57aa6ccd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Gauth=C3=A9?= Date: Tue, 17 Sep 2024 15:11:09 -0400 Subject: [PATCH 13/43] remove tests on a[:,:] --- .../ext/BlockSparseArraysGradedAxesExt/test/runtests.jl | 5 ----- 1 file changed, 5 deletions(-) diff --git a/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl b/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl index a91982d10b..637955a5ac 100644 --- a/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl +++ b/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl @@ -157,7 +157,6 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) b = 2 * a @test block_nstored(b) == 2 @test Array(b) == 2 * Array(a) - @test a[:, :] isa BlockSparseArray # broken in 1.6 for i in 1:2 @test axes(b, i) isa GradedUnitRangeDual @test_broken axes(a[:, :], i) isa GradedUnitRangeDual @@ -178,7 +177,6 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) b = 2 * a @test block_nstored(b) == 2 @test Array(b) == 2 * Array(a) - @test a[:, :] isa BlockSparseArray # broken in 1.6 for i in 1:2 @test axes(b, i) isa GradedUnitRangeDual @test_broken axes(a[:, :], i) isa GradedUnitRangeDual @@ -227,9 +225,6 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) @test ax isa typeof(dual(r)) end - @test a[:, :] isa BlockSparseArray # broken in 1.6 - @test axes(a[:, :]) isa Tuple{BlockedOneTo,BlockedOneTo} # broken in 1.6 - I = [Block(1)[1:1]] @test size(a[I, :]) == (1, 4) @test size(a[:, I]) == (4, 1) From 11319451b89925ed451a39b0e8251801d3e74225 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Gauth=C3=A9?= Date: Fri, 25 Oct 2024 10:30:57 -0400 Subject: [PATCH 14/43] WIP --- NDTensors/src/lib/GradedAxes/test/test_basics.jl | 3 ++- NDTensors/src/lib/GradedAxes/test/test_dual.jl | 1 + .../ext/TensorAlgebraGradedAxesExt/test/test_contract.jl | 1 - 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/NDTensors/src/lib/GradedAxes/test/test_basics.jl b/NDTensors/src/lib/GradedAxes/test/test_basics.jl index 4593d47688..430b84a368 100644 --- a/NDTensors/src/lib/GradedAxes/test/test_basics.jl +++ b/NDTensors/src/lib/GradedAxes/test/test_basics.jl @@ -11,7 +11,8 @@ using BlockArrays: blocks using NDTensors.GradedAxes: GradedOneTo, GradedUnitRange, OneToOne, blocklabels, gradedisequal, gradedrange -using NDTensors.LabelledNumbers: LabelledUnitRange, islabelled, label, labelled, unlabel +using NDTensors.LabelledNumbers: + LabelledUnitRange, islabelled, label, labelled, labelled_isequal, unlabel using Test: @test, @test_broken, @testset @testset "OneToOne" begin diff --git a/NDTensors/src/lib/GradedAxes/test/test_dual.jl b/NDTensors/src/lib/GradedAxes/test/test_dual.jl index 443e9d626d..80c218adf2 100644 --- a/NDTensors/src/lib/GradedAxes/test/test_dual.jl +++ b/NDTensors/src/lib/GradedAxes/test/test_dual.jl @@ -35,6 +35,7 @@ Base.isless(c1::U1, c2::U1) = c1.n < c2.n a0 = OneToOne() @test !isdual(a0) @test dual(a0) isa OneToOne + @test space_isequal(a0, a0) @test space_isequal(a0, dual(a0)) a = 1:3 diff --git a/NDTensors/src/lib/TensorAlgebra/ext/TensorAlgebraGradedAxesExt/test/test_contract.jl b/NDTensors/src/lib/TensorAlgebra/ext/TensorAlgebraGradedAxesExt/test/test_contract.jl index 99e679e4b3..c74b0ed179 100644 --- a/NDTensors/src/lib/TensorAlgebra/ext/TensorAlgebraGradedAxesExt/test/test_contract.jl +++ b/NDTensors/src/lib/TensorAlgebra/ext/TensorAlgebraGradedAxesExt/test/test_contract.jl @@ -3,7 +3,6 @@ using BlockArrays: Block, blocksize using Compat: Returns using NDTensors.BlockSparseArrays: BlockSparseArray using NDTensors.GradedAxes: GradedAxes, gradedrange -using NDTensors.Sectors: U1 using NDTensors.SparseArrayInterface: densearray using NDTensors.SymmetrySectors: U1 using NDTensors.TensorAlgebra: contract From 3f6bd2e9443187ea5d384a7663c4dbd3e0566526 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Gauth=C3=A9?= Date: Mon, 28 Oct 2024 16:22:00 -0400 Subject: [PATCH 15/43] passing tests --- NDTensors/src/lib/GradedAxes/src/fusion.jl | 27 ++++++------------- NDTensors/src/lib/GradedAxes/src/onetoone.jl | 26 +----------------- .../src/lib/GradedAxes/test/test_basics.jl | 22 +++++++-------- .../src/lib/GradedAxes/test/test_dual.jl | 6 +++-- .../GradedAxes/test/test_tensor_product.jl | 4 ++- 5 files changed, 27 insertions(+), 58 deletions(-) diff --git a/NDTensors/src/lib/GradedAxes/src/fusion.jl b/NDTensors/src/lib/GradedAxes/src/fusion.jl index 2506e67393..320244ea75 100644 --- a/NDTensors/src/lib/GradedAxes/src/fusion.jl +++ b/NDTensors/src/lib/GradedAxes/src/fusion.jl @@ -12,19 +12,21 @@ function tensor_product( return foldl(tensor_product, (a1, a2, a3, a_rest...)) end -function tensor_product(::AbstractUnitRange, ::AbstractUnitRange) - return error("Not implemented yet.") +flip_dual(r::AbstractUnitRange) = r +flip_dual(r::GradedUnitRangeDual) = flip(r) +function tensor_product(a1::AbstractUnitRange, a2::AbstractUnitRange) + return tensor_product(flip_dual(a1), flip_dual(a2)) end function tensor_product(a1::Base.OneTo, a2::Base.OneTo) return Base.OneTo(length(a1) * length(a2)) end -function tensor_product(::OneToOne, a2::AbstractBlockedUnitRange) +function tensor_product(::OneToOne, a2::AbstractUnitRange) return a2 end -function tensor_product(a1::AbstractBlockedUnitRange, ::OneToOne) +function tensor_product(a1::AbstractUnitRange, ::OneToOne) return a1 end @@ -32,19 +34,6 @@ function tensor_product(::OneToOne, ::OneToOne) return OneToOne() end -# Handle dual. Always return a non-dual GradedUnitRange. -function tensor_product(a1::AbstractBlockedUnitRange, a2::GradedUnitRangeDual) - return tensor_product(a1, flip(a2)) -end - -function tensor_product(a1::GradedUnitRangeDual, a2::AbstractBlockedUnitRange) - return tensor_product(flip(a1), a2) -end - -function tensor_product(a1::GradedUnitRangeDual, a2::GradedUnitRangeDual) - return tensor_product(flip(a1), flip(a2)) -end - function fuse_labels(x, y) return error( "`fuse_labels` not implemented for object of type `$(typeof(x))` and `$(typeof(y))`." @@ -98,7 +87,8 @@ end # Used by `TensorAlgebra.splitdims` in `BlockSparseArraysGradedAxesExt`. # Get the permutation for sorting, then group by common elements. # groupsortperm([2, 1, 2, 3]) == [[2], [1, 3], [4]] -function blockmergesortperm(a::AbstractBlockedUnitRange) +blockmergesort(g::AbstractUnitRange) = g +function blockmergesortperm(a::AbstractUnitRange) return Block.(groupsortperm(blocklabels(a))) end @@ -120,7 +110,6 @@ function blockmergesort(g::AbstractGradedUnitRange) end blockmergesort(g::GradedUnitRangeDual) = dual(blockmergesort(flip(g))) -blockmergesort(g::OneToOne) = g # fusion_product produces a sorted, non-dual GradedUnitRange function fusion_product(g1, g2) diff --git a/NDTensors/src/lib/GradedAxes/src/onetoone.jl b/NDTensors/src/lib/GradedAxes/src/onetoone.jl index 61ee3c2096..426df396b1 100644 --- a/NDTensors/src/lib/GradedAxes/src/onetoone.jl +++ b/NDTensors/src/lib/GradedAxes/src/onetoone.jl @@ -6,28 +6,4 @@ struct OneToOne{T} <: AbstractUnitRange{T} end OneToOne() = OneToOne{Bool}() Base.first(a::OneToOne) = one(eltype(a)) Base.last(a::OneToOne) = one(eltype(a)) - -# == is just a range comparison that ignores labels. Need dedicated function to check equality. -gradedisequal(::AbstractBlockedUnitRange, ::AbstractUnitRange) = false -gradedisequal(::AbstractUnitRange, ::AbstractBlockedUnitRange) = false -gradedisequal(::AbstractBlockedUnitRange, ::OneToOne) = false -gradedisequal(::OneToOne, ::AbstractBlockedUnitRange) = false -function gradedisequal(a1::AbstractBlockedUnitRange, a2::AbstractBlockedUnitRange) - return blockisequal(a1, a2) -end -function gradedisequal(a1::AbstractGradedUnitRange, a2::AbstractGradedUnitRange) - return blockisequal(a1, a2) && (blocklabels(a1) == blocklabels(a2)) -end -gradedisequal(::GradedUnitRangeDual, ::GradedUnitRange) = false -gradedisequal(::GradedUnitRange, ::GradedUnitRangeDual) = false -function gradedisequal(a1::GradedUnitRangeDual, a2::GradedUnitRangeDual) - return gradedisequal(nondual(a1), nondual(a2)) -end - -gradedisequal(::OneToOne, ::OneToOne) = true - -function gradedisequal(::OneToOne, g::AbstractUnitRange) - return !islabelled(eltype(g)) && (first(g) == last(g) == 1) -end -gradedisequal(g::AbstractUnitRange, a0::OneToOne) = gradedisequal(a0, g) -gradedisequal(a1::AbstractUnitRange, a2::AbstractUnitRange) = a1 == a2 +BlockArrays.blockaxes(g::OneToOne) = (Block.(g),) # BlockArrays default crashes for OneToOne{Bool} diff --git a/NDTensors/src/lib/GradedAxes/test/test_basics.jl b/NDTensors/src/lib/GradedAxes/test/test_basics.jl index 430b84a368..43dc53302d 100644 --- a/NDTensors/src/lib/GradedAxes/test/test_basics.jl +++ b/NDTensors/src/lib/GradedAxes/test/test_basics.jl @@ -9,8 +9,7 @@ using BlockArrays: blocklength, blocklengths, blocks -using NDTensors.GradedAxes: - GradedOneTo, GradedUnitRange, OneToOne, blocklabels, gradedisequal, gradedrange +using NDTensors.GradedAxes: GradedOneTo, GradedUnitRange, OneToOne, blocklabels, gradedrange using NDTensors.LabelledNumbers: LabelledUnitRange, islabelled, label, labelled, labelled_isequal, unlabel using Test: @test, @test_broken, @testset @@ -20,13 +19,14 @@ using Test: @test, @test_broken, @testset @test a0 isa OneToOne{Bool} @test eltype(a0) == Bool @test length(a0) == 1 - @test gradedisequal(a0, a0) + @test labelled_isequal(a0, a0) - @test gradedisequal(a0, 1:1) - @test gradedisequal(1:1, a0) - @test !gradedisequal(a0, 1:2) - @test !gradedisequal(1:2, a0) + @test labelled_isequal(a0, 1:1) + @test labelled_isequal(1:1, a0) + @test !labelled_isequal(a0, 1:2) + @test !labelled_isequal(1:2, a0) end + @testset "GradedAxes basics" begin a0 = OneToOne() for a in ( @@ -35,10 +35,10 @@ end gradedrange(["x" => 2, "y" => 3]), ) @test a isa GradedOneTo - @test gradedisequal(a, a) - @test !gradedisequal(a0, a) - @test !gradedisequal(a, a0) - @test !gradedisequal(a, 1:5) + @test labelled_isequal(a, a) + @test !labelled_isequal(a0, a) + @test !labelled_isequal(a, a0) + @test !labelled_isequal(a, 1:5) for x in iterate(a) @test x == 1 @test label(x) == "x" diff --git a/NDTensors/src/lib/GradedAxes/test/test_dual.jl b/NDTensors/src/lib/GradedAxes/test/test_dual.jl index 80c218adf2..a0ca3bdf49 100644 --- a/NDTensors/src/lib/GradedAxes/test/test_dual.jl +++ b/NDTensors/src/lib/GradedAxes/test/test_dual.jl @@ -5,6 +5,7 @@ using BlockArrays: blockaxes, blockedrange, blockfirsts, + blockisequal, blocklasts, blocklength, blocklengths, @@ -23,7 +24,7 @@ using NDTensors.GradedAxes: gradedrange, isdual, nondual -using NDTensors.LabelledNumbers: LabelledInteger, label, labelled +using NDTensors.LabelledNumbers: LabelledInteger, label, labelled, labelled_isequal using Test: @test, @testset struct U1 n::Int @@ -36,6 +37,7 @@ Base.isless(c1::U1, c2::U1) = c1.n < c2.n @test !isdual(a0) @test dual(a0) isa OneToOne @test space_isequal(a0, a0) + @test labelled_isequal(a0, a0) @test space_isequal(a0, dual(a0)) a = 1:3 @@ -50,7 +52,7 @@ Base.isless(c1::U1, c2::U1) = c1.n < c2.n @test !isdual(a) @test !isdual(ad) @test ad isa BlockedOneTo - @test space_isequal(ad, a) + @test blockisequal(ad, a) end @testset "GradedUnitRangeDual" begin diff --git a/NDTensors/src/lib/GradedAxes/test/test_tensor_product.jl b/NDTensors/src/lib/GradedAxes/test/test_tensor_product.jl index 02435b5ba7..99e41454ff 100644 --- a/NDTensors/src/lib/GradedAxes/test/test_tensor_product.jl +++ b/NDTensors/src/lib/GradedAxes/test/test_tensor_product.jl @@ -11,11 +11,12 @@ using NDTensors.GradedAxes: fusion_product, flip, gradedrange, - labelled_isequal, space_isequal, isdual, tensor_product +using NDTensors.LabelledNumbers: labelled_isequal + struct U1 n::Int end @@ -27,6 +28,7 @@ GradedAxes.fuse_labels(x::U1, y::U1) = U1(x.n + y.n) GradedAxes.fuse_labels(x::String, y::String) = x * y g0 = OneToOne() + @test labelled_isequal(g0, g0) @test labelled_isequal(tensor_product(g0, g0), g0) a = gradedrange(["x" => 2, "y" => 3]) From 8a353bdf7a902a6166412a55394cd03ac2c95b51 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Gauth=C3=A9?= Date: Mon, 28 Oct 2024 17:01:48 -0400 Subject: [PATCH 16/43] finish merging --- .../src/abstractblocksparsearray/views.jl | 9 +------- NDTensors/src/lib/GradedAxes/src/fusion.jl | 21 ++++++------------- 2 files changed, 7 insertions(+), 23 deletions(-) diff --git a/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/views.jl b/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/views.jl index a07ba72913..e409ed5500 100644 --- a/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/views.jl +++ b/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/views.jl @@ -1,12 +1,5 @@ using BlockArrays: - AbstractBlockedUnitRange, - BlockArrays, - Block, - BlockIndexRange, - BlockedVector, - blocklength, - blocksize, - viewblock + BlockArrays, Block, BlockIndexRange, BlockedVector, blocklength, blocksize, viewblock # This splits `BlockIndexRange{N}` into # `NTuple{N,BlockIndexRange{1}}`. diff --git a/NDTensors/src/lib/GradedAxes/src/fusion.jl b/NDTensors/src/lib/GradedAxes/src/fusion.jl index 320244ea75..1a6d98f318 100644 --- a/NDTensors/src/lib/GradedAxes/src/fusion.jl +++ b/NDTensors/src/lib/GradedAxes/src/fusion.jl @@ -51,9 +51,6 @@ function fuse_blocklengths(x::LabelledInteger, y::LabelledInteger) return blockedrange([labelled(x * y, fuse_labels(label(x), label(y)))]) end -flatten_maybe_nested(v::Vector{<:Integer}) = v -flatten_maybe_nested(v::Vector{<:AbstractGradedUnitRange}) = reduce(vcat, blocklengths.(v)) - using BlockArrays: blockedrange, blocks function tensor_product(a1::AbstractBlockedUnitRange, a2::AbstractBlockedUnitRange) nested = map(Iterators.flatten((Iterators.product(blocks(a1), blocks(a2)),))) do it @@ -63,13 +60,8 @@ function tensor_product(a1::AbstractBlockedUnitRange, a2::AbstractBlockedUnitRan return blockedrange(new_blocklengths) end -# convention: sort UnitRangeDual according to nondual blocks -function blocksortperm(a::AbstractUnitRange) - return Block.(sortperm(blocklabels(nondual(a)))) -end - # convention: sort GradedUnitRangeDual according to nondual blocks -function blocksortperm(a::GradedUnitRangeDual) +function blocksortperm(a::AbstractUnitRange) return Block.(sortperm(blocklabels(nondual(a)))) end @@ -102,14 +94,13 @@ end function blockmergesort(g::AbstractGradedUnitRange) glabels = blocklabels(g) gblocklengths = blocklengths(g) - new_blocklengths = map( - la -> labelled(sum(gblocklengths[findall(==(la), glabels)]; init=0), la), - sort(unique(glabels)), - ) - return GradedAxes.gradedrange(new_blocklengths) + new_blocklengths = map(sort(unique(glabels))) do la + return labelled(sum(gblocklengths[findall(==(la), glabels)]; init=0), la) + end + return gradedrange(new_blocklengths) end -blockmergesort(g::GradedUnitRangeDual) = dual(blockmergesort(flip(g))) +blockmergesort(g::GradedUnitRangeDual) = flip(blockmergesort(flip(g))) # fusion_product produces a sorted, non-dual GradedUnitRange function fusion_product(g1, g2) From aa1b655a6c4afd51b16afa69aa7366d171defbd5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Gauth=C3=A9?= Date: Mon, 28 Oct 2024 18:27:18 -0400 Subject: [PATCH 17/43] more tests --- .../test/runtests.jl | 81 +++++++++++++++---- .../src/abstractblocksparsearray/views.jl | 13 ++- 2 files changed, 77 insertions(+), 17 deletions(-) diff --git a/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl b/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl index 637955a5ac..ffd4407c30 100644 --- a/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl +++ b/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl @@ -1,10 +1,17 @@ @eval module $(gensym()) using Compat: Returns using Test: @test, @testset, @test_broken -using BlockArrays: Block, BlockedOneTo, blockedrange, blocklengths, blocksize +using BlockArrays: + AbstractBlockArray, Block, BlockedOneTo, blockedrange, blocklengths, blocksize using NDTensors.BlockSparseArrays: BlockSparseArray, block_nstored using NDTensors.GradedAxes: - GradedAxes, GradedOneTo, GradedUnitRangeDual, blocklabels, dual, gradedrange + GradedAxes, + GradedOneTo, + GradedUnitRange, + GradedUnitRangeDual, + blocklabels, + dual, + gradedrange using NDTensors.LabelledNumbers: label using NDTensors.SparseArrayInterface: nstored using NDTensors.TensorAlgebra: fusedims, splitdims @@ -147,8 +154,50 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) @test @view(a[Block(1, 1)]) == a[Block(1, 1)] end + @testset "GradedOneTo" begin + r = gradedrange([U1(0) => 2, U1(1) => 2]) + a = BlockSparseArray{elt}(r, r) + @views for i in [Block(1, 1), Block(2, 2)] + a[i] = randn(elt, size(a[i])) + end + b = 2 * a + @test block_nstored(b) == 2 + @test Array(b) == 2 * Array(a) + for i in 1:2 + @test axes(b, i) isa GradedOneTo + @test axes(a[:, :], i) isa GradedOneTo + end + + I = [Block(1)[1:1]] + @test a[I, :] isa AbstractBlockArray + @test a[:, I] isa AbstractBlockArray + @test size(a[I, I]) == (1, 1) + @test !GradedAxes.isdual(axes(a[I, I], 1)) + end + + @testset "GradedUnitRange" begin + r = gradedrange([U1(0) => 2, U1(1) => 2])[1:3] + a = BlockSparseArray{elt}(r, r) + @views for i in [Block(1, 1), Block(2, 2)] + a[i] = randn(elt, size(a[i])) + end + b = 2 * a + @test block_nstored(b) == 2 + @test Array(b) == 2 * Array(a) + for i in 1:2 + @test axes(b, i) isa GradedUnitRange + @test_broken axes(a[:, :], i) isa GradedUnitRange + end + + I = [Block(1)[1:1]] + @test_broken a[I, :] isa AbstractBlockArray + @test_broken a[:, I] isa AbstractBlockArray + @test size(a[I, I]) == (1, 1) + @test_broken GradedAxes.isdual(axes(a[I, I], 1)) + end + # Test case when all axes are dual. - @testset "BlockedOneTo" begin + @testset "dual BlockedOneTo" begin r = gradedrange([U1(0) => 2, U1(1) => 2]) a = BlockSparseArray{elt}(dual(r), dual(r)) @views for i in [Block(1, 1), Block(2, 2)] @@ -162,13 +211,13 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) @test_broken axes(a[:, :], i) isa GradedUnitRangeDual end I = [Block(1)[1:1]] - @test_broken a[I, :] - @test_broken a[:, I] + @test_broken a[I, :] isa AbstractBlockArray + @test_broken a[:, I] isa AbstractBlockArray @test size(a[I, I]) == (1, 1) @test_broken GradedAxes.isdual(axes(a[I, I], 1)) end - @testset "GradedUnitRange" begin + @testset "dual GradedUnitRange" begin r = gradedrange([U1(0) => 2, U1(1) => 2])[1:3] a = BlockSparseArray{elt}(dual(r), dual(r)) @views for i in [Block(1, 1), Block(2, 2)] @@ -183,13 +232,13 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) end I = [Block(1)[1:1]] - @test_broken a[I, :] - @test_broken a[:, I] + @test_broken a[I, :] isa AbstractBlockArray + @test_broken a[:, I] isa AbstractBlockArray @test size(a[I, I]) == (1, 1) @test_broken GradedAxes.isdual(axes(a[I, I], 1)) end - @testset "BlockedUnitRange" begin # self dual + @testset "dual BlockedUnitRange" begin # self dual r = blockedrange([2, 2]) a = BlockSparseArray{elt}(dual(r), dual(r)) @views for i in [Block(1, 1), Block(2, 2)] @@ -211,9 +260,11 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) @test !GradedAxes.isdual(axes(a[I, I], 1)) end - # Test case when all axes are dual - # from taking the adjoint. - for r in (gradedrange([U1(0) => 2, U1(1) => 2]), blockedrange([2, 2])) + # Test case when all axes are dual from taking the adjoint. + for r in ( + gradedrange([U1(0) => 2, U1(1) => 2]), + gradedrange([U1(0) => 2, U1(1) => 2])[begin:end], + ) a = BlockSparseArray{elt}(r, r) @views for i in [Block(1, 1), Block(2, 2)] a[i] = randn(elt, size(a[i])) @@ -226,9 +277,9 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) end I = [Block(1)[1:1]] - @test size(a[I, :]) == (1, 4) - @test size(a[:, I]) == (4, 1) - @test size(a[I, I]) == (1, 1) + @test_broken size(b[I, :]) == (1, 4) + @test_broken size(b[:, I]) == (4, 1) + @test size(b[I, I]) == (1, 1) end end @testset "Matrix multiplication" begin diff --git a/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/views.jl b/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/views.jl index e409ed5500..aa3c7711c6 100644 --- a/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/views.jl +++ b/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/views.jl @@ -1,5 +1,12 @@ using BlockArrays: - BlockArrays, Block, BlockIndexRange, BlockedVector, blocklength, blocksize, viewblock + AbstractBlockedUnitRange, + BlockArrays, + Block, + BlockIndexRange, + BlockedVector, + blocklength, + blocksize, + viewblock # This splits `BlockIndexRange{N}` into # `NTuple{N,BlockIndexRange{1}}`. @@ -191,7 +198,9 @@ function to_blockindexrange( # work right now. return blocks(a.blocks)[Int(I)] end -function to_blockindexrange(a::Base.Slice{<:BlockedOneTo{<:Integer}}, I::Block{1}) +function to_blockindexrange( + a::Base.Slice{<:AbstractBlockedUnitRange{<:Integer}}, I::Block{1} +) @assert I in only(blockaxes(a.indices)) return I end From 5c5abd634f3c1e1ee99da60579b861874376e82e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Gauth=C3=A9?= Date: Mon, 28 Oct 2024 18:36:26 -0400 Subject: [PATCH 18/43] more tests --- NDTensors/src/lib/GradedAxes/test/test_dual.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/NDTensors/src/lib/GradedAxes/test/test_dual.jl b/NDTensors/src/lib/GradedAxes/test/test_dual.jl index a0ca3bdf49..04203961a7 100644 --- a/NDTensors/src/lib/GradedAxes/test/test_dual.jl +++ b/NDTensors/src/lib/GradedAxes/test/test_dual.jl @@ -61,6 +61,8 @@ end ad = dual(a) @test ad isa GradedUnitRangeDual @test eltype(ad) == LabelledInteger{Int,U1} + @test blocklengths(ad) isa Vector + @test eltype(blocklengths(ad)) == eltype(blocklengths(a)) @test space_isequal(dual(ad), a) @test space_isequal(nondual(ad), a) From d7e3a8d63df4941242c1d49d56787186dbb9cb8b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Gauth=C3=A9?= Date: Tue, 29 Oct 2024 09:58:35 -0400 Subject: [PATCH 19/43] reduce diff --- NDTensors/src/lib/GradedAxes/src/fusion.jl | 8 ++------ .../ext/TensorAlgebraGradedAxesExt/test/test_contract.jl | 5 +---- 2 files changed, 3 insertions(+), 10 deletions(-) diff --git a/NDTensors/src/lib/GradedAxes/src/fusion.jl b/NDTensors/src/lib/GradedAxes/src/fusion.jl index 1a6d98f318..2e893ec1db 100644 --- a/NDTensors/src/lib/GradedAxes/src/fusion.jl +++ b/NDTensors/src/lib/GradedAxes/src/fusion.jl @@ -79,18 +79,13 @@ end # Used by `TensorAlgebra.splitdims` in `BlockSparseArraysGradedAxesExt`. # Get the permutation for sorting, then group by common elements. # groupsortperm([2, 1, 2, 3]) == [[2], [1, 3], [4]] -blockmergesort(g::AbstractUnitRange) = g function blockmergesortperm(a::AbstractUnitRange) - return Block.(groupsortperm(blocklabels(a))) + return Block.(groupsortperm(blocklabels(nondual(a)))) end # Used by `TensorAlgebra.splitdims` in `BlockSparseArraysGradedAxesExt`. invblockperm(a::Vector{<:Block{1}}) = Block.(invperm(Int.(a))) -function blockmergesortperm(a::GradedUnitRangeDual) - return Block.(groupsortperm(blocklabels(nondual(a)))) -end - function blockmergesort(g::AbstractGradedUnitRange) glabels = blocklabels(g) gblocklengths = blocklengths(g) @@ -101,6 +96,7 @@ function blockmergesort(g::AbstractGradedUnitRange) end blockmergesort(g::GradedUnitRangeDual) = flip(blockmergesort(flip(g))) +blockmergesort(g::AbstractUnitRange) = g # fusion_product produces a sorted, non-dual GradedUnitRange function fusion_product(g1, g2) diff --git a/NDTensors/src/lib/TensorAlgebra/ext/TensorAlgebraGradedAxesExt/test/test_contract.jl b/NDTensors/src/lib/TensorAlgebra/ext/TensorAlgebraGradedAxesExt/test/test_contract.jl index c74b0ed179..1ada7d3393 100644 --- a/NDTensors/src/lib/TensorAlgebra/ext/TensorAlgebraGradedAxesExt/test/test_contract.jl +++ b/NDTensors/src/lib/TensorAlgebra/ext/TensorAlgebraGradedAxesExt/test/test_contract.jl @@ -2,16 +2,13 @@ using BlockArrays: Block, blocksize using Compat: Returns using NDTensors.BlockSparseArrays: BlockSparseArray -using NDTensors.GradedAxes: GradedAxes, gradedrange +using NDTensors.GradedAxes: gradedrange using NDTensors.SparseArrayInterface: densearray using NDTensors.SymmetrySectors: U1 using NDTensors.TensorAlgebra: contract using Random: randn! using Test: @test, @testset -#TODO remove once fuse_labels is defined in Sectors -GradedAxes.fuse_labels(m::U1, n::U1) = U1(m.n + n.n) - function randn_blockdiagonal(elt::Type, axes::Tuple) a = BlockSparseArray{elt}(axes) blockdiaglength = minimum(blocksize(a)) From efe1c7c175c5e69fdfc6e11479be25655c6f5529 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Gauth=C3=A9?= Date: Tue, 29 Oct 2024 18:03:14 -0400 Subject: [PATCH 20/43] custom type for AbstractGradedUnitRange --- .../src/lib/GradedAxes/src/gradedunitrange.jl | 83 +++++++++++-------- .../lib/GradedAxes/src/gradedunitrangedual.jl | 5 +- 2 files changed, 52 insertions(+), 36 deletions(-) diff --git a/NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl b/NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl index 38fcfb0531..fd8e0d0a23 100644 --- a/NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl +++ b/NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl @@ -29,19 +29,26 @@ using ..LabelledNumbers: labelled_isequal, unlabel -const AbstractGradedUnitRange{T<:LabelledInteger} = AbstractBlockedUnitRange{T} +abstract type AbstractGradedUnitRange{T,CS} <: AbstractBlockedUnitRange{T,CS} end -const GradedUnitRange{T<:LabelledInteger,BlockLasts<:Vector{T}} = BlockedUnitRange{ - T,BlockLasts -} +struct GradedUnitRange{T,BlockLasts<:Vector{T}} <: AbstractGradedUnitRange{T,BlockLasts} + first::T + lasts::BlockLasts +end -const GradedOneTo{T<:LabelledInteger,BlockLasts<:Vector{T}} = BlockedOneTo{T,BlockLasts} +struct GradedOneTo{T,BlockLasts<:Vector{T}} <: AbstractGradedUnitRange{T,BlockLasts} + lasts::BlockLasts -# This is only needed in certain Julia versions below 1.10 -# (for example Julia 1.6). -# TODO: Delete this once we drop Julia 1.6 support. -function Base.OrdinalRange{T,T}(a::GradedOneTo{<:LabelledInteger{T}}) where {T} - return unlabel_blocks(a) + # assume that lasts is sorted, no checks carried out here + function GradedOneTo(lasts::CS) where {T<:Integer,CS<:AbstractVector{T}} + Base.require_one_based_indexing(lasts) + isempty(lasts) || first(lasts) >= 0 || throw(ArgumentError("blocklasts must be >= 0")) + return new{T,CS}(lasts) + end + function GradedOneTo(lasts::CS) where {T<:Integer,CS<:Tuple{T,Vararg{T}}} + first(lasts) >= 0 || throw(ArgumentError("blocklasts must be >= 0")) + return new{T,CS}(lasts) + end end # == is just a range comparison that ignores labels. Need dedicated function to check equality. @@ -90,7 +97,7 @@ Base.eltype(::Type{<:GradedUnitRange{T}}) where {T} = T function gradedrange(lblocklengths::AbstractVector{<:LabelledInteger}) brange = blockedrange(unlabel.(lblocklengths)) lblocklasts = labelled.(blocklasts(brange), label.(lblocklengths)) - return BlockedOneTo(lblocklasts) + return GradedOneTo(lblocklasts) end # To help with generic code. @@ -118,14 +125,12 @@ end function labelled_blocks(a::BlockedOneTo, labels) # TODO: Use `blocklasts(a)`? That might # cause a recursive loop. - return BlockedOneTo(labelled.(a.lasts, labels)) + return GradedOneTo(labelled.(a.lasts, labels)) end function labelled_blocks(a::BlockedUnitRange, labels) # TODO: Use `first(a)` and `blocklasts(a)`? Those might # cause a recursive loop. - return BlockArrays._BlockedUnitRange( - labelled(a.first, labels[1]), labelled.(a.lasts, labels) - ) + return GradedUnitRange(labelled(a.first, labels[1]), labelled.(a.lasts, labels)) end function BlockArrays.findblock(a::AbstractGradedUnitRange, index::Integer) @@ -185,7 +190,15 @@ function unlabel_blocks(a::BlockedUnitRange) return BlockArrays._BlockedUnitRange(a.first, unlabel.(a.lasts)) end -## BlockedUnitRage interface +function unlabel_blocks(a::GradedOneTo) + # TODO: Use `blocklasts(a)`. + return BlockedOneTo(unlabel.(a.lasts)) +end +function unlabel_blocks(a::GradedUnitRange) + return BlockArrays._BlockedUnitRange(a.first, unlabel.(a.lasts)) +end + +## BlockedUnitRange interface function Base.axes(ga::AbstractGradedUnitRange) return map(axes(unlabel_blocks(ga))) do a @@ -217,9 +230,6 @@ end function Base.first(a::AbstractGradedUnitRange) return gradedunitrange_first(a) end -function Base.first(a::GradedOneTo) - return gradedunitrange_first(a) -end Base.iterate(a::AbstractGradedUnitRange) = isempty(a) ? nothing : (first(a), first(a)) function Base.iterate(a::AbstractGradedUnitRange, i) @@ -232,7 +242,7 @@ function firstblockindices(a::AbstractGradedUnitRange) return labelled.(firstblockindices(unlabel_blocks(a)), blocklabels(a)) end -function blockedunitrange_getindex(a::AbstractGradedUnitRange, index) +function gradedunitrange_getindices(a::AbstractGradedUnitRange, index) # This uses `blocklasts` since that is what is stored # in `BlockedUnitRange`, maybe abstract that away. return labelled(unlabel_blocks(a)[index], get_label(a, index)) @@ -245,27 +255,34 @@ function blocklabels(a::AbstractUnitRange, indices) end end -function blockedunitrange_getindices( +function gradedunitrange_getindices( ga::AbstractGradedUnitRange, indices::AbstractUnitRange{<:Integer} ) a_indices = blockedunitrange_getindices(unlabel_blocks(ga), indices) return labelled_blocks(a_indices, blocklabels(ga, indices)) end +function gradedunitrange_getindices( + a::AbstractGradedUnitRange, + indices::Union{AbstractVector{<:Block{1}},AbstractVector{<:BlockIndexRange{1}}}, +) + return blockedunitrange_getindices(a, indices) +end + # Fixes ambiguity error with: # ```julia -# blockedunitrange_getindices(::GradedUnitRange, ::AbstractUnitRange{<:Integer}) +# gradedunitrange_getindices(::GradedUnitRange, ::AbstractUnitRange{<:Integer}) # ``` # TODO: Try removing once GradedAxes is rewritten for BlockArrays v1. -function blockedunitrange_getindices(a::AbstractGradedUnitRange, indices::BlockSlice) +function gradedunitrange_getindices(a::AbstractGradedUnitRange, indices::BlockSlice) return a[indices.block] end -function blockedunitrange_getindices(ga::AbstractGradedUnitRange, indices::BlockRange) +function gradedunitrange_getindices(ga::AbstractGradedUnitRange, indices::BlockRange) return labelled_blocks(unlabel_blocks(ga)[indices], blocklabels(ga, indices)) end -function blockedunitrange_getindices(a::AbstractGradedUnitRange, indices::BlockIndex{1}) +function gradedunitrange_getindices(a::AbstractGradedUnitRange, indices::BlockIndex{1}) return a[block(indices)][blockindex(indices)] end @@ -276,7 +293,7 @@ function Base.getindex(a::AbstractGradedUnitRange, index::Integer) end function Base.getindex(a::AbstractGradedUnitRange, index::Block{1}) - return blockedunitrange_getindex(a, index) + return gradedunitrange_getindices(a, index) end function Base.getindex(a::AbstractGradedUnitRange, indices::BlockIndexRange) @@ -286,18 +303,18 @@ end function Base.getindex( a::AbstractGradedUnitRange, indices::BlockRange{1,<:Tuple{AbstractUnitRange{Int}}} ) - return blockedunitrange_getindices(a, indices) + return gradedunitrange_getindices(a, indices) end # Fixes ambiguity error with `BlockArrays`. function Base.getindex( a::AbstractGradedUnitRange, indices::BlockRange{1,Tuple{Base.OneTo{Int}}} ) - return blockedunitrange_getindices(a, indices) + return gradedunitrange_getindices(a, indices) end function Base.getindex(a::AbstractGradedUnitRange, indices::BlockIndex{1}) - return blockedunitrange_getindices(a, indices) + return gradedunitrange_getindices(a, indices) end # Fixes ambiguity issues with: @@ -310,15 +327,15 @@ end # TODO: Maybe not needed once GradedAxes is rewritten # for BlockArrays v1. function Base.getindex(a::AbstractGradedUnitRange, indices::BlockSlice) - return blockedunitrange_getindices(a, indices) + return gradedunitrange_getindices(a, indices) end function Base.getindex(a::AbstractGradedUnitRange, indices) - return blockedunitrange_getindices(a, indices) + return gradedunitrange_getindices(a, indices) end function Base.getindex(a::AbstractGradedUnitRange, indices::AbstractUnitRange{<:Integer}) - return blockedunitrange_getindices(a, indices) + return gradedunitrange_getindices(a, indices) end # This fixes an issue that `combine_blockaxes` was promoting @@ -352,7 +369,7 @@ end # blocklengths = map(bs -> sum(b -> length(a[b]), bs), blocks(indices)) # return blockedrange(blocklengths) # ``` -function blockedunitrange_getindices( +function gradedunitrange_getindices( a::AbstractGradedUnitRange, indices::AbstractBlockVector{<:Block{1}} ) blks = map(bs -> mortar(map(b -> a[b], bs)), blocks(indices)) diff --git a/NDTensors/src/lib/GradedAxes/src/gradedunitrangedual.jl b/NDTensors/src/lib/GradedAxes/src/gradedunitrangedual.jl index 78f3a00417..1ffee324fc 100644 --- a/NDTensors/src/lib/GradedAxes/src/gradedunitrangedual.jl +++ b/NDTensors/src/lib/GradedAxes/src/gradedunitrangedual.jl @@ -1,6 +1,5 @@ -struct GradedUnitRangeDual{ - T<:LabelledInteger,NondualUnitRange<:AbstractGradedUnitRange{T} -} <: AbstractGradedUnitRange{T,Vector{T}} +struct GradedUnitRangeDual{T,CS,NondualUnitRange<:AbstractGradedUnitRange{T,CS}} <: + AbstractGradedUnitRange{T,CS} nondual_unitrange::NondualUnitRange end From 3dcdd20db514bb08674ebb155300358a66f2b45d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Gauth=C3=A9?= Date: Tue, 29 Oct 2024 18:39:38 -0400 Subject: [PATCH 21/43] clean new implementation --- .../src/lib/GradedAxes/src/gradedunitrange.jl | 65 +------------------ 1 file changed, 1 insertion(+), 64 deletions(-) diff --git a/NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl b/NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl index fd8e0d0a23..032f6b4b69 100644 --- a/NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl +++ b/NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl @@ -63,29 +63,6 @@ function space_isequal(a1::AbstractUnitRange, a2::AbstractUnitRange) return (isdual(a1) == isdual(a2)) && labelled_isequal(a1, a2) end -# This is only needed in certain Julia versions below 1.10 -# (for example Julia 1.6). -# TODO: Delete this once we drop Julia 1.6 support. -# The type constraint `T<:Integer` is needed to avoid an ambiguity -# error with a conversion method in Base. -function Base.UnitRange{T}( - a::AbstractGradedUnitRange{<:LabelledInteger{T}} -) where {T<:Integer} - return UnitRange(unlabel_blocks(a)) -end - -# This is only needed in certain Julia versions below 1.10 -# (for example Julia 1.6). -# TODO: Delete this once we drop Julia 1.6 support. -# The type constraint `T<:Integer` is needed to avoid an ambiguity -# error with a conversion method in Base. -using BlockArrays: BlockSlice -function Base.UnitRange{T}( - a::BlockSlice{<:Any,<:LabelledInteger{T},<:AbstractUnitRange{<:LabelledInteger{T}}} -) where {T<:Integer} - return UnitRange{T}(a.indices) -end - # TODO: See if this is needed. function Base.AbstractUnitRange{T}(a::GradedOneTo{<:LabelledInteger{T}}) where {T} return unlabel_blocks(a) @@ -107,17 +84,6 @@ end Base.last(a::AbstractGradedUnitRange) = isempty(a.lasts) ? first(a) - 1 : last(a.lasts) -# TODO: This needs to be defined to circumvent an issue -# in the `BlockArrays.BlocksView` constructor. This -# is likely caused by issues around `BlockedUnitRange` constraining -# the element type to be `Int`, which is being fixed in: -# https://github.com/JuliaArrays/BlockArrays.jl/pull/337 -# Remove this definition once that is fixed. -function BlockArrays.blocks(a::AbstractGradedUnitRange) - # TODO: Fix `BlockRange`, try using `BlockRange` instead. - return [a[Block(i)] for i in 1:blocklength(a)] -end - function gradedrange(lblocklengths::AbstractVector{<:Pair{<:Any,<:Integer}}) return gradedrange(labelled.(last.(lblocklengths), first.(lblocklengths))) end @@ -182,14 +148,6 @@ end # TODO: This relies on internals of `BlockArrays`, maybe redesign # to try to avoid that. # TODO: Define `set_grades`, `set_sector_labels`, `set_labels`. -function unlabel_blocks(a::BlockedOneTo) - # TODO: Use `blocklasts(a)`. - return BlockedOneTo(unlabel.(a.lasts)) -end -function unlabel_blocks(a::BlockedUnitRange) - return BlockArrays._BlockedUnitRange(a.first, unlabel.(a.lasts)) -end - function unlabel_blocks(a::GradedOneTo) # TODO: Use `blocklasts(a)`. return BlockedOneTo(unlabel.(a.lasts)) @@ -209,10 +167,7 @@ end function gradedunitrange_blockfirsts(a::AbstractGradedUnitRange) return labelled.(blockfirsts(unlabel_blocks(a)), blocklabels(a)) end -function BlockArrays.blockfirsts(a::GradedUnitRange) - return gradedunitrange_blockfirsts(a) -end -function BlockArrays.blockfirsts(a::GradedOneTo) +function BlockArrays.blockfirsts(a::AbstractGradedUnitRange) return gradedunitrange_blockfirsts(a) end @@ -243,8 +198,6 @@ function firstblockindices(a::AbstractGradedUnitRange) end function gradedunitrange_getindices(a::AbstractGradedUnitRange, index) - # This uses `blocklasts` since that is what is stored - # in `BlockedUnitRange`, maybe abstract that away. return labelled(unlabel_blocks(a)[index], get_label(a, index)) end @@ -269,11 +222,6 @@ function gradedunitrange_getindices( return blockedunitrange_getindices(a, indices) end -# Fixes ambiguity error with: -# ```julia -# gradedunitrange_getindices(::GradedUnitRange, ::AbstractUnitRange{<:Integer}) -# ``` -# TODO: Try removing once GradedAxes is rewritten for BlockArrays v1. function gradedunitrange_getindices(a::AbstractGradedUnitRange, indices::BlockSlice) return a[indices.block] end @@ -287,8 +235,6 @@ function gradedunitrange_getindices(a::AbstractGradedUnitRange, indices::BlockIn end function Base.getindex(a::AbstractGradedUnitRange, index::Integer) - # This uses `blocklasts` since that is what is stored - # in `BlockedUnitRange`, maybe abstract that away. return labelled(unlabel_blocks(a)[index], get_label(a, index)) end @@ -306,13 +252,6 @@ function Base.getindex( return gradedunitrange_getindices(a, indices) end -# Fixes ambiguity error with `BlockArrays`. -function Base.getindex( - a::AbstractGradedUnitRange, indices::BlockRange{1,Tuple{Base.OneTo{Int}}} -) - return gradedunitrange_getindices(a, indices) -end - function Base.getindex(a::AbstractGradedUnitRange, indices::BlockIndex{1}) return gradedunitrange_getindices(a, indices) end @@ -324,8 +263,6 @@ end # getindex(::GradedUnitRange, ::Any) # getindex(::AbstractUnitRange, ::AbstractUnitRange{<:Integer}) # ``` -# TODO: Maybe not needed once GradedAxes is rewritten -# for BlockArrays v1. function Base.getindex(a::AbstractGradedUnitRange, indices::BlockSlice) return gradedunitrange_getindices(a, indices) end From a184141fb0c58f4565f2f3f731eb33819eb70e3a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Gauth=C3=A9?= Date: Tue, 29 Oct 2024 19:03:07 -0400 Subject: [PATCH 22/43] custom printing --- NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl b/NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl index 032f6b4b69..60e067699f 100644 --- a/NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl +++ b/NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl @@ -51,6 +51,12 @@ struct GradedOneTo{T,BlockLasts<:Vector{T}} <: AbstractGradedUnitRange{T,BlockLa end end +function Base.show(io::IO, mimetype::MIME"text/plain", g::AbstractGradedUnitRange) + v = blocklabels(g) .=> Int.(blocklengths(g)) + println(typeof(g)) + return show(io, mimetype, v) +end + # == is just a range comparison that ignores labels. Need dedicated function to check equality. struct NoLabel end blocklabels(r::AbstractUnitRange) = Fill(NoLabel(), blocklength(r)) From db719624a335739928f6ce1d06a0393c6b5416c1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Gauth=C3=A9?= Date: Thu, 31 Oct 2024 19:05:41 -0400 Subject: [PATCH 23/43] fix getindex(a,::Vector{Int} --- NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl | 4 +++- NDTensors/src/lib/GradedAxes/test/test_basics.jl | 3 +++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl b/NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl index 60e067699f..7738d975a9 100644 --- a/NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl +++ b/NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl @@ -223,7 +223,9 @@ end function gradedunitrange_getindices( a::AbstractGradedUnitRange, - indices::Union{AbstractVector{<:Block{1}},AbstractVector{<:BlockIndexRange{1}}}, + indices::Union{ + AbstractVector{<:Block{1}},AbstractVector{<:BlockIndexRange{1}},Vector{<:Integer} + }, ) return blockedunitrange_getindices(a, indices) end diff --git a/NDTensors/src/lib/GradedAxes/test/test_basics.jl b/NDTensors/src/lib/GradedAxes/test/test_basics.jl index 43dc53302d..44e9ac083b 100644 --- a/NDTensors/src/lib/GradedAxes/test/test_basics.jl +++ b/NDTensors/src/lib/GradedAxes/test/test_basics.jl @@ -20,6 +20,8 @@ using Test: @test, @test_broken, @testset @test eltype(a0) == Bool @test length(a0) == 1 @test labelled_isequal(a0, a0) + @test a0[1] == true + @test a0[[1]] == [true] @test labelled_isequal(a0, 1:1) @test labelled_isequal(1:1, a0) @@ -109,6 +111,7 @@ end @test blocklengths(ax) == blocklengths(a) @test blocklabels(ax) == blocklabels(a) @test blockfirsts(a) == [2, 3] + @test x[[2, 4]] == [labelled(2, "x"), labelled(4, "y")] # Regression test for ambiguity error. x = gradedrange(["x" => 2, "y" => 3]) From c57bc023c3fcadff020ad07d4bc31cbab844be79 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Gauth=C3=A9?= Date: Thu, 31 Oct 2024 19:18:04 -0400 Subject: [PATCH 24/43] fix GradedUnitRangeDual tests --- NDTensors/src/lib/GradedAxes/src/dual.jl | 2 ++ NDTensors/src/lib/GradedAxes/src/gradedunitrangedual.jl | 2 -- NDTensors/src/lib/GradedAxes/test/test_dual.jl | 4 ++++ 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/NDTensors/src/lib/GradedAxes/src/dual.jl b/NDTensors/src/lib/GradedAxes/src/dual.jl index 28f140d69e..ca985e30a0 100644 --- a/NDTensors/src/lib/GradedAxes/src/dual.jl +++ b/NDTensors/src/lib/GradedAxes/src/dual.jl @@ -5,6 +5,8 @@ isdual(::AbstractUnitRange) = false using NDTensors.LabelledNumbers: LabelledStyle, IsLabelled, NotLabelled, label, labelled, unlabel + +dual(i::LabelledInteger) = labelled(unlabel(i), dual(label(i))) label_dual(x) = label_dual(LabelledStyle(x), x) label_dual(::NotLabelled, x) = x label_dual(::IsLabelled, x) = labelled(unlabel(x), dual(label(x))) diff --git a/NDTensors/src/lib/GradedAxes/src/gradedunitrangedual.jl b/NDTensors/src/lib/GradedAxes/src/gradedunitrangedual.jl index 1ffee324fc..c6d79495a5 100644 --- a/NDTensors/src/lib/GradedAxes/src/gradedunitrangedual.jl +++ b/NDTensors/src/lib/GradedAxes/src/gradedunitrangedual.jl @@ -104,8 +104,6 @@ end Base.unitrange(a::GradedUnitRangeDual) = a using NDTensors.LabelledNumbers: LabelledInteger, label, labelled, unlabel -dual(i::LabelledInteger) = labelled(unlabel(i), dual(label(i))) - using BlockArrays: BlockArrays, blockaxes, blocklasts, combine_blockaxes, findblock BlockArrays.blockaxes(a::GradedUnitRangeDual) = blockaxes(nondual(a)) BlockArrays.blockfirsts(a::GradedUnitRangeDual) = label_dual.(blockfirsts(nondual(a))) diff --git a/NDTensors/src/lib/GradedAxes/test/test_dual.jl b/NDTensors/src/lib/GradedAxes/test/test_dual.jl index 04203961a7..714dd04b7a 100644 --- a/NDTensors/src/lib/GradedAxes/test/test_dual.jl +++ b/NDTensors/src/lib/GradedAxes/test/test_dual.jl @@ -12,6 +12,7 @@ using BlockArrays: blocks, findblock using NDTensors.GradedAxes: + AbstractGradedUnitRange, GradedAxes, GradedUnitRangeDual, OneToOne, @@ -60,6 +61,7 @@ end [gradedrange([U1(0) => 2, U1(1) => 3]), gradedrange([U1(0) => 2, U1(1) => 3])[1:5]] ad = dual(a) @test ad isa GradedUnitRangeDual + @test ad isa AbstractGradedUnitRange @test eltype(ad) == LabelledInteger{Int,U1} @test blocklengths(ad) isa Vector @test eltype(blocklengths(ad)) == eltype(blocklengths(a)) @@ -78,6 +80,8 @@ end @test blocklasts(ad) == [labelled(2, U1(0)), labelled(5, U1(-1))] @test blocklength(ad) == 2 @test blocklengths(ad) == [2, 3] + @test blocklabels(ad) == [U1(0), U1(-1)] + @test label.(blocklengths(ad)) == [U1(0), U1(-1)] @test findblock(ad, 4) == Block(2) @test only(blockaxes(ad)) == Block(1):Block(2) @test blocks(ad) == [labelled(1:2, U1(0)), labelled(3:5, U1(-1))] From 032bde4d26075bca39bc484ef29235e921580240 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Gauth=C3=A9?= Date: Thu, 31 Oct 2024 19:26:55 -0400 Subject: [PATCH 25/43] remove doublets --- .../src/lib/GradedAxes/src/gradedunitrange.jl | 5 ---- .../lib/GradedAxes/src/gradedunitrangedual.jl | 24 +------------------ 2 files changed, 1 insertion(+), 28 deletions(-) diff --git a/NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl b/NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl index 7738d975a9..d32155b46f 100644 --- a/NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl +++ b/NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl @@ -69,11 +69,6 @@ function space_isequal(a1::AbstractUnitRange, a2::AbstractUnitRange) return (isdual(a1) == isdual(a2)) && labelled_isequal(a1, a2) end -# TODO: See if this is needed. -function Base.AbstractUnitRange{T}(a::GradedOneTo{<:LabelledInteger{T}}) where {T} - return unlabel_blocks(a) -end - # TODO: Use `TypeParameterAccessors`. Base.eltype(::Type{<:GradedUnitRange{T}}) where {T} = T diff --git a/NDTensors/src/lib/GradedAxes/src/gradedunitrangedual.jl b/NDTensors/src/lib/GradedAxes/src/gradedunitrangedual.jl index c6d79495a5..def3dbfc75 100644 --- a/NDTensors/src/lib/GradedAxes/src/gradedunitrangedual.jl +++ b/NDTensors/src/lib/GradedAxes/src/gradedunitrangedual.jl @@ -17,16 +17,6 @@ Base.step(a::GradedUnitRangeDual) = label_dual(step(nondual(a))) Base.view(a::GradedUnitRangeDual, index::Block{1}) = a[index] -function Base.show(io::IO, a::GradedUnitRangeDual) - return print(io, GradedUnitRangeDual, "(", blocklasts(a), ")") -end - -function Base.show(io::IO, mimetype::MIME"text/plain", a::GradedUnitRangeDual) - return Base.invoke( - show, Tuple{typeof(io),MIME"text/plain",AbstractArray}, io, mimetype, a - ) -end - function Base.getindex(a::GradedUnitRangeDual, indices::AbstractUnitRange{<:Integer}) return dual(getindex(nondual(a), indices)) end @@ -88,10 +78,7 @@ function BlockArrays.BlockSlice(b::Block, r::GradedUnitRangeDual) return BlockSlice(b, dual(r)) end -using NDTensors.LabelledNumbers: LabelledNumbers, label -LabelledNumbers.label(a::GradedUnitRangeDual) = dual(label(nondual(a))) - -using NDTensors.LabelledNumbers: LabelledUnitRange +using NDTensors.LabelledNumbers: LabelledNumbers, LabelledUnitRange, label # The Base version of `length(::AbstractUnitRange)` drops the label. function Base.length(a::GradedUnitRangeDual{<:Any,<:LabelledUnitRange}) return dual(length(nondual(a))) @@ -131,15 +118,6 @@ function Base.OrdinalRange{Int,Int}( return unlabel(nondual(r)) end -# This is only needed in certain Julia versions below 1.10 -# (for example Julia 1.6). -# TODO: Delete this once we drop Julia 1.6 support. -# The type constraint `T<:Integer` is needed to avoid an ambiguity -# error with a conversion method in Base. -function Base.UnitRange{T}(a::GradedUnitRangeDual{<:LabelledInteger{T}}) where {T<:Integer} - return UnitRange{T}(nondual(a)) -end - function unlabel_blocks(a::GradedUnitRangeDual) return unlabel_blocks(nondual(a)) end From e716903699456b41e3e0355665c051573c99eacf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Gauth=C3=A9?= Date: Thu, 31 Oct 2024 19:41:29 -0400 Subject: [PATCH 26/43] use GradedAxes interface --- .../lib/GradedAxes/src/gradedunitrangedual.jl | 28 ++++++++----------- 1 file changed, 11 insertions(+), 17 deletions(-) diff --git a/NDTensors/src/lib/GradedAxes/src/gradedunitrangedual.jl b/NDTensors/src/lib/GradedAxes/src/gradedunitrangedual.jl index def3dbfc75..bc0824c098 100644 --- a/NDTensors/src/lib/GradedAxes/src/gradedunitrangedual.jl +++ b/NDTensors/src/lib/GradedAxes/src/gradedunitrangedual.jl @@ -17,26 +17,28 @@ Base.step(a::GradedUnitRangeDual) = label_dual(step(nondual(a))) Base.view(a::GradedUnitRangeDual, index::Block{1}) = a[index] -function Base.getindex(a::GradedUnitRangeDual, indices::AbstractUnitRange{<:Integer}) +function gradedunitrange_getindices( + a::GradedUnitRangeDual, indices::AbstractUnitRange{<:Integer} +) return dual(getindex(nondual(a), indices)) end using BlockArrays: Block, BlockIndexRange, BlockRange -function Base.getindex(a::GradedUnitRangeDual, indices::Integer) +function gradedunitrange_getindices(a::GradedUnitRangeDual, indices::Integer) return label_dual(getindex(nondual(a), indices)) end -function Base.getindex(a::GradedUnitRangeDual, indices::Block{1}) +function gradedunitrange_getindices(a::GradedUnitRangeDual, indices::Block{1}) return label_dual(getindex(nondual(a), indices)) end -function Base.getindex(a::GradedUnitRangeDual, indices::BlockRange) +function gradedunitrange_getindices(a::GradedUnitRangeDual, indices::BlockRange) return label_dual(getindex(nondual(a), indices)) end # fix ambiguity -function Base.getindex( +function gradedunitrange_getindices( a::GradedUnitRangeDual, indices::BlockRange{1,<:Tuple{AbstractUnitRange{Int}}} ) return dual(getindex(nondual(a), indices)) @@ -52,15 +54,13 @@ function unitrangedual_getindices_blocks(a::GradedUnitRangeDual, indices) end # TODO: Move this to a `BlockArraysExtensions` library. -function blockedunitrange_getindices(a::GradedUnitRangeDual, indices::Block{1}) - return a[indices] -end - -function Base.getindex(a::GradedUnitRangeDual, indices::Vector{<:Block{1}}) +function gradedunitrange_getindices(a::GradedUnitRangeDual, indices::Vector{<:Block{1}}) return unitrangedual_getindices_blocks(a, indices) end -function Base.getindex(a::GradedUnitRangeDual, indices::Vector{<:BlockIndexRange{1}}) +function gradedunitrange_getindices( + a::GradedUnitRangeDual, indices::Vector{<:BlockIndexRange{1}} +) return unitrangedual_getindices_blocks(a, indices) end @@ -79,16 +79,10 @@ function BlockArrays.BlockSlice(b::Block, r::GradedUnitRangeDual) end using NDTensors.LabelledNumbers: LabelledNumbers, LabelledUnitRange, label -# The Base version of `length(::AbstractUnitRange)` drops the label. -function Base.length(a::GradedUnitRangeDual{<:Any,<:LabelledUnitRange}) - return dual(length(nondual(a))) -end function Base.iterate(a::GradedUnitRangeDual, i) i == last(a) && return nothing return dual.(iterate(nondual(a), i)) end -# TODO: Is this a good definition? -Base.unitrange(a::GradedUnitRangeDual) = a using NDTensors.LabelledNumbers: LabelledInteger, label, labelled, unlabel using BlockArrays: BlockArrays, blockaxes, blocklasts, combine_blockaxes, findblock From 9796676fdcada7eec9fcabef82e5ef12c48dfd1f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Gauth=C3=A9?= Date: Fri, 1 Nov 2024 11:44:02 -0400 Subject: [PATCH 27/43] Base.AbstractUnitRange actually needed --- NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl b/NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl index d32155b46f..3a56cc49e7 100644 --- a/NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl +++ b/NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl @@ -69,6 +69,13 @@ function space_isequal(a1::AbstractUnitRange, a2::AbstractUnitRange) return (isdual(a1) == isdual(a2)) && labelled_isequal(a1, a2) end +# needed in BlockSparseArrays +function Base.AbstractUnitRange{T}( + a::AbstractGradedUnitRange{<:LabelledInteger{T}} +) where {T} + return unlabel_blocks(a) +end + # TODO: Use `TypeParameterAccessors`. Base.eltype(::Type{<:GradedUnitRange{T}}) where {T} = T @@ -249,8 +256,9 @@ function Base.getindex(a::AbstractGradedUnitRange, indices::BlockIndexRange) return blockedunitrange_getindices(a, indices) end +# fix ambiguity function Base.getindex( - a::AbstractGradedUnitRange, indices::BlockRange{1,<:Tuple{AbstractUnitRange{Int}}} + a::AbstractGradedUnitRange, indices::BlockArrays.BlockRange{1,<:Tuple{Base.OneTo}} ) return gradedunitrange_getindices(a, indices) end From 3b8b2b9aa6395faaa3a4dac1f8724f4b13c57621 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Gauth=C3=A9?= Date: Fri, 1 Nov 2024 12:41:53 -0400 Subject: [PATCH 28/43] fix ambiguities --- NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl | 7 ++++++- NDTensors/src/lib/GradedAxes/test/test_basics.jl | 3 +++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl b/NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl index 3a56cc49e7..287e4acd00 100644 --- a/NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl +++ b/NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl @@ -256,12 +256,17 @@ function Base.getindex(a::AbstractGradedUnitRange, indices::BlockIndexRange) return blockedunitrange_getindices(a, indices) end -# fix ambiguity +# fix ambiguities function Base.getindex( a::AbstractGradedUnitRange, indices::BlockArrays.BlockRange{1,<:Tuple{Base.OneTo}} ) return gradedunitrange_getindices(a, indices) end +function Base.getindex( + a::AbstractGradedUnitRange, indices::BlockRange{1,<:Tuple{AbstractUnitRange{Int}}} +) + return gradedunitrange_getindices(a, indices) +end function Base.getindex(a::AbstractGradedUnitRange, indices::BlockIndex{1}) return gradedunitrange_getindices(a, indices) diff --git a/NDTensors/src/lib/GradedAxes/test/test_basics.jl b/NDTensors/src/lib/GradedAxes/test/test_basics.jl index 44e9ac083b..02d37f718f 100644 --- a/NDTensors/src/lib/GradedAxes/test/test_basics.jl +++ b/NDTensors/src/lib/GradedAxes/test/test_basics.jl @@ -1,6 +1,7 @@ @eval module $(gensym()) using BlockArrays: Block, + BlockRange, BlockSlice, BlockVector, blockedrange, @@ -111,7 +112,9 @@ end @test blocklengths(ax) == blocklengths(a) @test blocklabels(ax) == blocklabels(a) @test blockfirsts(a) == [2, 3] + @test x[[2, 4]] == [labelled(2, "x"), labelled(4, "y")] + @test labelled_isequal(x[BlockRange(1)], gradedrange(["x" => 2])) # Regression test for ambiguity error. x = gradedrange(["x" => 2, "y" => 3]) From 71cd86ac3e507cb1c7ea9d26ddd2fac14a6fe843 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Gauth=C3=A9?= Date: Fri, 1 Nov 2024 14:58:15 -0400 Subject: [PATCH 29/43] define label_type --- NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl b/NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl index 287e4acd00..8017d7674c 100644 --- a/NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl +++ b/NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl @@ -25,6 +25,7 @@ using ..LabelledNumbers: LabelledInteger, LabelledUnitRange, label, + label_type, labelled, labelled_isequal, unlabel @@ -78,6 +79,8 @@ end # TODO: Use `TypeParameterAccessors`. Base.eltype(::Type{<:GradedUnitRange{T}}) where {T} = T +LabelledNumbers.label_type(g::AbstractGradedUnitRange) = label_type(typeof(g)) +LabelledNumbers.label_type(T::Type{<:AbstractGradedUnitRange}) = label_type(eltype(T)) function gradedrange(lblocklengths::AbstractVector{<:LabelledInteger}) brange = blockedrange(unlabel.(lblocklengths)) From 5c3696780fe60e98d2b53382361bf16665d663ac Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Gauth=C3=A9?= Date: Fri, 1 Nov 2024 15:34:30 -0400 Subject: [PATCH 30/43] improve display --- .../src/lib/GradedAxes/src/gradedunitrange.jl | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl b/NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl index 8017d7674c..44dea116f7 100644 --- a/NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl +++ b/NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl @@ -53,9 +53,18 @@ struct GradedOneTo{T,BlockLasts<:Vector{T}} <: AbstractGradedUnitRange{T,BlockLa end function Base.show(io::IO, mimetype::MIME"text/plain", g::AbstractGradedUnitRange) - v = blocklabels(g) .=> Int.(blocklengths(g)) - println(typeof(g)) - return show(io, mimetype, v) + v = map(blocks(g)) do b + label(b) => unlabel(first(b)):unlabel(last(b)) + end + println(io, typeof(g)) + return print(join(repr.(v), '\n')) +end + +function Base.show(io::IO, g::AbstractGradedUnitRange) + v = map(blocks(g)) do b + label(b) => unlabel(first(b)):unlabel(last(b)) + end + return print(io, nameof(typeof(g)), '[', join(repr.(v), ", "), ']') end # == is just a range comparison that ignores labels. Need dedicated function to check equality. From de6080bc8b4807750556f0e2a941981b164bfc8d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Gauth=C3=A9?= Date: Fri, 1 Nov 2024 16:08:55 -0400 Subject: [PATCH 31/43] simpler display --- NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl b/NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl index 44dea116f7..ae21839a1b 100644 --- a/NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl +++ b/NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl @@ -53,17 +53,13 @@ struct GradedOneTo{T,BlockLasts<:Vector{T}} <: AbstractGradedUnitRange{T,BlockLa end function Base.show(io::IO, mimetype::MIME"text/plain", g::AbstractGradedUnitRange) - v = map(blocks(g)) do b - label(b) => unlabel(first(b)):unlabel(last(b)) - end + v = map(b -> label(b) => unlabel(b), blocks(g)) println(io, typeof(g)) - return print(join(repr.(v), '\n')) + return print(io, join(repr.(v), '\n')) end function Base.show(io::IO, g::AbstractGradedUnitRange) - v = map(blocks(g)) do b - label(b) => unlabel(first(b)):unlabel(last(b)) - end + v = map(b -> label(b) => unlabel(b), blocks(g)) return print(io, nameof(typeof(g)), '[', join(repr.(v), ", "), ']') end From 24fb3c88aaf2c2399ef1726d234b8e4477d5f3a6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Gauth=C3=A9?= Date: Fri, 1 Nov 2024 19:02:18 -0400 Subject: [PATCH 32/43] fix combine_blockaxes --- .../test/runtests.jl | 2 ++ .../src/lib/GradedAxes/src/gradedunitrange.jl | 15 ++++++++++++--- 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl b/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl index ffd4407c30..a321fe47ab 100644 --- a/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl +++ b/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl @@ -39,6 +39,8 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) d2 = gradedrange([U1(0) => 2, U1(1) => 2]) a = BlockSparseArray{elt}(d1, d2, d1, d2) blockdiagonal!(randn!, a) + @test axes(a, 1) isa GradedOneTo + @test axes(view(a, 1:4, 1:4), 1) isa GradedOneTo for b in (a + a, 2 * a) @test size(b) == (4, 4, 4, 4) diff --git a/NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl b/NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl index ae21839a1b..a26bf62e7c 100644 --- a/NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl +++ b/NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl @@ -17,7 +17,8 @@ using BlockArrays: blocklengths, findblock, findblockindex, - mortar + mortar, + sortedunion using Compat: allequal using FillArrays: Fill using ..LabelledNumbers: @@ -304,17 +305,25 @@ end # that mixed dense and graded axes. # TODO: Maybe come up with a more general solution. function BlockArrays.combine_blockaxes( - a1::GradedOneTo{<:LabelledInteger{T}}, a2::Base.OneTo{T} + a1::AbstractGradedUnitRange{T}, a2::Base.OneTo{T} ) where {T<:Integer} combined_blocklasts = sort!(union(unlabel.(blocklasts(a1)), blocklasts(a2))) return BlockedOneTo(combined_blocklasts) end function BlockArrays.combine_blockaxes( - a1::Base.OneTo{T}, a2::GradedOneTo{<:LabelledInteger{T}} + a1::Base.OneTo{T}, a2::AbstractGradedUnitRange{T} ) where {T<:Integer} return BlockArrays.combine_blockaxes(a2, a1) end +# preserve labels inside combine_blockaxes +# TODO dual +function BlockArrays.combine_blockaxes( + a::AbstractGradedUnitRange, b::AbstractGradedUnitRange +) + return gradedrange(sortedunion(blocklasts(a), blocklasts(b))) +end + # Version of length that checks that all blocks have the same label # and returns a labelled length with that label. function labelled_length(a::AbstractBlockVector{<:Integer}) From b2fd32b5be1857ebe59dfa40d74d986930c30b97 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Gauth=C3=A9?= Date: Fri, 1 Nov 2024 19:06:30 -0400 Subject: [PATCH 33/43] fix slicing BlockSparseArrays --- .../src/BlockArraysExtensions/BlockArraysExtensions.jl | 2 +- .../blocksparsearrayinterface.jl | 8 ++++---- NDTensors/src/lib/GradedAxes/test/test_basics.jl | 5 ++++- 3 files changed, 9 insertions(+), 6 deletions(-) diff --git a/NDTensors/src/lib/BlockSparseArrays/src/BlockArraysExtensions/BlockArraysExtensions.jl b/NDTensors/src/lib/BlockSparseArrays/src/BlockArraysExtensions/BlockArraysExtensions.jl index 7e7b503475..3214a8a230 100644 --- a/NDTensors/src/lib/BlockSparseArrays/src/BlockArraysExtensions/BlockArraysExtensions.jl +++ b/NDTensors/src/lib/BlockSparseArrays/src/BlockArraysExtensions/BlockArraysExtensions.jl @@ -20,7 +20,7 @@ using BlockArrays: findblockindex using Compat: allequal using Dictionaries: Dictionary, Indices -using ..GradedAxes: blockedunitrange_getindices, to_blockindices +using ..GradedAxes: blockedunitrange_getindices, gradedunitrange_getindices, to_blockindices using ..SparseArrayInterface: SparseArrayInterface, nstored, stored_indices # A return type for `blocks(array)` when `array` isn't blocked. diff --git a/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/blocksparsearrayinterface.jl b/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/blocksparsearrayinterface.jl index 182504e038..55f202a6a7 100644 --- a/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/blocksparsearrayinterface.jl +++ b/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/blocksparsearrayinterface.jl @@ -30,7 +30,7 @@ end # https://github.com/ITensor/ITensors.jl/issues/1336. function blocksparse_to_indices(a, inds, I::Tuple{UnitRange{<:Integer},Vararg{Any}}) bs1 = to_blockindices(inds[1], I[1]) - I1 = BlockSlice(bs1, blockedunitrange_getindices(inds[1], I[1])) + I1 = BlockSlice(bs1, gradedunitrange_getindices(inds[1], I[1])) return (I1, to_indices(a, Base.tail(inds), Base.tail(I))...) end @@ -45,7 +45,7 @@ end # a[[Block(2), Block(1)], [Block(2), Block(1)]] function blocksparse_to_indices(a, inds, I::Tuple{Vector{<:Block{1}},Vararg{Any}}) - I1 = BlockIndices(I[1], blockedunitrange_getindices(inds[1], I[1])) + I1 = BlockIndices(I[1], gradedunitrange_getindices(inds[1], I[1])) return (I1, to_indices(a, Base.tail(inds), Base.tail(I))...) end @@ -54,7 +54,7 @@ end function blocksparse_to_indices( a, inds, I::Tuple{BlockVector{<:BlockIndex{1},<:Vector{<:BlockIndexRange{1}}},Vararg{Any}} ) - I1 = BlockIndices(I[1], blockedunitrange_getindices(inds[1], I[1])) + I1 = BlockIndices(I[1], gradedunitrange_getindices(inds[1], I[1])) return (I1, to_indices(a, Base.tail(inds), Base.tail(I))...) end @@ -64,7 +64,7 @@ end function blocksparse_to_indices( a, inds, I::Tuple{AbstractBlockVector{<:Block{1}},Vararg{Any}} ) - I1 = BlockIndices(I[1], blockedunitrange_getindices(inds[1], I[1])) + I1 = BlockIndices(I[1], gradedunitrange_getindices(inds[1], I[1])) return (I1, to_indices(a, Base.tail(inds), Base.tail(I))...) end diff --git a/NDTensors/src/lib/GradedAxes/test/test_basics.jl b/NDTensors/src/lib/GradedAxes/test/test_basics.jl index 02d37f718f..15e04fec87 100644 --- a/NDTensors/src/lib/GradedAxes/test/test_basics.jl +++ b/NDTensors/src/lib/GradedAxes/test/test_basics.jl @@ -9,7 +9,8 @@ using BlockArrays: blocklasts, blocklength, blocklengths, - blocks + blocks, + combine_blockaxes using NDTensors.GradedAxes: GradedOneTo, GradedUnitRange, OneToOne, blocklabels, gradedrange using NDTensors.LabelledNumbers: LabelledUnitRange, islabelled, label, labelled, labelled_isequal, unlabel @@ -94,6 +95,8 @@ end @test length(a[Block(2)]) == 3 @test blocklengths(only(axes(a))) == blocklengths(a) @test blocklabels(only(axes(a))) == blocklabels(a) + + @test combine_blockaxes(a, a) isa GradedOneTo end # Slicing operations From 705da7d852ded1717999ea573e8df2d73b052dff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Gauth=C3=A9?= Date: Mon, 4 Nov 2024 12:36:10 -0500 Subject: [PATCH 34/43] use blockedunitrange_getindices instead of gradedunitrange_getindices --- .../BlockArraysExtensions.jl | 2 +- .../blocksparsearrayinterface.jl | 8 +- .../src/lib/GradedAxes/src/gradedunitrange.jl | 73 ++++++++++++------- .../lib/GradedAxes/src/gradedunitrangedual.jl | 14 ++-- 4 files changed, 59 insertions(+), 38 deletions(-) diff --git a/NDTensors/src/lib/BlockSparseArrays/src/BlockArraysExtensions/BlockArraysExtensions.jl b/NDTensors/src/lib/BlockSparseArrays/src/BlockArraysExtensions/BlockArraysExtensions.jl index 3214a8a230..7e7b503475 100644 --- a/NDTensors/src/lib/BlockSparseArrays/src/BlockArraysExtensions/BlockArraysExtensions.jl +++ b/NDTensors/src/lib/BlockSparseArrays/src/BlockArraysExtensions/BlockArraysExtensions.jl @@ -20,7 +20,7 @@ using BlockArrays: findblockindex using Compat: allequal using Dictionaries: Dictionary, Indices -using ..GradedAxes: blockedunitrange_getindices, gradedunitrange_getindices, to_blockindices +using ..GradedAxes: blockedunitrange_getindices, to_blockindices using ..SparseArrayInterface: SparseArrayInterface, nstored, stored_indices # A return type for `blocks(array)` when `array` isn't blocked. diff --git a/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/blocksparsearrayinterface.jl b/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/blocksparsearrayinterface.jl index 55f202a6a7..182504e038 100644 --- a/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/blocksparsearrayinterface.jl +++ b/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/blocksparsearrayinterface.jl @@ -30,7 +30,7 @@ end # https://github.com/ITensor/ITensors.jl/issues/1336. function blocksparse_to_indices(a, inds, I::Tuple{UnitRange{<:Integer},Vararg{Any}}) bs1 = to_blockindices(inds[1], I[1]) - I1 = BlockSlice(bs1, gradedunitrange_getindices(inds[1], I[1])) + I1 = BlockSlice(bs1, blockedunitrange_getindices(inds[1], I[1])) return (I1, to_indices(a, Base.tail(inds), Base.tail(I))...) end @@ -45,7 +45,7 @@ end # a[[Block(2), Block(1)], [Block(2), Block(1)]] function blocksparse_to_indices(a, inds, I::Tuple{Vector{<:Block{1}},Vararg{Any}}) - I1 = BlockIndices(I[1], gradedunitrange_getindices(inds[1], I[1])) + I1 = BlockIndices(I[1], blockedunitrange_getindices(inds[1], I[1])) return (I1, to_indices(a, Base.tail(inds), Base.tail(I))...) end @@ -54,7 +54,7 @@ end function blocksparse_to_indices( a, inds, I::Tuple{BlockVector{<:BlockIndex{1},<:Vector{<:BlockIndexRange{1}}},Vararg{Any}} ) - I1 = BlockIndices(I[1], gradedunitrange_getindices(inds[1], I[1])) + I1 = BlockIndices(I[1], blockedunitrange_getindices(inds[1], I[1])) return (I1, to_indices(a, Base.tail(inds), Base.tail(I))...) end @@ -64,7 +64,7 @@ end function blocksparse_to_indices( a, inds, I::Tuple{AbstractBlockVector{<:Block{1}},Vararg{Any}} ) - I1 = BlockIndices(I[1], gradedunitrange_getindices(inds[1], I[1])) + I1 = BlockIndices(I[1], blockedunitrange_getindices(inds[1], I[1])) return (I1, to_indices(a, Base.tail(inds), Base.tail(I))...) end diff --git a/NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl b/NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl index a26bf62e7c..5961b9b7a0 100644 --- a/NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl +++ b/NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl @@ -214,10 +214,37 @@ function firstblockindices(a::AbstractGradedUnitRange) return labelled.(firstblockindices(unlabel_blocks(a)), blocklabels(a)) end -function gradedunitrange_getindices(a::AbstractGradedUnitRange, index) +function blockedunitrange_getindices(a::AbstractGradedUnitRange, index::Block{1}) return labelled(unlabel_blocks(a)[index], get_label(a, index)) end +function blockedunitrange_getindices(a::AbstractGradedUnitRange, indices::Vector{<:Integer}) + return map(index -> a[index], indices) +end + +function blockedunitrange_getindices(a::AbstractGradedUnitRange, index) + return labelled(unlabel_blocks(a)[index], get_label(a, index)) +end + +function blockedunitrange_getindices(a::AbstractGradedUnitRange, indices::BlockIndexRange) + return a[block(indices)][only(indices.indices)] +end + +function blockedunitrange_getindices( + a::AbstractGradedUnitRange, indices::AbstractVector{<:Union{Block{1},BlockIndexRange{1}}} +) + # Without converting `indices` to `Vector`, + # mapping `indices` outputs a `BlockVector` + # which is harder to reason about. + blocks = map(index -> a[index], Vector(indices)) + # We pass `length.(blocks)` to `mortar` in order + # to pass block labels to the axes of the output, + # if they exist. This makes it so that + # `only(axes(a[indices])) isa `GradedUnitRange` + # if `a isa `GradedUnitRange`, for example. + return mortar(blocks, length.(blocks)) +end + # The block labels of the corresponding slice. function blocklabels(a::AbstractUnitRange, indices) return map(_blocks(a, indices)) do block @@ -225,31 +252,22 @@ function blocklabels(a::AbstractUnitRange, indices) end end -function gradedunitrange_getindices( +function blockedunitrange_getindices( ga::AbstractGradedUnitRange, indices::AbstractUnitRange{<:Integer} ) a_indices = blockedunitrange_getindices(unlabel_blocks(ga), indices) return labelled_blocks(a_indices, blocklabels(ga, indices)) end -function gradedunitrange_getindices( - a::AbstractGradedUnitRange, - indices::Union{ - AbstractVector{<:Block{1}},AbstractVector{<:BlockIndexRange{1}},Vector{<:Integer} - }, -) - return blockedunitrange_getindices(a, indices) -end - -function gradedunitrange_getindices(a::AbstractGradedUnitRange, indices::BlockSlice) +function blockedunitrange_getindices(a::AbstractGradedUnitRange, indices::BlockSlice) return a[indices.block] end -function gradedunitrange_getindices(ga::AbstractGradedUnitRange, indices::BlockRange) +function blockedunitrange_getindices(ga::AbstractGradedUnitRange, indices::BlockRange) return labelled_blocks(unlabel_blocks(ga)[indices], blocklabels(ga, indices)) end -function gradedunitrange_getindices(a::AbstractGradedUnitRange, indices::BlockIndex{1}) +function blockedunitrange_getindices(a::AbstractGradedUnitRange, indices::BlockIndex{1}) return a[block(indices)][blockindex(indices)] end @@ -258,7 +276,7 @@ function Base.getindex(a::AbstractGradedUnitRange, index::Integer) end function Base.getindex(a::AbstractGradedUnitRange, index::Block{1}) - return gradedunitrange_getindices(a, index) + return blockedunitrange_getindices(a, index) end function Base.getindex(a::AbstractGradedUnitRange, indices::BlockIndexRange) @@ -269,16 +287,16 @@ end function Base.getindex( a::AbstractGradedUnitRange, indices::BlockArrays.BlockRange{1,<:Tuple{Base.OneTo}} ) - return gradedunitrange_getindices(a, indices) + return blockedunitrange_getindices(a, indices) end function Base.getindex( a::AbstractGradedUnitRange, indices::BlockRange{1,<:Tuple{AbstractUnitRange{Int}}} ) - return gradedunitrange_getindices(a, indices) + return blockedunitrange_getindices(a, indices) end function Base.getindex(a::AbstractGradedUnitRange, indices::BlockIndex{1}) - return gradedunitrange_getindices(a, indices) + return blockedunitrange_getindices(a, indices) end # Fixes ambiguity issues with: @@ -289,15 +307,15 @@ end # getindex(::AbstractUnitRange, ::AbstractUnitRange{<:Integer}) # ``` function Base.getindex(a::AbstractGradedUnitRange, indices::BlockSlice) - return gradedunitrange_getindices(a, indices) + return blockedunitrange_getindices(a, indices) end function Base.getindex(a::AbstractGradedUnitRange, indices) - return gradedunitrange_getindices(a, indices) + return blockedunitrange_getindices(a, indices) end function Base.getindex(a::AbstractGradedUnitRange, indices::AbstractUnitRange{<:Integer}) - return gradedunitrange_getindices(a, indices) + return blockedunitrange_getindices(a, indices) end # This fixes an issue that `combine_blockaxes` was promoting @@ -318,10 +336,13 @@ end # preserve labels inside combine_blockaxes # TODO dual -function BlockArrays.combine_blockaxes( - a::AbstractGradedUnitRange, b::AbstractGradedUnitRange -) - return gradedrange(sortedunion(blocklasts(a), blocklasts(b))) +function BlockArrays.combine_blockaxes(a::GradedOneTo, b::GradedOneTo) + return GradedOneTo(sortedunion(blocklasts(a), blocklasts(b))) +end +function BlockArrays.combine_blockaxes(a::GradedUnitRange, b::GradedUnitRange) + new_blocklasts = sortedunion(blocklasts(a), blocklasts(b)) + new_first = labelled(oneunit(eltype(new_blocklasts)), label(first(new_blocklasts))) + return GradedUnitRange(new_first, new_blocklasts) end # Version of length that checks that all blocks have the same label @@ -339,7 +360,7 @@ end # blocklengths = map(bs -> sum(b -> length(a[b]), bs), blocks(indices)) # return blockedrange(blocklengths) # ``` -function gradedunitrange_getindices( +function blockedunitrange_getindices( a::AbstractGradedUnitRange, indices::AbstractBlockVector{<:Block{1}} ) blks = map(bs -> mortar(map(b -> a[b], bs)), blocks(indices)) diff --git a/NDTensors/src/lib/GradedAxes/src/gradedunitrangedual.jl b/NDTensors/src/lib/GradedAxes/src/gradedunitrangedual.jl index bc0824c098..80a7553784 100644 --- a/NDTensors/src/lib/GradedAxes/src/gradedunitrangedual.jl +++ b/NDTensors/src/lib/GradedAxes/src/gradedunitrangedual.jl @@ -17,7 +17,7 @@ Base.step(a::GradedUnitRangeDual) = label_dual(step(nondual(a))) Base.view(a::GradedUnitRangeDual, index::Block{1}) = a[index] -function gradedunitrange_getindices( +function blockedunitrange_getindices( a::GradedUnitRangeDual, indices::AbstractUnitRange{<:Integer} ) return dual(getindex(nondual(a), indices)) @@ -25,20 +25,20 @@ end using BlockArrays: Block, BlockIndexRange, BlockRange -function gradedunitrange_getindices(a::GradedUnitRangeDual, indices::Integer) +function blockedunitrange_getindices(a::GradedUnitRangeDual, indices::Integer) return label_dual(getindex(nondual(a), indices)) end -function gradedunitrange_getindices(a::GradedUnitRangeDual, indices::Block{1}) +function blockedunitrange_getindices(a::GradedUnitRangeDual, indices::Block{1}) return label_dual(getindex(nondual(a), indices)) end -function gradedunitrange_getindices(a::GradedUnitRangeDual, indices::BlockRange) +function blockedunitrange_getindices(a::GradedUnitRangeDual, indices::BlockRange) return label_dual(getindex(nondual(a), indices)) end # fix ambiguity -function gradedunitrange_getindices( +function blockedunitrange_getindices( a::GradedUnitRangeDual, indices::BlockRange{1,<:Tuple{AbstractUnitRange{Int}}} ) return dual(getindex(nondual(a), indices)) @@ -54,11 +54,11 @@ function unitrangedual_getindices_blocks(a::GradedUnitRangeDual, indices) end # TODO: Move this to a `BlockArraysExtensions` library. -function gradedunitrange_getindices(a::GradedUnitRangeDual, indices::Vector{<:Block{1}}) +function blockedunitrange_getindices(a::GradedUnitRangeDual, indices::Vector{<:Block{1}}) return unitrangedual_getindices_blocks(a, indices) end -function gradedunitrange_getindices( +function blockedunitrange_getindices( a::GradedUnitRangeDual, indices::Vector{<:BlockIndexRange{1}} ) return unitrangedual_getindices_blocks(a, indices) From 38a21f63a91a61166534ba29122ab847ee3a6e3f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Gauth=C3=A9?= Date: Mon, 4 Nov 2024 13:56:29 -0500 Subject: [PATCH 35/43] fix slicing --- .../ext/BlockSparseArraysGradedAxesExt/test/runtests.jl | 6 ++++-- NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl | 4 ++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl b/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl index a321fe47ab..0905613808 100644 --- a/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl +++ b/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl @@ -11,7 +11,8 @@ using NDTensors.GradedAxes: GradedUnitRangeDual, blocklabels, dual, - gradedrange + gradedrange, + isdual using NDTensors.LabelledNumbers: label using NDTensors.SparseArrayInterface: nstored using NDTensors.TensorAlgebra: fusedims, splitdims @@ -40,7 +41,7 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) a = BlockSparseArray{elt}(d1, d2, d1, d2) blockdiagonal!(randn!, a) @test axes(a, 1) isa GradedOneTo - @test axes(view(a, 1:4, 1:4), 1) isa GradedOneTo + @test axes(view(a, 1:4, 1:4, 1:4, 1:4), 1) isa GradedOneTo for b in (a + a, 2 * a) @test size(b) == (4, 4, 4, 4) @@ -121,6 +122,7 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) # TODO: Define and use `isdual` here. for dim in 1:ndims(a) @test typeof(ax[dim]) === typeof(axes(a, dim)) + @test isdual(ax[dim]) == isdual(axes(a, dim)) end @test @view(a[Block(1, 1)])[1, 1] == a[1, 1] @test @view(a[Block(1, 1)])[2, 1] == a[2, 1] diff --git a/NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl b/NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl index 5961b9b7a0..129b306eca 100644 --- a/NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl +++ b/NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl @@ -323,13 +323,13 @@ end # that mixed dense and graded axes. # TODO: Maybe come up with a more general solution. function BlockArrays.combine_blockaxes( - a1::AbstractGradedUnitRange{T}, a2::Base.OneTo{T} + a1::AbstractGradedUnitRange{<:LabelledInteger{T}}, a2::AbstractUnitRange{T} ) where {T<:Integer} combined_blocklasts = sort!(union(unlabel.(blocklasts(a1)), blocklasts(a2))) return BlockedOneTo(combined_blocklasts) end function BlockArrays.combine_blockaxes( - a1::Base.OneTo{T}, a2::AbstractGradedUnitRange{T} + a1::AbstractUnitRange{T}, a2::AbstractGradedUnitRange{<:LabelledInteger{T}} ) where {T<:Integer} return BlockArrays.combine_blockaxes(a2, a1) end From c1ec06d51a038841d194154e4eebfedd53079b8b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Gauth=C3=A9?= Date: Mon, 4 Nov 2024 17:16:14 -0500 Subject: [PATCH 36/43] fix slicing with BlockVector --- .../test/runtests.jl | 49 ++++++++++++------- .../src/lib/GradedAxes/src/gradedunitrange.jl | 10 ++++ .../src/lib/GradedAxes/test/test_basics.jl | 17 ++++++- .../src/lib/GradedAxes/test/test_dual.jl | 12 ++++- 4 files changed, 68 insertions(+), 20 deletions(-) diff --git a/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl b/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl index 0905613808..28cfab0693 100644 --- a/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl +++ b/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl @@ -112,6 +112,7 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) @test blocksize(m) == (3, 3) @test a == splitdims(m, (d1, d2), (d1, d2)) end + @testset "dual axes" begin r = gradedrange([U1(0) => 2, U1(1) => 2]) for ax in ((r, r), (dual(r), r), (r, dual(r)), (dual(r), dual(r))) @@ -119,7 +120,6 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) @views for b in [Block(1, 1), Block(2, 2)] a[b] = randn(elt, size(a[b])) end - # TODO: Define and use `isdual` here. for dim in 1:ndims(a) @test typeof(ax[dim]) === typeof(axes(a, dim)) @test isdual(ax[dim]) == isdual(axes(a, dim)) @@ -176,7 +176,7 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) @test a[I, :] isa AbstractBlockArray @test a[:, I] isa AbstractBlockArray @test size(a[I, I]) == (1, 1) - @test !GradedAxes.isdual(axes(a[I, I], 1)) + @test !isdual(axes(a[I, I], 1)) end @testset "GradedUnitRange" begin @@ -190,14 +190,19 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) @test Array(b) == 2 * Array(a) for i in 1:2 @test axes(b, i) isa GradedUnitRange - @test_broken axes(a[:, :], i) isa GradedUnitRange + @test axes(a[:, :], i) isa GradedUnitRange end I = [Block(1)[1:1]] - @test_broken a[I, :] isa AbstractBlockArray - @test_broken a[:, I] isa AbstractBlockArray + @test a[I, :] isa AbstractBlockArray + @test axes(a[I, :], 1) isa GradedOneTo + @test axes(a[I, :], 2) isa GradedUnitRange + + @test a[:, I] isa AbstractBlockArray + @test axes(a[:, I], 2) isa GradedOneTo + @test axes(a[:, I], 1) isa GradedUnitRange @test size(a[I, I]) == (1, 1) - @test_broken GradedAxes.isdual(axes(a[I, I], 1)) + @test !isdual(axes(a[I, I], 1)) end # Test case when all axes are dual. @@ -212,13 +217,18 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) @test Array(b) == 2 * Array(a) for i in 1:2 @test axes(b, i) isa GradedUnitRangeDual - @test_broken axes(a[:, :], i) isa GradedUnitRangeDual + @test axes(a[:, :], i) isa GradedUnitRangeDual end I = [Block(1)[1:1]] - @test_broken a[I, :] isa AbstractBlockArray - @test_broken a[:, I] isa AbstractBlockArray + @test a[I, :] isa AbstractBlockArray + @test a[:, I] isa AbstractBlockArray @test size(a[I, I]) == (1, 1) - @test_broken GradedAxes.isdual(axes(a[I, I], 1)) + @test isdual(axes(a[I, :], 2)) + @test isdual(axes(a[:, I], 1)) + @test_broken isdual(axes(a[I, :], 1)) + @test_broken isdual(axes(a[:, I], 2)) + @test_broken isdual(axes(a[I, I], 1)) + @test_broken isdual(axes(a[I, I], 2)) end @testset "dual GradedUnitRange" begin @@ -232,14 +242,19 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) @test Array(b) == 2 * Array(a) for i in 1:2 @test axes(b, i) isa GradedUnitRangeDual - @test_broken axes(a[:, :], i) isa GradedUnitRangeDual + @test axes(a[:, :], i) isa GradedUnitRangeDual end I = [Block(1)[1:1]] - @test_broken a[I, :] isa AbstractBlockArray - @test_broken a[:, I] isa AbstractBlockArray + @test a[I, :] isa AbstractBlockArray + @test a[:, I] isa AbstractBlockArray @test size(a[I, I]) == (1, 1) - @test_broken GradedAxes.isdual(axes(a[I, I], 1)) + @test isdual(axes(a[I, :], 2)) + @test isdual(axes(a[:, I], 1)) + @test_broken isdual(axes(a[I, :], 1)) + @test_broken isdual(axes(a[:, I], 2)) + @test_broken isdual(axes(a[I, I], 1)) + @test_broken isdual(axes(a[I, I], 2)) end @testset "dual BlockedUnitRange" begin # self dual @@ -261,7 +276,7 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) @test a[I, :] isa BlockSparseArray @test a[:, I] isa BlockSparseArray @test size(a[I, I]) == (1, 1) - @test !GradedAxes.isdual(axes(a[I, I], 1)) + @test !isdual(axes(a[I, I], 1)) end # Test case when all axes are dual from taking the adjoint. @@ -281,8 +296,8 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) end I = [Block(1)[1:1]] - @test_broken size(b[I, :]) == (1, 4) - @test_broken size(b[:, I]) == (4, 1) + @test size(b[I, :]) == (1, 4) + @test size(b[:, I]) == (4, 1) @test size(b[I, I]) == (1, 1) end end diff --git a/NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl b/NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl index 129b306eca..bf02786280 100644 --- a/NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl +++ b/NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl @@ -222,6 +222,13 @@ function blockedunitrange_getindices(a::AbstractGradedUnitRange, indices::Vector return map(index -> a[index], indices) end +function blockedunitrange_getindices( + a::AbstractGradedUnitRange, + indices::BlockVector{<:BlockIndex{1},<:Vector{<:BlockIndexRange{1}}}, +) + return mortar(map(b -> a[b], blocks(indices))) +end + function blockedunitrange_getindices(a::AbstractGradedUnitRange, index) return labelled(unlabel_blocks(a)[index], get_label(a, index)) end @@ -345,6 +352,9 @@ function BlockArrays.combine_blockaxes(a::GradedUnitRange, b::GradedUnitRange) return GradedUnitRange(new_first, new_blocklasts) end +# preserve axes in SubArray +Base.axes(S::Base.Slice{<:AbstractGradedUnitRange}) = (S.indices,) + # Version of length that checks that all blocks have the same label # and returns a labelled length with that label. function labelled_length(a::AbstractBlockVector{<:Integer}) diff --git a/NDTensors/src/lib/GradedAxes/test/test_basics.jl b/NDTensors/src/lib/GradedAxes/test/test_basics.jl index 15e04fec87..f04f35046c 100644 --- a/NDTensors/src/lib/GradedAxes/test/test_basics.jl +++ b/NDTensors/src/lib/GradedAxes/test/test_basics.jl @@ -10,7 +10,8 @@ using BlockArrays: blocklength, blocklengths, blocks, - combine_blockaxes + combine_blockaxes, + mortar using NDTensors.GradedAxes: GradedOneTo, GradedUnitRange, OneToOne, blocklabels, gradedrange using NDTensors.LabelledNumbers: LabelledUnitRange, islabelled, label, labelled, labelled_isequal, unlabel @@ -97,6 +98,7 @@ end @test blocklabels(only(axes(a))) == blocklabels(a) @test combine_blockaxes(a, a) isa GradedOneTo + @test axes(Base.Slice(a)) isa Tuple{typeof(a)} end # Slicing operations @@ -145,6 +147,7 @@ end @test length(ax) == length(a) @test blocklengths(ax) == blocklengths(a) @test blocklabels(ax) == blocklabels(a) + @test axes(Base.Slice(a)) isa Tuple{typeof(a)} x = gradedrange(["x" => 2, "y" => 3]) a = x[2:4][1:2] @@ -227,5 +230,17 @@ end # once `blocklengths(::BlockVector)` is defined. @test blocklengths(ax) == [2, 2] @test blocklabels(ax) == blocklabels(a) + + x = gradedrange(["x" => 2, "y" => 3]) + I = mortar([Block(1)[1:1]]) + a = x[I] + @test length(a) == 1 + @test label(first(a)) == "x" + + x = gradedrange(["x" => 2, "y" => 3])[1:5] + I = mortar([Block(1)[1:1]]) + a = x[I] + @test length(a) == 1 + @test label(first(a)) == "x" end end diff --git a/NDTensors/src/lib/GradedAxes/test/test_dual.jl b/NDTensors/src/lib/GradedAxes/test/test_dual.jl index 714dd04b7a..5a211a77b8 100644 --- a/NDTensors/src/lib/GradedAxes/test/test_dual.jl +++ b/NDTensors/src/lib/GradedAxes/test/test_dual.jl @@ -10,7 +10,8 @@ using BlockArrays: blocklength, blocklengths, blocks, - findblock + findblock, + mortar using NDTensors.GradedAxes: AbstractGradedUnitRange, GradedAxes, @@ -26,7 +27,7 @@ using NDTensors.GradedAxes: isdual, nondual using NDTensors.LabelledNumbers: LabelledInteger, label, labelled, labelled_isequal -using Test: @test, @testset +using Test: @test, @test_broken, @testset struct U1 n::Int end @@ -75,6 +76,7 @@ end @test isdual(ad) @test !isdual(a) + @test axes(Base.Slice(a)) isa Tuple{typeof(a)} @test blockfirsts(ad) == [labelled(1, U1(0)), labelled(3, U1(-1))] @test blocklasts(ad) == [labelled(2, U1(0)), labelled(5, U1(-1))] @@ -106,6 +108,12 @@ end @test blocklength(blockmergesortperm(ad)) == 2 @test blockmergesortperm(a) == [Block(1), Block(2)] @test blockmergesortperm(ad) == [Block(1), Block(2)] + + I = mortar([Block(2)[1:1]]) + g = ad[I] + @test length(g) == 1 + @test label(first(g)) == U1(-1) + @test_broken isdual(g[Block(1)]) end end From 5111bb2903c830a43732ebade53e672778165e13 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Gauth=C3=A9?= Date: Mon, 4 Nov 2024 18:04:13 -0500 Subject: [PATCH 37/43] use isdual ins tests --- .../ext/BlockSparseArraysGradedAxesExt/test/runtests.jl | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl b/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl index 28cfab0693..e6b9955a77 100644 --- a/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl +++ b/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl @@ -140,17 +140,16 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) @test a[I] == a_dense[I] end @test axes(a') == dual.(reverse(axes(a))) - # TODO: Define and use `isdual` here. - @test typeof(axes(a', 1)) === typeof(dual(axes(a, 2))) - @test typeof(axes(a', 2)) === typeof(dual(axes(a, 1))) + + @test isdual(axes(a', 1)) ⊻ isdual(axes(a, 2)) + @test isdual(axes(a', 2)) ⊻ isdual(axes(a, 1)) @test isnothing(show(devnull, MIME("text/plain"), a)) # Check preserving dual in tensor algebra. for b in (a + a, 2 * a, 3 * a - a) @test Array(b) ≈ 2 * Array(a) - # TODO: Define and use `isdual` here. for dim in 1:ndims(a) - @test typeof(axes(b, dim)) === typeof(axes(b, dim)) + @test isdual(axes(b, dim)) == isdual(axes(a, dim)) end end From a0ddb30a6c09a4e2a33ff53907beccec1312bbc7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Gauth=C3=A9?= Date: Mon, 4 Nov 2024 18:14:56 -0500 Subject: [PATCH 38/43] add broken tests --- NDTensors/src/lib/GradedAxes/test/test_dual.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/NDTensors/src/lib/GradedAxes/test/test_dual.jl b/NDTensors/src/lib/GradedAxes/test/test_dual.jl index 5a211a77b8..6049041bb0 100644 --- a/NDTensors/src/lib/GradedAxes/test/test_dual.jl +++ b/NDTensors/src/lib/GradedAxes/test/test_dual.jl @@ -109,6 +109,8 @@ end @test blockmergesortperm(a) == [Block(1), Block(2)] @test blockmergesortperm(ad) == [Block(1), Block(2)] + @test_broken isdual(ad[Block(1)]) + @test_broken isdual(ad[Block(1)[1:1]]) I = mortar([Block(2)[1:1]]) g = ad[I] @test length(g) == 1 From 92707fbd05a64d86d1ae53f31651fe3dba831ffc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Gauth=C3=A9?= Date: Tue, 5 Nov 2024 10:07:14 -0500 Subject: [PATCH 39/43] remove dead code --- .../src/lib/GradedAxes/src/gradedunitrange.jl | 1 - .../lib/GradedAxes/src/gradedunitrangedual.jl | 19 +++---------------- 2 files changed, 3 insertions(+), 17 deletions(-) diff --git a/NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl b/NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl index bf02786280..7f951b54dd 100644 --- a/NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl +++ b/NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl @@ -342,7 +342,6 @@ function BlockArrays.combine_blockaxes( end # preserve labels inside combine_blockaxes -# TODO dual function BlockArrays.combine_blockaxes(a::GradedOneTo, b::GradedOneTo) return GradedOneTo(sortedunion(blocklasts(a), blocklasts(b))) end diff --git a/NDTensors/src/lib/GradedAxes/src/gradedunitrangedual.jl b/NDTensors/src/lib/GradedAxes/src/gradedunitrangedual.jl index 80a7553784..c4cb491396 100644 --- a/NDTensors/src/lib/GradedAxes/src/gradedunitrangedual.jl +++ b/NDTensors/src/lib/GradedAxes/src/gradedunitrangedual.jl @@ -48,20 +48,20 @@ function BlockArrays.blocklengths(a::GradedUnitRangeDual) return dual.(blocklengths(nondual(a))) end -function unitrangedual_getindices_blocks(a::GradedUnitRangeDual, indices) +function gradedunitrangedual_getindices_blocks(a::GradedUnitRangeDual, indices) a_indices = getindex(nondual(a), indices) return mortar([label_dual(b) for b in blocks(a_indices)]) end # TODO: Move this to a `BlockArraysExtensions` library. function blockedunitrange_getindices(a::GradedUnitRangeDual, indices::Vector{<:Block{1}}) - return unitrangedual_getindices_blocks(a, indices) + return gradedunitrangedual_getindices_blocks(a, indices) end function blockedunitrange_getindices( a::GradedUnitRangeDual, indices::Vector{<:BlockIndexRange{1}} ) - return unitrangedual_getindices_blocks(a, indices) + return gradedunitrangedual_getindices_blocks(a, indices) end Base.axes(a::GradedUnitRangeDual) = axes(nondual(a)) @@ -99,19 +99,6 @@ function BlockArrays.combine_blockaxes(a1::GradedUnitRangeDual, a2::GradedUnitRa return dual(combine_blockaxes(dual(a1), dual(a2))) end -# This is needed when constructing `CartesianIndices` from -# a tuple of unit ranges that have this kind of dual unit range. -# TODO: See if we can find some more elegant way of constructing -# `CartesianIndices`, maybe by defining conversion of `LabelledInteger` -# to `Int`, defining a more general `convert` function, etc. -function Base.OrdinalRange{Int,Int}( - r::GradedUnitRangeDual{<:LabelledInteger{Int},<:LabelledUnitRange{Int,UnitRange{Int}}} -) - # TODO: Implement this broadcasting operation and use it here. - # return Int.(r) - return unlabel(nondual(r)) -end - function unlabel_blocks(a::GradedUnitRangeDual) return unlabel_blocks(nondual(a)) end From 23b4fb1c3114dc179c04356a24f4d5c54cac6fa4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Gauth=C3=A9?= Date: Tue, 5 Nov 2024 09:31:49 -0500 Subject: [PATCH 40/43] more tests --- .../lib/GradedAxes/src/gradedunitrangedual.jl | 2 +- .../src/lib/GradedAxes/test/test_basics.jl | 14 ++++++++++++-- NDTensors/src/lib/GradedAxes/test/test_dual.jl | 17 ++++++++++++++++- 3 files changed, 29 insertions(+), 4 deletions(-) diff --git a/NDTensors/src/lib/GradedAxes/src/gradedunitrangedual.jl b/NDTensors/src/lib/GradedAxes/src/gradedunitrangedual.jl index c4cb491396..b5d85b9e97 100644 --- a/NDTensors/src/lib/GradedAxes/src/gradedunitrangedual.jl +++ b/NDTensors/src/lib/GradedAxes/src/gradedunitrangedual.jl @@ -96,7 +96,7 @@ end blocklabels(a::GradedUnitRangeDual) = dual.(blocklabels(nondual(a))) function BlockArrays.combine_blockaxes(a1::GradedUnitRangeDual, a2::GradedUnitRangeDual) - return dual(combine_blockaxes(dual(a1), dual(a2))) + return dual(combine_blockaxes(nondual(a1), nondual(a2))) end function unlabel_blocks(a::GradedUnitRangeDual) diff --git a/NDTensors/src/lib/GradedAxes/test/test_basics.jl b/NDTensors/src/lib/GradedAxes/test/test_basics.jl index f04f35046c..90faa59b93 100644 --- a/NDTensors/src/lib/GradedAxes/test/test_basics.jl +++ b/NDTensors/src/lib/GradedAxes/test/test_basics.jl @@ -12,7 +12,8 @@ using BlockArrays: blocks, combine_blockaxes, mortar -using NDTensors.GradedAxes: GradedOneTo, GradedUnitRange, OneToOne, blocklabels, gradedrange +using NDTensors.GradedAxes: + GradedOneTo, GradedUnitRange, OneToOne, blocklabels, gradedrange, space_isequal using NDTensors.LabelledNumbers: LabelledUnitRange, islabelled, label, labelled, labelled_isequal, unlabel using Test: @test, @test_broken, @testset @@ -97,8 +98,12 @@ end @test blocklengths(only(axes(a))) == blocklengths(a) @test blocklabels(only(axes(a))) == blocklabels(a) - @test combine_blockaxes(a, a) isa GradedOneTo @test axes(Base.Slice(a)) isa Tuple{typeof(a)} + @test AbstractUnitRange{Int}(a) == 1:5 + b = combine_blockaxes(a, a) + @test b isa GradedOneTo + @test b == 1:5 + @test space_isequal(b, a) end # Slicing operations @@ -118,6 +123,11 @@ end @test blocklabels(ax) == blocklabels(a) @test blockfirsts(a) == [2, 3] + @test AbstractUnitRange{Int}(a) == 2:4 + b = combine_blockaxes(a, a) + @test b isa GradedUnitRange + @test b == 1:4 + @test x[[2, 4]] == [labelled(2, "x"), labelled(4, "y")] @test labelled_isequal(x[BlockRange(1)], gradedrange(["x" => 2])) diff --git a/NDTensors/src/lib/GradedAxes/test/test_dual.jl b/NDTensors/src/lib/GradedAxes/test/test_dual.jl index 6049041bb0..18dbac045c 100644 --- a/NDTensors/src/lib/GradedAxes/test/test_dual.jl +++ b/NDTensors/src/lib/GradedAxes/test/test_dual.jl @@ -11,7 +11,8 @@ using BlockArrays: blocklengths, blocks, findblock, - mortar + mortar, + combine_blockaxes using NDTensors.GradedAxes: AbstractGradedUnitRange, GradedAxes, @@ -77,6 +78,20 @@ end @test isdual(ad) @test !isdual(a) @test axes(Base.Slice(a)) isa Tuple{typeof(a)} + @test AbstractUnitRange{Int}(ad) == 1:5 + b = combine_blockaxes(ad, ad) + @test b isa GradedUnitRangeDual + @test b == 1:5 + @test space_isequal(b, ad) + + for x in iterate(ad) + @test x == 1 + @test label(x) == U1(0) + end + for x in iterate(ad, labelled(3, U1(-1))) + @test x == 4 + @test label(x) == U1(-1) + end @test blockfirsts(ad) == [labelled(1, U1(0)), labelled(3, U1(-1))] @test blocklasts(ad) == [labelled(2, U1(0)), labelled(5, U1(-1))] From 0953484b1ea7838ee291fc279afd2ce7771704d4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Gauth=C3=A9?= Date: Tue, 5 Nov 2024 10:44:28 -0500 Subject: [PATCH 41/43] remove julia v1.6 code --- .../src/lib/LabelledNumbers/src/labelledunitrange.jl | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/NDTensors/src/lib/LabelledNumbers/src/labelledunitrange.jl b/NDTensors/src/lib/LabelledNumbers/src/labelledunitrange.jl index a82c666987..03965f62f5 100644 --- a/NDTensors/src/lib/LabelledNumbers/src/labelledunitrange.jl +++ b/NDTensors/src/lib/LabelledNumbers/src/labelledunitrange.jl @@ -17,16 +17,6 @@ unlabel_type(::Type{<:LabelledUnitRange{Value}}) where {Value} = Value function Base.AbstractUnitRange{T}(a::LabelledUnitRange) where {T} return AbstractUnitRange{T}(unlabel(a)) end -# Used by `CartesianIndices` constructor. -# TODO: Seems to only be needed for Julia v1.6, maybe remove once we -# drop Julia v1.6 support. -function Base.OrdinalRange{T1,T2}(a::LabelledUnitRange) where {T1,T2<:Integer} - return OrdinalRange{T1,T2}(unlabel(a)) -end -# Fix ambiguity error in Julia v1.10. -function Base.OrdinalRange{T,T}(a::LabelledUnitRange) where {T<:Integer} - return OrdinalRange{T,T}(unlabel(a)) -end # TODO: Is this a good definition? Base.unitrange(a::LabelledUnitRange) = a From 0c6f7e5d7ed62d47461c35c86a2d00ae300deccf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Gauth=C3=A9?= Date: Tue, 5 Nov 2024 11:50:58 -0500 Subject: [PATCH 42/43] housekeeping --- .../ext/BlockSparseArraysGradedAxesExt/test/runtests.jl | 4 ++-- NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl b/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl index e6b9955a77..67cf96de00 100644 --- a/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl +++ b/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl @@ -205,8 +205,8 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) end # Test case when all axes are dual. - @testset "dual BlockedOneTo" begin - r = gradedrange([U1(0) => 2, U1(1) => 2]) + @testset "dual GradedOneTo" begin + r = gradedrange([U1(-1) => 2, U1(1) => 2]) a = BlockSparseArray{elt}(dual(r), dual(r)) @views for i in [Block(1, 1), Block(2, 2)] a[i] = randn(elt, size(a[i])) diff --git a/NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl b/NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl index 7f951b54dd..036efd9c0f 100644 --- a/NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl +++ b/NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl @@ -53,7 +53,7 @@ struct GradedOneTo{T,BlockLasts<:Vector{T}} <: AbstractGradedUnitRange{T,BlockLa end end -function Base.show(io::IO, mimetype::MIME"text/plain", g::AbstractGradedUnitRange) +function Base.show(io::IO, ::MIME"text/plain", g::AbstractGradedUnitRange) v = map(b -> label(b) => unlabel(b), blocks(g)) println(io, typeof(g)) return print(io, join(repr.(v), '\n')) From 20009e148e3fcc36902688807ff2b5841c790ce9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Gauth=C3=A9?= Date: Tue, 5 Nov 2024 12:00:08 -0500 Subject: [PATCH 43/43] use BlockLasts as argument name --- .../BlockSparseArraysGradedAxesExt/test/runtests.jl | 4 ++-- NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl | 11 ++++++----- .../src/lib/GradedAxes/src/gradedunitrangedual.jl | 5 +++-- 3 files changed, 11 insertions(+), 9 deletions(-) diff --git a/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl b/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl index 67cf96de00..585dbb5afe 100644 --- a/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl +++ b/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl @@ -141,8 +141,8 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) end @test axes(a') == dual.(reverse(axes(a))) - @test isdual(axes(a', 1)) ⊻ isdual(axes(a, 2)) - @test isdual(axes(a', 2)) ⊻ isdual(axes(a, 1)) + @test isdual(axes(a', 1)) ≠ isdual(axes(a, 2)) + @test isdual(axes(a', 2)) ≠ isdual(axes(a, 1)) @test isnothing(show(devnull, MIME("text/plain"), a)) # Check preserving dual in tensor algebra. diff --git a/NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl b/NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl index 036efd9c0f..0bd35707a7 100644 --- a/NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl +++ b/NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl @@ -31,7 +31,8 @@ using ..LabelledNumbers: labelled_isequal, unlabel -abstract type AbstractGradedUnitRange{T,CS} <: AbstractBlockedUnitRange{T,CS} end +abstract type AbstractGradedUnitRange{T,BlockLasts} <: + AbstractBlockedUnitRange{T,BlockLasts} end struct GradedUnitRange{T,BlockLasts<:Vector{T}} <: AbstractGradedUnitRange{T,BlockLasts} first::T @@ -42,14 +43,14 @@ struct GradedOneTo{T,BlockLasts<:Vector{T}} <: AbstractGradedUnitRange{T,BlockLa lasts::BlockLasts # assume that lasts is sorted, no checks carried out here - function GradedOneTo(lasts::CS) where {T<:Integer,CS<:AbstractVector{T}} + function GradedOneTo(lasts::BlockLasts) where {T<:Integer,BlockLasts<:AbstractVector{T}} Base.require_one_based_indexing(lasts) isempty(lasts) || first(lasts) >= 0 || throw(ArgumentError("blocklasts must be >= 0")) - return new{T,CS}(lasts) + return new{T,BlockLasts}(lasts) end - function GradedOneTo(lasts::CS) where {T<:Integer,CS<:Tuple{T,Vararg{T}}} + function GradedOneTo(lasts::BlockLasts) where {T<:Integer,BlockLasts<:Tuple{T,Vararg{T}}} first(lasts) >= 0 || throw(ArgumentError("blocklasts must be >= 0")) - return new{T,CS}(lasts) + return new{T,BlockLasts}(lasts) end end diff --git a/NDTensors/src/lib/GradedAxes/src/gradedunitrangedual.jl b/NDTensors/src/lib/GradedAxes/src/gradedunitrangedual.jl index b5d85b9e97..217d4b401f 100644 --- a/NDTensors/src/lib/GradedAxes/src/gradedunitrangedual.jl +++ b/NDTensors/src/lib/GradedAxes/src/gradedunitrangedual.jl @@ -1,5 +1,6 @@ -struct GradedUnitRangeDual{T,CS,NondualUnitRange<:AbstractGradedUnitRange{T,CS}} <: - AbstractGradedUnitRange{T,CS} +struct GradedUnitRangeDual{ + T,BlockLasts,NondualUnitRange<:AbstractGradedUnitRange{T,BlockLasts} +} <: AbstractGradedUnitRange{T,BlockLasts} nondual_unitrange::NondualUnitRange end