From d6bce6b78e9dd03f3af10be494a2ebccd51d4737 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Fri, 3 Nov 2023 09:13:54 -0400 Subject: [PATCH 1/2] Stricter kwargs --- src/treetensornetworks/solvers/contract.jl | 2 +- src/treetensornetworks/solvers/dmrg.jl | 54 +++++++++++++++---- src/treetensornetworks/solvers/dmrg_x.jl | 4 +- src/treetensornetworks/solvers/linsolve.jl | 22 +++++--- src/treetensornetworks/solvers/tdvp.jl | 37 ++++++++----- src/treetensornetworks/solvers/update_step.jl | 34 ++++++------ .../test_solvers/test_contract.jl | 4 +- .../test_solvers/test_tdvp.jl | 26 +++------ 8 files changed, 112 insertions(+), 71 deletions(-) diff --git a/src/treetensornetworks/solvers/contract.jl b/src/treetensornetworks/solvers/contract.jl index a2572112..51f6787f 100644 --- a/src/treetensornetworks/solvers/contract.jl +++ b/src/treetensornetworks/solvers/contract.jl @@ -1,4 +1,4 @@ -function contract_solver(PH, psi; kwargs...) +function contract_solver(PH, psi; normalize, region, half_sweep) v = ITensor(1.0) for j in sites(PH) v *= PH.psi0[j] diff --git a/src/treetensornetworks/solvers/dmrg.jl b/src/treetensornetworks/solvers/dmrg.jl index ac35d0ad..c67b3b86 100644 --- a/src/treetensornetworks/solvers/dmrg.jl +++ b/src/treetensornetworks/solvers/dmrg.jl @@ -1,15 +1,25 @@ -function eigsolve_solver(; solver_which_eigenvalue=:SR, kwargs...) - function solver(H, init; kws...) +function eigsolve_solver(; + solver_which_eigenvalue=:SR, + ishermitian=true, + solver_tol=1e-14, + solver_krylovdim=3, + solver_maxiter=1, + solver_verbosity=0, +) + function solver(H, init; normalize=nothing, region=nothing, half_sweep=nothing) howmany = 1 which = solver_which_eigenvalue - solver_kwargs = (; - ishermitian=get(kwargs, :ishermitian, true), - tol=get(kwargs, :solver_tol, 1E-14), - krylovdim=get(kwargs, :solver_krylovdim, 3), - maxiter=get(kwargs, :solver_maxiter, 1), - verbosity=get(kwargs, :solver_verbosity, 0), + vals, vecs, info = eigsolve( + H, + init, + howmany, + which; + ishermitian, + tol=solver_tol, + krylovdim=solver_krylovdim, + maxiter=solver_maxiter, + verbosity=solver_verbosity, ) - vals, vecs, info = eigsolve(H, init, howmany, which; solver_kwargs...) psi = vecs[1] return psi, (; solver_info=info, energies=vals) end @@ -19,8 +29,30 @@ end """ Overload of `ITensors.dmrg`. """ -function dmrg(H, init::AbstractTTN; kwargs...) - return alternating_update(eigsolve_solver(; kwargs...), H, init; kwargs...) +function dmrg( + H, + init::AbstractTTN; + solver_which_eigenvalue=:SR, + ishermitian=true, + solver_tol=1e-14, + solver_krylovdim=3, + solver_maxiter=1, + solver_verbosity=0, + kwargs..., +) + return alternating_update( + eigsolve_solver(; + solver_which_eigenvalue, + ishermitian, + solver_tol, + solver_krylovdim, + solver_maxiter, + solver_verbosity, + ), + H, + init; + kwargs..., + ) end """ diff --git a/src/treetensornetworks/solvers/dmrg_x.jl b/src/treetensornetworks/solvers/dmrg_x.jl index 15595dd1..30a97f3a 100644 --- a/src/treetensornetworks/solvers/dmrg_x.jl +++ b/src/treetensornetworks/solvers/dmrg_x.jl @@ -1,4 +1,6 @@ -function dmrg_x_solver(PH, init; kwargs...) +function dmrg_x_solver( + PH, init; normalize=nothing, region=nothing, half_sweep=nothing, reverse_step=nothing +) H = contract(PH, ITensor(1.0)) D, U = eigen(H; ishermitian=true) u = uniqueind(U, H) diff --git a/src/treetensornetworks/solvers/linsolve.jl b/src/treetensornetworks/solvers/linsolve.jl index 5ecf8c49..90ba8572 100644 --- a/src/treetensornetworks/solvers/linsolve.jl +++ b/src/treetensornetworks/solvers/linsolve.jl @@ -23,7 +23,14 @@ Keyword arguments: Overload of `KrylovKit.linsolve`. """ function linsolve( - A::AbstractTTN, b::AbstractTTN, x₀::AbstractTTN, a₀::Number=0, a₁::Number=1; kwargs... + A::AbstractTTN, + b::AbstractTTN, + x₀::AbstractTTN, + a₀::Number=0, + a₁::Number=1; + normalize, + region, + half_sweep, ) function linsolve_solver( P, @@ -33,17 +40,20 @@ function linsolve( solver_krylovdim=30, solver_maxiter=100, solver_verbosity=0, - kwargs..., ) - solver_kwargs = (; - ishermitian=ishermitian, + b = dag(only(proj_mps(P))) + x, info = KrylovKit.linsolve( + P, + b, + x₀, + a₀, + a₁; + ishermitian, tol=solver_tol, krylovdim=solver_krylovdim, maxiter=solver_maxiter, verbosity=solver_verbosity, ) - b = dag(only(proj_mps(P))) - x, info = KrylovKit.linsolve(P, b, x₀, a₀, a₁; solver_kwargs...) return x, NamedTuple() end diff --git a/src/treetensornetworks/solvers/tdvp.jl b/src/treetensornetworks/solvers/tdvp.jl index 24e69b57..7391d357 100644 --- a/src/treetensornetworks/solvers/tdvp.jl +++ b/src/treetensornetworks/solvers/tdvp.jl @@ -1,4 +1,4 @@ -function exponentiate_solver(; kwargs...) +function exponentiate_solver() function solver( H, init; @@ -10,10 +10,13 @@ function exponentiate_solver(; kwargs...) solver_outputlevel=0, solver_tol=1E-12, substep, + normalize, time_step, - kws..., ) - solver_kwargs = (; + psi, exp_info = KrylovKit.exponentiate( + H, + time_step, + init; ishermitian, issymmetric, tol=solver_tol, @@ -22,14 +25,12 @@ function exponentiate_solver(; kwargs...) verbosity=solver_outputlevel, eager=true, ) - - psi, exp_info = KrylovKit.exponentiate(H, time_step, init; solver_kwargs...) return psi, (; info=exp_info) end return solver end -function applyexp_solver(; kwargs...) +function applyexp_solver() function solver( H, init; @@ -39,13 +40,13 @@ function applyexp_solver(; kwargs...) solver_tol=1E-8, substep, time_step, - kws..., + normalize, ) - solver_kwargs = (; maxiter=solver_krylovdim, outputlevel=solver_outputlevel) - #applyexp tol is absolute, compute from tol_per_unit_time: tol = abs(time_step) * tol_per_unit_time - psi, exp_info = applyexp(H, time_step, init; tol, solver_kwargs..., kws...) + psi, exp_info = applyexp( + H, time_step, init; tol, maxiter=solver_krylovdim, outputlevel=solver_outputlevel + ) return psi, (; info=exp_info) end return solver @@ -84,7 +85,12 @@ function sub_time_steps(order) end function tdvp_sweep( - order::Int, nsite::Int, time_step::Number, graph::AbstractGraph; kwargs... + order::Int, + nsite::Int, + time_step::Number, + graph::AbstractGraph; + root_vertex=default_root_vertex(graph), + reverse_step=true, ) sweep = [] for (substep, fac) in enumerate(sub_time_steps(order)) @@ -93,10 +99,11 @@ function tdvp_sweep( direction(substep), graph, make_region; + root_vertex, nsite, region_args=(; substep, time_step=sub_time_step), reverse_args=(; substep, time_step=-sub_time_step), - reverse_step=true, + reverse_step, ) append!(sweep, half) end @@ -113,10 +120,12 @@ function tdvp( nsteps=nothing, order::Integer=2, (sweep_observer!)=observer(), + root_vertex=default_root_vertex(init), + reverse_step=true, kwargs..., ) nsweeps = _compute_nsweeps(nsteps, t, time_step, order) - sweep_regions = tdvp_sweep(order, nsite, time_step, init; kwargs...) + sweep_regions = tdvp_sweep(order, nsite, time_step, init; root_vertex, reverse_step) function sweep_time_printer(; outputlevel, sweep, kwargs...) if outputlevel >= 1 @@ -169,5 +178,5 @@ function tdvp(H, t::Number, init::AbstractTTN; solver_backend="exponentiate", kw "solver_backend=$solver_backend not recognized (options are \"applyexp\" or \"exponentiate\")", ) end - return tdvp(solver(; kwargs...), H, t, init; kwargs...) + return tdvp(solver(), H, t, init; kwargs...) end diff --git a/src/treetensornetworks/solvers/update_step.jl b/src/treetensornetworks/solvers/update_step.jl index c1c2bdc0..0f69ed52 100644 --- a/src/treetensornetworks/solvers/update_step.jl +++ b/src/treetensornetworks/solvers/update_step.jl @@ -112,17 +112,29 @@ function insert_local_tensor( psi::AbstractTTN, phi::ITensor, pos::Vector; - which_decomp=nothing, normalize=false, + # factorize kwargs + maxdim=nothing, + mindim=nothing, + cutoff=nothing, + which_decomp=nothing, eigen_perturbation=nothing, - kwargs..., + ortho=nothing, ) spec = nothing for (v, vnext) in IterTools.partition(pos, 2, 1) e = edgetype(psi)(v, vnext) indsTe = inds(psi[v]) L, phi, spec = factorize( - phi, indsTe; which_decomp, tags=tags(psi, e), eigen_perturbation, kwargs... + phi, + indsTe; + tags=tags(psi, e), + maxdim, + mindim, + cutoff, + which_decomp, + eigen_perturbation, + ortho, ) psi[v] = L eigen_perturbation = nothing # TODO: fix this @@ -162,7 +174,7 @@ function local_update( sweep, sweep_regions, sweep_step, - kwargs..., + solver_kwargs..., ) psi = orthogonalize(psi, current_ortho(region)) psi, phi = extract_local_tensor(psi, region) @@ -170,8 +182,7 @@ function local_update( nsites = (region isa AbstractEdge) ? 0 : length(region) PH = set_nsite(PH, nsites) PH = position(PH, psi, region) - - phi, info = solver(PH, phi; normalize, region, step_kwargs..., kwargs...) + phi, info = solver(PH, phi; normalize, region, step_kwargs..., solver_kwargs...) if !(phi isa ITensor && info isa NamedTuple) println("Solver returned the following types: $(typeof(phi)), $(typeof(info))") error("In alternating_update, solver must return an ITensor and a NamedTuple") @@ -185,16 +196,7 @@ function local_update( #end psi, spec = insert_local_tensor( - psi, - phi, - region; - eigen_perturbation=drho, - ortho, - normalize, - cutoff, - maxdim, - mindim, - kwargs..., + psi, phi, region; eigen_perturbation=drho, ortho, normalize, maxdim, mindim, cutoff ) update!( diff --git a/test/test_treetensornetworks/test_solvers/test_contract.jl b/test/test_treetensornetworks/test_solvers/test_contract.jl index 65d039a0..2643b5db 100644 --- a/test/test_treetensornetworks/test_solvers/test_contract.jl +++ b/test/test_treetensornetworks/test_solvers/test_contract.jl @@ -40,8 +40,8 @@ using Test @test inner(psit, Hpsi) ≈ inner(psit, H, psi) atol = 1E-5 # Test with less good initial guess MPS not equal to psi - psi_guess = truncate(psi; maxdim=2) - Hpsi = apply(H, psi; alg="fit", nsweeps=4, init_state=psi_guess) + psi_guess = truncate(psit; maxdim=2) + Hpsi = apply(H, psi; alg="fit", nsweeps=4, init=psi_guess) @test inner(psit, Hpsi) ≈ inner(psit, H, psi) atol = 1E-5 # Test with nsite=1 diff --git a/test/test_treetensornetworks/test_solvers/test_tdvp.jl b/test/test_treetensornetworks/test_solvers/test_tdvp.jl index b9f05766..ec002af1 100644 --- a/test/test_treetensornetworks/test_solvers/test_tdvp.jl +++ b/test/test_treetensornetworks/test_solvers/test_tdvp.jl @@ -248,7 +248,7 @@ using Test solver_tol=1e-12, solver_maxiter=500, solver_krylovdim=25, - solver_solver_backend="exponentiate", + solver_backend="exponentiate", ) # TODO: What should `expect` output? Right now # it outputs a dictionary. @@ -312,14 +312,7 @@ using Test nsite = (step <= 3 ? 2 : 1) phi = tdvp( - H, - -tau * im, - phi; - nsteps=1, - cutoff, - nsite, - normalize=true, - exponentiate_krylovdim=15, + H, -tau * im, phi; nsteps=1, cutoff, nsite, normalize=true, solver_krylovdim=15 ) Sz1[step] = real(expect("Sz", psi; vertices=[c])[c]) @@ -379,7 +372,7 @@ using Test for (step, t) in enumerate(trange) nsite = (step <= 10 ? 2 : 1) psi = tdvp( - H, -tau, psi; cutoff, nsite, reverse_step, normalize=true, exponentiate_krylovdim=15 + H, -tau, psi; cutoff, nsite, reverse_step, normalize=true, solver_krylovdim=15 ) #Different backend solvers, default solver_backend = "applyexp" psi2 = tdvp( @@ -390,7 +383,7 @@ using Test nsite, reverse_step, normalize=true, - exponentiate_krylovdim=15, + solver_krylovdim=15, solver_backend="exponentiate", ) end @@ -684,14 +677,7 @@ end nsite = (step <= 3 ? 2 : 1) phi = tdvp( - H, - -tau * im, - phi; - nsteps=1, - cutoff, - nsite, - normalize=true, - exponentiate_krylovdim=15, + H, -tau * im, phi; nsteps=1, cutoff, nsite, normalize=true, solver_krylovdim=15 ) Sz1[step] = real(expect("Sz", psi; vertices=[c])[c]) @@ -742,7 +728,7 @@ end for (step, t) in enumerate(trange) nsite = (step <= 10 ? 2 : 1) psi = tdvp( - H, -tau, psi; cutoff, nsite, reverse_step, normalize=true, exponentiate_krylovdim=15 + H, -tau, psi; cutoff, nsite, reverse_step, normalize=true, solver_krylovdim=15 ) end From 147d8e4ffcc41b1cf38fc2d1d0234fb0b4484ba8 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Fri, 3 Nov 2023 09:16:54 -0400 Subject: [PATCH 2/2] Explicit kwargs in applyexp --- src/treetensornetworks/solvers/applyexp.jl | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/src/treetensornetworks/solvers/applyexp.jl b/src/treetensornetworks/solvers/applyexp.jl index f18e8091..6e84036e 100644 --- a/src/treetensornetworks/solvers/applyexp.jl +++ b/src/treetensornetworks/solvers/applyexp.jl @@ -21,12 +21,7 @@ struct ApplyExpInfo converged::Int end -function applyexp(H, tau::Number, x0; kwargs...) - maxiter = get(kwargs, :maxiter, 30) - tol = get(kwargs, :tol, 1E-12) - outputlevel = get(kwargs, :outputlevel, 0) - beta_tol = get(kwargs, :normcutoff, 1E-7) - +function applyexp(H, tau::Number, x0; maxiter=30, tol=1e-12, outputlevel=0, normcutoff=1e-7) # Initialize Lanczos vectors v1 = copy(x0) nrm = norm(v1)