Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change internal fixed-point API to (x, info) = f(x, info) to simplify SCF metadata tracking #811

Merged
merged 31 commits into from
Aug 5, 2024
Merged
Show file tree
Hide file tree
Changes from 27 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
843c834
Add anderson f(x, info) API suggestion
niklasschmitz Dec 16, 2022
14ca59e
rm whitespace
niklasschmitz Dec 16, 2022
3a7194d
Merge branch 'master' into fp-solvers-functional-api
niklasschmitz Feb 15, 2023
43556b3
Update scf_solvers comments
niklasschmitz Feb 15, 2023
135c28e
Switch scf damping solver to (x, info) API
niklasschmitz Feb 15, 2023
e42526d
Update custom_solvers example with fixedpoint API
niklasschmitz Feb 15, 2023
628d215
Update CROP solver API for info
niklasschmitz Feb 16, 2023
7294b6c
Give CROP solver closure a name
niklasschmitz Feb 16, 2023
3cec66c
Merge branch 'master' into fp-solvers-functional-api
niklasschmitz Mar 14, 2024
91d9b0c
Rename max_iter -> maxiter in example
niklasschmitz Mar 14, 2024
3084397
Fix history_Etot bookkeeping
niklasschmitz Mar 14, 2024
7bb0d45
Update scf algorithms guide with info fixed-point API
niklasschmitz Mar 14, 2024
b0d0a1a
Merge branch 'master' into fp-solvers-functional-api
niklasschmitz May 22, 2024
d234ad0
Merge branch 'master' into fp-solvers-functional-api
niklasschmitz Jul 26, 2024
08f999f
Merge branch 'master' into fp-solvers-functional-api
niklasschmitz Jul 29, 2024
dddaac7
Merge remote-tracking branch 'origin/master' into fp-solvers-function…
niklasschmitz Aug 2, 2024
e9a3394
Remove CROP
niklasschmitz Aug 2, 2024
c0a9d74
Fix MPI broadcast of 'converged' flag
niklasschmitz Aug 2, 2024
caee55e
Remove `tol` from FP solvers and instead use info.converged flag. Mak…
niklasschmitz Aug 2, 2024
96b32e5
Update SCF docs example with new API
niklasschmitz Aug 2, 2024
6e28b8e
Fix SCF docs example info0 handling
niklasschmitz Aug 2, 2024
0ddf05e
Add note on convergence criterion in SCF intro docs
niklasschmitz Aug 5, 2024
60e0210
Add test for SCF timeout
niklasschmitz Aug 5, 2024
f092c4b
Add timeout flag check to SCF solvers
niklasschmitz Aug 5, 2024
00a1b19
Fix timeout tests and rename `timeout` flag to `timedout` for consist…
niklasschmitz Aug 5, 2024
3a37b2a
Add docs on customizing `is_converged` to the custom solvers example
niklasschmitz Aug 5, 2024
009349d
Replace stress convergence example with simpler energy convergence
niklasschmitz Aug 5, 2024
4f8980e
Remove short-circuiting early return from fixpoint_map and leave the …
niklasschmitz Aug 5, 2024
0a89022
Remove density from convergence_criterion
niklasschmitz Aug 5, 2024
1d405e9
Comment on internals for flags in `scf_solvers.jl`
niklasschmitz Aug 5, 2024
21e70ce
Merge branch 'master' into fp-solvers-functional-api
niklasschmitz Aug 5, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 29 additions & 19 deletions docs/src/guide/self_consistent_field.jl
Original file line number Diff line number Diff line change
Expand Up @@ -123,27 +123,36 @@

using DFTK
using LinearAlgebra
function fixed_point_iteration(F, ρ₀, maxiter; tol)

function fixed_point_iteration(F, ρ0, info0; maxiter, tol=1e-10)
## F: The SCF step function
## ρ₀: The initial guess density
## ρ0: The initial guess density
## info0: The initial metadata
## maxiter: The maximal number of iterations to be performed
## tol: The selected convergence tolerance

ρ = ρ₀
= F(ρ)
ρ = ρ0
info = info0
for n = 1:maxiter
## If change less than tolerance, break iterations:
Fρ, info = F(ρ, info)
## If the change is less than the tolerance, break iteration.
if norm(Fρ - ρ) < tol
break
end
ρ = Fρ
Fρ = F(ρ)
ρ = Fρ
end

## Return some stuff DFTK needs ...
(fixpoint=ρ, converged=norm(Fρ-ρ) < tol)
(; fixpoint=ρ, info)
end;

# !!! note "Convergence checks in DFTK"
# The ad-hoc convergence criterion in the example above is included only for
# pedagogical purposes. It does not yet include the correct scaling,
# which depends on the discretization.
# It is preferred to use the provided DFTK utilities for specifiying
# convergence, that can be shared across different solvers. For the more
# advanced version, see the tutorial on [custom SCF solvers](@ref custom-solvers).

# To test this algorithm we use the following simple setting, which builds and discretises
# a PBE model for an aluminium supercell.

Expand Down Expand Up @@ -213,21 +222,22 @@ self_consistent_field(aluminium_setup(1); solver=fixed_point_iteration, damping=
# ```
# In terms of an algorithm Anderson iteration is

function anderson_iteration(F, ρ₀, maxiter; tol)
function anderson_iteration(F, ρ0, info0; maxiter)
## F: The SCF step function
## ρ₀: The initial guess density
## ρ0: The initial guess density
## info0: The initial metadata
## maxiter: The maximal number of iterations to be performed
## tol: The selected convergence tolerance

converged = false
ρ = ρ₀
info = info0
ρ = ρ0
ρs = []
Rs = []
for n = 1:maxiter
Fρ = F(ρ)
Fρ, info = F(ρ, info)
if info.converged
break
end
Rρ = Fρ - ρ
converged = norm(Rρ) < tol
converged && break

ρnext = vec(ρ) .+ vec(Rρ)
if !isempty(Rs)
Expand All @@ -241,11 +251,11 @@ function anderson_iteration(F, ρ₀, maxiter; tol)

push!(ρs, vec(ρ))
push!(Rs, vec(Rρ))
ρ = reshape(ρnext, size(ρ₀)...)
ρ = reshape(ρnext, size(ρ0)...)
end

## Return some stuff DFTK needs ...
(fixpoint=ρ, converged=converged)
(; fixpoint=ρ, info)
end;

# To work with this algorithm we will use DFTK's intrinsic mechanism to choose a damping. The syntax for this is
Expand Down
50 changes: 36 additions & 14 deletions examples/custom_solvers.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# # Custom solvers
# # [Custom solvers](@id custom-solvers)
# In this example, we show how to define custom solvers. Our system
# will again be silicon, because we are not very imaginative
using DFTK, LinearAlgebra
Expand All @@ -16,20 +16,24 @@ model = model_LDA(lattice, atoms, positions)
basis = PlaneWaveBasis(model; Ecut=5, kgrid=[1, 1, 1]);

# We define our custom fix-point solver: simply a damped fixed-point
function my_fp_solver(f, x0, max_iter; tol)
function my_fp_solver(f, x0, info0; maxiter)
mixing_factor = .7
x = x0
fx = f(x)
for n = 1:max_iter
inc = fx - x
if norm(inc) < tol
info = info0
for n = 1:maxiter
fx, info = f(x, info)
if info.converged || info.timedout
break
end
x = x + mixing_factor * inc
fx = f(x)
x = x + mixing_factor * (fx - x)
end
(; fixpoint=x, converged=norm(fx-x) < tol)
(; fixpoint=x, info)
end;
# Note that the fixpoint map `f` operates on an auxiliary variable `info` for
# state bookkeeping. Early termination criteria are flagged from inside
# the function `f` using boolean flags `info.converged` and `info.timedout`.
# For control over these criteria, see the `is_converged` and `maxtime`
# keyword arguments of `self_consistent_field`.

# Our eigenvalue solver just forms the dense matrix and diagonalizes
# it explicitly (this only works for very small systems)
Expand Down Expand Up @@ -69,8 +73,26 @@ scfres = self_consistent_field(basis;
eigensolver=my_eig_solver,
mixing=MyMixing());
# Note that the default convergence criterion is the difference in
# density. When this gets below `tol`, the
# "driver" `self_consistent_field` artificially makes the fixed-point
# solver think it's converged by forcing `f(x) = x`. You can customize
# this with the `is_converged` keyword argument to
# `self_consistent_field`.
# density. When this gets below `tol`, the fixed-point solver terminates.
# You can also customize this with the `is_converged` keyword argument to
# `self_consistent_field`, as shown below.

# ## Customizing the convergence criterion
# Here is an example of a defining a custom convergence criterion and specifying
# it using the `is_converged` callback keyword to `self_consistent_field`.

function my_convergence_criterion(info)
tol = 1e-10
if last(info.history_Δρ) > 10sqrt(conv.tolerance)
return false # The ρ change should also be small to avoid the SCF being just stuck
end
length(info.history_Etot) < 2 && return false
mfherbst marked this conversation as resolved.
Show resolved Hide resolved
ΔE = (info.history_Etot[end-1] - info.history_Etot[end])
ΔE < tol
end

scfres2 = self_consistent_field(basis;
solver=my_fp_solver,
is_converged=my_convergence_criterion,
eigensolver=my_eig_solver,
mixing=MyMixing());
1 change: 0 additions & 1 deletion src/DFTK.jl
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,6 @@ export LdosMixing, HybridMixing, χ0Mixing
export FixedBands, AdaptiveBands
export scf_damping_solver
export scf_anderson_solver
export scf_CROP_solver
export self_consistent_field, kwargs_scf_checkpoints
export ScfConvergenceEnergy, ScfConvergenceDensity, ScfConvergenceForce
export ScfSaveCheckpoints, ScfDefaultCallback, AdaptiveDiagtol
Expand Down
100 changes: 18 additions & 82 deletions src/scf/scf_solvers.jl
Original file line number Diff line number Diff line change
@@ -1,30 +1,27 @@
# these provide fixed-point solvers that can be passed to scf()
# these provide fixed-point solvers that can be passed to self_consistent_field()

# the fp_solver function must accept being called like fp_solver(f, x0, tol,
# maxiter), where f(x) is the fixed-point map. It must return an
# object supporting res.sol and res.converged
# the fp_solver function must accept being called like
# fp_solver(f, x0, info0; maxiter), where f(x, info) is the fixed-point map.
# It must return an object supporting res.fixpoint and res.info

mfherbst marked this conversation as resolved.
Show resolved Hide resolved
"""
Create a damped SCF solver updating the density as
`x = β * x_new + (1 - β) * x`
"""
function scf_damping_solver(β=0.2)
function fp_solver(f, x0, maxiter; tol=1e-6)
function fp_solver(f, x0, info0; maxiter)
β = convert(eltype(x0), β)
converged = false
x = copy(x0)
info = info0
for i = 1:maxiter
x_new = f(x)

if norm(x_new - x) < tol
x_new, info = f(x, info)
if info.converged || info.timedout
x = x_new
converged = true
break
end

x = @. β * x_new + (1 - β) * x
end
(; fixpoint=x, converged)
(; fixpoint=x, info)
end
fp_solver
end
Expand All @@ -34,80 +31,19 @@ Create a simple anderson-accelerated SCF solver. `m` specifies the number
of steps to keep the history of.
"""
function scf_anderson_solver(m=10; kwargs...)
function anderson(f, x0, maxiter; tol=1e-6)
function anderson(f, x0, info0; maxiter)
T = eltype(x0)
x = x0

converged = false
info = info0
acceleration = AndersonAcceleration(; m, kwargs...)
for n = 1:maxiter
residual = f(x) - x
converged = norm(residual) < tol
converged && break
x = acceleration(x, one(T), residual)
end
(; fixpoint=x, converged)
end
end

"""
CROP-accelerated root-finding iteration for `f`, starting from `x0` and keeping
a history of `m` steps. Optionally `warming` specifies the number of non-accelerated
steps to perform for warming up the history.
"""
function CROP(f, x0, m::Int, maxiter::Int, tol::Real, warming=0)
# CROP iterates maintain xn and fn (/!\ fn != f(xn)).
# xtn+1 = xn + fn
# ftn+1 = f(xtn+1)
# Determine αi from min ftn+1 + sum αi(fi - ftn+1)
# fn+1 = ftn+1 + sum αi(fi - ftn+1)
# xn+1 = xtn+1 + sum αi(xi - xtn+1)
# Reference:
# Patrick Ettenhuber and Poul Jørgensen, JCTC 2015, 11, 1518-1524
# https://doi.org/10.1021/ct501114q

# Cheat support for multidimensional arrays
if length(size(x0)) != 1
x, conv= CROP(x -> vec(f(reshape(x, size(x0)...))), vec(x0), m, maxiter, tol, warming)
return (; fixpoint=reshape(x, size(x0)...), converged=conv)
end
N = size(x0,1)
T = eltype(x0)
xs = zeros(T, N, m+1) # Ring buffers storing the iterates
fs = zeros(T, N, m+1) # newest to oldest
xs[:,1] = x0
fs[:,1] = f(x0) # Residual
errs = zeros(maxiter)
err = Inf

for n = 1:maxiter
xtnp1 = xs[:, 1] + fs[:, 1] # Richardson update
ftnp1 = f(xtnp1) # Residual
err = norm(ftnp1)
errs[n] = err
if err < tol
break
end

# CROP acceleration
m_eff = min(n, m)
if m_eff > 0 && n > warming
mat = fs[:, 1:m_eff] .- ftnp1
alphas = -mat \ ftnp1
bak_xtnp1 = copy(xtnp1)
bak_ftnp1 = copy(ftnp1)
for i = 1:m_eff
xtnp1 .+= alphas[i].*(xs[:, i] .- bak_xtnp1)
ftnp1 .+= alphas[i].*(fs[:, i] .- bak_ftnp1)
for i = 1:maxiter
fx, info = f(x, info)
if info.converged || info.timedout
break
end
residual = fx - x
x = acceleration(x, one(T), residual)
end

xs = circshift(xs,(0,1))
fs = circshift(fs,(0,1))
xs[:,1] = xtnp1
fs[:,1] = ftnp1
# fs[:,1] = f(xs[:,1])
(; fixpoint=x, info)
end
(; fixpoint=xs[:, 1], converged=err < tol)
end
scf_CROP_solver(m=10) = (f, x0, maxiter; tol=1e-6) -> CROP(x -> f(x) - x, x0, m, maxiter, tol)
Loading
Loading