
Commit

Start renaming. One tdvp test not passing, observer related.
Benedikt Kloss committed Jan 19, 2024
1 parent 5587391 commit 0620943
Showing 7 changed files with 152 additions and 152 deletions.
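For orientation, the hunks below converge on a renamed updater keyword interface: `psi_ref!` becomes `state!` and `PH_ref!` becomes `projected_operator!`, both read with `[]`. The following Julia sketch shows only the shape of that interface as it appears in these diffs; the trailing keyword names differ per updater, and the return convention is not visible in these hunks, so the placeholder return value is an assumption.

# Sketch of the renamed updater interface (shape only, not code taken verbatim from the commit).
# `state!` and `projected_operator!` are Ref-like containers, dereferenced with `[]` as in the hunks below.
function example_updater(
  init;
  state!,
  projected_operator!,
  outputlevel,
  which_sweep,
  region_updates,
  updater_kwargs...,
)
  projected_operator = projected_operator![]  # read the projected operator for the current region
  # ... compute an updated local tensor from `init` and `projected_operator` ...
  return init  # placeholder: the actual return convention is not shown in these hunks
end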
6 changes: 3 additions & 3 deletions src/solvers/dmrg_x_solver.jl
@@ -1,15 +1,15 @@
 function dmrg_x_solver(
   init;
-  psi_ref!,
-  PH_ref!,
+  state!,
+  projected_operator!,
   normalize=nothing,
   region,
   sweep_regions,
   sweep_step,
   half_sweep,
   step_kwargs...,
 )
-  H = contract(PH_ref![], ITensor(1.0))
+  H = contract(projected_operator![], ITensor(1.0))
   D, U = eigen(H; ishermitian=true)
   u = uniqueind(U, H)
   max_overlap, max_ind = findmax(abs, array(dag(init) * U))
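For intuition about the unchanged body of `dmrg_x_solver`: it diagonalizes the contracted projected operator and keeps the eigenvector with the largest overlap with the current tensor. The same selection step can be sketched with plain arrays; this is an illustrative analogue with made-up numbers, not code from the repository.

using LinearAlgebra

# Plain-array analogue of the DMRG-X selection step above (illustrative only).
H = [2.0 0.3; 0.3 1.0]              # stand-in for the contracted projected operator
init = normalize([1.0, 0.1])        # stand-in for the current local tensor
D, U = eigen(Symmetric(H))          # D: eigenvalues, U: eigenvectors in columns
overlaps = abs.(U' * init)          # |<eigenvector_k, init>| for each k
max_overlap, max_ind = findmax(overlaps)
phi = U[:, max_ind]                 # DMRG-X keeps the best-overlapping eigenvector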
6 changes: 3 additions & 3 deletions src/solvers/eigsolve.jl
@@ -1,7 +1,7 @@
 function eigsolve_updater(
   init;
-  psi_ref!,
-  PH_ref!,
+  state!,
+  projected_operator!,
   outputlevel,
   which_sweep,
   region_updates,
@@ -22,7 +22,7 @@ function eigsolve_updater(
   howmany = 1
   which = updater_kwargs.which_eigenvalue
   vals, vecs, info = eigsolve(
-    PH_ref![],
+    projected_operator![],
     init,
     howmany,
     which;
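The renamed first argument here is handed to `eigsolve`, presumably KrylovKit's, whose positional arguments `(operator, start_vector, howmany, which)` match the call in the hunk; that attribution is inferred, not shown in this diff. A minimal standalone call under that assumption, with placeholder data:

using KrylovKit

A = [2.0 1.0; 1.0 3.0]          # stand-in for the projected operator
x0 = randn(2)                   # stand-in for `init`
howmany, which = 1, :SR         # one eigenpair, smallest real part (ground state for Hermitian A)
vals, vecs, info = eigsolve(A, x0, howmany, which; ishermitian=true)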
10 changes: 5 additions & 5 deletions src/solvers/exponentiate.jl
@@ -1,7 +1,7 @@
 function exponentiate_updater(
   init;
-  psi_ref!,
-  PH_ref!,
+  state!,
+  projected_operator!,
   outputlevel,
   which_sweep,
   region_updates,
@@ -19,8 +19,8 @@ function exponentiate_updater(
     eager=true,
   )
   updater_kwargs = merge(default_updater_kwargs, updater_kwargs) #last collection has precedence
-  #H=copy(PH_ref![])
-  H = PH_ref![] ###since we are not changing H we don't need the copy
+  #H=copy(projected_operator![])
+  projected_operator = projected_operator![] ###since we are not changing H we don't need the copy
   # let's test whether given region and sweep regions we can find out what the previous and next region were
   # this will be needed in subspace expansion
   #@show step_kwargs
@@ -33,7 +33,7 @@ function exponentiate_updater(
   previous_region = region_ind == 1 ? nothing : first(region_updates[region_ind - 1])
 
   phi, exp_info = exponentiate(
-    H,
+    projected_operator,
     time_step,
     init;
     ishermitian=updater_kwargs.ishermitian,
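Likewise, the renamed variable here is the operator passed to `exponentiate`, presumably KrylovKit's, judging from the `(operator, time_step, init)` argument order and the `eager`/`ishermitian` keywords in this file; that is an inference, not stated in the diff. A minimal standalone call under that assumption, with placeholder data:

using KrylovKit

H = [0.0 1.0; 1.0 0.0]          # stand-in for the projected operator
v0 = ComplexF64[1.0, 0.0]       # stand-in for `init`
time_step = -0.1im              # an imaginary value gives real-time evolution for Hermitian H
phi, exp_info = exponentiate(H, time_step, v0; ishermitian=true)
# phi approximates exp(time_step * H) * v0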
52 changes: 26 additions & 26 deletions src/treetensornetworks/solvers/alternating_update.jl
@@ -26,10 +26,10 @@ function process_sweeps(
   return maxdim, mindim, cutoff, noise, kwargs
 end
 
-function sweep_printer(; outputlevel, psi, which_sweep, sw_time)
+function sweep_printer(; outputlevel, state, which_sweep, sw_time)
   if outputlevel >= 1
     print("After sweep ", which_sweep, ":")
-    print(" maxlinkdim=", maxlinkdim(psi))
+    print(" maxlinkdim=", maxlinkdim(state))
     print(" cpu_time=", round(sw_time; digits=3))
     println()
     flush(stdout)
@@ -38,8 +38,8 @@ end
 
 function alternating_update(
   updater,
-  PH,
-  psi0::AbstractTTN;
+  projected_operator,
+  init_state::AbstractTTN;
   checkdone=(; kws...) -> false,
   outputlevel::Integer=0,
   nsweeps::Integer=1,
@@ -51,7 +51,7 @@
 )
   maxdim, mindim, cutoff, noise, kwargs = process_sweeps(nsweeps; kwargs...)
 
-  psi = copy(psi0)
+  state = copy(init_state)
 
   insert_function!(sweep_observer!, "sweep_printer" => sweep_printer) # FIX THIS
 
@@ -63,7 +63,7 @@
           "write_when_maxdim_exceeds = $write_when_maxdim_exceeds and maxdim[which_sweep] = $(maxdim[which_sweep]), writing environment tensors to disk",
         )
       end
-      PH = disk(PH)
+      projected_operator = disk(projected_operator)
     end
     sweep_params=(;
       maxdim=maxdim[which_sweep],
@@ -72,10 +72,10 @@ function alternating_update(
       noise=noise[which_sweep],
     )
     sw_time = @elapsed begin
-      psi, PH = sweep_update(
+      state, projected_operator = sweep_update(
         updater,
-        PH,
-        psi;
+        projected_operator,
+        state;
         outputlevel,
         which_sweep,
         sweep_params,
@@ -84,30 +84,30 @@
       )
     end
 
-    update!(sweep_observer!; psi, which_sweep, sw_time, outputlevel)
+    update!(sweep_observer!; state, which_sweep, sw_time, outputlevel)
 
-    checkdone(; psi, which_sweep, outputlevel, kwargs...) && break
+    checkdone(; state, which_sweep, outputlevel, kwargs...) && break
   end
   select!(sweep_observer!, Observers.DataFrames.Not("sweep_printer"))
-  return psi
+  return state
 end
 
-function alternating_update(updater, H::AbstractTTN, psi0::AbstractTTN; kwargs...)
-  check_hascommoninds(siteinds, H, psi0)
-  check_hascommoninds(siteinds, H, psi0')
+function alternating_update(updater, H::AbstractTTN, init_state::AbstractTTN; kwargs...)
+  check_hascommoninds(siteinds, H, init_state)
+  check_hascommoninds(siteinds, H, init_state')
   # Permute the indices to have a better memory layout
   # and minimize permutations
   H = ITensors.permute(H, (linkind, siteinds, linkind))
-  PH = ProjTTN(H)
-  return alternating_update(updater, PH, psi0; kwargs...)
+  projected_operator = ProjTTN(H)
+  return alternating_update(updater, projected_operator, init_state; kwargs...)
 end
 
 """
-    tdvp(Hs::Vector{MPO},psi0::MPS,t::Number; kwargs...)
-    tdvp(Hs::Vector{MPO},psi0::MPS,t::Number, sweeps::Sweeps; kwargs...)
+    tdvp(Hs::Vector{MPO},init_state::MPS,t::Number; kwargs...)
+    tdvp(Hs::Vector{MPO},init_state::MPS,t::Number, sweeps::Sweeps; kwargs...)
 Use the time dependent variational principle (TDVP) algorithm
-to compute `exp(t*H)*psi0` using an efficient algorithm based
+to compute `exp(t*H)*init_state` using an efficient algorithm based
 on alternating optimization of the MPS tensors and local Krylov
 exponentiation of H.
@@ -119,16 +119,16 @@ the set of MPOs [H1,H2,H3,..] is efficiently looped over at
 each step of the algorithm when optimizing the MPS.
 Returns:
-* `psi::MPS` - time-evolved MPS
+* `state::MPS` - time-evolved MPS
 """
 function alternating_update(
-  updater, Hs::Vector{<:AbstractTTN}, psi0::AbstractTTN; kwargs...
+  updater, Hs::Vector{<:AbstractTTN}, init_state::AbstractTTN; kwargs...
 )
   for H in Hs
-    check_hascommoninds(siteinds, H, psi0)
-    check_hascommoninds(siteinds, H, psi0')
+    check_hascommoninds(siteinds, H, init_state)
+    check_hascommoninds(siteinds, H, init_state')
   end
   Hs .= ITensors.permute.(Hs, Ref((linkind, siteinds, linkind)))
-  PHs = ProjTTNSum(Hs)
-  return alternating_update(updater, PHs, psi0; kwargs...)
+  projected_operators = ProjTTNSum(Hs)
+  return alternating_update(updater, projected_operators, init_state; kwargs...)
 end
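The observer-facing keywords change in this file as well: `update!(sweep_observer!; psi, which_sweep, sw_time, outputlevel)` becomes `update!(sweep_observer!; state, which_sweep, sw_time, outputlevel)`, so measurement functions registered with Observers.jl need to accept `state` rather than `psi`. Below is a sketch of an observer written against the new keyword; the guess that the failing tdvp test in the commit message is an observer still expecting `psi` is only a guess.

using Observers: observer

# Measurement functions receive sweep data as keyword arguments; after this commit
# the evolving state is passed as `state`, not `psi`.
measure_sweep(; which_sweep, kwargs...) = which_sweep
measure_time(; sw_time, kwargs...) = sw_time
record_state(; state, kwargs...) = copy(state)   # uses the renamed keyword

sweep_obs = observer(
  "sweep" => measure_sweep, "time" => measure_time, "state" => record_state
)
# passed to the solvers via the `sweep_observer!` keyword, e.g.
#   alternating_update(updater, H, init_state; (sweep_observer!)=sweep_obs, ...)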
26 changes: 13 additions & 13 deletions src/treetensornetworks/solvers/tdvp.jl
@@ -58,21 +58,21 @@ end
 
 function tdvp(
   updater,
-  H,
+  operator,
   t::Number,
-  init::AbstractTTN;
+  init_state::AbstractTTN;
   time_step::Number=t,
   nsites=2,
   nsteps=nothing,
   order::Integer=2,
   (sweep_observer!)=observer(),
-  root_vertex=default_root_vertex(init),
+  root_vertex=default_root_vertex(init_state),
   reverse_step=true,
   updater_kwargs=(;),
   kwargs...,
 )
   nsweeps = _compute_nsweeps(nsteps, t, time_step, order)
-  region_updates = tdvp_sweep(order, nsites, time_step, init; root_vertex, reverse_step)
+  region_updates = tdvp_sweep(order, nsites, time_step, init_state; root_vertex, reverse_step)

 function sweep_time_printer(; outputlevel, which_sweep, kwargs...)
   if outputlevel >= 1
@@ -87,26 +87,26 @@ function tdvp(
 
   insert_function!(sweep_observer!, "sweep_time_printer" => sweep_time_printer)
 
-  psi = alternating_update(
-    updater, H, init; nsweeps, sweep_observer!, region_updates, updater_kwargs, kwargs...
+  state = alternating_update(
+    updater, operator, init_state; nsweeps, sweep_observer!, region_updates, updater_kwargs, kwargs...
   )
 
   # remove sweep_time_printer from sweep_observer!
   select!(sweep_observer!, Observers.DataFrames.Not("sweep_time_printer"))
 
-  return psi
+  return state
 end
 
 """
-    tdvp(H::TTN, t::Number, psi0::TTN; kwargs...)
+    tdvp(operator::TTN, t::Number, init_state::TTN; kwargs...)
 Use the time dependent variational principle (TDVP) algorithm
-to approximately compute `exp(H*t)*psi0` using an efficient algorithm based
+to approximately compute `exp(operator*t)*init_state` using an efficient algorithm based
 on alternating optimization of the state tensors and local Krylov
-exponentiation of H. The time parameter `t` can be a real or complex number.
+exponentiation of operator. The time parameter `t` can be a real or complex number.
 Returns:
-* `psi` - time-evolved state
+* `state` - time-evolved state
 Optional keyword arguments:
 * `time_step::Number = t` - time step to use when evolving the state. Smaller time steps generally give more accurate results but can make the algorithm take more computational time to run.
@@ -115,6 +115,6 @@ Optional keyword arguments:
 * `observer` - object implementing the Observer interface which can perform measurements and stop early
 * `write_when_maxdim_exceeds::Int` - when the allowed maxdim exceeds this value, begin saving tensors to disk to free memory in large calculations
 """
-function tdvp(H, t::Number, init::AbstractTTN; updater=exponentiate_updater, kwargs...)
-  return tdvp(updater, H, t, init; kwargs...)
+function tdvp(operator, t::Number, init_state::AbstractTTN; updater=exponentiate_updater, kwargs...)
+  return tdvp(updater, operator, t, init_state; kwargs...)
 end
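Read together with the docstring above, the renamed entry point is `tdvp(operator, t, init_state; ...)`, computing `exp(operator*t)*init_state`. The sketch below uses the MPS/MPO form described in the `alternating_update` docstring; whether a `tdvp` method for plain MPS/MPO is exported by this package at this commit is an assumption, the Heisenberg model and all parameters are placeholders, and `heisenberg_mpo` is a local helper defined here, not part of the repository.

using ITensors

# Helper (defined here for the example): nearest-neighbor Heisenberg MPO on sites `s`.
function heisenberg_mpo(s)
  os = OpSum()
  for j in 1:(length(s) - 1)
    os += "Sz", j, "Sz", j + 1
    os += 0.5, "S+", j, "S-", j + 1
    os += 0.5, "S-", j, "S+", j + 1
  end
  return MPO(os, s)
end

s = siteinds("S=1/2", 8)
H = heisenberg_mpo(s)
init_state = randomMPS(s; linkdims=4)

# Evolve to total time 0.5 in steps of 0.1; per the docstring this computes
# exp(t*H)*init_state, so real-time evolution uses an imaginary t.
state = tdvp(H, -0.5im, init_state; time_step=-0.1im, nsites=2)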