Skip to content

Commit

Permalink
OpSum to TTN refactor (#166)
Browse files Browse the repository at this point in the history
  • Loading branch information
emstoudenmire authored May 10, 2024
1 parent 65f40ec commit 7ecd931
Show file tree
Hide file tree
Showing 7 changed files with 397 additions and 287 deletions.
4 changes: 3 additions & 1 deletion src/ITensorNetworks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@ include("solvers/local_solvers/contract.jl")
include("solvers/local_solvers/linsolve.jl")
include("treetensornetworks/abstracttreetensornetwork.jl")
include("treetensornetworks/treetensornetwork.jl")
include("treetensornetworks/opsum_to_ttn.jl")
include("treetensornetworks/opsum_to_ttn/matelem.jl")
include("treetensornetworks/opsum_to_ttn/qnarrelem.jl")
include("treetensornetworks/opsum_to_ttn/opsum_to_ttn.jl")
include("treetensornetworks/projttns/abstractprojttn.jl")
include("treetensornetworks/projttns/projttn.jl")
include("treetensornetworks/projttns/projttnsum.jl")
Expand Down
7 changes: 6 additions & 1 deletion src/apply.jl
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,12 @@ function ITensors.apply(
end

function ITensors.apply(
o⃗::Scaled, ψ::AbstractITensorNetwork; normalize=false, ortho=false, apply_kwargs...
o⃗::Scaled,
ψ::AbstractITensorNetwork;
cutoff=nothing,
normalize=false,
ortho=false,
apply_kwargs...,
)
return maybe_real(Ops.coefficient(o⃗)) *
apply(Ops.argument(o⃗), ψ; cutoff, maxdim, normalize, ortho, apply_kwargs...)
Expand Down
41 changes: 41 additions & 0 deletions src/treetensornetworks/opsum_to_ttn/matelem.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@

#
# MatElem
#

struct MatElem{T}
row::Int
col::Int
val::T
end

#function Base.show(io::IO,m::MatElem)
# print(io,"($(m.row),$(m.col),$(m.val))")
#end

function toMatrix(els::Vector{MatElem{T}})::Matrix{T} where {T}
nr = 0
nc = 0
for el in els
nr = max(nr, el.row)
nc = max(nc, el.col)
end
M = zeros(T, nr, nc)
for el in els
M[el.row, el.col] = el.val
end
return M
end

function Base.:(==)(m1::MatElem{T}, m2::MatElem{T})::Bool where {T}
return (m1.row == m2.row && m1.col == m2.col && m1.val == m2.val)
end

function Base.isless(m1::MatElem{T}, m2::MatElem{T})::Bool where {T}
if m1.row != m2.row
return m1.row < m2.row
elseif m1.col != m2.col
return m1.col < m2.col
end
return m1.val < m2.val
end
Loading

0 comments on commit 7ecd931

Please sign in to comment.