Skip to content

Commit

Permalink
Modify alternating update kwarg naming, structure, and also solver_in…
Browse files Browse the repository at this point in the history
…terface.
  • Loading branch information
b-kloss committed Jan 17, 2024
1 parent c2c5455 commit ed0df91
Show file tree
Hide file tree
Showing 5 changed files with 188 additions and 171 deletions.
56 changes: 30 additions & 26 deletions src/solvers/exponentiate.jl
Original file line number Diff line number Diff line change
@@ -1,43 +1,47 @@
function exponentiate_solver()
function solver(
function exponentiate_updater(
init;
psi_ref!,
PH_ref!,
outputlevel,
which_sweep, # keep, change name
region_updates,
which_region_update,
region_kwargs, # region_kwargs (timestep for solver)
updater_kwargs,
)
default_updater_kwargs=(;
krylovdim=30, #from here only solver kwargs
maxiter=100,
outputlevel=0,
tol=1E-12,
ishermitian=true,
issymmetric=true,
region,
sweep_regions,
sweep_step,
solver_krylovdim=30,
solver_maxiter=100,
solver_outputlevel=0,
solver_tol=1E-12,
substep,
normalize,
time_step,
)
eager=true
)
updater_kwargs=merge(default_updater_kwargs,updater_kwargs) #last collection has precedence
#H=copy(PH_ref![])
H = PH_ref![] ###since we are not changing H we don't need the copy
# let's test whether given region and sweep regions we can find out what the previous and next region were
# this will be needed in subspace expansion
region_ind = sweep_step
next_region =
region_ind == length(sweep_regions) ? nothing : first(sweep_regions[region_ind + 1])
previous_region = region_ind == 1 ? nothing : first(sweep_regions[region_ind - 1])
#@show step_kwargs
substep=get(region_kwargs,:substep,nothing)
time_step=get(region_kwargs,:time_step,nothing)
@assert !isnothing(time_step) && !isnothing(substep)
region_ind = which_region_update
next_region = region_ind == length(region_updates) ? nothing : first(region_updates[region_ind + 1])
previous_region = region_ind == 1 ? nothing : first(region_updates[region_ind - 1])

phi, exp_info = KrylovKit.exponentiate(
H,
time_step,
init;
ishermitian,
issymmetric,
tol=solver_tol,
krylovdim=solver_krylovdim,
maxiter=solver_maxiter,
verbosity=solver_outputlevel,
eager=true,
ishermitian=updater_kwargs[:ishermitian],
issymmetric=updater_kwargs[:issymmetric],
tol=updater_kwargs[:tol],
krylovdim=updater_kwargs[:krylovdim],
maxiter=updater_kwargs[:maxiter],
verbosity=updater_kwargs[:outputlevel],
eager=updater_kwargs[:eager],
)
return phi, (; info=exp_info)
end
return solver
end
47 changes: 25 additions & 22 deletions src/treetensornetworks/solvers/alternating_update.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@ function process_sweeps(
return maxdim, mindim, cutoff, noise, kwargs
end

function sweep_printer(; outputlevel, psi, sweep, sw_time)
function sweep_printer(; outputlevel, psi, which_sweep, sw_time)
if outputlevel >= 1
print("After sweep ", sweep, ":")
print("After sweep ", which_sweep, ":")
print(" maxlinkdim=", maxlinkdim(psi))
print(" cpu_time=", round(sw_time; digits=3))
println()
Expand All @@ -37,7 +37,7 @@ function sweep_printer(; outputlevel, psi, sweep, sw_time)
end

function alternating_update(
solver,
updater,
PH,
psi0::AbstractTTN;
checkdone=(; kws...) -> false,
Expand All @@ -46,55 +46,58 @@ function alternating_update(
(sweep_observer!)=observer(),
sweep_printer=sweep_printer,
write_when_maxdim_exceeds::Union{Int,Nothing}=nothing,
updater_kwargs,
kwargs...,
)
maxdim, mindim, cutoff, noise, kwargs = process_sweeps(nsweeps; kwargs...)

psi = copy(psi0)

insert_function!(sweep_observer!, "sweep_printer" => sweep_printer)
insert_function!(sweep_observer!, "sweep_printer" => sweep_printer) # FIX THIS

for sweep in 1:nsweeps
if !isnothing(write_when_maxdim_exceeds) && maxdim[sweep] > write_when_maxdim_exceeds
for which_sweep in 1:nsweeps
if !isnothing(write_when_maxdim_exceeds) && maxdim[which_sweep] > write_when_maxdim_exceeds
if outputlevel >= 2
println(
"write_when_maxdim_exceeds = $write_when_maxdim_exceeds and maxdim[sweep] = $(maxdim[sweep]), writing environment tensors to disk",
"write_when_maxdim_exceeds = $write_when_maxdim_exceeds and maxdim[which_sweep] = $(maxdim[which_sweep]), writing environment tensors to disk",
)
end
PH = disk(PH)
end

sw_time = @elapsed begin
psi, PH = update_step(
solver,
psi, PH = sweep_update(
updater,
PH,
psi;
outputlevel,
sweep,
maxdim=maxdim[sweep],
mindim=mindim[sweep],
cutoff=cutoff[sweep],
noise=noise[sweep],
which_sweep,
sweep_params=(;
maxdim=maxdim[which_sweep],
mindim=mindim[which_sweep],
cutoff=cutoff[which_sweep],
noise=noise[which_sweep]),
updater_kwargs,
kwargs...,
)
end

update!(sweep_observer!; psi, sweep, sw_time, outputlevel)
update!(sweep_observer!; psi, which_sweep, sw_time, outputlevel)

checkdone(; psi, sweep, outputlevel, kwargs...) && break
checkdone(; psi, which_sweep, outputlevel, kwargs...) && break
end
select!(sweep_observer!, Observers.DataFrames.Not("sweep_printer")) # remove sweep_printer
select!(sweep_observer!, Observers.DataFrames.Not("sweep_printer"))
return psi
end

function alternating_update(solver, H::AbstractTTN, psi0::AbstractTTN; kwargs...)
function alternating_update(updater, H::AbstractTTN, psi0::AbstractTTN; kwargs...)
check_hascommoninds(siteinds, H, psi0)
check_hascommoninds(siteinds, H, psi0')
# Permute the indices to have a better memory layout
# and minimize permutations
H = ITensors.permute(H, (linkind, siteinds, linkind))
PH = ProjTTN(H)
return alternating_update(solver, PH, psi0; kwargs...)
return alternating_update(updater, PH, psi0; kwargs...)
end

"""
Expand All @@ -116,12 +119,12 @@ each step of the algorithm when optimizing the MPS.
Returns:
* `psi::MPS` - time-evolved MPS
"""
function alternating_update(solver, Hs::Vector{<:AbstractTTN}, psi0::AbstractTTN; kwargs...)
function alternating_update(updater, Hs::Vector{<:AbstractTTN}, psi0::AbstractTTN; kwargs...)
for H in Hs
check_hascommoninds(siteinds, H, psi0)
check_hascommoninds(siteinds, H, psi0')
end
Hs .= ITensors.permute.(Hs, Ref((linkind, siteinds, linkind)))
PHs = ProjTTNSum(Hs)
return alternating_update(solver, PHs, psi0; kwargs...)
return alternating_update(updater, PHs, psi0; kwargs...)
end
17 changes: 9 additions & 8 deletions src/treetensornetworks/solvers/tdvp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ function tdvp_sweep(
end

function tdvp(
solver,
updater,
H,
t::Number,
init::AbstractTTN;
Expand All @@ -68,17 +68,18 @@ function tdvp(
(sweep_observer!)=observer(),
root_vertex=default_root_vertex(init),
reverse_step=true,
updater_kwargs=NamedTuple(;),
kwargs...,
)
nsweeps = _compute_nsweeps(nsteps, t, time_step, order)
sweep_regions = tdvp_sweep(order, nsite, time_step, init; root_vertex, reverse_step)
region_updates = tdvp_sweep(order, nsite, time_step, init; root_vertex, reverse_step)

function sweep_time_printer(; outputlevel, sweep, kwargs...)
function sweep_time_printer(; outputlevel, which_sweep, kwargs...)
if outputlevel >= 1
sweeps_per_step = order ÷ 2
if sweep % sweeps_per_step == 0
current_time = (sweep / sweeps_per_step) * time_step
println("Current time (sweep $sweep) = ", round(current_time; digits=3))
current_time = (which_sweep / sweeps_per_step) * time_step
println("Current time (sweep $which_sweep) = ", round(current_time; digits=3))
end
end
return nothing
Expand All @@ -87,7 +88,7 @@ function tdvp(
insert_function!(sweep_observer!, "sweep_time_printer" => sweep_time_printer)

psi = alternating_update(
solver, H, init; nsweeps, sweep_observer!, sweep_regions, nsite, kwargs...
updater, H, init; nsweeps, sweep_observer!, region_updates, updater_kwargs, kwargs...
)

# remove sweep_time_printer from sweep_observer!
Expand All @@ -114,6 +115,6 @@ Optional keyword arguments:
* `observer` - object implementing the Observer interface which can perform measurements and stop early
* `write_when_maxdim_exceeds::Int` - when the allowed maxdim exceeds this value, begin saving tensors to disk to free memory in large calculations
"""
function tdvp(H, t::Number, init::AbstractTTN; solver=exponentiate_solver, kwargs...)
return tdvp(solver(), H, t, init; kwargs...)
function tdvp(H, t::Number, init::AbstractTTN; updater=exponentiate_updater, kwargs...)
return tdvp(updater, H, t, init; kwargs...)
end
Loading

0 comments on commit ed0df91

Please sign in to comment.