Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change internal fixed-point API to (x, info) = f(x, info) to simplify SCF metadata tracking #811

Merged
merged 31 commits into from
Aug 5, 2024
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
843c834
Add anderson f(x, info) API suggestion
niklasschmitz Dec 16, 2022
14ca59e
rm whitespace
niklasschmitz Dec 16, 2022
3a7194d
Merge branch 'master' into fp-solvers-functional-api
niklasschmitz Feb 15, 2023
43556b3
Update scf_solvers comments
niklasschmitz Feb 15, 2023
135c28e
Switch scf damping solver to (x, info) API
niklasschmitz Feb 15, 2023
e42526d
Update custom_solvers example with fixedpoint API
niklasschmitz Feb 15, 2023
628d215
Update CROP solver API for info
niklasschmitz Feb 16, 2023
7294b6c
Give CROP solver closure a name
niklasschmitz Feb 16, 2023
3cec66c
Merge branch 'master' into fp-solvers-functional-api
niklasschmitz Mar 14, 2024
91d9b0c
Rename max_iter -> maxiter in example
niklasschmitz Mar 14, 2024
3084397
Fix history_Etot bookkeeping
niklasschmitz Mar 14, 2024
7bb0d45
Update scf algorithms guide with info fixed-point API
niklasschmitz Mar 14, 2024
b0d0a1a
Merge branch 'master' into fp-solvers-functional-api
niklasschmitz May 22, 2024
d234ad0
Merge branch 'master' into fp-solvers-functional-api
niklasschmitz Jul 26, 2024
08f999f
Merge branch 'master' into fp-solvers-functional-api
niklasschmitz Jul 29, 2024
dddaac7
Merge remote-tracking branch 'origin/master' into fp-solvers-function…
niklasschmitz Aug 2, 2024
e9a3394
Remove CROP
niklasschmitz Aug 2, 2024
c0a9d74
Fix MPI broadcast of 'converged' flag
niklasschmitz Aug 2, 2024
caee55e
Remove `tol` from FP solvers and instead use info.converged flag. Mak…
niklasschmitz Aug 2, 2024
96b32e5
Update SCF docs example with new API
niklasschmitz Aug 2, 2024
6e28b8e
Fix SCF docs example info0 handling
niklasschmitz Aug 2, 2024
0ddf05e
Add note on convergence criterion in SCF intro docs
niklasschmitz Aug 5, 2024
60e0210
Add test for SCF timeout
niklasschmitz Aug 5, 2024
f092c4b
Add timeout flag check to SCF solvers
niklasschmitz Aug 5, 2024
00a1b19
Fix timeout tests and rename `timeout` flag to `timedout` for consist…
niklasschmitz Aug 5, 2024
3a37b2a
Add docs on customizing `is_converged` to the custom solvers example
niklasschmitz Aug 5, 2024
009349d
Replace stress convergence example with simpler energy convergence
niklasschmitz Aug 5, 2024
4f8980e
Remove short-circuiting early return from fixpoint_map and leave the …
niklasschmitz Aug 5, 2024
0a89022
Remove density from convergence_criterion
niklasschmitz Aug 5, 2024
1d405e9
Comment on internals for flags in `scf_solvers.jl`
niklasschmitz Aug 5, 2024
21e70ce
Merge branch 'master' into fp-solvers-functional-api
niklasschmitz Aug 5, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions examples/custom_solvers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,20 @@ model = model_LDA(lattice, atoms, positions)
basis = PlaneWaveBasis(model; Ecut=5, kgrid=[1, 1, 1]);

# We define our custom fixed-point solver: a simple damped fixed-point iteration
function my_fp_solver(f, x0, max_iter; tol)
function my_fp_solver(f, x0, info0, max_iter; tol)
mixing_factor = .7
x = x0
fx = f(x)
info = info0
fx, info = f(x, info)
for n = 1:max_iter
inc = fx - x
if norm(inc) < tol
break
end
x = x + mixing_factor * inc
fx = f(x)
fx, info = f(x, info)
end
(fixpoint=x, converged=norm(fx-x) < tol)
(fixpoint=x, info, converged=norm(fx-x) < tol)
end;

# Our eigenvalue solver just forms the dense matrix and diagonalizes
Expand Down
52 changes: 34 additions & 18 deletions src/scf/scf_solvers.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# these provide fixed-point solvers that can be passed to scf()
# these provide fixed-point solvers that can be passed to self_consistent_field()

# the fp_solver function must accept being called like fp_solver(f, x0, tol,
# maxiter), where f(x) is the fixed-point map. It must return an
# object supporting res.sol and res.converged
# the fp_solver function must accept being called like
# fp_solver(f, x0, info0, maxiter; tol), where f(x, info) is the fixed-point map
# returning the pair (fx, info). It must return an object supporting
# res.fixpoint, res.info and res.converged

# TODO max_iter could go to the solver generator function arguments

Expand All @@ -11,12 +11,13 @@ Create a damped SCF solver updating the density as
`x = β * x_new + (1 - β) * x`
"""
function scf_damping_solver(β=0.2)
function fp_solver(f, x0, max_iter; tol=1e-6)
function fp_solver(f, x0, info0, max_iter; tol=1e-6)
niklasschmitz marked this conversation as resolved.
Show resolved Hide resolved
β = convert(eltype(x0), β)
converged = false
x = copy(x0)
info = info0
for i in 1:max_iter
x_new = f(x)
x_new, info = f(x, info)

if norm(x_new - x) < tol
x = x_new
Expand All @@ -26,7 +27,7 @@ function scf_damping_solver(β=0.2)

x = @. β * x_new + (1 - β) * x
end
(; fixpoint=x, converged)
(; fixpoint=x, info, converged)
end
fp_solver
end
Expand All @@ -36,19 +37,21 @@ Create a simple anderson-accelerated SCF solver. `m` specifies the number
of steps to keep the history of.
"""
function scf_anderson_solver(m=10; kwargs...)
function anderson(f, x0, max_iter; tol=1e-6)
function anderson(f, x0, info0, max_iter; tol=1e-6)
T = eltype(x0)
x = x0
info = info0

converged = false
acceleration = AndersonAcceleration(; m, kwargs...)
for n = 1:max_iter
residual = f(x) - x
fx, info = f(x, info)
residual = fx - x
converged = norm(residual) < tol
converged && break
x = acceleration(x, one(T), residual)
end
(; fixpoint=x, converged)
(; fixpoint=x, info, converged)
end
end

Expand All @@ -57,7 +60,7 @@ CROP-accelerated root-finding iteration for `f`, starting from `x0` and keeping
a history of `m` steps. Optionally `warming` specifies the number of non-accelerated
steps to perform for warming up the history.
"""
function CROP(f, x0, m::Int, max_iter::Int, tol::Real, warming=0)
function CROP(f, x0, info0, m::Int, max_iter::Int, tol::Real, warming=0)
mfherbst marked this conversation as resolved.
Show resolved Hide resolved
# CROP iterates maintain xn and fn (/!\ fn != f(xn)).
# xtn+1 = xn + fn
# ftn+1 = f(xtn+1)
Expand All @@ -70,21 +73,25 @@ function CROP(f, x0, m::Int, max_iter::Int, tol::Real, warming=0)

# Cheat support for multidimensional arrays
if length(size(x0)) != 1
x, conv= CROP(x -> vec(f(reshape(x, size(x0)...))), vec(x0), m, max_iter, tol, warming)
return (fixpoint=reshape(x, size(x0)...), converged=conv)
function vec_f(x, info)
x, info = f(reshape(x, size(x0)...), info)
vec(x), info
end
x, info, conv = CROP(vec_f, vec(x0), info0, m, max_iter, tol, warming)
return (; fixpoint=reshape(x, size(x0)...), info, converged=conv)
end
N = size(x0,1)
N = size(x0, 1)
T = eltype(x0)
xs = zeros(T, N, m+1) # Ring buffers storing the iterates
fs = zeros(T, N, m+1) # newest to oldest
xs[:,1] = x0
fs[:,1] = f(x0) # Residual
fs[:,1], info = f(x0, info0) # Residual
errs = zeros(max_iter)
err = Inf

for n = 1:max_iter
xtnp1 = xs[:, 1] + fs[:, 1] # Richardson update
ftnp1 = f(xtnp1) # Residual
ftnp1, info = f(xtnp1, info) # Residual
err = norm(ftnp1)
errs[n] = err
if err < tol
Expand All @@ -110,6 +117,15 @@ function CROP(f, x0, m::Int, max_iter::Int, tol::Real, warming=0)
fs[:,1] = ftnp1
# fs[:,1] = f(xs[:,1])
end
(fixpoint=xs[:, 1], converged=err < tol)
(; fixpoint=xs[:, 1], info, converged=err < tol)
end
function scf_CROP_solver(m=10)
function crop_solver(f, x0, info0, max_iter; tol=1e-6)
function residual_f(x, info)
niklasschmitz marked this conversation as resolved.
Show resolved Hide resolved
fx, info = f(x, info)
(fx - x, info)
end
CROP(residual_f, x0, info0, m, max_iter, tol)
end
crop_solver
end
scf_CROP_solver(m=10) = (f, x0, max_iter; tol=1e-6) -> CROP(x -> f(x) - x, x0, m, max_iter, tol)
50 changes: 24 additions & 26 deletions src/scf/self_consistent_field.jl
Original file line number Diff line number Diff line change
Expand Up @@ -102,20 +102,12 @@ Overview of parameters:
if !isnothing(ψ)
@assert length(ψ) == length(basis.kpoints)
end
occupation = nothing
eigenvalues = nothing
ρout = ρ
εF = nothing
n_iter = 0
energies = nothing
ham = nothing
info = (; n_iter=0, ρin=ρ) # Populate info with initial values
converged = false

# We do density mixing in the real representation
# TODO support other mixing types
function fixpoint_map(ρin)
converged && return ρin # No more iterations if convergence flagged
function fixpoint_map(ρin, info)
(; ψ, occupation, eigenvalues, εF, n_iter, converged) = info
converged && return ρin, info # No more iterations if convergence flagged
n_iter += 1

# Note that ρin is not the density of ψ, and the eigenvalues
Expand All @@ -128,7 +120,7 @@ Overview of parameters:
ψ, eigenvalues, occupation, εF, ρout = nextstate

# Update info with results gathered so far
info = (; ham, basis, converged, stage=:iterate, algorithm="SCF",
info_next = (; ham, basis, converged, stage=:iterate, algorithm="SCF",
ρin, ρout, α=damping, n_iter, nbandsalg.occupation_threshold,
nextstate..., diagonalization=[nextstate.diagonalization])

Expand All @@ -137,37 +129,43 @@ Overview of parameters:
energies = energy_hamiltonian(basis, ψ, occupation;
ρ=ρout, eigenvalues, εF).energies
end
info = merge(info, (; energies))
info_next = merge(info_next, (; energies))

# Apply mixing and pass it the full info as kwargs
δρ = mix_density(mixing, basis, ρout - ρin; info...)
ρnext = ρin .+ T(damping) .* δρ
info = merge(info, (; ρnext))
Δρ = mix_density(mixing, basis, ρout - ρin; info_next...)
ρnext = ρin .+ T(damping) .* Δρ
info_next = merge(info_next, (; ρnext))

callback(info)
is_converged(info) && (converged = true)
callback(info_next)
if is_converged(info_next)
info_next = merge(info_next, (; converged=true))
end

ρnext
ρnext, info_next
end

info_init = (; ρin=ρ, ψ=ψ, occupation=nothing, eigenvalues=nothing, εF=nothing,
n_iter=0, converged=false)

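# The tolerance passed to the solver is only a dummy here: convergence is flagged by
# is_converged inside the fixpoint_map, which from then on returns its input unchanged,
# so the solver sees a zero residual and stops. maxiter still bounds the solver's loop.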
solver(fixpoint_map, ρout, maxiter; tol=eps(T))

_, info = solver(fixpoint_map, ρ, info_init, maxiter; tol=eps(T)) # TODO ?? why dummy?
niklasschmitz marked this conversation as resolved.
Show resolved Hide resolved
# We do not use the return value of solver but rather the one that got updated by fixpoint_map
# ψ is consistent with ρout, so we return that. We also perform a last energy computation
# to return a correct variational energy
(; ρin, ρout, ψ, occupation, eigenvalues, εF, converged) = info
energies, ham = energy_hamiltonian(basis, ψ, occupation; ρ=ρout, eigenvalues, εF)

# Measure for the accuracy of the SCF
# TODO probably should be tracked all the way ...
norm_Δρ = norm(info.ρout - info.ρin) * sqrt(basis.dvol)
norm_Δρ = norm(ρout - ρin) * sqrt(basis.dvol)

# Callback is run one last time with final state to allow callback to clean up
info = (; ham, basis, energies, converged, nbandsalg.occupation_threshold,
scfres = (; ham, basis, energies, converged, nbandsalg.occupation_threshold,
ρ=ρout, α=damping, eigenvalues, occupation, εF, info.n_bands_converge,
n_iter, ψ, info.diagonalization, stage=:finalize,
info.n_iter, ψ, info.diagonalization, stage=:finalize,
algorithm="SCF", norm_Δρ)
callback(info)
info
callback(scfres)
scfres
end