Skip to content

Commit

Permalink
[NDTensors] Fix output type of Empty BlockSparse contraction (#1135)
Browse files Browse the repository at this point in the history
  • Loading branch information
mtfishman authored Jun 21, 2023
1 parent dc92810 commit 08daddf
Show file tree
Hide file tree
Showing 20 changed files with 132 additions and 41 deletions.
1 change: 1 addition & 0 deletions NDTensors/src/NDTensors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ include("abstractarray/fill.jl")
include("array/set_types.jl")
include("tupletools.jl")
include("tensorstorage/tensorstorage.jl")
include("tensorstorage/set_types.jl")
include("tensorstorage/default_storage.jl")
include("tensorstorage/similar.jl")
include("tensor/tensor.jl")
Expand Down
9 changes: 9 additions & 0 deletions NDTensors/src/abstractarray/set_types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,15 @@ function set_ndims(arraytype::Type{<:AbstractArray}, ndims)
)
end

# This is for uniform `Diag` storage which uses
# a Number as the data type.
# TODO: Delete this when we change to using a
# `FillArray` instead. This is a stand-in
# to make things work with the current design.
function set_ndims(numbertype::Type{<:Number}, ndims)
return numbertype
end

"""
`set_indstype` should be overloaded for
types with structured dimensions,
Expand Down
9 changes: 9 additions & 0 deletions NDTensors/src/abstractarray/similar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,15 @@ end
return similartype(parenttype(arraytype), dims)
end

# This is for uniform `Diag` storage which uses
# a Number as the data type.
# TODO: Delete this when we change to using a
# `FillArray` instead. This is a stand-in
# to make things work with the current design.
function similartype(numbertype::Type{<:Number})
return numbertype
end

# Instances
function similartype(array::AbstractArray, eltype::Type, dims...)
return similartype(typeof(array), eltype, dims...)
Expand Down
4 changes: 2 additions & 2 deletions NDTensors/src/blocksparse/block.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ end
# Constructors
#

Block{N}(t::Tuple{Vararg{<:Any,N}}) where {N} = Block{N}(UInt.(t))
Block{N}(t::Tuple{Vararg{Any,N}}) where {N} = Block{N}(UInt.(t))

Block{N}(I::CartesianIndex{N}) where {N} = Block{N}(I.I)

Expand All @@ -38,7 +38,7 @@ Block(v::SVector{N}) where {N} = Block{N}(v)

Block(t::NTuple{N,UInt}) where {N} = Block{N}(t)

Block(t::Tuple{Vararg{<:Any,N}}) where {N} = Block{N}(t)
Block(t::Tuple{Vararg{Any,N}}) where {N} = Block{N}(t)

Block(::Tuple{}) = Block{0}(())

Expand Down
2 changes: 1 addition & 1 deletion NDTensors/src/blocksparse/blockoffsets.jl
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ Assumes the blocks are allong the diagonal.
"""
function diagblockoffsets(
blocks::Vector{BlockT}, inds
) where {BlockT<:Union{Block{N},Tuple{Vararg{<:Any,N}}}} where {N}
) where {BlockT<:Union{Block{N},Tuple{Vararg{Any,N}}}} where {N}
blockoffsets = BlockOffsets{N}()
nnzdiag = 0
for (i, block) in enumerate(blocks)
Expand Down
2 changes: 1 addition & 1 deletion NDTensors/src/dense/densetensor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ First T is permuted as `permutedims(3,2,1)`, then reshaped such
that the original indices 3 and 2 are combined.
"""
function permute_reshape(
T::DenseTensor{ElT,NT,IndsT}, pos::Vararg{<:Any,N}
T::DenseTensor{ElT,NT,IndsT}, pos::Vararg{Any,N}
) where {ElT,NT,IndsT,N}
perm = flatten(pos...)

Expand Down
7 changes: 7 additions & 0 deletions NDTensors/src/empty/EmptyTensor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,13 @@ fulltype(T::EmptyStorage) = fulltype(typeof(T))

fulltype(T::Tensor) = fulltype(typeof(T))

# Needed for correct `NDTensors.ndims` definitions, for
# example `EmptyStorage` that wraps a `BlockSparse` which
# can have non-unity dimensions.
function ndims(storagetype::Type{<:EmptyStorage})
return ndims(fulltype(storagetype))
end

# From an EmptyTensor, return the closest Tensor type
function fulltype(::Type{TensorT}) where {TensorT<:Tensor}
return Tensor{
Expand Down
5 changes: 4 additions & 1 deletion NDTensors/src/tensor/similar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,5 +62,8 @@ end

function similartype(tensortype::Type{<:Tensor}, dims::Tuple)
tensortype_new_inds = set_indstype(tensortype, dims)
return set_storagetype(tensortype_new_inds, similartype(storagetype(tensortype_new_inds)))
# Need to pass `dims` in case that information is needed to make a storage type,
# for example `BlockSparse` needs the number of dimensions.
storagetype_new_inds = similartype(storagetype(tensortype_new_inds), dims)
return set_storagetype(tensortype_new_inds, storagetype_new_inds)
end
5 changes: 4 additions & 1 deletion NDTensors/src/tensorstorage/set_types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,8 @@ function set_eltype(arraytype::Type{<:TensorStorage}, eltype::Type)
end

function set_ndims(arraytype::Type{<:TensorStorage}, ndims)
return set_datatype(arraytype, set_ndims(datatype(arraytype), ndims))
# TODO: Change to this once `TensorStorage` types support wrapping
# non-AbstractVector types.
# return set_datatype(arraytype, set_ndims(datatype(arraytype), ndims))
return arraytype
end
12 changes: 10 additions & 2 deletions NDTensors/src/tensorstorage/similar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,15 @@ function similartype(storagetype::Type{<:TensorStorage}, eltype::Type)
# TODO: Don't convert to an `AbstractVector` with `set_ndims(datatype, 1)`, once we support
# more general data types.
# return set_datatype(storagetype, NDTensors.similartype(datatype(storagetype), eltype))
return set_datatype(
storagetype, set_ndims(NDTensors.similartype(datatype(storagetype), eltype), 1)
return set_datatype(storagetype, set_ndims(similartype(datatype(storagetype), eltype), 1))
end

function similartype(storagetype::Type{<:TensorStorage}, dims::Tuple)
# TODO: In the future, set the dimensions of the data type based on `dims`, once
# more general data types beyond `AbstractVector` are supported.
# `similartype` unwraps any wrapped data.
return set_ndims(
set_datatype(storagetype, set_ndims(similartype(datatype(storagetype)), 1)),
length(dims),
)
end
4 changes: 2 additions & 2 deletions NDTensors/src/tupletools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ ValLength(::NTuple{N}) where {N} = Val(N)
# is not type stable and therefore not efficient.
ValLength(v::Vector) = Val(length(v))

ValLength(::Tuple{Vararg{<:Any,N}}) where {N} = Val(N)
ValLength(::Tuple{Vararg{Any,N}}) where {N} = Val(N)

ValLength(::Type{<:Tuple{Vararg{<:Any,N}}}) where {N} = Val{N}
ValLength(::Type{<:Tuple{Vararg{Any,N}}}) where {N} = Val{N}

ValLength(::CartesianIndex{N}) where {N} = Val(N)
ValLength(::Type{CartesianIndex{N}}) where {N} = Val{N}
Expand Down
12 changes: 12 additions & 0 deletions NDTensors/test/blocksparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,18 @@ using Test
@test conj(data(store(A))) == data(store(conj(A)))
@test typeof(conj(A)) <: BlockSparseTensor

@testset "similartype regression test" begin
# Regression test for issue seen in:
# https://github.com/ITensor/ITensorInfiniteMPS.jl/pull/77
# Previously, `similartype` wasn't using information about the dimensions
# properly and was returning a `BlockSparse` storage of the dimensions
# of the input tensor.
T = BlockSparseTensor([(1, 1)], ([2], [2]))
@test NDTensors.ndims(
NDTensors.storagetype(NDTensors.similartype(typeof(T), ([2], [2], [2])))
) == 3
end

@testset "Random constructor" begin
T = randomBlockSparseTensor([(1, 1), (2, 2)], ([2, 2], [2, 2]))
@test nnzblocks(T) == 2
Expand Down
14 changes: 7 additions & 7 deletions src/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -280,16 +280,16 @@ end
# B .+= A
#

function fmap(bc::Broadcasted{ITensorStyle,<:Any,typeof(+),<:Tuple{Vararg{<:ITensor}}})
function fmap(bc::Broadcasted{ITensorStyle,<:Any,typeof(+),<:Tuple{Vararg{ITensor}}})
return (r, t) -> bc.f(r, t)
end

function fmap(bc::Broadcasted{ITensorStyle,<:Any,typeof(-),<:Tuple{Vararg{<:ITensor}}})
function fmap(bc::Broadcasted{ITensorStyle,<:Any,typeof(-),<:Tuple{Vararg{ITensor}}})
return (r, t) -> bc.f(r, t)
end

function Base.copyto!(
T::ITensor, bc::Broadcasted{ITensorStyle,<:Any,typeof(+),<:Tuple{Vararg{<:ITensor}}}
T::ITensor, bc::Broadcasted{ITensorStyle,<:Any,typeof(+),<:Tuple{Vararg{ITensor}}}
)
if T === bc.args[1]
A = bc.args[2]
Expand All @@ -307,7 +307,7 @@ end
#

function Base.copyto!(
T::ITensor, bc::Broadcasted{ITensorStyle,<:Any,typeof(-),<:Tuple{Vararg{<:ITensor}}}
T::ITensor, bc::Broadcasted{ITensorStyle,<:Any,typeof(-),<:Tuple{Vararg{ITensor}}}
)
if T === bc.args[1]
A = bc.args[2]
Expand Down Expand Up @@ -380,7 +380,7 @@ end

function Base.copyto!(
T::ITensor,
bc::Broadcasted{ITensorOpScalarStyle,<:Any,typeof(+),<:Tuple{Vararg{<:Broadcasted}}},
bc::Broadcasted{ITensorOpScalarStyle,<:Any,typeof(+),<:Tuple{Vararg{Broadcasted}}},
)
bc_α = bc.args[1]
bc_β = bc.args[2]
Expand Down Expand Up @@ -417,7 +417,7 @@ end
function Base.copyto!(
T::ITensor,
bc::Broadcasted{
ITensorOpScalarStyle,<:Any,typeof(+),<:Tuple{Vararg{<:Union{<:ITensor,<:Number}}}
ITensorOpScalarStyle,<:Any,typeof(+),<:Tuple{Vararg{Union{<:ITensor,<:Number}}}
},
)
α = find_type(Number, bc.args)
Expand Down Expand Up @@ -496,7 +496,7 @@ end
#

function Base.copyto!(
R::ITensor, bc::Broadcasted{ITensorStyle,<:Any,typeof(+),<:Tuple{Vararg{<:Broadcasted}}}
R::ITensor, bc::Broadcasted{ITensorStyle,<:Any,typeof(+),<:Tuple{Vararg{Broadcasted}}}
)
bc1 = bc.args[1]
bc2 = bc.args[2]
Expand Down
4 changes: 2 additions & 2 deletions src/indexset.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,9 @@ Tuple(is::IndexSet) = _Tuple(is)
NTuple{N}(is::IndexSet) where {N} = _NTuple(Val(N), is)

"""
not(inds::Union{IndexSet, Tuple{Vararg{<:Index}}})
not(inds::Union{IndexSet, Tuple{Vararg{Index}}})
not(inds::Index...)
!(inds::Union{IndexSet, Tuple{Vararg{<:Index}}})
!(inds::Union{IndexSet, Tuple{Vararg{Index}}})
!(inds::Index...)
Represents the set of indices not in the specified
Expand Down
13 changes: 7 additions & 6 deletions src/itensor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -667,6 +667,7 @@ randomITensor() = randomITensor(Random.default_rng())
randomITensor(rng::AbstractRNG) = randomITensor(rng, Float64, ())

copy(T::ITensor)::ITensor = itensor(copy(tensor(T)))
zero(T::ITensor)::ITensor = itensor(zero(tensor(T)))

#
# Construct from Array
Expand Down Expand Up @@ -1112,7 +1113,7 @@ end
# CartesianIndices
@propagate_inbounds getindex(T::ITensor, I::CartesianIndex)::Any = T[Tuple(I)...]

@propagate_inbounds @inline function _getindex(T::Tensor, ivs::Vararg{<:Any,N}) where {N}
@propagate_inbounds @inline function _getindex(T::Tensor, ivs::Vararg{Any,N}) where {N}
# Tried ind.(ivs), val.(ivs) but it is slower
p = NDTensors.getperm(inds(T), ntuple(n -> ind(@inbounds ivs[n]), Val(N)))
fac = NDTensors.permfactor(p, ivs...) #<fermions> possible sign
Expand All @@ -1133,7 +1134,7 @@ A = ITensor(2.0, i, i')
A[i => 1, i' => 2] # 2.0, same as: A[i' => 2, i => 1]
```
"""
@propagate_inbounds (getindex(T::ITensor, ivs::Vararg{<:Any,N})::Any) where {N} =
@propagate_inbounds (getindex(T::ITensor, ivs::Vararg{Any,N})::Any) where {N} =
_getindex(tensor(T), ivs...)

@propagate_inbounds function getindex(T::ITensor)::Any
Expand Down Expand Up @@ -1233,7 +1234,7 @@ end
end

@propagate_inbounds @inline function _setindex!!(
T::Tensor, x::Number, ivs::Vararg{<:Any,N}
T::Tensor, x::Number, ivs::Vararg{Any,N}
) where {N}
# Would be nice to split off the functions for extracting the `ind` and `val` as Tuples,
# but it was slower.
Expand All @@ -1245,7 +1246,7 @@ end
end

@propagate_inbounds @inline function setindex!(
T::ITensor, x::Number, I::Vararg{<:Any,N}
T::ITensor, x::Number, I::Vararg{Any,N}
) where {N}
return settensor!(T, _setindex!!(tensor(T), x, I...))
end
Expand Down Expand Up @@ -1335,7 +1336,7 @@ itensor2inds(A::ITensor)::Any = inds(A)
itensor2inds(A::Tensor) = inds(A)
itensor2inds(i::Index) = (i,)
itensor2inds(A) = A
function map_itensor2inds(A::Tuple{Vararg{<:Any,N}}) where {N}
function map_itensor2inds(A::Tuple{Vararg{Any,N}}) where {N}
return ntuple(i -> itensor2inds(A[i]), Val(N))
end

Expand Down Expand Up @@ -1376,7 +1377,7 @@ hassameinds(A, B) = issetequal(itensor2inds(A), itensor2inds(B))

# Apply the Index set function and then filter the results
function filter_inds_set_function(
ffilter::Function, fset::Function, A::Vararg{<:Any,N}
ffilter::Function, fset::Function, A::Vararg{Any,N}
) where {N}
return filter(ffilter, fset(map_itensor2inds(A)...))
end
Expand Down
6 changes: 3 additions & 3 deletions src/tensor_operations/tensor_algebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ function contract(A::ITensor, B::ITensor)::ITensor
end
end

function optimal_contraction_sequence(A::Union{Vector{<:ITensor},Tuple{Vararg{<:ITensor}}})
function optimal_contraction_sequence(A::Union{Vector{<:ITensor},Tuple{Vararg{ITensor}}})
if length(A) == 1
return optimal_contraction_sequence(A[1])
elseif length(A) == 2
Expand All @@ -130,7 +130,7 @@ _optimal_contraction_sequence(As::Tuple{<:ITensor,<:ITensor}) = Any[1, 2]
function _optimal_contraction_sequence(As::Tuple{<:ITensor,<:ITensor,<:ITensor})
return optimal_contraction_sequence(inds(As[1]), inds(As[2]), inds(As[3]))
end
function _optimal_contraction_sequence(As::Tuple{Vararg{<:ITensor}})
function _optimal_contraction_sequence(As::Tuple{Vararg{ITensor}})
return __optimal_contraction_sequence(As)
end

Expand All @@ -145,7 +145,7 @@ function default_sequence()
return using_contraction_sequence_optimization() ? "automatic" : "left_associative"
end

function contraction_cost(As::Union{Vector{<:ITensor},Tuple{Vararg{<:ITensor}}}; kwargs...)
function contraction_cost(As::Union{Vector{<:ITensor},Tuple{Vararg{ITensor}}}; kwargs...)
indsAs = [inds(A) for A in As]
return contraction_cost(indsAs; kwargs...)
end
Expand Down
35 changes: 22 additions & 13 deletions test/ITensorChainRules/test_chainrules_ops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,20 +43,29 @@ using Zygote: ZygoteRuleConfig, gradient
atol=1.0e-7,
)

f = function (x)
y = Op("Ry", 1; θ=x) + Op("Ry", 1; θ=x)
return y[1].params.θ
function sometimes_broken_test()
f = function (x)
y = Op("Ry", 1; θ=x) + Op("Ry", 1; θ=x)
return y[1].params.θ
end
args = (x,)
test_rrule(
ZygoteRuleConfig(),
f,
args...;
rrule_f=rrule_via_ad,
check_inferred=false,
rtol=1.0e-7,
atol=1.0e-7,
)
return nothing
end

@static if VERSION > v"1.8"
@test_skip sometimes_broken_test()
else
sometimes_broken_test()
end
args = (x,)
test_rrule(
ZygoteRuleConfig(),
f,
args...;
rrule_f=rrule_via_ad,
check_inferred=false,
rtol=1.0e-7,
atol=1.0e-7,
)

f = function (x)
y = ITensor(Op("Ry", 1; θ=x) + Op("Ry", 1; θ=x), s)
Expand Down
8 changes: 8 additions & 0 deletions test/base/test_contract.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,14 @@ digits(::Type{T}, i, j, k) where {T} = T(i * 10^2 + j * 10 + k)
CArray = transpose(array(Ai)) * array(Bi)
@test CArray scalar(C)
end
@testset "Test Matrix{ITensor} * Matrix{ITensor}" begin
M1 = [Aij Aij; Aij Aij]
M2 = [Ajk Ajk; Ajk Ajk]
M12 = M1 * M2
for x in 1:2, y in 1:2
@test M12[x, y] 2 * Aij * Ajk
end
end
@testset "Test contract ITensors (Vector*Vectorᵀ -> Matrix)" begin
C = Ai * Aj
for ii in 1:dim(i), jj in 1:dim(j)
Expand Down
7 changes: 7 additions & 0 deletions test/base/test_itensor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -448,6 +448,13 @@ end
@test all(ITensors.data(A) .== 1.0)
end

@testset "zero" begin
i = Index(2)
A = randomITensor(i)
B = zero(A)
@test false * A B
end

@testset "copyto!" begin
i = Index(2, "i")
j = Index(2, "j")
Expand Down
Loading

0 comments on commit 08daddf

Please sign in to comment.