[NDTensors] Fix output type of Empty BlockSparse contraction #1135

Merged · 4 commits · Jun 21, 2023
1 change: 1 addition & 0 deletions NDTensors/src/NDTensors.jl
@@ -46,6 +46,7 @@ include("abstractarray/fill.jl")
include("array/set_types.jl")
include("tupletools.jl")
include("tensorstorage/tensorstorage.jl")
include("tensorstorage/set_types.jl")
include("tensorstorage/default_storage.jl")
include("tensorstorage/similar.jl")
include("tensor/tensor.jl")
9 changes: 9 additions & 0 deletions NDTensors/src/abstractarray/set_types.jl
@@ -22,6 +22,15 @@ function set_ndims(arraytype::Type{<:AbstractArray}, ndims)
)
end

+# This is for uniform `Diag` storage which uses
+# a Number as the data type.
+# TODO: Delete this when we change to using a
+# `FillArray` instead. This is a stand-in
+# to make things work with the current design.
+function set_ndims(numbertype::Type{<:Number}, ndims)
+  return numbertype
+end

"""
`set_indstype` should be overloaded for
types with structured dimensions,
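The new `set_ndims` method above is a deliberate no-op: uniform `Diag` storage keeps a single scalar as its "data", so there is no array dimensionality to rewrite. A minimal sketch of the resulting behavior (assuming `NDTensors` is loaded; `set_ndims` is an internal, unexported function):

```julia
using NDTensors: set_ndims

# A plain number type has no dimensions to change, so it passes through unchanged:
set_ndims(Float64, 3) == Float64  # true
```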
9 changes: 9 additions & 0 deletions NDTensors/src/abstractarray/similar.jl
@@ -177,6 +177,15 @@ end
return similartype(parenttype(arraytype), dims)
end

+# This is for uniform `Diag` storage which uses
+# a Number as the data type.
+# TODO: Delete this when we change to using a
+# `FillArray` instead. This is a stand-in
+# to make things work with the current design.
+function similartype(numbertype::Type{<:Number})
+  return numbertype
+end

# Instances
function similartype(array::AbstractArray, eltype::Type, dims...)
return similartype(typeof(array), eltype, dims...)
4 changes: 2 additions & 2 deletions NDTensors/src/blocksparse/block.jl
@@ -20,7 +20,7 @@ end
# Constructors
#

-Block{N}(t::Tuple{Vararg{<:Any,N}}) where {N} = Block{N}(UInt.(t))
+Block{N}(t::Tuple{Vararg{Any,N}}) where {N} = Block{N}(UInt.(t))

Block{N}(I::CartesianIndex{N}) where {N} = Block{N}(I.I)

@@ -38,7 +38,7 @@ Block(v::SVector{N}) where {N} = Block{N}(v)

Block(t::NTuple{N,UInt}) where {N} = Block{N}(t)

-Block(t::Tuple{Vararg{<:Any,N}}) where {N} = Block{N}(t)
+Block(t::Tuple{Vararg{Any,N}}) where {N} = Block{N}(t)

Block(::Tuple{}) = Block{0}(())

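The `Vararg{<:Any,N}` to `Vararg{Any,N}` rewrites that recur throughout this PR appear to target the deprecation warning recent Julia versions emit for wrapping `Vararg` directly in a `UnionAll` (`Vararg{<:T,N}` lowers to `Vararg{S,N} where {S<:T}`). Since `Vararg` inside a `Tuple` already matches element types covariantly, both spellings accept the same arguments. A quick sketch with hypothetical demo functions (not part of the package):

```julia
# The first spelling can trigger "Wrapping `Vararg` directly in UnionAll is
# deprecated" on recent Julia versions; the second is equivalent and warning-free.
f_old(t::Tuple{Vararg{<:Any,N}}) where {N} = N
f_new(t::Tuple{Vararg{Any,N}}) where {N} = N

f_new((1, "two"))  # 2
f_new((1.0, :a))   # 2; `Any` still matches heterogeneous tuples
```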
2 changes: 1 addition & 1 deletion NDTensors/src/blocksparse/blockoffsets.jl
@@ -90,7 +90,7 @@ Assumes the blocks are along the diagonal.
"""
function diagblockoffsets(
blocks::Vector{BlockT}, inds
-) where {BlockT<:Union{Block{N},Tuple{Vararg{<:Any,N}}}} where {N}
+) where {BlockT<:Union{Block{N},Tuple{Vararg{Any,N}}}} where {N}
blockoffsets = BlockOffsets{N}()
nnzdiag = 0
for (i, block) in enumerate(blocks)
2 changes: 1 addition & 1 deletion NDTensors/src/dense/densetensor.jl
@@ -261,7 +261,7 @@ First T is permuted as `permutedims(3,2,1)`, then reshaped such
that the original indices 3 and 2 are combined.
"""
function permute_reshape(
-  T::DenseTensor{ElT,NT,IndsT}, pos::Vararg{<:Any,N}
+  T::DenseTensor{ElT,NT,IndsT}, pos::Vararg{Any,N}
) where {ElT,NT,IndsT,N}
perm = flatten(pos...)

7 changes: 7 additions & 0 deletions NDTensors/src/empty/EmptyTensor.jl
@@ -25,6 +25,13 @@ fulltype(T::EmptyStorage) = fulltype(typeof(T))

fulltype(T::Tensor) = fulltype(typeof(T))

+# Needed for correct `NDTensors.ndims` definitions, for
+# example `EmptyStorage` that wraps a `BlockSparse` which
+# can have non-unity dimensions.
+function ndims(storagetype::Type{<:EmptyStorage})
+  return ndims(fulltype(storagetype))
+end

# From an EmptyTensor, return the closest Tensor type
function fulltype(::Type{TensorT}) where {TensorT<:Tensor}
return Tensor{
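The new `ndims` overload answers a type-level query on the `EmptyStorage` placeholder by forwarding to `fulltype`, the concrete storage type it would become, so an empty wrapper around a 3-index `BlockSparse` reports `ndims == 3`. A toy analogue of the pattern (hypothetical names, not the NDTensors API):

```julia
# A placeholder type that forwards type-level `ndims` to the type it wraps.
struct EmptyLike{S} end
full_type(::Type{EmptyLike{S}}) where {S} = S
Base.ndims(::Type{EmptyLike{S}}) where {S} = ndims(full_type(EmptyLike{S}))

ndims(EmptyLike{Array{Float64,3}})  # 3, taken from the wrapped array type
```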
5 changes: 4 additions & 1 deletion NDTensors/src/tensor/similar.jl
@@ -62,5 +62,8 @@ end

function similartype(tensortype::Type{<:Tensor}, dims::Tuple)
tensortype_new_inds = set_indstype(tensortype, dims)
-  return set_storagetype(tensortype_new_inds, similartype(storagetype(tensortype_new_inds)))
+  # Need to pass `dims` in case that information is needed to make a storage type,
+  # for example `BlockSparse` needs the number of dimensions.
+  storagetype_new_inds = similartype(storagetype(tensortype_new_inds), dims)
+  return set_storagetype(tensortype_new_inds, storagetype_new_inds)
end
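This is the heart of the fix: previously the storage type was computed without `dims`, so a `BlockSparse` storage kept the dimensionality of the input tensor even when `dims` implied a different number of indices. A sketch of the intended behavior, mirroring the regression test added to NDTensors/test/blocksparse.jl below (assuming `NDTensors` is loaded):

```julia
using NDTensors

# A 2-index block sparse tensor with a single (1, 1) block:
T = BlockSparseTensor([(1, 1)], ([2], [2]))

# Requesting a similar type over 3-index dims should yield 3-dimensional storage:
S = NDTensors.similartype(typeof(T), ([2], [2], [2]))
NDTensors.ndims(NDTensors.storagetype(S))  # 3 after this PR (previously 2)
```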
5 changes: 4 additions & 1 deletion NDTensors/src/tensorstorage/set_types.jl
@@ -3,5 +3,8 @@ function set_eltype(arraytype::Type{<:TensorStorage}, eltype::Type)
end

function set_ndims(arraytype::Type{<:TensorStorage}, ndims)
-  return set_datatype(arraytype, set_ndims(datatype(arraytype), ndims))
+  # TODO: Change to this once `TensorStorage` types support wrapping
+  # non-AbstractVector types.
+  # return set_datatype(arraytype, set_ndims(datatype(arraytype), ndims))
+  return arraytype
end
12 changes: 10 additions & 2 deletions NDTensors/src/tensorstorage/similar.jl
@@ -65,7 +65,15 @@ function similartype(storagetype::Type{<:TensorStorage}, eltype::Type)
# TODO: Don't convert to an `AbstractVector` with `set_ndims(datatype, 1)`, once we support
# more general data types.
# return set_datatype(storagetype, NDTensors.similartype(datatype(storagetype), eltype))
-  return set_datatype(
-    storagetype, set_ndims(NDTensors.similartype(datatype(storagetype), eltype), 1)
+  return set_datatype(storagetype, set_ndims(similartype(datatype(storagetype), eltype), 1))
end

+function similartype(storagetype::Type{<:TensorStorage}, dims::Tuple)
+  # TODO: In the future, set the dimensions of the data type based on `dims`, once
+  # more general data types beyond `AbstractVector` are supported.
+  # `similartype` unwraps any wrapped data.
+  return set_ndims(
+    set_datatype(storagetype, set_ndims(similartype(datatype(storagetype)), 1)),
+    length(dims),
+  )
+end
4 changes: 2 additions & 2 deletions NDTensors/src/tupletools.jl
@@ -21,9 +21,9 @@ ValLength(::NTuple{N}) where {N} = Val(N)
# is not type stable and therefore not efficient.
ValLength(v::Vector) = Val(length(v))

-ValLength(::Tuple{Vararg{<:Any,N}}) where {N} = Val(N)
+ValLength(::Tuple{Vararg{Any,N}}) where {N} = Val(N)

-ValLength(::Type{<:Tuple{Vararg{<:Any,N}}}) where {N} = Val{N}
+ValLength(::Type{<:Tuple{Vararg{Any,N}}}) where {N} = Val{N}

ValLength(::CartesianIndex{N}) where {N} = Val(N)
ValLength(::Type{CartesianIndex{N}}) where {N} = Val{N}
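`ValLength` lifts a container's length into the type domain so callers can dispatch on it; the tuple methods touched here are the type-stable ones, since the length is already encoded in the tuple type. A standalone re-creation of the idea (not the package definition):

```julia
val_length(::Tuple{Vararg{Any,N}}) where {N} = Val(N)

val_length((1, 2, 3))  # Val{3}(): the length is known from the tuple type alone
```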
12 changes: 12 additions & 0 deletions NDTensors/test/blocksparse.jl
@@ -88,6 +88,18 @@ using Test
@test conj(data(store(A))) == data(store(conj(A)))
@test typeof(conj(A)) <: BlockSparseTensor

@testset "similartype regression test" begin
# Regression test for issue seen in:
# https://github.com/ITensor/ITensorInfiniteMPS.jl/pull/77
# Previously, `similartype` wasn't using information about the dimensions
# properly and was returning a `BlockSparse` storage of the dimensions
# of the input tensor.
T = BlockSparseTensor([(1, 1)], ([2], [2]))
@test NDTensors.ndims(
NDTensors.storagetype(NDTensors.similartype(typeof(T), ([2], [2], [2])))
) == 3
end

@testset "Random constructor" begin
T = randomBlockSparseTensor([(1, 1), (2, 2)], ([2, 2], [2, 2]))
@test nnzblocks(T) == 2
14 changes: 7 additions & 7 deletions src/broadcast.jl
@@ -280,16 +280,16 @@ end
# B .+= A
#

-function fmap(bc::Broadcasted{ITensorStyle,<:Any,typeof(+),<:Tuple{Vararg{<:ITensor}}})
+function fmap(bc::Broadcasted{ITensorStyle,<:Any,typeof(+),<:Tuple{Vararg{ITensor}}})
return (r, t) -> bc.f(r, t)
end

-function fmap(bc::Broadcasted{ITensorStyle,<:Any,typeof(-),<:Tuple{Vararg{<:ITensor}}})
+function fmap(bc::Broadcasted{ITensorStyle,<:Any,typeof(-),<:Tuple{Vararg{ITensor}}})
return (r, t) -> bc.f(r, t)
end

function Base.copyto!(
-  T::ITensor, bc::Broadcasted{ITensorStyle,<:Any,typeof(+),<:Tuple{Vararg{<:ITensor}}}
+  T::ITensor, bc::Broadcasted{ITensorStyle,<:Any,typeof(+),<:Tuple{Vararg{ITensor}}}
)
if T === bc.args[1]
A = bc.args[2]
@@ -307,7 +307,7 @@ end
#

function Base.copyto!(
-  T::ITensor, bc::Broadcasted{ITensorStyle,<:Any,typeof(-),<:Tuple{Vararg{<:ITensor}}}
+  T::ITensor, bc::Broadcasted{ITensorStyle,<:Any,typeof(-),<:Tuple{Vararg{ITensor}}}
)
if T === bc.args[1]
A = bc.args[2]
@@ -380,7 +380,7 @@ end

function Base.copyto!(
T::ITensor,
-  bc::Broadcasted{ITensorOpScalarStyle,<:Any,typeof(+),<:Tuple{Vararg{<:Broadcasted}}},
+  bc::Broadcasted{ITensorOpScalarStyle,<:Any,typeof(+),<:Tuple{Vararg{Broadcasted}}},
)
bc_α = bc.args[1]
bc_β = bc.args[2]
@@ -417,7 +417,7 @@ end
function Base.copyto!(
T::ITensor,
bc::Broadcasted{
-    ITensorOpScalarStyle,<:Any,typeof(+),<:Tuple{Vararg{<:Union{<:ITensor,<:Number}}}
+    ITensorOpScalarStyle,<:Any,typeof(+),<:Tuple{Vararg{Union{<:ITensor,<:Number}}}
},
)
α = find_type(Number, bc.args)
@@ -496,7 +496,7 @@ end
#

function Base.copyto!(
-  R::ITensor, bc::Broadcasted{ITensorStyle,<:Any,typeof(+),<:Tuple{Vararg{<:Broadcasted}}}
+  R::ITensor, bc::Broadcasted{ITensorStyle,<:Any,typeof(+),<:Tuple{Vararg{Broadcasted}}}
)
bc1 = bc.args[1]
bc2 = bc.args[2]
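These `copyto!` methods implement fused in-place updates such as `B .+= A`, which Julia lowers to a `Broadcasted` call tree that the tightened signatures above still match. A usage sketch (assuming ITensors is loaded):

```julia
using ITensors

i = Index(2)
A = randomITensor(i)
B = randomITensor(i)
C = copy(B)

B .+= A    # handled by the `fmap`/`copyto!` methods in this file
B ≈ C + A  # true
```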
4 changes: 2 additions & 2 deletions src/indexset.jl
@@ -71,9 +71,9 @@ Tuple(is::IndexSet) = _Tuple(is)
NTuple{N}(is::IndexSet) where {N} = _NTuple(Val(N), is)

"""
-not(inds::Union{IndexSet, Tuple{Vararg{<:Index}}})
+not(inds::Union{IndexSet, Tuple{Vararg{Index}}})
not(inds::Index...)
-!(inds::Union{IndexSet, Tuple{Vararg{<:Index}}})
+!(inds::Union{IndexSet, Tuple{Vararg{Index}}})
!(inds::Index...)

Represents the set of indices not in the specified
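For context, `not` builds a negated pattern over indices (or tag sets) for use with index-filtering functions. A usage sketch (assuming ITensors is loaded; behavior as documented for `not`):

```julia
using ITensors

i = Index(2, "i")
j = Index(2, "j")
A = randomITensor(i, j)

B = prime(A, not(i))  # prime every index of A except i
hasinds(B, i, j')     # true
```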
13 changes: 7 additions & 6 deletions src/itensor.jl
@@ -667,6 +667,7 @@ randomITensor() = randomITensor(Random.default_rng())
randomITensor(rng::AbstractRNG) = randomITensor(rng, Float64, ())

copy(T::ITensor)::ITensor = itensor(copy(tensor(T)))
+zero(T::ITensor)::ITensor = itensor(zero(tensor(T)))

#
# Construct from Array
@@ -1112,7 +1113,7 @@ end
# CartesianIndices
@propagate_inbounds getindex(T::ITensor, I::CartesianIndex)::Any = T[Tuple(I)...]

-@propagate_inbounds @inline function _getindex(T::Tensor, ivs::Vararg{<:Any,N}) where {N}
+@propagate_inbounds @inline function _getindex(T::Tensor, ivs::Vararg{Any,N}) where {N}
# Tried ind.(ivs), val.(ivs) but it is slower
p = NDTensors.getperm(inds(T), ntuple(n -> ind(@inbounds ivs[n]), Val(N)))
fac = NDTensors.permfactor(p, ivs...) #<fermions> possible sign
@@ -1133,7 +1134,7 @@ A = ITensor(2.0, i, i')
A[i => 1, i' => 2] # 2.0, same as: A[i' => 2, i => 1]
```
"""
-@propagate_inbounds (getindex(T::ITensor, ivs::Vararg{<:Any,N})::Any) where {N} =
+@propagate_inbounds (getindex(T::ITensor, ivs::Vararg{Any,N})::Any) where {N} =
_getindex(tensor(T), ivs...)

@propagate_inbounds function getindex(T::ITensor)::Any
@@ -1233,7 +1234,7 @@ end
end

@propagate_inbounds @inline function _setindex!!(
-  T::Tensor, x::Number, ivs::Vararg{<:Any,N}
+  T::Tensor, x::Number, ivs::Vararg{Any,N}
) where {N}
# Would be nice to split off the functions for extracting the `ind` and `val` as Tuples,
# but it was slower.
@@ -1245,7 +1246,7 @@ end
end

@propagate_inbounds @inline function setindex!(
-  T::ITensor, x::Number, I::Vararg{<:Any,N}
+  T::ITensor, x::Number, I::Vararg{Any,N}
) where {N}
return settensor!(T, _setindex!!(tensor(T), x, I...))
end
@@ -1335,7 +1336,7 @@ itensor2inds(A::ITensor)::Any = inds(A)
itensor2inds(A::Tensor) = inds(A)
itensor2inds(i::Index) = (i,)
itensor2inds(A) = A
-function map_itensor2inds(A::Tuple{Vararg{<:Any,N}}) where {N}
+function map_itensor2inds(A::Tuple{Vararg{Any,N}}) where {N}
return ntuple(i -> itensor2inds(A[i]), Val(N))
end

@@ -1376,7 +1377,7 @@ hassameinds(A, B) = issetequal(itensor2inds(A), itensor2inds(B))

# Apply the Index set function and then filter the results
function filter_inds_set_function(
-  ffilter::Function, fset::Function, A::Vararg{<:Any,N}
+  ffilter::Function, fset::Function, A::Vararg{Any,N}
) where {N}
return filter(ffilter, fset(map_itensor2inds(A)...))
end
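Besides the `Vararg` signature cleanups, this file adds a one-line `zero` method. Like `copy` directly above it, it round-trips through the wrapped tensor, so the result keeps the indices and storage structure of `T` with zeroed elements; the new test in test/base/test_itensor.jl below exercises it. Usage sketch:

```julia
using ITensors

i = Index(2)
A = randomITensor(i)
B = zero(A)    # same indices as A, all entries zero
B ≈ false * A  # true (the new test below checks exactly this)
```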
6 changes: 3 additions & 3 deletions src/tensor_operations/tensor_algebra.jl
@@ -106,7 +106,7 @@ function contract(A::ITensor, B::ITensor)::ITensor
end
end

-function optimal_contraction_sequence(A::Union{Vector{<:ITensor},Tuple{Vararg{<:ITensor}}})
+function optimal_contraction_sequence(A::Union{Vector{<:ITensor},Tuple{Vararg{ITensor}}})
if length(A) == 1
return optimal_contraction_sequence(A[1])
elseif length(A) == 2
@@ -130,7 +130,7 @@ _optimal_contraction_sequence(As::Tuple{<:ITensor,<:ITensor}) = Any[1, 2]
function _optimal_contraction_sequence(As::Tuple{<:ITensor,<:ITensor,<:ITensor})
return optimal_contraction_sequence(inds(As[1]), inds(As[2]), inds(As[3]))
end
-function _optimal_contraction_sequence(As::Tuple{Vararg{<:ITensor}})
+function _optimal_contraction_sequence(As::Tuple{Vararg{ITensor}})
return __optimal_contraction_sequence(As)
end

@@ -145,7 +145,7 @@ function default_sequence()
return using_contraction_sequence_optimization() ? "automatic" : "left_associative"
end

-function contraction_cost(As::Union{Vector{<:ITensor},Tuple{Vararg{<:ITensor}}}; kwargs...)
+function contraction_cost(As::Union{Vector{<:ITensor},Tuple{Vararg{ITensor}}}; kwargs...)
indsAs = [inds(A) for A in As]
return contraction_cost(indsAs; kwargs...)
end
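These helpers compute contraction orderings and costs for collections of ITensors; the only change is again the `Vararg` spelling. A usage sketch (assuming ITensors is loaded; `optimal_contraction_sequence` is internal, so it is qualified here):

```julia
using ITensors

i, j, k = Index(2), Index(3), Index(4)
A = randomITensor(i, j)
B = randomITensor(j, k)
C = randomITensor(k, i)

seq = ITensors.optimal_contraction_sequence([A, B, C])
ITensors.contract([A, B, C]; sequence=seq)  # scalar ITensor, contracted in that order
```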
35 changes: 22 additions & 13 deletions test/ITensorChainRules/test_chainrules_ops.jl
@@ -43,20 +43,29 @@ using Zygote: ZygoteRuleConfig, gradient
atol=1.0e-7,
)

-f = function (x)
-  y = Op("Ry", 1; θ=x) + Op("Ry", 1; θ=x)
-  return y[1].params.θ
+function sometimes_broken_test()
+  f = function (x)
+    y = Op("Ry", 1; θ=x) + Op("Ry", 1; θ=x)
+    return y[1].params.θ
+  end
+  args = (x,)
+  test_rrule(
+    ZygoteRuleConfig(),
+    f,
+    args...;
+    rrule_f=rrule_via_ad,
+    check_inferred=false,
+    rtol=1.0e-7,
+    atol=1.0e-7,
+  )
+  return nothing
+end

+@static if VERSION > v"1.8"
+  @test_skip sometimes_broken_test()
+else
+  sometimes_broken_test()
+end
-args = (x,)
-test_rrule(
-  ZygoteRuleConfig(),
-  f,
-  args...;
-  rrule_f=rrule_via_ad,
-  check_inferred=false,
-  rtol=1.0e-7,
-  atol=1.0e-7,
-)

f = function (x)
y = ITensor(Op("Ry", 1; θ=x) + Op("Ry", 1; θ=x), s)
8 changes: 8 additions & 0 deletions test/base/test_contract.jl
@@ -45,6 +45,14 @@ digits(::Type{T}, i, j, k) where {T} = T(i * 10^2 + j * 10 + k)
CArray = transpose(array(Ai)) * array(Bi)
@test CArray ≈ scalar(C)
end
@testset "Test Matrix{ITensor} * Matrix{ITensor}" begin
M1 = [Aij Aij; Aij Aij]
M2 = [Ajk Ajk; Ajk Ajk]
M12 = M1 * M2
for x in 1:2, y in 1:2
@test M12[x, y] ≈ 2 * Aij * Ajk
end
end
@testset "Test contract ITensors (Vector*Vectorᵀ -> Matrix)" begin
C = Ai * Aj
for ii in 1:dim(i), jj in 1:dim(j)
7 changes: 7 additions & 0 deletions test/base/test_itensor.jl
@@ -448,6 +448,13 @@ end
@test all(ITensors.data(A) .== 1.0)
end

@testset "zero" begin
i = Index(2)
A = randomITensor(i)
B = zero(A)
@test false * A ≈ B
end

@testset "copyto!" begin
i = Index(2, "i")
j = Index(2, "j")