Skip to content

Commit

Permalink
type hints and minor fixes in html/js
Browse files Browse the repository at this point in the history
  • Loading branch information
gmatteo committed Jan 10, 2022
1 parent 0848fc4 commit 33d58bb
Show file tree
Hide file tree
Showing 11 changed files with 15,412 additions and 321 deletions.
7 changes: 4 additions & 3 deletions dev_scripts/extract_den.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,9 @@ def parse_and_generate_fc(path):
if os.path.isfile(ae_path):
print("Warning: Cannot overwrite existent file:", ae_path)
return 1
#fc_file, ae_file = open(fc_path, "wt"), open(ae_path, "wt")
fc_file, ae_file = sys.stdout, sys.stdout

fc_file, ae_file = open(fc_path, "wt"), open(ae_path, "wt")
#fc_file, ae_file = sys.stdout, sys.stdout
psp8_get_densities(path, fc_file=fc_file, ae_file=ae_file)
#fc_file.close()
#ae_file.close()
Expand All @@ -48,4 +49,4 @@ def parse_and_generate_fc(path):


if __name__ == "__main__":
sys.exit(main())
sys.exit(main())
116 changes: 58 additions & 58 deletions pseudo_dojo/core/atom.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
# coding: utf-8
"""This module provides objects and helper functions for atomic calculations."""
from __future__ import annotations

import collections
import numpy as np

from io import StringIO
from typing import Any, List, Union, Optional, Iterable, Tuple
from pseudo_dojo.refdata.nist import database as nist_database
from scipy.interpolate import UnivariateSpline
from scipy.integrate import cumtrapz
Expand All @@ -22,8 +25,6 @@
"plot_logders",
]

# Helper functions

_char2l = {
"s": 0,
"p": 1,
Expand All @@ -35,14 +36,14 @@
}


def _asl(obj):
def _asl(obj: Any) -> int:
try:
return _char2l[obj]
except KeyError:
return int(obj)


def states_from_string(confstr: str):
def states_from_string(confstr: str) -> List[QState]:
"""
Parse a string with an atomic configuration and build a list of `QState` instance.
"""
Expand All @@ -57,7 +58,7 @@ def states_from_string(confstr: str):
return states


def parse_orbtoken(orbtoken: str):
def parse_orbtoken(orbtoken: str) -> QState:
import re
m = re.match(r"(\d+)([spdfghi]+)(\d+)", orbtoken.strip())
if m:
Expand All @@ -68,9 +69,9 @@ def parse_orbtoken(orbtoken: str):

class NlkState(collections.namedtuple("NlkState", "n, l, k")):
"""
Named tuple storing (n,l) or (n,l,k) for relativistic pseudos.
Named tuple storing (n,l) or (n,l,k) if relativistic pseudos.
"""
def __str__(self):
def __str__(self) -> str:
if self.k is None:
return "n=%i, l=%i" % (self.n, self.l)
else:
Expand Down Expand Up @@ -98,26 +99,15 @@ class QState(collections.namedtuple("QState", "n, l, occ, eig, j, s")):
# TODO
# Spin +1, -1 or 1,2 or 0,1?

def __new__(cls, n, l, occ, eig=None, j=None, s=None):
def __new__(cls, n: int, l: int, occ: float,
eig: Optional[float] = None,
j: Optional[int] = None,
s: Optional[int] = None):
"""Intercepts super.__new__ adding type conversion and default values."""
eig = float(eig) if eig is not None else eig
j = int(j) if j is not None else j
s = int(s) if s is not None else s
return super(QState, cls).__new__(cls, int(n), _asl(l), float(occ), eig=eig, j=j, s=s)

# Rich comparison support.
# Note that the ordering is based on the quantum numbers and not on energies!
#def __gt__(self, other):
# if self.has_j:
# raise NotImplementedError("")
# if self.n != other.n: return self.n > other.n
# if self.l != other.l

# if self == other:
# return False
# else:
# raise RuntimeError("Don't know how to compare %s with %s" % (self, other))
#def __lt__(self, other):
return super().__new__(cls, int(n), _asl(l), float(occ), eig=eig, j=j, s=s)

@property
def has_j(self) -> bool:
Expand All @@ -128,9 +118,11 @@ def has_s(self) -> bool:
return self.s is not None


class AtomicConfiguration(object):
"""Atomic configuration defining the all-electron atom."""
def __init__(self, Z, states):
class AtomicConfiguration:
"""
Atomic configuration of an all-electron atom.
"""
def __init__(self, Z: int, states: List[QState]) -> None:
"""
Args:
Z: Atomic number.
Expand All @@ -140,7 +132,8 @@ def __init__(self, Z, states):
self.states = states

@classmethod
def from_string(cls, Z, string, has_s=False, has_j=False):
def from_string(cls, Z: int, string: str,
has_s: bool = False, has_j: bool = False) -> AtomicConfiguration:
if not has_s and not has_j:
# Ex: [He] 2s2 2p3
states = states_from_string(string)
Expand All @@ -149,29 +142,29 @@ def from_string(cls, Z, string, has_s=False, has_j=False):

return cls(Z, states)

def __str__(self):
def __str__(self) -> str:
lines = ["%s: " % self.Z]
lines += [str(state) for state in self]
return "\n".join(lines)

def __len__(self):
def __len__(self) -> int:
return len(self.states)

def __iter__(self):
def __iter__(self) -> Iterable:
return self.states.__iter__()

def __eq__(self, other):
def __eq__(self, other: AtomicConfiguration) -> bool:
if len(self.states) != len(other.states):
return False

return (self.Z == other.Z and
all(s1 == s2 for s1, s2 in zip(self.states, other.states)))

def __ne__(self, other):
def __ne__(self, other: AtomicConfiguration) -> bool:
return not self == other

@classmethod
def neutral_from_symbol(cls, symbol):
def neutral_from_symbol(cls, symbol: Union[str, int]) -> AtomicConfiguration:
"""
symbol: str or int
Can be a chemical symbol (str) or an atomic number (int).
Expand All @@ -180,17 +173,17 @@ def neutral_from_symbol(cls, symbol):
states = [QState(n=s[0], l=s[1], occ=s[2]) for s in entry.states]
return cls(entry.Z, states)

def copy(self):
def copy(self) -> AtomicConfiguration:
"""Shallow copy of self."""
return AtomicConfiguration(self.Z, [s for s in self.states])

@property
def symbol(self):
def symbol(self) -> str:
"""Chemical symbol"""
return nist_database.symbol_from_Z(self.Z)

@property
def spin_mode(self):
def spin_mode(self) -> str:
"""
unpolarized: Spin-unpolarized calculation.
polarized: Spin-polarized calculation.
Expand All @@ -202,12 +195,12 @@ def spin_mode(self):
return "unpolarized"

@property
def echarge(self):
"""Electronic charge (float <0)."""
def echarge(self) -> float:
"""Electronic charge (float < 0 )."""
return -sum(state.occ for state in self)

@property
def isneutral(self):
def isneutral(self) -> bool:
"""True if self is a neutral configuration."""
return abs(self.echarge + self.Z) < 1.e-8

Expand All @@ -232,8 +225,10 @@ def _pop(self, state):
raise


class RadialFunction(object):
"""A RadialFunction has a name, a radial mesh and values defined on this mesh."""
class RadialFunction:
"""
A RadialFunction has a name, a radial mesh and values defined on this mesh.
"""

def __init__(self, name: str, rmesh, values):
"""
Expand All @@ -248,7 +243,7 @@ def __init__(self, name: str, rmesh, values):
assert len(self.rmesh) == len(self.values)

@classmethod
def from_filename(cls, filename: str, rfunc_name=None, cols=(0, 1)):
def from_filename(cls, filename: str, rfunc_name=None, cols=(0, 1)) -> RadialFunction:
"""
Initialize the object reading data from filename (txt format).
Expand All @@ -262,20 +257,20 @@ def from_filename(cls, filename: str, rfunc_name=None, cols=(0, 1)):
name = filename if rfunc_name is None else rfunc_name
return cls(name, rmesh, values)

def __len__(self):
def __len__(self) -> int:
return len(self.values)

def __iter__(self):
def __iter__(self) -> Iterable:
"""Iterate over (rpoint, value)."""
return iter(zip(self.rmesh, self.values))

def __getitem__(self, rslice):
return self.rmesh[rslice], self.values[rslice]

def __repr__(self):
def __repr__(self) -> str:
return "<%s, name=%s at %s>" % (self.__class__.__name__, self.name, id(self))

def __str__(self):
def __str__(self) -> str:
stream = StringIO()
self.pprint(stream=stream)
return stream.getvalue()
Expand All @@ -284,7 +279,7 @@ def __str__(self):
#def __sub__(self, other):
#def __mul__(self, other):

def __abs__(self):
def __abs__(self) -> RadialFunction:
return self.__class__(self.rmesh, np.abs(self.values))

@property
Expand All @@ -295,7 +290,7 @@ def to_dict(self) -> dict:
values=list(self.values),
)

def pprint(self, what: str = "rmesh+values", stream=None):
def pprint(self, what: str = "rmesh+values", stream=None) -> None:
"""pprint method (useful for debugging)"""
from pprint import pprint
if "rmesh" in what:
Expand All @@ -307,17 +302,17 @@ def pprint(self, what: str = "rmesh+values", stream=None):
pprint(self.values, stream=stream)

@property
def rmax(self):
def rmax(self) -> float:
"""Outermost point of the radial mesh."""
return self.rmesh[-1]

@property
def rsize(self):
def rsize(self) -> float:
"""Size of the radial mesh."""
return len(self.rmesh)

@property
def minmax_ridx(self):
def minmax_ridx(self) -> Tuple[int, int]:
"""
Returns the indices of the values in a list with the maximum and minimum value.
"""
Expand All @@ -326,8 +321,10 @@ def minmax_ridx(self):
return minimum[0], maximum[0]

@property
def inodes(self):
""""List with the index of the nodes of the radial function."""
def inodes(self) -> List[int]:
""""
List with the index of the nodes of the radial function.
"""
inodes = []
for i in range(len(self.values)-1):
if self.values[i] * self.values[i+1] <= 0:
Expand All @@ -352,7 +349,7 @@ def derivatives(self, r):
"""Return all derivatives of the spline at the point r."""
return self.spline.derivatives(r)

def integral(self):
def integral(self) -> RadialFunction:
r"""
Cumulatively integrate y(x) using the composite trapezoidal rule.
Expand All @@ -369,6 +366,7 @@ def integral(self):
def integral3d(self, a=None, b=None):
"""
Return definite integral of the spline of (r**2 values**2) between two given points a and b
Args:
a: First point. rmesh[0] if a is None
b: Last point. rmesh[-1] if a is None
Expand All @@ -380,17 +378,19 @@ def integral3d(self, a=None, b=None):
return r2v2_spline.integral(a, b)

def ifromr(self, rpoint):
"""The index of the point."""
"""
The index of the point in the radial mesh.
"""
for (i, r) in enumerate(self.rmesh):
if r > rpoint:
return i-1
return i - 1

if rpoint == self.rmesh[-1]:
return len(self.rmesh)
else:
raise ValueError("Cannot find %s in rmesh" % rpoint)

def ir_small(self, abs_tol=0.01):
def ir_small(self, abs_tol: float = 0.01) -> int:
"""
Returns the rightmost index where the abs value of the wf becomes greater than abs_tol
Expand Down Expand Up @@ -425,7 +425,7 @@ def r2f_integral(self):
pad_intg[1:] = integ
return pad_intg

def get_intr2j0(self, ecut, numq=3001):
def get_intr2j0(self, ecut: float, numq: float = 3001):
r"""
Compute 4\pi\int[(\frac{\sin(2\pi q r)}{2\pi q r})(r^2 n(r))dr].
"""
Expand Down
Loading

0 comments on commit 33d58bb

Please sign in to comment.