Neural network potentials
The goal of ModelForge's NNP initiative is to streamline computational chemistry research by offering pre-trained, ready-to-use neural network potentials. This approach allows researchers to bypass the computationally expensive task of training NNPs from scratch and focus directly on their scientific inquiries, such as molecular simulations, free energy calculations, and property prediction.
Each neural network potential (NNP) can be separated into three core modules, each with distinct responsibilities in the workflow:
- Neighborlist Module: The neighbor list consumes the atomic positions $\vec{R}$ for $N$ atoms and generates interaction pairs. Based on a predefined cutoff radius $r_c$, the neighborlist determines which atom pairs interact. If $P$ is the number of interacting pairs, it holds that $P \le N \cdot N$, though typically far fewer pairs are generated due to the cutoff distance. This module outputs pairwise information such as distance vectors and distances between atoms (a minimal sketch follows this list).
- Core Network Module: The core network consumes the pairwise information generated by the neighborlist, along with information about atomic identity (e.g., atomic numbers). It processes this information through several layers and learns atomic properties (e.g., energies, partial charges) using message-passing algorithms. In this context, the core network predicts $K$ atomic properties for each atom.
- Postprocessing Module: The postprocessing module normalizes and aggregates the atomic properties from the core network to generate meaningful molecular-level outputs. This module can either retain atomic properties or reduce them to molecular properties (e.g., total energy) via summation or other reduction methods.
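As a rough illustration of what the neighborlist module produces, here is a minimal brute-force pair search in PyTorch; the function name and the $O(N^2)$ strategy are illustrative assumptions, not ModelForge's optimized implementation.

```python
import torch

def brute_force_neighborlist(positions: torch.Tensor, cutoff: float):
    """Return interacting pairs (i, j), distance vectors, and distances.

    positions: (N, 3) tensor of atomic positions.
    cutoff: cutoff radius r_c.
    """
    n_atoms = positions.shape[0]
    # All candidate pairs with i != j (P <= N * N, as described above)
    i, j = torch.meshgrid(torch.arange(n_atoms), torch.arange(n_atoms), indexing="ij")
    mask = i != j
    i, j = i[mask], j[mask]
    # Pairwise distance vectors and scalar distances
    r_ij = positions[j] - positions[i]       # (P, 3)
    d_ij = torch.linalg.norm(r_ij, dim=-1)   # (P,)
    # Keep only pairs within the cutoff radius
    within = d_ij < cutoff
    return i[within], j[within], r_ij[within], d_ij[within]
```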
During training, data is typically processed in batches.
In every message-passing neural network (MPNN), there are two fundamental components: the message function and the update function.
- Message function: the message function computes the information passed between neighboring atoms. This function depends on both node (e.g., atomic properties or learned feature vectors) and edge (e.g., pairwise distances between atoms) features, ensuring that atomic interactions are encoded.
A general form of the message function is:

$$ m_{i,j}^{t} = M\left(x_i^t, x_j^t, e_{i,j}\right) $$
where:
- $x_i^t$ and $x_j^t$ are the feature vectors of atoms $i$ and $j$ at iteration $t$
- $e_{i,j}$ is the edge feature, e.g., the featurized distance between atoms $i$ and $j$
- $M$ is the message-passing function that learns to propagate relevant information between atoms
- Update function: after receiving the messages from neighboring atoms, the atomic feature vectors are updated through an update function. The update function aggregates the incoming messages and transforms the atomic feature vectors accordingly.
The general update function can be written as:

$$ x_i^{t+1} = U\left(x_i^t, \sum_{j \in N(i)} m_{i,j}\right) $$
where:
- $x_i^t$ is the current state of atom $i$ at iteration $t$
- $\sum_{j\in N(i)} m_{i,j}$ is the sum of messages from neighboring atoms
- $U$ is any learnable operation that updates the atomic features
Conceptually, message-passing neural networks operate by iteratively updating atomic feature vectors based on the messages received from neighboring atoms. The architecture of MPNNs can be further extended with various modifications, such as incorporating edge features (like radial symmetry functions), using attention mechanisms, or enforcing equivariance for modeling directional interactions.
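To make these two components concrete, the sketch below shows a single, generic message-passing step in PyTorch. The MLPs used for $M$ and $U$, the feature sizes, and the edge featurization are illustrative assumptions rather than any specific ModelForge layer.

```python
import torch
import torch.nn as nn

class SimpleMessagePassingLayer(nn.Module):
    """One generic step: x_i^{t+1} = U(x_i^t, sum_j M(x_i^t, x_j^t, e_ij))."""

    def __init__(self, n_features: int, n_edge_features: int):
        super().__init__()
        # Message function M(x_i, x_j, e_ij)
        self.message_mlp = nn.Sequential(
            nn.Linear(2 * n_features + n_edge_features, n_features), nn.SiLU()
        )
        # Update function U(x_i, aggregated message)
        self.update_mlp = nn.Sequential(
            nn.Linear(2 * n_features, n_features), nn.SiLU()
        )

    def forward(self, x, edge_index, edge_features):
        # edge_index: (2, P) with rows (i, j); edge_features: (P, n_edge_features)
        idx_i, idx_j = edge_index
        messages = self.message_mlp(
            torch.cat([x[idx_i], x[idx_j], edge_features], dim=-1)
        )
        # Sum messages per receiving atom i
        aggregated = torch.zeros_like(x).index_add_(0, idx_i, messages)
        return self.update_mlp(torch.cat([x, aggregated], dim=-1))
```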
We have selected several state-of-the-art NNP models with great inference speed, training efficiency, and adaptability for systematic refinement. These include SchNet, PaiNN, SAKE, TensorNet, AIMNet2, PhysNet, DimeNet++, and ANI2x.
In the following, we briefly describe each of these architectures.
One of the earliest 2nd generation deep neural network potentials, SchNet, includes many ideas that are still relevant. The network is based on pairwise distances and is invariant under the Euclidean group. The SchNet architecture uses a learnable atom embedding for each atom based on its atomic number.
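Such an embedding can be realized with a simple lookup table; the sizes below are assumptions for illustration.

```python
import torch
import torch.nn as nn

n_features = 64                              # assumed feature dimension F
embedding = nn.Embedding(100, n_features)    # one learnable vector per atomic number

atomic_numbers = torch.tensor([8, 1, 1])     # e.g., a water molecule
s = embedding(atomic_numbers)                # initial scalar features, shape (3, F)
```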
Message function: The message function for atom pair $(i,j)$ is:

$$ m_{i,j} = W\left(g(d_{i,j})\right) \odot s_j $$
with
- $m_{i,j} \in \mathbb{R}^F$ is the message from atom $j$ to atom $i$
- $W : \mathbb{R}^K \mapsto \mathbb{R}^F$ is a learnable distance filter that depends on the distance $d_{i,j}$ between atoms $i$ and $j$
- $g: \mathbb{R} \mapsto \mathbb{R}^K$ is the radial symmetry basis
- $d_{i,j} \in \mathbb{R}$ is the scalar distance between atoms $i$ and $j$
- $s_j : \mathbb{N} \mapsto \mathbb{R}^F$ is the scalar embedding (feature) of atom $j$
- $\odot$ denotes element-wise multiplication (Hadamard product)
or, more explicitly:
with
Update function: The update function for the atomic feature vector $s_i$ is:

$$ s_i^{t+1} = U\left(s_i^t, \sum_{j \in N(i)} m_{i,j}\right) $$
where:
- $s_i^t \in \mathbb{R}^F$ is the atomic feature vector of atom $i$ at iteration $t$
- $N(i)$ denotes the neighborhood of atom $i$ within cutoff $r_c$
- $U$ is a learnable update function
The final updated feature vector is passed to the postprocessing module to obtain per-atom and molecular properties.
Further details are given in the publication: Schütt, K. T., H. E. Sauceda, P. J. Kindermans, A. Tkatchenko, and K. R. Müller. 2018. “SchNet - A Deep Learning Architecture for Molecules and Materials.” The Journal of Chemical Physics 148 (24): 241722.
It is recommended to share interaction layers; otherwise, the memory footprint quickly becomes large.
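Putting the message and update functions together, a simplified SchNet-style interaction step could look like the following sketch; the Gaussian radial basis, the filter network `W`, and the residual form of the update are assumptions chosen for brevity and do not reproduce ModelForge's SchNet module exactly.

```python
import torch
import torch.nn as nn

n_features, n_rbf = 64, 20   # assumed F and K

# g: featurize the scalar distance with K Gaussian radial basis functions
centers = torch.linspace(0.0, 5.0, n_rbf)
def g(d_ij):                                   # (P,) -> (P, K)
    return torch.exp(-10.0 * (d_ij.unsqueeze(-1) - centers) ** 2)

# W: learnable distance filter, R^K -> R^F
W = nn.Sequential(nn.Linear(n_rbf, n_features), nn.SiLU(), nn.Linear(n_features, n_features))
# U: learnable update applied to the aggregated messages
U = nn.Sequential(nn.Linear(n_features, n_features), nn.SiLU(), nn.Linear(n_features, n_features))

def schnet_interaction(s, idx_i, idx_j, d_ij):
    """s: (N, F) atomic features; idx_i/idx_j: (P,) pair indices; d_ij: (P,) distances."""
    m_ij = W(g(d_ij)) * s[idx_j]                           # message: filter ⊙ neighbor features
    agg = torch.zeros_like(s).index_add_(0, idx_i, m_ij)   # sum over neighbors N(i)
    return s + U(agg)                                      # residual update of s_i
```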
PaiNN is an equivariant graph neural network designed to model atomic interactions by considering interatomic distances and directional information. This architecture allows PaiNN to predict properties that depend on scalar values (such as energy) and tensorial quantities (such as forces and dipole moments), all while ensuring the model respects the fundamental symmetries of physical space.
PaiNN employs a message-passing mechanism that exchanges information between atoms, using both scalar and vector features. Each atom $i$ carries:
- a scalar feature $s_i \in \mathbb{R}^F$
- vector features $\vec{v}_i \in \mathbb{R}^{3 \times F}$, used to model directional information (e.g., dipoles).
Message function: While the scalar message-passing mechanism is identical to that of SchNet, PaiNN introduces an additional vector message-passing mechanism to handle directional interactions. For a given atom pair $(i,j)$, the vector message is:

$$ \vec{m}_{i,j} = \vec{v}_j \odot W_{v,1}\left(g(d_{i,j})\right) \odot s_j + W_{v,2}\left(g(d_{i,j})\right) \odot s_j \, \frac{\vec{r}_{i,j}}{\lVert \vec{r}_{i,j} \rVert} $$

with
- $m_{i,j} \in \mathbb{R}^F$ is the message from atom $j$ to atom $i$
- $W : \mathbb{R}^K \mapsto \mathbb{R}^F$ is a learnable distance filter that depends on the distance $d_{i,j}$ between atoms $i$ and $j$
- $g: \mathbb{R} \mapsto \mathbb{R}^K$ is the radial symmetry basis
- $d_{i,j} \in \mathbb{R}$ is the scalar distance between atoms $i$ and $j$
- $s_j : \mathbb{N} \mapsto \mathbb{R}^F$ is the scalar embedding (feature) of atom $j$
and the scalar message function is the same as in SchNet:

$$ m_{i,j} = W_s\left(g(d_{i,j})\right) \odot s_j $$
Note that both terms of the vector message are conceptually similar to the SchNet architecture (a sketch of both terms follows below):
- $\vec{v}_j \odot W_{v,1}(g(d_{i,j})) \odot s_j$ modulates the vector features component-wise as a function of the distance $d_{i,j}$ and the atom identity $s_j$,
- $W_{v,2}(g(d_{i,j})) \odot s_j \, \frac{\vec{r}_{i,j}}{\lVert\vec{r}_{i,j}\rVert}$ modulates the scalar representation of SchNet by the normalized distance vector.
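A hedged sketch of these two terms for a batch of pairs; the shapes and the precomputed filters `w1`, `w2` are assumptions.

```python
import torch

def painn_vector_message(v_j, s_j, w1, w2, r_ij, d_ij):
    """Per-pair vector message with the two terms described above.

    v_j:   (P, 3, F) vector features of atom j
    s_j:   (P, F)    scalar features of atom j
    w1/w2: (P, F)    distance filters W_{v,1}(g(d_ij)), W_{v,2}(g(d_ij))
    r_ij:  (P, 3)    displacement vectors; d_ij: (P,) distances
    """
    unit = (r_ij / d_ij.unsqueeze(-1)).unsqueeze(-1)   # (P, 3, 1), normalized direction
    term1 = v_j * (w1 * s_j).unsqueeze(1)              # modulate v_j per feature channel
    term2 = (w2 * s_j).unsqueeze(1) * unit             # scalar term along r_ij / |r_ij|
    return term1 + term2                               # (P, 3, F)
```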
Update function: The update function in PaiNN is more involved: scalar and vector features are mixed in both the message and the update block. The residual of the scalar update is:
$$ \Delta s_i^{u} = a_{ss}\left(s_i, \lVert V\vec{v}_i \rVert\right) + a_{sv}\left(s_i, \lVert V\vec{v}_i \rVert\right)\,\langle U\vec{v}_i, V\vec{v}_i \rangle $$
with
- $s_j \in \mathbb{R}^F$ as the scalar embedding (feature) of atom $j$
- $\vec{v}_j \in \mathbb{R}^{3 \times F}$ as the vector features
- $U, V \in \mathbb{R}^{F \times F}$ as linear transformations
and the residual of the vector update is:

$$ \Delta \vec{v}_i^{u} = a_{vv}\left(s_i, \lVert V\vec{v}_i \rVert\right)\, U\vec{v}_i $$
For further detail, see Schütt, Kristof T., Oliver T. Unke, and Michael Gastegger. 2021. “Equivariant Message Passing for the Prediction of Tensorial Properties and Molecular Spectra.”
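The scalar and vector update residuals above can be sketched as follows; the gating network `a` and the realization of `U` and `V` as linear layers acting on the feature axis are illustrative assumptions.

```python
import torch
import torch.nn as nn

n_features = 64  # assumed F

U = nn.Linear(n_features, n_features, bias=False)   # U in R^{F x F}
V = nn.Linear(n_features, n_features, bias=False)   # V in R^{F x F}
# a: produces the gating coefficients a_ss, a_sv, a_vv from (s_i, ||V v_i||)
a = nn.Sequential(nn.Linear(2 * n_features, n_features), nn.SiLU(),
                  nn.Linear(n_features, 3 * n_features))

def painn_update(s, v):
    """s: (N, F) scalar features; v: (N, 3, F) vector features."""
    Uv, Vv = U(v), V(v)                              # linear maps act on the feature axis
    Vv_norm = torch.linalg.norm(Vv, dim=1)           # (N, F), norm over the spatial axis
    a_ss, a_sv, a_vv = torch.split(a(torch.cat([s, Vv_norm], dim=-1)), n_features, dim=-1)
    delta_s = a_ss + a_sv * (Uv * Vv).sum(dim=1)     # scalar residual with <U v_i, V v_i>
    delta_v = a_vv.unsqueeze(1) * Uv                 # vector residual
    return s + delta_s, v + delta_v
```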
AIMNet2 is a third-generation atoms-in-molecules neural network potential (NNP) designed for high generalizability across a broad range of organic species. The architecture extends traditional machine learning interatomic potentials (MLIPs) by incorporating both short-range, ML-parameterized terms and long-range physics-based interactions, such as electrostatics and dispersion. AIMNet2 can handle systems composed of multiple chemical elements, including charged species, making it versatile for applications in both molecular and macromolecular simulations.
In AIMNet2, the message-passing framework is central to capturing atomic interactions. Atomic embeddings, initially based on atomic numbers and possibly other atomic features, are iteratively updated during each message-passing step to refine the representation of the chemical environment. Both scalar and vector embeddings are used to describe atomic features, capturing isotropic (distance-dependent) and anisotropic (direction-dependent) interactions.
The message function in AIMNet2 incorporates both radial (scalar) and vector (angular) contributions to capture comprehensive atomic interactions.
For each atom pair $(i, j)$, the radial contribution to the message is computed as

$$ \mathbf{m}_{i,j}^{\text{radial}} = \mathbf{g}_{i,j} \odot \mathbf{s}_{j} $$

where:
- $\mathbf{g}_{i,j} \in \mathbb{R}^{G}$: radial symmetry functions multiplied by the cutoff function, evaluated at the distance $d_{ij}$ between atoms $i$ and $j$.
- $\mathbf{s}_{j} \in \mathbb{R}^{F}$: scalar embedding of atom $j$.
- $\odot$: element-wise multiplication with appropriate broadcasting.
This operation results in a tensor of shape $(G, F)$ for each pair.
In the code, this corresponds to:
# Compute radial contributions
avf_s = gs.unsqueeze(-1) * a_j.unsqueeze(1) # Shape: (num_pairs, G, F_atom)
- `gs`: radial symmetry functions (`f_ij * f_cutoff`), shape `(num_pairs, G)`.
- `a_j`: atomic features of atom $j$, shape `(num_pairs, F_atom)`.
After summing over the $G$ radial basis functions:
avf_s = avf_s.sum(dim=1) # Shape: (num_pairs, F_atom)
These contributions are then aggregated per atom $i$:
radial_contributions = torch.zeros(
(number_of_atoms, F_atom),
device=avf_s.device,
dtype=avf_s.dtype,
)
radial_contributions.index_add_(0, idx_j, avf_s)
To capture angular dependencies, the vector contributions are calculated using directional information between atoms and a learnable transformation tensor.
- Unit Direction Vector:

  $$ \hat{u}_{i,j} = \frac{\vec{r}_{i,j}}{d_{i,j}} $$

  - $\vec{r}_{i,j} \in \mathbb{R}^{3}$: displacement vector from atom $i$ to atom $j$.
  - $d_{i,j}$: distance between atoms $i$ and $j$.
- Vector Symmetry Functions:

  $$ \mathbf{v}_{i,j} = \hat{u}_{i,j} \otimes \mathbf{g}_{i,j} \in \mathbb{R}^{3 \times G} $$

  - $\otimes$: outer product, resulting in a tensor of shape $(3, G)$.
  - $\mathbf{g}_{ij} \in \mathbb{R}^{G}$: radial symmetry functions.
In the code:
# Unit direction vectors
u_ij = r_ij / d_ij # Shape: (num_pairs, 3)
# Compute gv with shape (num_pairs, 3, G)
gv = u_ij.unsqueeze(-1) * gs.unsqueeze(1) # Shape: (num_pairs, 3, G)
- Vector Message Computation:

  The vector contributions are computed using an Einstein summation over atomic features, vector symmetry functions, and a learnable tensor $\mathbf{A}$:

  $$ m_{i,j}^{h,d} = \sum_{f,g} s_{j}^{f}\, v_{i,j}^{d,g}\, A^{f h g} $$

  - $s_{j}^{f}$: scalar feature $f$ of atom $j$.
  - $v_{i,j}^{d,g}$: component $d$ of the vector symmetry function for basis $g$.
  - $A^{f h g} \in \mathbb{R}^{F \times H \times G}$: learnable transformation tensor (`agh` in the code).
  - $h$: index over vector feature dimensions $H$.
  - $d$: spatial dimension (3D vectors).
In the code:
# Compute per-pair vector contributions by contracting over the feature (a) and basis (g) indices
# avf_v: Shape (num_pairs, H, 3)
avf_v = torch.einsum("pa, pdg, ahg -> phd", a_j, gv, agh)
- `a_j`: atomic features of atom $j$, shape `(num_pairs, F_atom)`.
- `gv`: vector symmetry functions, shape `(num_pairs, 3, G)`.
- `agh`: learnable tensor, shape `(F_atom, H, G)`.
Note: The indices in the Einstein summation are arranged to match the dimensions:
- `pa`: pair index $p$, atomic feature $a$.
- `pdg`: pair index $p$, spatial dimension $d$, radial basis $g$.
- `ahg`: atomic feature $a$, vector feature $h$, radial basis $g$.
- The result `avf_v` has shape `(num_pairs, H, 3)`.
- Aggregation of Vector Contributions:

  The vector contributions are aggregated per atom $i$:
# Initialize tensor to accumulate vector contributions per atom
avf_v_sum = torch.zeros(
(number_of_atoms, H, 3),
device=device,
dtype=avf_v.dtype,
)
# Aggregate per atom by summing the vectors
avf_v_sum.index_add_(0, idx_j, avf_v) # Shape: (number_of_atoms, H, 3)
- Compute Norm of Vector Contributions:

  The norm over the spatial dimensions is computed:

  $$ v_{i}^{h} = \left\lVert \mathrm{avf\_v\_sum}_{i}^{h} \right\rVert $$
In the code:
vector_contributions = torch.norm(avf_v_sum, dim=-1) # Shape: (number_of_atoms, H)
For each atom $i$, two sets of contributions are obtained:
- Radial Contributions: accumulated in `radial_contributions`, shape `(number_of_atoms, F_atom)`.
- Vector Contributions: stored in `vector_contributions`, shape `(number_of_atoms, H)`.
These are combined to form the combined message for each atom.
The update function integrates scalar and vector messages to refine atomic embeddings.
The combined message for each atom $i$ is assembled as follows:
- First Interaction Module:

  $$ \text{combined\_message}_{i} = \left[ \mathbf{m}_{i}^{s}, \mathbf{v}_{i} \right] $$

- Subsequent Interaction Modules:

  $$ \text{combined\_message}_{i} = \left[ \mathbf{m}_{i}^{s}, \mathbf{v}_{i}, \mathbf{m}_{i}^{s,\text{charge}}, \mathbf{v}_{i}^{\text{charge}} \right] $$

  - $\mathbf{m}_{i}^{s,\text{charge}}$ and $\mathbf{v}_{i}^{\text{charge}}$ are contributions computed from partial charges.
In the code:
if not self.is_first_module:
# Combine messages from embeddings and charges
combined_message = torch.cat(
[
radial_contributions_emb, # (N, F_atom)
vector_contributions_emb, # (N, H)
radial_contributions_charge, # (N, 1)
vector_contributions_charge, # (N, H)
],
dim=1,
)
else:
combined_message = torch.cat(
[
radial_contributions_emb, # (N, F_atom)
vector_contributions_emb, # (N, H)
],
dim=1,
)
The combined message is passed through a single MLP to produce updates for atomic embeddings, partial charges, and a scaling factor:
$$
\begin{aligned}
\text{output} &= \text{MLP}(\text{combined\_message}_{i}) \\
\left[ \Delta q_{i},\, f_{i},\, \Delta \mathbf{a}_{i} \right] &= \text{split}(\text{output})
\end{aligned}
$$
In the code:
# Pass combined message through single MLP
out = self.mlp(combined_message)
# Split the output tensor into delta_q, f, and delta_a
delta_q, f, delta_a = torch.split(
out, [1, 1, self.number_of_per_atom_features], dim=1
)
- `delta_q`: update for partial charges, shape `(N, 1)`.
- `f`: scaling factor for charge updates, shape `(N, 1)`.
- `delta_a`: update for atomic embeddings, shape `(N, F_atom)`.
The atomic embeddings are updated:
# Update atomic embeddings
atomic_embedding = atomic_embedding + delta_a
The partial charges are updated with the scaling factor:
# Apply scaling factor `f` to `delta_q`
scaled_delta_q = f * delta_q
# Update partial charges
if is_first_module:
partial_charges = scaled_delta_q # Initialize charges
else:
partial_charges = partial_charges + scaled_delta_q # Incremental update
Charge conservation is enforced after each update.
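As a simple illustration of this constraint, the sketch below redistributes any excess charge uniformly over the atoms; AIMNet2's actual scheme may weight this correction differently.

```python
import torch

def enforce_charge_conservation(partial_charges: torch.Tensor, total_charge: float):
    """Shift per-atom charges so that they sum to the molecular total charge.

    partial_charges: (N, 1) predicted partial charges.
    total_charge: the known net charge of the system.
    """
    n_atoms = partial_charges.shape[0]
    excess = partial_charges.sum() - total_charge
    # Distribute the excess uniformly over all atoms (illustrative choice)
    return partial_charges - excess / n_atoms
```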
Note: The code snippets provided correspond directly to the mathematical operations described, ensuring clarity and alignment between the implementation and theoretical formulation.
PhysNet is a neural network architecture designed for the prediction of energies, forces, and dipole moments of molecular systems. It incorporates both short-range machine-learned interactions and long-range physics-based terms such as Coulomb interactions. This makes PhysNet particularly suitable for capturing a variety of chemical environments, including those with significant long-range interactions such as non-covalent complexes and charged systems.
Message function: PhysNet (an invariant, 3rd generation network) uses a message-passing scheme similar to other graph neural networks (GNNs), where atomic features are iteratively updated by exchanging information between neighboring atoms. For each pair of atoms $(i, j)$ within the cutoff, a message $m_{i,j}^{l}$ is computed from the featurized distance $g(d_{ij})$ and the embedding of atom $j$,
where:
- $m_{i,j}^{l}\in \mathbb{R}^F$ is the message representing the interaction of atom $i$ with atom $j$ at layer $l$
- $g(d_{ij}) \in \mathbb{R}^K$ is the featurized distance between atoms $i$ and $j$
- $b_j \in \mathbb{R}^F$ is the scalar embedding of atom $j$
- $\mathbf{W}_j \in \mathbb{R}^{F \times F}$ is a learnable weight matrix applied to the embedding of atom $j$
- $\mathbf{G}^l \in \mathbb{R}^{F \times K}$ is the attention mask selecting features based on pairwise distances
- $\odot$ denotes element-wise multiplication.
This message is then passed through several layers to form the input for the update function.
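As a rough sketch of such a distance-gated message (the exact composition of $\mathbf{W}_j$, $b_j$, and $\mathbf{G}^l$ in PhysNet may differ; this only illustrates how an attention mask over the featurized distance selects features of atom $j$):

```python
import torch
import torch.nn as nn

n_features, n_rbf = 64, 20   # assumed F and K

G = nn.Linear(n_rbf, n_features, bias=False)   # attention mask G^l: R^K -> R^F
dense_j = nn.Linear(n_features, n_features)    # W_j x_j + b_j applied to neighbor features

def physnet_like_message(x_j, g_ij):
    """x_j: (P, F) neighbor features; g_ij: (P, K) featurized distances."""
    # The distance-dependent mask element-wise gates the transformed neighbor features
    return G(g_ij) * torch.nn.functional.silu(dense_j(x_j))   # (P, F)
```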
Update function: After receiving the messages from its neighbors, each atom updates its atomic embedding:

$$ x_i^{l+1} = U\left(x_i^{l}, \sum_{j \in N(i)} m_{i,j}\right) $$
where:
- $x_i^l \in \mathbb{R}^{F}$ is the feature vector of atom $i$ at iteration $l$
- $N(i)$ is the set of neighbors of atom $i$ within a specified cutoff distance
- $m_{i,j}$ is the message received from atom $j$
- $U$ is the update function, a learnable transformation applied to the aggregated messages.
What is particularly interesting in PhysNet is the update function
This is a straightforward update function for node updates in neural networks, but
Unke, Oliver T., and Markus Meuwly. 2019. “PhysNet: A Neural Network for Predicting Energies, Forces, Dipole Moments and Partial Charges.” https://doi.org/10.1021/acs.jctc.9b00181.
[coming soon]
TensorNet is a 3rd generation neural network potential (NNP) that is able to capture up to rank-2 tensor properties of an atomic system. The central idea of TensorNet is the irreducible tensor decomposition and the corresponding Cartesian tensor representation. Any rank-2 tensor $X \in \mathbb{R}^{3 \times 3}$ can be decomposed as

$$ X = \frac{\operatorname{tr}(X)}{3}\, I + A + S $$

Here, $I$ is the identity, the antisymmetric part is

$$ A = \tfrac{1}{2}\left(X - X^{\top}\right) $$

and the symmetric, traceless part is

$$ S = \tfrac{1}{2}\left(X + X^{\top}\right) - \frac{\operatorname{tr}(X)}{3}\, I $$
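A minimal sketch of this decomposition in PyTorch (illustrative only):

```python
import torch

def decompose_rank2(X: torch.Tensor):
    """Split a 3x3 tensor into scalar (I), antisymmetric (A), and symmetric-traceless (S) parts."""
    identity = torch.eye(3, dtype=X.dtype)
    I_part = (torch.trace(X) / 3.0) * identity   # scalar (trace) component
    A_part = 0.5 * (X - X.T)                     # antisymmetric component
    S_part = 0.5 * (X + X.T) - I_part            # symmetric, traceless component
    return I_part, A_part, S_part

# X is recovered exactly as the sum of its irreducible components
X = torch.randn(3, 3)
I_part, A_part, S_part = decompose_rank2(X)
assert torch.allclose(X, I_part + A_part + S_part)
```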
The radial basis functions are defined as
Now that we have the irreducible representation of displacement vectors and the radial basis functions, we can define the initial messages and update messages.
Initialization:
where radial symmetry functions are mapped with linear layers:
Update:
Now update
For further details, refer to the TensorNet arXiv paper.