Skip to content

Commit

Permalink
Filter rgrid symmetries (#642)
Browse files Browse the repository at this point in the history
  • Loading branch information
antoine-levitt authored May 13, 2022
1 parent eeb1a56 commit 58bf881
Show file tree
Hide file tree
Showing 12 changed files with 137 additions and 86 deletions.
69 changes: 44 additions & 25 deletions src/PlaneWaveBasis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -88,12 +88,13 @@ struct PlaneWaveBasis{T} <: AbstractBasis{T}
krange_allprocs::Vector{Vector{Int}} # indices of kpoints treated by the
# respective rank in comm_kpts

## Symmetry operations that leave the reducible Brillouin zone invariant.
## Symmetry operations that leave the discretized model (k and r grids) invariant.
# Subset of model.symmetries.
# Nearly all computations will be done inside this symmetry group;
# the exception is inexact operations on the FFT grid (ie xc),
# which don't respect the symmetry
symmetries::Vector{SymOp}
# Whether the symmetry operations leave the rgrid invariant
# If this is true, the symmetries are a property of the complete discretized model.
# Therefore, all quantities should be symmetric to machine precision
symmetries_respect_rgrid::Bool

## Instantiated terms (<: Term). See Hamiltonian for high-level usage
terms::Vector{Any}
Expand Down Expand Up @@ -141,35 +142,41 @@ end
# All given parameters must be the same on all processors
# and are stored in PlaneWaveBasis for easy reconstruction.
function PlaneWaveBasis(model::Model{T}, Ecut::Number, fft_size, variational,
kcoords, kweights, kgrid, kshift, symmetries, comm_kpts) where {T <: Real}
kcoords, kweights, kgrid, kshift,
symmetries_respect_rgrid, comm_kpts) where {T <: Real}
# Validate fft_size
if variational
max_E = sum(abs2, model.recip_lattice * floor.(Int, Vec3(fft_size) ./ 2)) / 2
Ecut > max_E && @warn(
"For a variational method, Ecut should be less than the maximal kinetic " *
"energy the grid supports ($max_E)"
"energy the grid supports ($max_E)"
)
end
if !(all(fft_size .== next_working_fft_size(T, fft_size)))
@show fft_size next_working_fft_size(T, fft_size)
error("Selected fft_size will not work for the buggy generic " *
"FFT routines; use next_working_fft_size")
end
fft_size = Tuple{Int, Int, Int}(fft_size) # explicit conversion in case passed as array

# filter out the symmetries that don't preserve the real-space grid
symmetries = model.symmetries
if symmetries_respect_rgrid
symmetries = symmetries_preserving_rgrid(symmetries, fft_size)
end

# build or validate the kgrid, and get symmetries preserving the kgrid
if isnothing(kcoords)
# MP grid based on kgrid/kshift
@assert !isnothing(kgrid)
@assert !isnothing(kshift)
@assert isnothing(kweights)
@assert isnothing(symmetries)
kcoords, kweights, symmetries = bzmesh_ir_wedge(kgrid, model.symmetries; kshift)
kcoords, kweights, symmetries = bzmesh_ir_wedge(kgrid, symmetries; kshift)
else
# Manual kpoint set based on kcoords/kweights
@assert length(kcoords) == length(kweights)
if isnothing(symmetries)
all_kcoords = unfold_kcoords(kcoords, model.symmetries)
symmetries = symmetries_preserving_kgrid(model.symmetries, all_kcoords)
end
all_kcoords = unfold_kcoords(kcoords, symmetries)
symmetries = symmetries_preserving_kgrid(symmetries, all_kcoords)
end

# Init MPI, and store MPI-global values for reference
Expand Down Expand Up @@ -241,7 +248,7 @@ function PlaneWaveBasis(model::Model{T}, Ecut::Number, fft_size, variational,
r_to_G_normalization, G_to_r_normalization,
kpoints, kweights_thisproc, kgrid, kshift,
kcoords_global, kweights_global, comm_kpts, krange_thisproc, krange_allprocs,
symmetries, terms)
symmetries, symmetries_respect_rgrid, terms)

# Instantiate the terms with the basis
for (it, t) in enumerate(model.term_types)
Expand All @@ -257,13 +264,28 @@ end
@timing function PlaneWaveBasis(model::Model{T}, Ecut::Number,
kcoords ::Union{Nothing, AbstractVector},
kweights::Union{Nothing, AbstractVector};
symmetries=nothing,
variational=true,
fft_size=(@assert variational; compute_fft_size(model, Ecut, kcoords)),
variational=true, fft_size=nothing,
kgrid=nothing, kshift=nothing,
symmetries_respect_rgrid=isnothing(fft_size),
comm_kpts=MPI.COMM_WORLD) where {T <: Real}
if isnothing(fft_size)
@assert variational
if symmetries_respect_rgrid
# ensure that the FFT grid is compatible with the "reasonable" symmetries
# (those with fractional translations with denominators 2, 3, 4, 6,
# this set being more or less arbitrary) by forcing the FFT size to be
# a multiple of the denominators.
# See https://github.com/JuliaMolSim/DFTK.jl/pull/642 for discussion
denominators = [denominator(rationalize(sym.w[i]; tol=SYMMETRY_TOLERANCE))
for sym in model.symmetries for i = 1:3]
factors = intersect((2, 3, 4, 6), denominators)
else
factors = (1,)
end
fft_size = compute_fft_size(model, Ecut, kcoords; factors=factors)
end
PlaneWaveBasis(model, Ecut, fft_size, variational, kcoords, kweights,
kgrid, kshift, symmetries, comm_kpts)
kgrid, kshift, symmetries_respect_rgrid, comm_kpts)
end

@doc raw"""
Expand All @@ -287,12 +309,10 @@ Creates a new basis identical to `basis`, but with a custom set of kpoints
@timing function PlaneWaveBasis(basis::PlaneWaveBasis, kcoords::AbstractVector,
kweights::AbstractVector)
kgrid = kshift = nothing
all_kcoords = unfold_kcoords(kcoords, basis.model.symmetries)
symmetries = symmetries_preserving_kgrid(basis.model.symmetries, all_kcoords)
PlaneWaveBasis(basis.model, basis.Ecut,
basis.fft_size, basis.variational,
kcoords, kweights, kgrid, kshift,
symmetries, basis.comm_kpts)
basis.symmetries_respect_rgrid, basis.comm_kpts)
end

"""
Expand Down Expand Up @@ -451,11 +471,10 @@ function gather_kpts(basis::PlaneWaveBasis)
basis.Ecut,
kcoords[1:n_kcoords],
kweights[1:n_kcoords];
basis.symmetries,
fft_size=basis.fft_size,
kgrid=basis.kgrid,
kshift=basis.kshift,
variational=basis.variational,
basis.variational,
basis.kgrid,
basis.kshift,
basis.symmetries_respect_rgrid,
comm_kpts=MPI.COMM_SELF,
)
end
Expand Down
5 changes: 3 additions & 2 deletions src/SymOp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
# Tolerance to consider two atomic positions as equal (in relative coordinates)
const SYMMETRY_TOLERANCE = 1e-5

is_approx_integer(r; tol=SYMMETRY_TOLERANCE) = all(ri -> abs(ri - round(ri)) tol, r)

struct SymOp{T <: Real}
# (Uu)(x) = u(W x + w) in real space
W::Mat3{Int}
Expand All @@ -38,8 +40,7 @@ end

Base.:(==)(op1::SymOp, op2::SymOp) = op1.W == op2.W && op1.w == op2.w
function Base.isapprox(op1::SymOp, op2::SymOp; atol=SYMMETRY_TOLERANCE)
is_approx_integer(r) = all(ri -> abs(ri - round(ri)) atol, r)
op1.W == op2.W && is_approx_integer(op1.w - op2.w)
op1.W == op2.W && is_approx_integer(op1.w - op2.w; tol=atol)
end
Base.one(::Type{SymOp}) = SymOp(Mat3{Int}(I), Vec3(zeros(Bool, 3)))
Base.one(::SymOp) = one(SymOp)
Expand Down
10 changes: 5 additions & 5 deletions src/bzmesh.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,16 +40,12 @@ end
Construct the irreducible wedge of a uniform Brillouin zone mesh for sampling ``k``-points,
given the crystal symmetries `symmetries`. Returns the list of irreducible ``k``-point
(fractional) coordinates, the associated weights adn the new `symmetries` compatible with
(fractional) coordinates, the associated weights and the new `symmetries` compatible with
the grid.
"""
function bzmesh_ir_wedge(kgrid_size, symmetries; kshift=[0, 0, 0])
all(isequal.(kgrid_size, 1)) && return bzmesh_uniform(kgrid_size; kshift)

# Filter those symmetry operations (S, τ) that preserve the MP grid
kcoords_mp = kgrid_monkhorst_pack(kgrid_size; kshift)
symmetries = symmetries_preserving_kgrid(symmetries, kcoords_mp)

# Transform kshift to the convention used in spglib:
# If is_shift is set (i.e. integer 1), then a shift of 0.5 is performed,
# else no shift is performed along an axis.
Expand All @@ -59,6 +55,10 @@ function bzmesh_ir_wedge(kgrid_size, symmetries; kshift=[0, 0, 0])
convert(Int, 2 * ks)
end

# Filter those symmetry operations that preserve the MP grid
kcoords_mp = kgrid_monkhorst_pack(kgrid_size; kshift)
symmetries = symmetries_preserving_kgrid(symmetries, kcoords_mp)

# Give the remaining symmetries to spglib to compute an irreducible k-point mesh
# TODO implement time-reversal symmetry and turn the flag to true
Ws = [symop.W for symop in symmetries]
Expand Down
11 changes: 5 additions & 6 deletions src/external/jld2io.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,8 @@ struct PlaneWaveBasisSerialisation{T <: Real}
kweights::Vector{T}
kgrid::Union{Nothing,Vec3{Int}}
kshift::Union{Nothing,Vec3{T}}
symmetries_respect_rgrid::Bool
fft_size::Tuple{Int, Int, Int}
symmetries::Vector{SymOp}
end
JLD2.writeas(::Type{PlaneWaveBasis{T}}) where {T} = PlaneWaveBasisSerialisation{T}

Expand All @@ -96,15 +96,14 @@ function Base.convert(::Type{PlaneWaveBasisSerialisation{T}}, basis::PlaneWaveBa
basis.kweights_global,
basis.kgrid,
basis.kshift,
basis.symmetries_respect_rgrid,
basis.fft_size,
basis.symmetries
)
end

function Base.convert(::Type{PlaneWaveBasis{T}}, serial::PlaneWaveBasisSerialisation{T}) where {T}
PlaneWaveBasis(serial.model, serial.Ecut, serial.kcoords,
serial.kweights; serial.symmetries,
fft_size=serial.fft_size,
kgrid=serial.kgrid, kshift=serial.kshift,
variational=serial.variational)
serial.kweights; serial.fft_size,
serial.kgrid, serial.kshift, serial.symmetries_respect_rgrid,
serial.variational)
end
3 changes: 1 addition & 2 deletions src/external/spglib.jl
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,7 @@ end
for coord in group_positions
# If all elements of a difference in diffs is integer, then
# W * coord + w and pos are equivalent lattice positions
is_approx_integer(r) = all(ri -> abs(ri - round(ri)) tol_symmetry, r)
if !any(c -> is_approx_integer(W * coord + w - c), group_positions)
if !any(c -> is_approx_integer(W * coord + w - c; tol=tol_symmetry), group_positions)
error("spglib returned bad symmetries: Cannot map the atom at position " *
"$coord to another atom of the same element under the symmetry " *
"operation (W, w):\n($W, $w)")
Expand Down
41 changes: 32 additions & 9 deletions src/fft.jl
Original file line number Diff line number Diff line change
Expand Up @@ -140,9 +140,11 @@ The function will determine the smallest parallelepiped containing the wave vect
``|G|^2/2 \leq E_\text{cut} ⋅ \text{supersampling}^2``.
For an exact representation of the density resulting from wave functions
represented in the spherical basis sets, `supersampling` should be at least `2`.
If `factors` is not empty, ensure that the resulting fft_size contains all the factors
"""
function compute_fft_size(model::Model{T}, Ecut, kcoords=nothing;
ensure_smallprimes=true, algorithm=:fast, kwargs...) where T
ensure_smallprimes=true, algorithm=:fast, factors=1, kwargs...) where T
if algorithm == :fast
Glims = compute_Glims_fast(model.lattice, Ecut; kwargs...)
elseif algorithm == :precise
Expand All @@ -162,18 +164,39 @@ function compute_fft_size(model::Model{T}, Ecut, kcoords=nothing;
error("Unknown fft_size_algorithm :$algorithm, try :fast or :precise")
end

# Optimize FFT grid size: Make sure the number factorises in small primes only
fft_size = Vec3(2 .* Glims .+ 1)
# TODO Make default small primes type-dependent, since generic FFT is broken for some
# prime factors ... temporary workaround, see more details in workarounds/fft_generic.jl
if ensure_smallprimes
fft_size = nextprod.(Ref([2, 3, 5]), fft_size)
smallprimes = default_primes(T) # Usually (2, 3 ,5)
else
smallprimes = ()
end

# TODO generic FFT is kind of broken for some fft sizes
# ... temporary workaround, see more details in workarounds/fft_generic.jl
fft_size = next_working_fft_size(T, fft_size)
# Consider only sizes that are (a) a product of small primes and (b) contain the factors
fft_size = Vec3(2 .* Glims .+ 1)
fft_size = next_compatible_fft_size(fft_size; factors, smallprimes)
Tuple{Int, Int, Int}(fft_size)
end

"""
Find the next compatible FFT size
Sizes must (a) be a product of small primes only and (b) contain the factors.
If smallprimes is empty (a) is skipped.
"""
function next_compatible_fft_size(size::Int; smallprimes=(2, 3, 5), factors=(1, ))
# This could be optimized
is_product_of_primes(n) = isempty(smallprimes) || (n == nextprod(smallprimes, n))
@assert all(is_product_of_primes, factors) # ensure compatibility between (a) and (b)
has_factors(n) = rem(n, prod(factors)) == 0

while !(has_factors(size) && is_product_of_primes(size))
size += 1
end
size
end
function next_compatible_fft_size(sizes::Union{Tuple, AbstractArray}; kwargs...)
next_compatible_fft_size.(sizes; kwargs...)
end

# This uses a more precise and slower algorithm than the one above,
# simply enumerating all G vectors and seeing where their difference
Expand Down Expand Up @@ -255,8 +278,8 @@ end

# TODO Some grid sizes are broken in the generic FFT implementation
# in FourierTransforms, for more details see workarounds/fft_generic.jl
# This function is needed to provide a noop fallback for grid adjustment for
# for floating-point types natively supported by FFTW
default_primes(::Type{Float32}) = (2, 3, 5)
default_primes(::Type{Float64}) = default_primes(Float32)
next_working_fft_size(::Type{Float32}, size::Int) = size
next_working_fft_size(::Type{Float64}, size::Int) = size
next_working_fft_size(T, sizes::Union{Tuple, AbstractArray}) = next_working_fft_size.(T, sizes)
2 changes: 1 addition & 1 deletion src/postprocess/stresses.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ Compute the stresses (= 1/Vol dE/d(M*lattice), taken at M=I) of an obtained SCF
new_basis = PlaneWaveBasis(new_model,
basis.Ecut, basis.fft_size, basis.variational,
basis.kcoords_global, basis.kweights_global,
basis.kgrid, basis.kshift, basis.symmetries,
basis.kgrid, basis.kshift, basis.symmetries_respect_rgrid,
basis.comm_kpts)
ρ = DFTK.compute_density(new_basis, scfres.ψ, scfres.occupation)
energies, _ = energy_hamiltonian(new_basis, scfres.ψ, scfres.occupation;
Expand Down
27 changes: 21 additions & 6 deletions src/symmetry.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ function symmetry_operations(lattice, atoms, positions, magnetic_moments=[];
end

"""
Filter out the symmetry operations that respect the symmetries of the discrete BZ grid
Filter out the symmetry operations that don't respect the symmetries of the discrete BZ grid
"""
function symmetries_preserving_kgrid(symmetries, kcoords)
kcoords_normalized = normalize_kpoint_coordinate.(kcoords)
Expand All @@ -63,6 +63,21 @@ function symmetries_preserving_kgrid(symmetries, kcoords)
filter(preserves_grid, symmetries)
end

"""
Filter out the symmetry operations that don't respect the symmetries of the discrete real-space grid
"""
function symmetries_preserving_rgrid(symmetries, fft_size)
is_in_grid(r) = all(zip(r, fft_size)) do (ri, size)
abs(ri * size - round(ri * size)) / size SYMMETRY_TOLERANCE
end

onehot3(i) = (x = zeros(Bool, 3); x[i] = true; Vec3(x))
function preserves_grid(symop)
all(is_in_grid(symop.W * onehot3(i) .// fft_size[i] + symop.w) for i=1:3)
end

filter(preserves_grid, symmetries)
end

@doc raw"""
Apply various standardisations to a lattice and a list of atoms. It uses spglib to detect
Expand Down Expand Up @@ -224,7 +239,6 @@ function symmetrize_forces(model::Model, forces; symmetries)
# see (A.27) of https://arxiv.org/pdf/0906.2569.pdf
# (but careful that our symmetries are r -> Wr+w, not R(r+f))
other_at = W \ (position - w)
is_approx_integer(r) = all(ri -> abs(ri - round(ri)) SYMMETRY_TOLERANCE, r)
i_other_at = findfirst(a -> is_approx_integer(a - other_at), positions_group)
symmetrized_forces[idx] += W * forces[group[i_other_at]]
end
Expand All @@ -245,10 +259,11 @@ function unfold_bz(basis::PlaneWaveBasis)
return basis
else
kcoords = unfold_kcoords(basis.kcoords_global, basis.symmetries)
new_basis = PlaneWaveBasis(basis.model,
basis.Ecut, basis.fft_size, basis.variational,
kcoords, [1/length(kcoords) for _ in kcoords],
basis.kgrid, basis.kshift, basis.symmetries, basis.comm_kpts)
return PlaneWaveBasis(basis.model,
basis.Ecut, basis.fft_size, basis.variational,
kcoords, [1/length(kcoords) for _ in kcoords],
basis.kgrid, basis.kshift,
basis.symmetries_respect_rgrid, basis.comm_kpts)
end
end

Expand Down
11 changes: 3 additions & 8 deletions src/workarounds/fft_generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,13 @@ end
# yet fully compliant with the AbstractFFTs interface and has still
# various bugs we work around.

function next_working_fft_size(::Any, size)
function next_working_fft_size(::Any, size::Integer)
# TODO FourierTransforms has a bug, which is triggered
# only in some factorizations, see
# https://github.com/JuliaComputing/FourierTransforms.jl/issues/10
# To be safe we fall back to powers of two

adjusted = nextpow(2, size)
if adjusted != size
@info "Changing fft size to $adjusted (smallest working size for generic FFTs)"
end
adjusted
nextpow(2, size) # We fall back to powers of two to be safe
end
default_primes(::Any) = (2, )

# Generic fallback function, Float32 and Float64 specialization in fft.jl
function build_fft_plans(T, fft_size)
Expand Down
Loading

0 comments on commit 58bf881

Please sign in to comment.