Skip to content

Commit

Permalink
Format.
Browse files Browse the repository at this point in the history
  • Loading branch information
b-kloss committed Jan 20, 2024
1 parent 59f45c1 commit 8bc62fb
Show file tree
Hide file tree
Showing 9 changed files with 74 additions and 52 deletions.
36 changes: 18 additions & 18 deletions src/solvers/contract.jl
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
function contract_updater(
init;
state!,
projected_operator!,
outputlevel,
which_sweep,
sweep_plan,
which_region_update,
region_kwargs,
updater_kwargs,
)
v = ITensor(1.0)
projected_operator = projected_operator![]
for j in sites(projected_operator)
v *= projected_operator.psi0[j]
end
Hpsi0 = contract(projected_operator, v)
return Hpsi0, (;)
end
init;
state!,
projected_operator!,
outputlevel,
which_sweep,
sweep_plan,
which_region_update,
region_kwargs,
updater_kwargs,
)
v = ITensor(1.0)
projected_operator = projected_operator![]
for j in sites(projected_operator)
v *= projected_operator.psi0[j]
end
Hpsi0 = contract(projected_operator, v)
return Hpsi0, (;)
end
5 changes: 2 additions & 3 deletions src/solvers/dmrg_x.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,8 @@ function dmrg_x_updater(
updater_kwargs,
)
# this updater does not seem to accept any kwargs?
default_updater_kwargs = (;
)
updater_kwargs = merge(default_updater_kwargs, updater_kwargs)
default_updater_kwargs = (;)
updater_kwargs = merge(default_updater_kwargs, updater_kwargs)
H = contract(projected_operator![], ITensor(1.0))
D, U = eigen(H; ishermitian=true)
u = uniqueind(U, H)
Expand Down
8 changes: 4 additions & 4 deletions src/solvers/eigsolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,17 @@ function eigsolve_updater(
)
updater_kwargs = merge(default_updater_kwargs, updater_kwargs) #last collection has precedence
howmany = 1
which, updater_kwargs = _pop_which_eigenvalue(;updater_kwargs...)
which, updater_kwargs = _pop_which_eigenvalue(; updater_kwargs...)
vals, vecs, info = eigsolve(
projected_operator![],
init,
howmany,
which;
updater_kwargs... #this leaves it
updater_kwargs..., #this leaves it
)
return vecs[1], (; info, eigvals=vals)
end

function _pop_which_eigenvalue(;which_eigenvalue, kwargs...)
function _pop_which_eigenvalue(; which_eigenvalue, kwargs...)
return which_eigenvalue, NamedTuple(kwargs)
end
end
7 changes: 2 additions & 5 deletions src/solvers/exponentiate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,10 @@ function exponentiate_updater(
issymmetric=true,
eager=true,
)

updater_kwargs = merge(default_updater_kwargs, updater_kwargs) #last collection has precedence
result, exp_info = exponentiate(
projected_operator![],
region_kwargs.time_step,
init;
updater_kwargs...
projected_operator![], region_kwargs.time_step, init; updater_kwargs...
)
return result, (; info=exp_info)
end
10 changes: 5 additions & 5 deletions src/treetensornetworks/solvers/alternating_update.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,11 @@ function alternating_update(
end
projected_operator = disk(projected_operator)
end
sweep_params=(;
maxdim=maxdim[which_sweep],
mindim=mindim[which_sweep],
cutoff=cutoff[which_sweep],
noise=noise[which_sweep],
sweep_params = (;
maxdim=maxdim[which_sweep],
mindim=mindim[which_sweep],
cutoff=cutoff[which_sweep],
noise=noise[which_sweep],
)
sw_time = @elapsed begin
state, projected_operator = sweep_update(
Expand Down
4 changes: 3 additions & 1 deletion src/treetensornetworks/solvers/contract.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,9 @@ function contract(

PH = ProjTTNApply(tn2, tn1)
sweep_plan = default_sweep_regions(nsites, init; kwargs...)
psi = alternating_update(contract_updater, PH, init; nsweeps, sweep_plan, updater_kwargs, kwargs...)
psi = alternating_update(
contract_updater, PH, init; nsweeps, sweep_plan, updater_kwargs, kwargs...
)

return psi
end
Expand Down
1 change: 0 additions & 1 deletion src/treetensornetworks/solvers/dmrg_x.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,3 @@ end
function dmrg_x(operator, init::AbstractTTN; updater=dmrg_x_updater, kwargs...)
return dmrg_x(updater, operator, init; kwargs...)
end

17 changes: 14 additions & 3 deletions src/treetensornetworks/solvers/tdvp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,9 @@ function tdvp(
kwargs...,
)
nsweeps = _compute_nsweeps(nsteps, t, time_step, order)
sweep_plan = tdvp_sweep_plan(order, nsites, time_step, init_state; root_vertex, reverse_step)
sweep_plan = tdvp_sweep_plan(
order, nsites, time_step, init_state; root_vertex, reverse_step
)

function sweep_time_printer(; outputlevel, which_sweep, kwargs...)
if outputlevel >= 1
Expand All @@ -88,7 +90,14 @@ function tdvp(
insert_function!(sweep_observer!, "sweep_time_printer" => sweep_time_printer)

state = alternating_update(
updater, operator, init_state; nsweeps, sweep_observer!, sweep_plan, updater_kwargs, kwargs...
updater,
operator,
init_state;
nsweeps,
sweep_observer!,
sweep_plan,
updater_kwargs,
kwargs...,
)

# remove sweep_time_printer from sweep_observer!
Expand All @@ -115,6 +124,8 @@ Optional keyword arguments:
* `observer` - object implementing the Observer interface which can perform measurements and stop early
* `write_when_maxdim_exceeds::Int` - when the allowed maxdim exceeds this value, begin saving tensors to disk to free memory in large calculations
"""
function tdvp(operator, t::Number, init_state::AbstractTTN; updater=exponentiate_updater, kwargs...)
function tdvp(
operator, t::Number, init_state::AbstractTTN; updater=exponentiate_updater, kwargs...
)
return tdvp(updater, operator, t, init_state; kwargs...)
end
38 changes: 26 additions & 12 deletions src/treetensornetworks/solvers/update_step.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,19 @@ function default_sweep_regions(nsites, graph::AbstractGraph; kwargs...) ###move
end

function region_update_printer(;
cutoff, maxdim, mindim, outputlevel::Int=0, state, sweep_plan, spec, which_region_update, which_sweep,kwargs...
cutoff,
maxdim,
mindim,
outputlevel::Int=0,
state,
sweep_plan,
spec,
which_region_update,
which_sweep,
kwargs...,
)
if outputlevel >= 2
region=first(sweep_plan[which_region_update])
region = first(sweep_plan[which_region_update])
@printf("Sweep %d, region=%s \n", which_sweep, region)
print(" Truncated using")
@printf(" cutoff=%.1E", cutoff)
Expand Down Expand Up @@ -60,10 +69,10 @@ function sweep_update(
"`alternating_update` currently does not support system sizes of 1. You can diagonalize the MPO tensor directly with tools like `LinearAlgebra.eigen`, `KrylovKit.exponentiate`, etc.",
)
end

for which_region_update in eachindex(sweep_plan)
(region, region_kwargs)=sweep_plan[which_region_update]
region_kwargs=merge(region_kwargs, sweep_params) # sweep params has precedence over step_kwargs
(region, region_kwargs) = sweep_plan[which_region_update]
region_kwargs = merge(region_kwargs, sweep_params) # sweep params has precedence over step_kwargs
state, projected_operator = region_update(
solver,
projected_operator,
Expand All @@ -79,7 +88,7 @@ function sweep_update(
)
end

select!(region_observer!, Observers.DataFrames.Not("region_update_printer")) # remove update_printer
select!(region_observer!, Observers.DataFrames.Not("region_update_printer")) # remove update_printer
# Just to be sure:
normalize && normalize!(state)

Expand Down Expand Up @@ -173,16 +182,16 @@ function region_update(
region_observer!,
#insertion_kwargs, #ToDo: later
#extraction_kwargs, #ToDo: implement later with possibility to pass custom extraction/insertion func (or code into func)
updater_kwargs
updater_kwargs,
)
region=first(sweep_plan[which_region_update])
region = first(sweep_plan[which_region_update])
state = orthogonalize(state, current_ortho(region))
state, phi = extract_local_tensor(state, region;)
nsites = (region isa AbstractEdge) ? 0 : length(region) #ToDo move into separate funtion
projected_operator = set_nsite(projected_operator, nsites)
projected_operator = position(projected_operator, state, region)
state! = Ref(state) # create references, in case solver does (out-of-place) modify PH or state
projected_operator! = Ref(projected_operator)
projected_operator! = Ref(projected_operator)
phi, info = updater(
phi;
state!,
Expand All @@ -192,7 +201,7 @@ function region_update(
sweep_plan,
which_region_update,
region_kwargs,
updater_kwargs
updater_kwargs,
) # args passed by reference are supposed to be modified out of place
state = state![] # dereference
projected_operator = projected_operator![]
Expand All @@ -210,10 +219,15 @@ function region_update(
#end

state, spec = insert_local_tensor(
state, phi, region; eigen_perturbation=drho, ortho, normalize,
state,
phi,
region;
eigen_perturbation=drho,
ortho,
normalize,
maxdim=region_kwargs.maxdim,
mindim=region_kwargs.mindim,
cutoff=region_kwargs.cutoff
cutoff=region_kwargs.cutoff,
)

update!(
Expand Down

0 comments on commit 8bc62fb

Please sign in to comment.