[GradedAxes] Introduce GradedUnitRangeDual #1531

Merged · 46 commits · Nov 5, 2024

Commits
2869a8c
pass GradedAxes test
ogauthe Sep 5, 2024
b21f404
separate types UnitRangeDual and GradedUnitRangeDual
ogauthe Sep 10, 2024
743592a
pass BlockSparseArrays tests
ogauthe Sep 10, 2024
36e7bfb
test UnitRangeDual
ogauthe Sep 11, 2024
cb1c4d9
fix blocklengths
ogauthe Sep 11, 2024
5a7d4a7
add tests
ogauthe Sep 11, 2024
c373729
fix some slicing
ogauthe Sep 12, 2024
113a1fe
fix tests
ogauthe Sep 16, 2024
14c95ca
generalize gradedisequal
ogauthe Sep 16, 2024
a11107c
remove UnitRangeDual
ogauthe Sep 17, 2024
dc762be
fix tests
ogauthe Sep 17, 2024
32d0f11
add tests
ogauthe Sep 17, 2024
47a9ed1
remove tests on a[:,:]
ogauthe Sep 17, 2024
9b35445
Merge branch 'main' into GradedUnitRangeDual
ogauthe Oct 25, 2024
1131945
WIP
ogauthe Oct 25, 2024
cc6d7ea
Merge branch 'main' into GradedUnitRangeDual
ogauthe Oct 25, 2024
3f6bd2e
passing tests
ogauthe Oct 28, 2024
8a353bd
finish merging
ogauthe Oct 28, 2024
aa1b655
more tests
ogauthe Oct 28, 2024
5c5abd6
more tests
ogauthe Oct 28, 2024
d7e3a8d
reduce diff
ogauthe Oct 29, 2024
efe1c7c
custom type for AbstractGradedUnitRange
ogauthe Oct 29, 2024
3dcdd20
clean new implementation
ogauthe Oct 29, 2024
a184141
custom printing
ogauthe Oct 29, 2024
db71962
fix getindex(a,::Vector{Int}
ogauthe Oct 31, 2024
c57bc02
fix GradedUnitRangeDual tests
ogauthe Oct 31, 2024
032bde4
remove doublets
ogauthe Oct 31, 2024
e716903
use GradedAxes interface
ogauthe Oct 31, 2024
9796676
Base.AbstractUnitRange actually needed
ogauthe Nov 1, 2024
3b8b2b9
fix ambiguities
ogauthe Nov 1, 2024
71cd86a
define label_type
ogauthe Nov 1, 2024
5c36967
improve display
ogauthe Nov 1, 2024
de6080b
simpler display
ogauthe Nov 1, 2024
24fb3c8
fix combine_blockaxes
ogauthe Nov 1, 2024
b2fd32b
fix slicing BlockSparseArrays
ogauthe Nov 1, 2024
705da7d
use blockedunitrange_getindices instead of gradedunitrange_getindices
ogauthe Nov 4, 2024
38a21f6
fix slicing
ogauthe Nov 4, 2024
c1ec06d
fix slicing with BlockVector
ogauthe Nov 4, 2024
5111bb2
use isdual ins tests
ogauthe Nov 4, 2024
a0ddb30
add broken tests
ogauthe Nov 4, 2024
92707fb
remove dead code
ogauthe Nov 5, 2024
23b4fb1
more tests
ogauthe Nov 5, 2024
0953484
remove julia v1.6 code
ogauthe Nov 5, 2024
0c6f7e5
housekeeping
ogauthe Nov 5, 2024
20009e1
use BlockLasts as argument name
ogauthe Nov 5, 2024
d632ff1
Merge branch 'main' into GradedUnitRangeDual
mtfishman Nov 5, 2024
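Before the diff, here is a minimal usage sketch of what this PR introduces, pieced together from the tests below. The `U1` sector type and its `dual` overload are toy stand-ins (the actual test suite defines its own sector type), so treat this as an orientation aid rather than the package's documented API.

```julia
# Minimal sketch, assuming a toy `U1` sector whose dual negates the charge.
using BlockArrays: blockedrange
using NDTensors.GradedAxes:
  GradedAxes, GradedOneTo, GradedUnitRangeDual, dual, gradedrange, isdual

struct U1
  n::Int
end
GradedAxes.dual(c::U1) = U1(-c.n)  # assumed convention for the toy sector

r = gradedrange([U1(0) => 2, U1(1) => 2])  # graded range with two labelled blocks
@assert r isa GradedOneTo
@assert !isdual(r)

rd = dual(r)  # the wrapper type this PR introduces: same blocks, dual labels
@assert rd isa GradedUnitRangeDual
@assert isdual(rd)

# Non-graded ranges stay self-dual:
@assert !isdual(dual(blockedrange([2, 2])))
```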
@@ -1,13 +1,22 @@
@eval module $(gensym())
using Compat: Returns
using Test: @test, @testset, @test_broken
using BlockArrays: Block, BlockedOneTo, blockedrange, blocklengths, blocksize
using BlockArrays:
AbstractBlockArray, Block, BlockedOneTo, blockedrange, blocklengths, blocksize
using NDTensors.BlockSparseArrays: BlockSparseArray, block_nstored
using NDTensors.GradedAxes:
GradedAxes, GradedOneTo, UnitRangeDual, blocklabels, dual, gradedrange
GradedAxes,
GradedOneTo,
GradedUnitRange,
GradedUnitRangeDual,
blocklabels,
dual,
gradedrange,
isdual
using NDTensors.LabelledNumbers: label
using NDTensors.SparseArrayInterface: nstored
using NDTensors.TensorAlgebra: fusedims, splitdims
using LinearAlgebra: adjoint
using Random: randn!
function blockdiagonal!(f, a::AbstractArray)
for i in 1:minimum(blocksize(a))
@@ -31,15 +40,15 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
d2 = gradedrange([U1(0) => 2, U1(1) => 2])
a = BlockSparseArray{elt}(d1, d2, d1, d2)
blockdiagonal!(randn!, a)
@test axes(a, 1) isa GradedOneTo
@test axes(view(a, 1:4, 1:4, 1:4, 1:4), 1) isa GradedOneTo

for b in (a + a, 2 * a)
@test size(b) == (4, 4, 4, 4)
@test blocksize(b) == (2, 2, 2, 2)
@test blocklengths.(axes(b)) == ([2, 2], [2, 2], [2, 2], [2, 2])
@test nstored(b) == 32
@test block_nstored(b) == 2
# TODO: Have to investigate why this fails
# on Julia v1.6, or drop support for v1.6.
for i in 1:ndims(a)
@test axes(b, i) isa GradedOneTo
end
@@ -103,16 +112,17 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
@test blocksize(m) == (3, 3)
@test a == splitdims(m, (d1, d2), (d1, d2))
end

@testset "dual axes" begin
r = gradedrange([U1(0) => 2, U1(1) => 2])
for ax in ((r, r), (dual(r), r), (r, dual(r)), (dual(r), dual(r)))
a = BlockSparseArray{elt}(ax...)
@views for b in [Block(1, 1), Block(2, 2)]
a[b] = randn(elt, size(a[b]))
end
# TODO: Define and use `isdual` here.
for dim in 1:ndims(a)
@test typeof(ax[dim]) === typeof(axes(a, dim))
@test isdual(ax[dim]) == isdual(axes(a, dim))
end
@test @view(a[Block(1, 1)])[1, 1] == a[1, 1]
@test @view(a[Block(1, 1)])[2, 1] == a[2, 1]
@@ -130,41 +140,149 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
@test a[I] == a_dense[I]
end
@test axes(a') == dual.(reverse(axes(a)))
# TODO: Define and use `isdual` here.
@test typeof(axes(a', 1)) === typeof(dual(axes(a, 2)))
@test typeof(axes(a', 2)) === typeof(dual(axes(a, 1)))

@test isdual(axes(a', 1)) ≠ isdual(axes(a, 2))
@test isdual(axes(a', 2)) ≠ isdual(axes(a, 1))
@test isnothing(show(devnull, MIME("text/plain"), a))

# Check preserving dual in tensor algebra.
for b in (a + a, 2 * a, 3 * a - a)
@test Array(b) ≈ 2 * Array(a)
# TODO: Define and use `isdual` here.
for dim in 1:ndims(a)
@test typeof(axes(b, dim)) === typeof(axes(b, dim))
@test isdual(axes(b, dim)) == isdual(axes(a, dim))
end
end

@test isnothing(show(devnull, MIME("text/plain"), @view(a[Block(1, 1)])))
@test @view(a[Block(1, 1)]) == a[Block(1, 1)]
end

@testset "GradedOneTo" begin
r = gradedrange([U1(0) => 2, U1(1) => 2])
a = BlockSparseArray{elt}(r, r)
@views for i in [Block(1, 1), Block(2, 2)]
a[i] = randn(elt, size(a[i]))
end
b = 2 * a
@test block_nstored(b) == 2
@test Array(b) == 2 * Array(a)
for i in 1:2
@test axes(b, i) isa GradedOneTo
@test axes(a[:, :], i) isa GradedOneTo
end

I = [Block(1)[1:1]]
@test a[I, :] isa AbstractBlockArray
@test a[:, I] isa AbstractBlockArray
@test size(a[I, I]) == (1, 1)
@test !isdual(axes(a[I, I], 1))
end

@testset "GradedUnitRange" begin
r = gradedrange([U1(0) => 2, U1(1) => 2])[1:3]
a = BlockSparseArray{elt}(r, r)
@views for i in [Block(1, 1), Block(2, 2)]
a[i] = randn(elt, size(a[i]))
end
b = 2 * a
@test block_nstored(b) == 2
@test Array(b) == 2 * Array(a)
for i in 1:2
@test axes(b, i) isa GradedUnitRange
@test axes(a[:, :], i) isa GradedUnitRange
end

I = [Block(1)[1:1]]
@test a[I, :] isa AbstractBlockArray
@test axes(a[I, :], 1) isa GradedOneTo
@test axes(a[I, :], 2) isa GradedUnitRange

@test a[:, I] isa AbstractBlockArray
@test axes(a[:, I], 2) isa GradedOneTo
@test axes(a[:, I], 1) isa GradedUnitRange
@test size(a[I, I]) == (1, 1)
@test !isdual(axes(a[I, I], 1))
end

# Test case when all axes are dual.
for r in (gradedrange([U1(0) => 2, U1(1) => 2]), blockedrange([2, 2]))
@testset "dual GradedOneTo" begin
r = gradedrange([U1(-1) => 2, U1(1) => 2])
a = BlockSparseArray{elt}(dual(r), dual(r))
@views for i in [Block(1, 1), Block(2, 2)]
a[i] = randn(elt, size(a[i]))
end
b = 2 * a
@test block_nstored(b) == 2
@test Array(b) == 2 * Array(a)
for ax in axes(b)
@test ax isa UnitRangeDual
for i in 1:2
@test axes(b, i) isa GradedUnitRangeDual
@test axes(a[:, :], i) isa GradedUnitRangeDual
end
I = [Block(1)[1:1]]
@test a[I, :] isa AbstractBlockArray
@test a[:, I] isa AbstractBlockArray
@test size(a[I, I]) == (1, 1)
@test isdual(axes(a[I, :], 2))
@test isdual(axes(a[:, I], 1))
@test_broken isdual(axes(a[I, :], 1))
@test_broken isdual(axes(a[:, I], 2))
@test_broken isdual(axes(a[I, I], 1))
@test_broken isdual(axes(a[I, I], 2))
end

@testset "dual GradedUnitRange" begin
r = gradedrange([U1(0) => 2, U1(1) => 2])[1:3]
a = BlockSparseArray{elt}(dual(r), dual(r))
@views for i in [Block(1, 1), Block(2, 2)]
a[i] = randn(elt, size(a[i]))
end
b = 2 * a
@test block_nstored(b) == 2
@test Array(b) == 2 * Array(a)
for i in 1:2
@test axes(b, i) isa GradedUnitRangeDual
@test axes(a[:, :], i) isa GradedUnitRangeDual
end

I = [Block(1)[1:1]]
@test a[I, :] isa AbstractBlockArray
@test a[:, I] isa AbstractBlockArray
@test size(a[I, I]) == (1, 1)
@test isdual(axes(a[I, :], 2))
@test isdual(axes(a[:, I], 1))
@test_broken isdual(axes(a[I, :], 1))
@test_broken isdual(axes(a[:, I], 2))
@test_broken isdual(axes(a[I, I], 1))
@test_broken isdual(axes(a[I, I], 2))
end

# Test case when all axes are dual
# from taking the adjoint.
for r in (gradedrange([U1(0) => 2, U1(1) => 2]), blockedrange([2, 2]))
@testset "dual BlockedUnitRange" begin # self dual
r = blockedrange([2, 2])
a = BlockSparseArray{elt}(dual(r), dual(r))
@views for i in [Block(1, 1), Block(2, 2)]
a[i] = randn(elt, size(a[i]))
end
b = 2 * a
@test block_nstored(b) == 2
@test Array(b) == 2 * Array(a)
@test a[:, :] isa BlockSparseArray
for i in 1:2
@test axes(b, i) isa BlockedOneTo
@test axes(a[:, :], i) isa BlockedOneTo
end

I = [Block(1)[1:1]]
@test a[I, :] isa BlockSparseArray
@test a[:, I] isa BlockSparseArray
@test size(a[I, I]) == (1, 1)
@test !isdual(axes(a[I, I], 1))
end

# Test case when all axes are dual from taking the adjoint.
for r in (
gradedrange([U1(0) => 2, U1(1) => 2]),
gradedrange([U1(0) => 2, U1(1) => 2])[begin:end],
)
a = BlockSparseArray{elt}(r, r)
@views for i in [Block(1, 1), Block(2, 2)]
a[i] = randn(elt, size(a[i]))
@@ -173,8 +291,13 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
@test block_nstored(b) == 2
@test Array(b) == 2 * Array(a)'
for ax in axes(b)
@test ax isa UnitRangeDual
@test ax isa typeof(dual(r))
end

I = [Block(1)[1:1]]
@test size(b[I, :]) == (1, 4)
@test size(b[:, I]) == (4, 1)
@test size(b[I, I]) == (1, 1)
end
end
@testset "Matrix multiplication" begin
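For orientation, here is a hedged sketch of the adjoint behavior the "dual axes" testset above exercises: taking the adjoint reverses the axes and dualizes them. It reuses the toy `U1` sector from the sketch near the top of this page.

```julia
# Sketch only; `U1` and `GradedAxes.dual(::U1)` are the toy definitions from the first sketch.
using NDTensors.BlockSparseArrays: BlockSparseArray
using NDTensors.GradedAxes: dual, gradedrange, isdual

r = gradedrange([U1(0) => 2, U1(1) => 2])
a = BlockSparseArray{Float64}(r, dual(r))  # one regular axis, one dual axis

# The adjoint reverses and dualizes the axes, as the tests above check:
@assert axes(a') == dual.(reverse(axes(a)))
@assert isdual(axes(a', 1)) != isdual(axes(a, 2))
@assert isdual(axes(a', 2)) != isdual(axes(a, 1))
```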
@@ -1,5 +1,12 @@
using BlockArrays:
BlockArrays, Block, BlockIndexRange, BlockedVector, blocklength, blocksize, viewblock
AbstractBlockedUnitRange,
BlockArrays,
Block,
BlockIndexRange,
BlockedVector,
blocklength,
blocksize,
viewblock

# This splits `BlockIndexRange{N}` into
# `NTuple{N,BlockIndexRange{1}}`.
@@ -191,7 +198,9 @@ function to_blockindexrange(
# work right now.
return blocks(a.blocks)[Int(I)]
end
function to_blockindexrange(a::Base.Slice{<:BlockedOneTo{<:Integer}}, I::Block{1})
function to_blockindexrange(
a::Base.Slice{<:AbstractBlockedUnitRange{<:Integer}}, I::Block{1}
)
@assert I in only(blockaxes(a.indices))
return I
end
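The signature above is widened from `Base.Slice{<:BlockedOneTo{<:Integer}}` to `Base.Slice{<:AbstractBlockedUnitRange{<:Integer}}`, which appears to be what lets `a[:, :]`-style slicing work when an axis is a sliced, non-one-based graded range (see the `GradedUnitRange` testsets earlier on this page). A hedged illustration of the type distinction, with the same toy `U1`:

```julia
# Sketch only; `U1` is the toy sector from the first sketch.
using BlockArrays: AbstractBlockedUnitRange, BlockedOneTo
using NDTensors.GradedAxes: gradedrange

r1 = gradedrange([U1(0) => 2, U1(1) => 2])  # one-based graded range (`GradedOneTo`)
r2 = r1[1:3]                                # sliced: a `GradedUnitRange`, not one-based

r2 isa AbstractBlockedUnitRange  # expected: true
r2 isa BlockedOneTo              # expected: false, so a `BlockedOneTo`-only method misses it
```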
21 changes: 19 additions & 2 deletions NDTensors/src/lib/BlockSparseArrays/test/test_basics.jl
@@ -15,9 +15,15 @@ using BlockArrays:
blocksizes,
mortar
using Compat: @compat
using LinearAlgebra: mul!
using LinearAlgebra: Adjoint, mul!
using NDTensors.BlockSparseArrays:
@view!, BlockSparseArray, BlockView, block_nstored, block_reshape, view!
@view!,
BlockSparseArray,
BlockView,
block_nstored,
block_reshape,
block_stored_indices,
view!
using NDTensors.SparseArrayInterface: nstored
using NDTensors.TensorAlgebra: contract
using Test: @test, @test_broken, @test_throws, @testset
@@ -44,6 +50,17 @@ include("TestBlockSparseArraysUtils.jl")
a[Block(2, 2)] = randn(elt, 3, 3)
@test a[2:4, 4] == Array(a)[2:4, 4]
@test_broken a[4, 2:4]

@test a[Block(1), :] isa BlockSparseArray{elt}
@test adjoint(a) isa Adjoint{elt,<:BlockSparseArray}
@test_broken adjoint(a)[Block(1), :] isa Adjoint{elt,<:BlockSparseArray}
# could also be directly a BlockSparseArray

a = BlockSparseArray{elt}([1], [1, 1])
a[1, 2] = 1
@test [a[Block(Tuple(it))] for it in eachindex(block_stored_indices(a))] isa Vector
ah = adjoint(a)
@test_broken [ah[Block(Tuple(it))] for it in eachindex(block_stored_indices(ah))] isa Vector
end
@testset "Basics" begin
a = BlockSparseArray{elt}([2, 3], [2, 3])
3 changes: 2 additions & 1 deletion NDTensors/src/lib/GradedAxes/src/GradedAxes.jl
@@ -2,6 +2,7 @@ module GradedAxes
include("blockedunitrange.jl")
include("gradedunitrange.jl")
include("dual.jl")
include("unitrangedual.jl")
include("gradedunitrangedual.jl")
include("onetoone.jl")
include("fusion.jl")
end
2 changes: 1 addition & 1 deletion NDTensors/src/lib/GradedAxes/src/blockedunitrange.jl
@@ -167,7 +167,7 @@ end
# Slice `a` by `I`, returning a:
# `BlockVector{<:BlockIndex{1},<:Vector{<:BlockIndexRange{1}}}`
# with the `BlockIndex{1}` corresponding to each value of `I`.
function to_blockindices(a::BlockedOneTo{<:Integer}, I::UnitRange{<:Integer})
function to_blockindices(a::AbstractBlockedUnitRange{<:Integer}, I::UnitRange{<:Integer})
return mortar(
map(blocks(blockedunitrange_getindices(a, I))) do r
bi_first = findblockindex(a, first(r))
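As the comment in the diff says, `to_blockindices` maps an integer range onto per-block indices; the change above widens it from `BlockedOneTo` to any `AbstractBlockedUnitRange`. A hedged usage sketch follows; the expected return value is my reading of that comment, not something asserted by the tests on this page.

```julia
# Sketch only; `to_blockindices` is internal, so it is reached through the module.
using BlockArrays: blockedrange
using NDTensors.GradedAxes: GradedAxes

r = blockedrange([2, 2])  # blocks 1:2 and 3:4
bi = GradedAxes.to_blockindices(r, 2:3)
# Expected, per the comment above: a `BlockVector` of `BlockIndex{1}` values, one per
# element of 2:3, grouped into `BlockIndexRange{1}` pieces; here Block(1)[2:2] and Block(2)[1:1].
```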
7 changes: 6 additions & 1 deletion NDTensors/src/lib/GradedAxes/src/dual.jl
@@ -1,7 +1,12 @@
function dual end
# default behavior: self-dual
dual(r::AbstractUnitRange) = r
nondual(r::AbstractUnitRange) = r
isdual(::AbstractUnitRange) = false

using NDTensors.LabelledNumbers:
LabelledStyle, IsLabelled, NotLabelled, label, labelled, unlabel

dual(i::LabelledInteger) = labelled(unlabel(i), dual(label(i)))
label_dual(x) = label_dual(LabelledStyle(x), x)
label_dual(::NotLabelled, x) = x
label_dual(::IsLabelled, x) = labelled(unlabel(x), dual(label(x)))
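A hedged sketch of the defaults added here: plain ranges are self-dual, and dualizing a labelled integer keeps its value while dualizing its label (toy `U1` again).

```julia
# Sketch only; `U1` and `GradedAxes.dual(::U1)` are the toy definitions from the first sketch.
using NDTensors.GradedAxes: dual, isdual, nondual
using NDTensors.LabelledNumbers: label, labelled, unlabel

r = 1:5
@assert dual(r) == r  # default: ordinary ranges are their own dual
@assert !isdual(r)
@assert nondual(r) == r

x = labelled(3, U1(1))
@assert unlabel(dual(x)) == 3     # the value is kept
@assert label(dual(x)) == U1(-1)  # the label is dualized (under the toy convention)
```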
15 changes: 4 additions & 11 deletions NDTensors/src/lib/GradedAxes/src/fusion.jl
@@ -1,12 +1,5 @@
using BlockArrays: AbstractBlockedUnitRange, blocklengths

# Represents the range `1:1` or `Base.OneTo(1)`.
struct OneToOne{T} <: AbstractUnitRange{T} end
OneToOne() = OneToOne{Bool}()
Base.first(a::OneToOne) = one(eltype(a))
Base.last(a::OneToOne) = one(eltype(a))
BlockArrays.blockaxes(g::OneToOne) = (Block.(g),) # BlockArrays default crashes for OneToOne{Bool}

# https://github.com/ITensor/ITensors.jl/blob/v0.3.57/NDTensors/src/lib/GradedAxes/src/tensor_product.jl
# https://en.wikipedia.org/wiki/Tensor_product
# https://github.com/KeitaNakamura/Tensorial.jl
@@ -20,7 +13,7 @@ function tensor_product(
end

flip_dual(r::AbstractUnitRange) = r
flip_dual(r::UnitRangeDual) = flip(r)
flip_dual(r::GradedUnitRangeDual) = flip(r)
function tensor_product(a1::AbstractUnitRange, a2::AbstractUnitRange)
return tensor_product(flip_dual(a1), flip_dual(a2))
end
@@ -67,7 +60,7 @@ function tensor_product(a1::AbstractBlockedUnitRange, a2::AbstractBlockedUnitRan
return blockedrange(new_blocklengths)
end

# convention: sort UnitRangeDual according to nondual blocks
# convention: sort GradedUnitRangeDual according to nondual blocks
function blocksortperm(a::AbstractUnitRange)
return Block.(sortperm(blocklabels(nondual(a))))
end
@@ -102,7 +95,7 @@ function blockmergesort(g::AbstractGradedUnitRange)
return gradedrange(new_blocklengths)
end

blockmergesort(g::UnitRangeDual) = flip(blockmergesort(flip(g)))
blockmergesort(g::GradedUnitRangeDual) = flip(blockmergesort(flip(g)))
blockmergesort(g::AbstractUnitRange) = g

# fusion_product produces a sorted, non-dual GradedUnitRange
Expand All @@ -111,7 +104,7 @@ function fusion_product(g1, g2)
end

fusion_product(g::AbstractUnitRange) = blockmergesort(g)
fusion_product(g::UnitRangeDual) = fusion_product(flip(g))
fusion_product(g::GradedUnitRangeDual) = fusion_product(flip(g))

# recursive fusion_product. Simpler than reduce + fix type stability issues with reduce
function fusion_product(g1, g2, g3...)
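To round off, a hedged sketch of the duality handling in `fusion.jl`: dual inputs are flipped back before fusing, so `fusion_product` should always return a sorted, non-dual graded range (toy `U1` plus an assumed `isless` so block labels can be sorted).

```julia
# Sketch only; `U1`, its `dual` overload, and this `isless` are toy assumptions.
using NDTensors.GradedAxes: dual, fusion_product, gradedrange, isdual

Base.isless(a::U1, b::U1) = isless(a.n, b.n)  # needed so blocks can be sorted by label

g = gradedrange([U1(1) => 2, U1(0) => 3])
@assert !isdual(fusion_product(g))        # single-argument case: sort and merge blocks
@assert !isdual(fusion_product(dual(g)))  # dual inputs are flipped back before fusing
```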