Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

extend contractor interface #32

Merged
merged 7 commits into from
Oct 8, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 54 additions & 27 deletions src/contractor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,16 @@ struct MpsParameters{S<:Real}
iters_var = 1,
Dtemp_multiplier = 2,
method = :psvd_sparse,
) where {S} = new(bond_dim, var_tol, num_sweeps, tol_SVD, iters_svd, iters_var, Dtemp_multiplier, method)
) where {S} = new(
bond_dim,
var_tol,
num_sweeps,
tol_SVD,
iters_svd,
iters_var,
Dtemp_multiplier,
method,
)
end

"""
Expand Down Expand Up @@ -186,6 +195,44 @@ mutable struct MpsContractor{T<:AbstractStrategy,R<:AbstractGauge,S<:Real} <:
end
end

function MpsContractor(
::Type{T},
::Type{R},
::Type{S},
net,
params;
beta::S,
graduate_truncation::Bool,
onGPU = true,
depth::Int = 0,
) where {T, R, S}
return MpsContractor{T,R,S}(net, params; beta, graduate_truncation, onGPU, depth)
end

function MpsContractor(
::Type{T},
::Type{R},
net,
params;
beta::S,
graduate_truncation::Bool,
onGPU = true,
depth::Int = 0,
) where {T, R, S}
return MpsContractor(T, R, S, net, params; beta, graduate_truncation, onGPU, depth)
end

function MpsContractor(
::Type{T},
net,
params;
beta::S,
graduate_truncation::Bool,
onGPU = true,
depth::Int = 0,
) where {T, S}
return MpsContractor(T, NoUpdate, net, params; beta, graduate_truncation, onGPU, depth)
end
"""
$(TYPEDSIGNATURES)
Get the strategy used to contract the PEPS network.
Expand Down Expand Up @@ -243,10 +290,7 @@ Construct and memoize the top Matrix Product State (MPS) using Singular Value De

This function constructs the top MPS using SVD for a given row in the PEPS network contraction. It recursively builds the MPS row by row, performing canonicalization, truncation, and compression steps as needed based on the specified parameters in `ctr.params`. The resulting MPS is memoized for efficient reuse.
"""
@memoize Dict function mps_top(
ctr::MpsContractor{SVDTruncate,R,S},
i::Int,
) where {R,S}
@memoize Dict function mps_top(ctr::MpsContractor{SVDTruncate,R,S}, i::Int) where {R,S}
Dcut = ctr.params.bond_dimension
tolV = ctr.params.variational_tol
tolS = ctr.params.tol_SVD
Expand Down Expand Up @@ -285,10 +329,7 @@ Construct and memoize the (bottom) Matrix Product State (MPS) using Singular Val

This function constructs the (bottom) MPS using SVD for a given row in the PEPS network contraction. It recursively builds the MPS row by row, performing canonicalization, truncation, and compression steps as needed based on the specified parameters in `ctr.params`. The resulting MPS is memoized for efficient reuse.
"""
@memoize Dict function mps(
ctr::MpsContractor{SVDTruncate,R,S},
i::Int,
) where {R,S}
@memoize Dict function mps(ctr::MpsContractor{SVDTruncate,R,S}, i::Int) where {R,S}
Dcut = ctr.params.bond_dimension
tolV = ctr.params.variational_tol
tolS = ctr.params.tol_SVD
Expand Down Expand Up @@ -328,10 +369,7 @@ Construct and memoize the (bottom) Matrix Product State (MPS) approximation usin

This function constructs the (bottom) MPS approximation using SVD for a given row in the PEPS network contraction. It recursively builds the MPS row by row, performing canonicalization, and truncation steps based on the specified parameters in `ctr.params`. The resulting MPS approximation is memoized for efficient reuse.
"""
@memoize Dict function mps_approx(
ctr::MpsContractor{SVDTruncate,R,S},
i::Int,
) where {R,S}
@memoize Dict function mps_approx(ctr::MpsContractor{SVDTruncate,R,S}, i::Int) where {R,S}
if i > ctr.peps.nrows
W = mpo(ctr, ctr.layers.main, ctr.peps.nrows)
return IdentityQMps(S, local_dims(W, :down); onGPU = ctr.onGPU) # F64 for now
Expand Down Expand Up @@ -370,10 +408,7 @@ Construct and memoize the top Matrix Product State (MPS) using the Zipper (trunc

This function constructs the top Matrix Product State (MPS) using the Zipper (truncated Singular Value Decomposition) method for a given row in the PEPS network contraction. It recursively builds the MPS row by row, performing canonicalization, and truncation steps based on the specified parameters in `ctr.params`. The resulting MPS is memoized for efficient reuse.
"""
@memoize Dict function mps_top(
ctr::MpsContractor{Zipper,R,S},
i::Int,
) where {R,S}
@memoize Dict function mps_top(ctr::MpsContractor{Zipper,R,S}, i::Int) where {R,S}
Dcut = ctr.params.bond_dimension
tolV = ctr.params.variational_tol
tolS = ctr.params.tol_SVD
Expand Down Expand Up @@ -711,22 +746,14 @@ function sweep_gauges!(
end


function update_gauges!(
ctr::MpsContractor{T,S},
row::Site,
::Val{:down},
) where {T,S}
function update_gauges!(ctr::MpsContractor{T,S}, row::Site, ::Val{:down}) where {T,S}
for i ∈ 1:row-1
sweep_gauges!(ctr, i)
end
end


function update_gauges!(
ctr::MpsContractor{T,S},
row::Site,
::Val{:up},
) where {T,S}
function update_gauges!(ctr::MpsContractor{T,S}, row::Site, ::Val{:up}) where {T,S}
for i ∈ row-1:-1:1
sweep_gauges!(ctr, i)
end
Expand Down
5 changes: 4 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,13 @@ using LinearAlgebra
using TensorCast
using Statistics
using MetaGraphs
using CUDA

disable_logging(LogLevel(1))

onGPU = true
user_onGPU = true # or false, based on user's preference
gpu_available = CUDA.functional()
onGPU = user_onGPU && gpu_available

using Test
my_tests = []
Expand Down
Loading