Skip to content

Commit

Permalink
Merge pull request #401 from ab5424/type_ann
Browse files Browse the repository at this point in the history
Type annotations
  • Loading branch information
shyuep authored Jul 18, 2024
2 parents 6496886 + bf19f79 commit 6943676
Show file tree
Hide file tree
Showing 28 changed files with 496 additions and 386 deletions.
2 changes: 1 addition & 1 deletion docs_rst/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@

# -- Options for LaTeX output ------------------------------------------------

latex_elements = {
latex_elements: dict = {
# The paper size ('letterpaper' or 'a4paper').
#'papersize': 'letterpaper',
# The font size ('10pt', '11pt' or '12pt').
Expand Down
30 changes: 15 additions & 15 deletions pymatgen/analysis/diffusion/aimd/clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,14 @@
class Kmeans:
"""Simple kmeans clustering."""

def __init__(self, max_iterations: int = 1000):
def __init__(self, max_iterations: int = 1000) -> None:
"""
Args:
max_iterations (int): Maximum number of iterations to run KMeans algo.
"""
self.max_iterations = max_iterations

def cluster(self, points, k, initial_centroids=None):
def cluster(self, points: np.ndarray, k: int, initial_centroids: np.ndarray = None) -> tuple:
"""
Args:
points (ndarray): Data points as a mxn ndarray, where m is the
Expand Down Expand Up @@ -67,7 +67,7 @@ def cluster(self, points, k, initial_centroids=None):
return centroids, labels, ss

@staticmethod
def get_labels(points, centroids):
def get_labels(points: np.ndarray, centroids: np.ndarray) -> tuple:
"""
For each element in the dataset, chose the closest centroid.
Make that centroid the element's label.
Expand All @@ -81,7 +81,7 @@ def get_labels(points, centroids):
return np.where(dists == min_dists[:, None])[1], np.sum(min_dists**2)

@staticmethod
def get_centroids(points, labels, k, centroids):
def get_centroids(points: np.ndarray, labels: np.ndarray, k: int, centroids: np.ndarray) -> np.ndarray:
"""
Each centroid is the geometric mean of the points that
have that centroid's label. Important: If a centroid is empty (no
Expand All @@ -94,16 +94,16 @@ def get_centroids(points, labels, k, centroids):
centroids: List of centroids
"""
labels = np.array(labels)
centroids = []
_centroids = []
for i in range(k):
ind = np.where(labels == i)[0]
if len(ind) > 0:
centroids.append(np.average(points[ind, :], axis=0))
_centroids.append(np.average(points[ind, :], axis=0))
else:
centroids.append(get_random_centroid(points))
return np.array(centroids)
_centroids.append(get_random_centroid(points))
return np.array(_centroids)

def should_stop(self, old_centroids, centroids, iterations):
def should_stop(self, old_centroids: np.ndarray | None, centroids: np.ndarray, iterations: int) -> bool:
"""
Check for stopping conditions.
Expand All @@ -127,7 +127,7 @@ class KmeansPBC(Kmeans):
fractional coordinates.
"""

def __init__(self, lattice, max_iterations=1000):
def __init__(self, lattice: np.ndarray, max_iterations: int = 1000) -> None:
"""
Args:
lattice: Lattice
Expand All @@ -136,7 +136,7 @@ def __init__(self, lattice, max_iterations=1000):
self.lattice = lattice
self.max_iterations = max_iterations

def get_labels(self, points, centroids):
def get_labels(self, points, centroids): # noqa: ANN001,ANN201
"""
For each element in the dataset, chose the closest centroid.
Make that centroid the element's label.
Expand All @@ -149,7 +149,7 @@ def get_labels(self, points, centroids):
min_dists = np.min(dists, axis=1)
return np.where(dists == min_dists[:, None])[1], np.sum(min_dists**2)

def get_centroids(self, points, labels, k, centroids):
def get_centroids(self, points, labels, k, centroids): # noqa: ANN001,ANN201
"""
Each centroid is the geometric mean of the points that
have that centroid's label. Important: If a centroid is empty (no
Expand Down Expand Up @@ -179,7 +179,7 @@ def get_centroids(self, points, labels, k, centroids):
new_centroids.append(c)
return np.array(new_centroids)

def should_stop(self, old_centroids, centroids, iterations):
def should_stop(self, old_centroids: np.ndarray | None, centroids: np.ndarray, iterations: int) -> bool:
"""
Check for stopping conditions.
Expand All @@ -196,7 +196,7 @@ def should_stop(self, old_centroids, centroids, iterations):
return all(np.allclose(pbc_diff(c1, c2), [0, 0, 0]) for c1, c2 in zip(old_centroids, centroids))


def get_random_centroid(points):
def get_random_centroid(points: np.ndarray) -> np.ndarray:
"""
Generate a random centroid based on points.
Expand All @@ -209,7 +209,7 @@ def get_random_centroid(points):
return np.array([random.uniform(mind[i], maxd[i]) for i in range(n)])


def get_random_centroids(points, k):
def get_random_centroids(points: np.ndarray, k: int) -> np.ndarray:
"""
Generate k random centroids based on points.
Expand Down
72 changes: 50 additions & 22 deletions pymatgen/analysis/diffusion/aimd/pathway.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,19 @@

import itertools
from collections import Counter
from typing import TYPE_CHECKING

import numpy as np
from scipy.cluster.hierarchy import fcluster, linkage
from scipy.spatial.distance import squareform

if TYPE_CHECKING:
from collections.abc import Sequence

from pymatgen.analysis.diffusion.analyzer import DiffusionAnalyzer
from pymatgen.core.structure import Structure
from pymatgen.util.typing import PathLike, SpeciesLike


class ProbabilityDensityAnalysis:
r"""
Expand All @@ -25,7 +33,13 @@ class ProbabilityDensityAnalysis:
Conductor". Chem. Mater. (2015), 27, pp 8318-8325.
"""

def __init__(self, structure, trajectories, interval=0.5, species=("Li", "Na")):
def __init__(
self,
structure: Structure,
trajectories: np.ndarray,
interval: float = 0.5,
species: Sequence[SpeciesLike] = ("Li", "Na"),
) -> None:
"""
Initialization.
Expand Down Expand Up @@ -66,7 +80,7 @@ def __init__(self, structure, trajectories, interval=0.5, species=("Li", "Na")):
grid = agrid[:, None, None] + bgrid[None, :, None] + cgrid[None, None, :]

# Calculate time-averaged probability density function distribution Pr
count = Counter()
count: Counter = Counter()
Pr = np.zeros(ngrid, dtype=np.double)

for it in range(nsteps):
Expand Down Expand Up @@ -113,10 +127,12 @@ def __init__(self, structure, trajectories, interval=0.5, species=("Li", "Na")):
self.lens = lens
self.Pr = Pr
self.species = species
self.stable_sites = None
self.stable_sites: np.ndarray | None = None

@classmethod
def from_diffusion_analyzer(cls, diffusion_analyzer, interval=0.5, species=("Li", "Na")):
def from_diffusion_analyzer(
cls, diffusion_analyzer: DiffusionAnalyzer, interval: float = 0.5, species: Sequence[SpeciesLike] = ("Li", "Na")
) -> ProbabilityDensityAnalysis:
"""
Create a ProbabilityDensityAnalysis from a diffusion_analyzer object.
Expand All @@ -128,16 +144,16 @@ def from_diffusion_analyzer(cls, diffusion_analyzer, interval=0.5, species=("Li"
species(list of str): list of species that are of interest
"""
structure = diffusion_analyzer.structure
trajectories = []
_trajectories = []

for _i, s in enumerate(diffusion_analyzer.get_drift_corrected_structures()):
trajectories.append(s.frac_coords)
_trajectories.append(s.frac_coords)

trajectories = np.array(trajectories)
trajectories = np.array(_trajectories)

return ProbabilityDensityAnalysis(structure, trajectories, interval=interval, species=species)

def generate_stable_sites(self, p_ratio=0.25, d_cutoff=1.0):
def generate_stable_sites(self, p_ratio: float = 0.25, d_cutoff: float = 1.0) -> None:
"""
Obtain a set of low-energy sites from probability density function with
given probability threshold 'p_ratio'. The set of grid points with
Expand All @@ -157,14 +173,14 @@ def generate_stable_sites(self, p_ratio=0.25, d_cutoff=1.0):
as a Nx3 numpy array.
"""
# Set of grid points with high probability density.
grid_fcoords = []
_grid_fcoords = []
indices = np.where(self.Pr > self.Pr.max() * p_ratio)
lattice = self.structure.lattice

for x, y, z in zip(indices[0], indices[1], indices[2]):
grid_fcoords.append([x / self.lens[0], y / self.lens[1], z / self.lens[2]])
_grid_fcoords.append([x / self.lens[0], y / self.lens[1], z / self.lens[2]])

grid_fcoords = np.array(grid_fcoords)
grid_fcoords = np.array(_grid_fcoords)
dist_matrix = np.array(lattice.get_all_distances(grid_fcoords, grid_fcoords))
np.fill_diagonal(dist_matrix, 0)

Expand All @@ -191,34 +207,35 @@ def generate_stable_sites(self, p_ratio=0.25, d_cutoff=1.0):
stable_sites = []

for i in set(cluster_indices):
indices = np.where(cluster_indices == i)[0]
_indices = np.where(cluster_indices == i)[0]

if len(indices) == 1:
stable_sites.append(grid_fcoords[indices[0]])
if len(_indices) == 1:
stable_sites.append(grid_fcoords[_indices[0]])
continue

# Consider periodic boundary condition
members = grid_fcoords[indices] - grid_fcoords[indices[0]]
members = grid_fcoords[_indices] - grid_fcoords[_indices[0]]
members = np.where(members > 0.5, members - 1.0, members)
members = np.where(members < -0.5, members + 1.0, members)
members += grid_fcoords[indices[0]]
members += grid_fcoords[_indices[0]]

stable_sites.append(np.mean(members, axis=0))

self.stable_sites = np.array(stable_sites)

def get_full_structure(self):
def get_full_structure(self) -> Structure:
"""
Generate the structure with the low-energy sites included. In the end, a
pymatgen Structure object will be returned.
"""
full_structure = self.structure.copy()
assert self.stable_sites is not None, "Please run generate_stable_sites() first!"
for fcoord in self.stable_sites:
full_structure.append("X", fcoord)

return full_structure

def to_chgcar(self, filename="CHGCAR.vasp"):
def to_chgcar(self, filename: PathLike = "CHGCAR.vasp") -> None:
"""
Save the probability density distribution in the format of CHGCAR,
which can be visualized by VESTA.
Expand Down Expand Up @@ -282,7 +299,13 @@ class SiteOccupancyAnalyzer:
"""

def __init__(self, structure, coords_ref, trajectories, species=("Li", "Na")):
def __init__(
self,
structure: Structure,
coords_ref: np.ndarray | Sequence[Sequence[float]],
trajectories: np.ndarray | Sequence[Sequence[Sequence[float]]],
species: Sequence[SpeciesLike] = ("Li", "Na"),
) -> None:
"""
Args:
structure (pmg_structure): Initial structure.
Expand All @@ -296,7 +319,7 @@ def __init__(self, structure, coords_ref, trajectories, species=("Li", "Na")):
lattice = structure.lattice
coords_ref = np.array(coords_ref)
trajectories = np.array(trajectories)
count = Counter()
count: Counter = Counter()

indices = [i for i, site in enumerate(structure) if site.specie.symbol in species]

Expand All @@ -318,12 +341,17 @@ def __init__(self, structure, coords_ref, trajectories, species=("Li", "Na")):
self.nsteps = len(trajectories)
self.site_occ = site_occ

def get_average_site_occupancy(self, indices):
def get_average_site_occupancy(self, indices: list) -> float:
"""Get the average site occupancy over a subset of reference sites."""
return np.sum(self.site_occ[indices]) / len(indices)

@classmethod
def from_diffusion_analyzer(cls, coords_ref, diffusion_analyzer, species=("Li", "Na")):
def from_diffusion_analyzer(
cls,
coords_ref: np.ndarray | Sequence[Sequence[float]],
diffusion_analyzer: DiffusionAnalyzer,
species: Sequence[SpeciesLike] = ("Li", "Na"),
) -> SiteOccupancyAnalyzer:
"""
Create a SiteOccupancyAnalyzer object using a diffusion_analyzer object.
Expand Down
Loading

0 comments on commit 6943676

Please sign in to comment.