Skip to content

Commit

Permalink
Stricter kwargs
Browse files Browse the repository at this point in the history
  • Loading branch information
mtfishman committed Nov 3, 2023
1 parent 232068c commit d6bce6b
Show file tree
Hide file tree
Showing 8 changed files with 112 additions and 71 deletions.
2 changes: 1 addition & 1 deletion src/treetensornetworks/solvers/contract.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
function contract_solver(PH, psi; kwargs...)
function contract_solver(PH, psi; normalize, region, half_sweep)
v = ITensor(1.0)
for j in sites(PH)
v *= PH.psi0[j]
Expand Down
54 changes: 43 additions & 11 deletions src/treetensornetworks/solvers/dmrg.jl
Original file line number Diff line number Diff line change
@@ -1,15 +1,25 @@
function eigsolve_solver(; solver_which_eigenvalue=:SR, kwargs...)
function solver(H, init; kws...)
function eigsolve_solver(;
solver_which_eigenvalue=:SR,
ishermitian=true,
solver_tol=1e-14,
solver_krylovdim=3,
solver_maxiter=1,
solver_verbosity=0,
)
function solver(H, init; normalize=nothing, region=nothing, half_sweep=nothing)
howmany = 1
which = solver_which_eigenvalue
solver_kwargs = (;
ishermitian=get(kwargs, :ishermitian, true),
tol=get(kwargs, :solver_tol, 1E-14),
krylovdim=get(kwargs, :solver_krylovdim, 3),
maxiter=get(kwargs, :solver_maxiter, 1),
verbosity=get(kwargs, :solver_verbosity, 0),
vals, vecs, info = eigsolve(
H,
init,
howmany,
which;
ishermitian,
tol=solver_tol,
krylovdim=solver_krylovdim,
maxiter=solver_maxiter,
verbosity=solver_verbosity,
)
vals, vecs, info = eigsolve(H, init, howmany, which; solver_kwargs...)
psi = vecs[1]
return psi, (; solver_info=info, energies=vals)
end
Expand All @@ -19,8 +29,30 @@ end
"""
Overload of `ITensors.dmrg`.
"""
function dmrg(H, init::AbstractTTN; kwargs...)
return alternating_update(eigsolve_solver(; kwargs...), H, init; kwargs...)
function dmrg(
H,
init::AbstractTTN;
solver_which_eigenvalue=:SR,
ishermitian=true,
solver_tol=1e-14,
solver_krylovdim=3,
solver_maxiter=1,
solver_verbosity=0,
kwargs...,
)
return alternating_update(
eigsolve_solver(;
solver_which_eigenvalue,
ishermitian,
solver_tol,
solver_krylovdim,
solver_maxiter,
solver_verbosity,
),
H,
init;
kwargs...,
)
end

"""
Expand Down
4 changes: 3 additions & 1 deletion src/treetensornetworks/solvers/dmrg_x.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
function dmrg_x_solver(PH, init; kwargs...)
function dmrg_x_solver(
PH, init; normalize=nothing, region=nothing, half_sweep=nothing, reverse_step=nothing
)
H = contract(PH, ITensor(1.0))
D, U = eigen(H; ishermitian=true)
u = uniqueind(U, H)
Expand Down
22 changes: 16 additions & 6 deletions src/treetensornetworks/solvers/linsolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,14 @@ Keyword arguments:
Overload of `KrylovKit.linsolve`.
"""
function linsolve(
A::AbstractTTN, b::AbstractTTN, x₀::AbstractTTN, a₀::Number=0, a₁::Number=1; kwargs...
A::AbstractTTN,
b::AbstractTTN,
x₀::AbstractTTN,
a₀::Number=0,
a₁::Number=1;
normalize,
region,
half_sweep,
)
function linsolve_solver(
P,
Expand All @@ -33,17 +40,20 @@ function linsolve(
solver_krylovdim=30,
solver_maxiter=100,
solver_verbosity=0,
kwargs...,
)
solver_kwargs = (;
ishermitian=ishermitian,
b = dag(only(proj_mps(P)))
x, info = KrylovKit.linsolve(
P,
b,
x₀,
a₀,
a₁;
ishermitian,
tol=solver_tol,
krylovdim=solver_krylovdim,
maxiter=solver_maxiter,
verbosity=solver_verbosity,
)
b = dag(only(proj_mps(P)))
x, info = KrylovKit.linsolve(P, b, x₀, a₀, a₁; solver_kwargs...)
return x, NamedTuple()
end

Expand Down
37 changes: 23 additions & 14 deletions src/treetensornetworks/solvers/tdvp.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
function exponentiate_solver(; kwargs...)
function exponentiate_solver()
function solver(
H,
init;
Expand All @@ -10,10 +10,13 @@ function exponentiate_solver(; kwargs...)
solver_outputlevel=0,
solver_tol=1E-12,
substep,
normalize,
time_step,
kws...,
)
solver_kwargs = (;
psi, exp_info = KrylovKit.exponentiate(
H,
time_step,
init;
ishermitian,
issymmetric,
tol=solver_tol,
Expand All @@ -22,14 +25,12 @@ function exponentiate_solver(; kwargs...)
verbosity=solver_outputlevel,
eager=true,
)

psi, exp_info = KrylovKit.exponentiate(H, time_step, init; solver_kwargs...)
return psi, (; info=exp_info)
end
return solver
end

function applyexp_solver(; kwargs...)
function applyexp_solver()
function solver(
H,
init;
Expand All @@ -39,13 +40,13 @@ function applyexp_solver(; kwargs...)
solver_tol=1E-8,
substep,
time_step,
kws...,
normalize,
)
solver_kwargs = (; maxiter=solver_krylovdim, outputlevel=solver_outputlevel)

#applyexp tol is absolute, compute from tol_per_unit_time:
tol = abs(time_step) * tol_per_unit_time
psi, exp_info = applyexp(H, time_step, init; tol, solver_kwargs..., kws...)
psi, exp_info = applyexp(
H, time_step, init; tol, maxiter=solver_krylovdim, outputlevel=solver_outputlevel
)
return psi, (; info=exp_info)
end
return solver
Expand Down Expand Up @@ -84,7 +85,12 @@ function sub_time_steps(order)
end

function tdvp_sweep(
order::Int, nsite::Int, time_step::Number, graph::AbstractGraph; kwargs...
order::Int,
nsite::Int,
time_step::Number,
graph::AbstractGraph;
root_vertex=default_root_vertex(graph),
reverse_step=true,
)
sweep = []
for (substep, fac) in enumerate(sub_time_steps(order))
Expand All @@ -93,10 +99,11 @@ function tdvp_sweep(
direction(substep),
graph,
make_region;
root_vertex,
nsite,
region_args=(; substep, time_step=sub_time_step),
reverse_args=(; substep, time_step=-sub_time_step),
reverse_step=true,
reverse_step,
)
append!(sweep, half)
end
Expand All @@ -113,10 +120,12 @@ function tdvp(
nsteps=nothing,
order::Integer=2,
(sweep_observer!)=observer(),
root_vertex=default_root_vertex(init),
reverse_step=true,
kwargs...,
)
nsweeps = _compute_nsweeps(nsteps, t, time_step, order)
sweep_regions = tdvp_sweep(order, nsite, time_step, init; kwargs...)
sweep_regions = tdvp_sweep(order, nsite, time_step, init; root_vertex, reverse_step)

function sweep_time_printer(; outputlevel, sweep, kwargs...)
if outputlevel >= 1
Expand Down Expand Up @@ -169,5 +178,5 @@ function tdvp(H, t::Number, init::AbstractTTN; solver_backend="exponentiate", kw
"solver_backend=$solver_backend not recognized (options are \"applyexp\" or \"exponentiate\")",
)
end
return tdvp(solver(; kwargs...), H, t, init; kwargs...)
return tdvp(solver(), H, t, init; kwargs...)
end
34 changes: 18 additions & 16 deletions src/treetensornetworks/solvers/update_step.jl
Original file line number Diff line number Diff line change
Expand Up @@ -112,17 +112,29 @@ function insert_local_tensor(
psi::AbstractTTN,
phi::ITensor,
pos::Vector;
which_decomp=nothing,
normalize=false,
# factorize kwargs
maxdim=nothing,
mindim=nothing,
cutoff=nothing,
which_decomp=nothing,
eigen_perturbation=nothing,
kwargs...,
ortho=nothing,
)
spec = nothing
for (v, vnext) in IterTools.partition(pos, 2, 1)
e = edgetype(psi)(v, vnext)
indsTe = inds(psi[v])
L, phi, spec = factorize(
phi, indsTe; which_decomp, tags=tags(psi, e), eigen_perturbation, kwargs...
phi,
indsTe;
tags=tags(psi, e),
maxdim,
mindim,
cutoff,
which_decomp,
eigen_perturbation,
ortho,
)
psi[v] = L
eigen_perturbation = nothing # TODO: fix this
Expand Down Expand Up @@ -162,16 +174,15 @@ function local_update(
sweep,
sweep_regions,
sweep_step,
kwargs...,
solver_kwargs...,
)
psi = orthogonalize(psi, current_ortho(region))
psi, phi = extract_local_tensor(psi, region)

nsites = (region isa AbstractEdge) ? 0 : length(region)
PH = set_nsite(PH, nsites)
PH = position(PH, psi, region)

phi, info = solver(PH, phi; normalize, region, step_kwargs..., kwargs...)
phi, info = solver(PH, phi; normalize, region, step_kwargs..., solver_kwargs...)
if !(phi isa ITensor && info isa NamedTuple)
println("Solver returned the following types: $(typeof(phi)), $(typeof(info))")
error("In alternating_update, solver must return an ITensor and a NamedTuple")
Expand All @@ -185,16 +196,7 @@ function local_update(
#end

psi, spec = insert_local_tensor(
psi,
phi,
region;
eigen_perturbation=drho,
ortho,
normalize,
cutoff,
maxdim,
mindim,
kwargs...,
psi, phi, region; eigen_perturbation=drho, ortho, normalize, maxdim, mindim, cutoff
)

update!(
Expand Down
4 changes: 2 additions & 2 deletions test/test_treetensornetworks/test_solvers/test_contract.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ using Test
@test inner(psit, Hpsi) inner(psit, H, psi) atol = 1E-5

# Test with less good initial guess MPS not equal to psi
psi_guess = truncate(psi; maxdim=2)
Hpsi = apply(H, psi; alg="fit", nsweeps=4, init_state=psi_guess)
psi_guess = truncate(psit; maxdim=2)
Hpsi = apply(H, psi; alg="fit", nsweeps=4, init=psi_guess)
@test inner(psit, Hpsi) inner(psit, H, psi) atol = 1E-5

# Test with nsite=1
Expand Down
26 changes: 6 additions & 20 deletions test/test_treetensornetworks/test_solvers/test_tdvp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ using Test
solver_tol=1e-12,
solver_maxiter=500,
solver_krylovdim=25,
solver_solver_backend="exponentiate",
solver_backend="exponentiate",
)
# TODO: What should `expect` output? Right now
# it outputs a dictionary.
Expand Down Expand Up @@ -312,14 +312,7 @@ using Test

nsite = (step <= 3 ? 2 : 1)
phi = tdvp(
H,
-tau * im,
phi;
nsteps=1,
cutoff,
nsite,
normalize=true,
exponentiate_krylovdim=15,
H, -tau * im, phi; nsteps=1, cutoff, nsite, normalize=true, solver_krylovdim=15
)

Sz1[step] = real(expect("Sz", psi; vertices=[c])[c])
Expand Down Expand Up @@ -379,7 +372,7 @@ using Test
for (step, t) in enumerate(trange)
nsite = (step <= 10 ? 2 : 1)
psi = tdvp(
H, -tau, psi; cutoff, nsite, reverse_step, normalize=true, exponentiate_krylovdim=15
H, -tau, psi; cutoff, nsite, reverse_step, normalize=true, solver_krylovdim=15
)
#Different backend solvers, default solver_backend = "applyexp"
psi2 = tdvp(
Expand All @@ -390,7 +383,7 @@ using Test
nsite,
reverse_step,
normalize=true,
exponentiate_krylovdim=15,
solver_krylovdim=15,
solver_backend="exponentiate",
)
end
Expand Down Expand Up @@ -684,14 +677,7 @@ end

nsite = (step <= 3 ? 2 : 1)
phi = tdvp(
H,
-tau * im,
phi;
nsteps=1,
cutoff,
nsite,
normalize=true,
exponentiate_krylovdim=15,
H, -tau * im, phi; nsteps=1, cutoff, nsite, normalize=true, solver_krylovdim=15
)

Sz1[step] = real(expect("Sz", psi; vertices=[c])[c])
Expand Down Expand Up @@ -742,7 +728,7 @@ end
for (step, t) in enumerate(trange)
nsite = (step <= 10 ? 2 : 1)
psi = tdvp(
H, -tau, psi; cutoff, nsite, reverse_step, normalize=true, exponentiate_krylovdim=15
H, -tau, psi; cutoff, nsite, reverse_step, normalize=true, solver_krylovdim=15
)
end

Expand Down

0 comments on commit d6bce6b

Please sign in to comment.