
Neural network potentials


Introduction

The goal of ModelForge's NNP initiative is to streamline computational chemistry research by offering pre-trained, ready-to-use neural network potentials. This approach allows researchers to bypass the computationally expensive task of training NNPs from scratch and focus directly on their scientific inquiries, such as molecular simulations, free energy calculations, and property prediction.

Structure of a neural network potential

Each neural network potential (NNP) can be separated into three core modules, each with distinct responsibilities in the workflow:

  • Neighborlist Module: The neighborlist consumes the atomic positions $\vec{R}$ for $N$ atoms and generates interaction pairs. Based on a predefined cutoff radius $r_c$, the neighborlist determines which atom pairs interact. If $P$ is the number of interacting pairs, it holds that $P \le N\cdot N$, though typically far fewer pairs are generated because of the cutoff distance. This module outputs pairwise information such as distance vectors and distances between atoms.
  • Core Network Module: The core network consumes the pairwise information generated by the neighborlist, along with information about atomic identity (e.g., atomic numbers). It processes this information through several layers and learns atomic properties (e.g., energies, partial charges) by using message-passing algorithms. In this context, the core network predicts $K$ atomic properties for each atom.
  • Postprocessing Module: The postprocessing module normalizes and aggregates the atomic properties from the core network to generate meaningful molecular-level outputs. This module can either retain atomic properties or reduce them to molecular properties (e.g., total energy) via summation or other reduction methods.

During training, data is typically processed in batches of $M$ molecules. This requires no architectural changes to the network but changes the meaning of $N$, which now refers to the total number of atoms in the batch, i.e., $\sum_{m=1}^{M}N_m$. It is the responsibility of the neighborlist to ensure that pairs are only formed between atoms within the same molecule.
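
To make the data flow concrete, a minimal brute-force neighborlist for a batch could look like the sketch below (illustrative only; the names positions, molecule_indices, and cutoff are placeholders, and ModelForge's actual neighborlist may be implemented differently and more efficiently):

import torch

def brute_force_neighborlist(positions, molecule_indices, cutoff):
    # positions:        (N, 3) atomic positions of the whole batch
    # molecule_indices: (N,)   index of the molecule each atom belongs to
    # cutoff:           float, cutoff radius r_c
    n_atoms = positions.shape[0]
    # enumerate all candidate pairs i < j
    idx_i, idx_j = torch.combinations(torch.arange(n_atoms), r=2).unbind(-1)
    # only pair atoms that belong to the same molecule
    same_molecule = molecule_indices[idx_i] == molecule_indices[idx_j]
    idx_i, idx_j = idx_i[same_molecule], idx_j[same_molecule]
    # pairwise distance vectors and scalar distances
    r_ij = positions[idx_j] - positions[idx_i]
    d_ij = torch.norm(r_ij, dim=-1)
    # apply the cutoff r_c
    mask = d_ij <= cutoff
    return idx_i[mask], idx_j[mask], r_ij[mask], d_ij[mask]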

Message-Passing Networks: General Structure

In every message-passing neural network (MPNN), there are two fundamental components: the message function and the update function.

  1. Message function: the message function computes the information passed between neighboring atoms. This function depends on both node (e.g., atomic properties or learned feature vectors) and edge (e.g., pairwise distances between atoms) features, ensuring that atomic interactions are encoded.

A general form of the message function is:

$$ m_{i,j} = M(x_i^t, x_j^t, e_{i,j}) $$

where:

  • $x_i^t$ and $x_j^t$ are the feature vectors of atom $i$ and $j$ at iteration $t$
  • $e_{i,j}$ is the edge feature, e.g. the featurized distance between atom $i$ and $j$
  • $M$ is the message-passing function that learns to propagate relevant information between atoms
  2. Update function: after receiving the messages from neighboring atoms, the atomic feature vectors are updated through an update function. The update function aggregates the incoming messages and transforms the atomic feature vectors accordingly.

The general update function can be written as:

$$ x_i^{t+1} = U (x_i^t + \sum_{j\in N(i)} m_{i,j}) $$

where:

  • $x_i^t$ is the current state of atom $i$ at iteration $t$
  • $\sum_{j\in N(i)} m_{i,j}$ is the sum of messages from neighboring atoms
  • $U$ is any learnable operation that updates the atomic features

Conceptually, message-passing neural networks operate by iteratively updating atomic feature vectors based on the messages received from neighboring atoms. The architecture of MPNNs can be further extended with various modifications, such as incorporating edge features (like radial symmetry functions), using attention mechanisms, or enforcing equivariance for modeling directional interactions.
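
As a schematic example (not ModelForge's API; module and parameter names are placeholders), a single message-passing iteration with summation aggregation could be sketched as follows:

import torch
import torch.nn as nn

class MessagePassingStep(nn.Module):
    # One generic iteration: m_ij = M(x_i, x_j, e_ij), then x_i <- U(x_i + sum_j m_ij)
    def __init__(self, feature_dim, edge_dim):
        super().__init__()
        self.message_mlp = nn.Sequential(nn.Linear(2 * feature_dim + edge_dim, feature_dim), nn.SiLU())
        self.update_mlp = nn.Sequential(nn.Linear(feature_dim, feature_dim), nn.SiLU())

    def forward(self, x, edge_features, idx_i, idx_j):
        # message function M(x_i, x_j, e_ij)
        m_ij = self.message_mlp(torch.cat([x[idx_i], x[idx_j], edge_features], dim=-1))
        # aggregate incoming messages per receiving atom i
        aggregated = torch.zeros_like(x).index_add_(0, idx_i, m_ij)
        # update function U(x_i + sum_j m_ij)
        return self.update_mlp(x + aggregated)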

Implemented potentials

We have selected several state-of-the-art NNP models with great inference speed, training efficiency, and adaptability for systematic refinement. These include SchNet, PaiNN, SAKE, TensorNet, AimNet2, PhysNet, DimeNet++, and ANI2x. In the following, $F$ denotes the number of embedding features and $K$ the number of radial symmetry functions.

SchNet

One of the earliest 2nd generation deep neural networks, SchNet, includes many ideas that are still relevant. The network is based on pairwise distances and is invariant under the Euclidean group. The SchNet architecture uses a learnable atom embedding for each atom based on its atomic number $Z$. Interactions are modeled using a continuous-filter convolution layer $W$. The network can be classified as a graph neural network using learned embeddings and message passing to model atom interactions based on interatomic distances.

Message function: The message function for atom pair (i,j) is:

$$ m_{i,j} = W(g(d_{i,j})) \odot s_j $$

with

  • $m_{i,j} \in \mathbb{R}^F$ is the message from atom $j$ to atom $i$
  • $W : \mathbb{R}^K \mapsto \mathbb{R}^F$ is a learnable distance filter that depends on the distance $d_{i,j}$ between atom $i$ and $j$,
  • $g: \mathbb{R} \mapsto \mathbb{R}^K$ is the radial symmetry basis
  • $d_{i,j} \in \mathbb{R}$ is the scalar distance between atom $i$ and $j$
  • $s_j \in \mathbb{R}^F$ is the scalar embedding (feature) of atom $j$, a learned map $\mathbb{N} \mapsto \mathbb{R}^F$ of its atomic number
  • $\odot$ denotes element-wise multiplication (Hadamard product).

or, more explicitly:

$$ m_{i,j,f} = W_{i,j,f} \cdot s_{j,f} $$

with $f = 1, \dots, F$.

Update function: The update function for atomic feature vector $s_i$ at iteration $t$ is

$$ s_i^{t+1} = U (s_i^t + \sum_{j \in N(i)} m_{i,j}), $$

where:

  • $s_i^t \in \mathbb{R}^F$ is the atomic feature vector of atom $i$ at iteration $t$
  • $N(i)$ denotes the neighborhood of atom $i$ with cutoff $r_c$
  • $U$ is a learnable update function

The final updated feature vector $s_{i}^{t+1}$ remains in the same feature space, i.e., $s_{i}^{t+1} \in \mathbb{R}^F$.
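
To connect these equations to code, a minimal sketch of one SchNet-style interaction block might look as follows (placeholder modules and names, not the ModelForge implementation):

import torch
import torch.nn as nn

class SchNetInteraction(nn.Module):
    # Sketch of one interaction: m_ij = W(g(d_ij)) ⊙ s_j, then s_i <- U(s_i + sum_j m_ij)
    def __init__(self, n_rbf, n_features):
        super().__init__()
        self.filter_network = nn.Sequential(nn.Linear(n_rbf, n_features), nn.SiLU())   # W: R^K -> R^F
        self.update_network = nn.Sequential(nn.Linear(n_features, n_features), nn.SiLU())  # U

    def forward(self, s, g_ij, idx_i, idx_j):
        # s:    (N, F) scalar atomic features
        # g_ij: (P, K) featurized distances g(d_ij)
        m_ij = self.filter_network(g_ij) * s[idx_j]                 # (P, F) messages
        aggregated = torch.zeros_like(s).index_add_(0, idx_i, m_ij)  # sum over neighbors
        return self.update_network(s + aggregated)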

Further details are given in the publication: Schütt, K. T., H. E. Sauceda, P. J. Kindermans, A. Tkatchenko, and K. R. Müller. 2018. “SchNet - A Deep Learning Architecture for Molecules and Materials.” The Journal of Chemical Physics 148 (24): 241722.

What we learned about SchNet:

It is recommended to share interaction layers, otherwise the memory footprint becomes large quickly.

PaiNN (Polarizable Interaction Neural Network)

PaiNN is an equivariant graph neural network designed to model atomic interactions by considering interatomic distances and directional information. This architecture allows PaiNN to predict properties that depend on scalar values (such as energy) and tensorial quantities (such as forces and dipole moments), all while ensuring the model respects the fundamental symmetries of physical space.

PaiNN employs a message-passing mechanism that exchanges information between atoms, using both scalar and vector features. Each atom $i$ is represented by:

  • scalar features $s_i \in \mathbb{R}^F$
  • vector features $\vec{v}_i \in \mathbb{R}^{3\times F}$, used to model directional information (e.g., dipoles).

Message function: While the scalar message-passing mechanism is identical to that of SchNet, PaiNN introduces an additional vector message-passing mechanism to handle directional interactions. For a given atom pair $(i,j)$, the vector message function is expressed as:

$$ m_{i,j} = \vec{v}_j \odot W_{v,1}(g(d_{i,j})) \odot s_j + W_{v,2}(g(d_{i,j})) \odot s_j \frac{\vec{r}_{i,j}}{\lVert\vec{r}_{i,j}\rVert} $$

  • $m_{i,j}$ is the message from atom $j$ to atom $i$ (the vector message lives in $\mathbb{R}^{3\times F}$, the scalar message in $\mathbb{R}^F$)
  • $W_{v,1}, W_{v,2}, W_{s} : \mathbb{R}^K \mapsto \mathbb{R}^F$ are learnable distance filters that depend on the distance $d_{i,j}$ between atom $i$ and $j$
  • $g: \mathbb{R} \mapsto \mathbb{R}^K$ is the radial symmetry basis
  • $d_{i,j} \in \mathbb{R}$ is the scalar distance between atom $i$ and $j$
  • $s_j \in \mathbb{R}^F$ is the scalar embedding (feature) of atom $j$
  • $\vec{v}_j \in \mathbb{R}^{3\times F}$ are the vector features of atom $j$, and $\vec{r}_{i,j}$ is the displacement vector between atom $i$ and $j$

and the scalar message function is

$$ m_{i,j} = W_{s}(g(d_{i,j})) \odot s_j $$

Note that both terms are conceptually similar to the SchNet architecture:

  1. $\vec{v}_j \odot W_{v,1}(g(d_{i,j})) \odot s_j$ modulates the vector features component-wise as a function of distance ($d_{i,j}$) and atom identity $s_j$,
  2. $W_{v,2}(g(d_{i,j})) \odot s_j \frac{\vec{r}_{i,j}}{\lVert\vec{r}_{i,j}\rVert}$ modulates the scalar representation of SchNet by the normalized distance vector.

Update function: The update function in PaiNN is more involved: scalar and vector features are mixed both in the message and in the update block. The residual of the scalar update is:

$$ \Delta s_i^{u} = a_{ss}(s_i, \lVert V\vec{v}_i \rVert) + a_{sv}(s_i, \lVert V\vec{v}_i \rVert)\,\langle U\vec{v}_i, V\vec{v}_i\rangle $$

with

  • $s_i \in \mathbb{R}^F$ as the scalar embedding (feature) of atom $i$
  • $\vec{v}_i \in \mathbb{R}^{3\times F}$ as the vector features
  • $U, V \in \mathbb{R}^{F\times F}$ as linear transformations

and the residual of the vector update is:

$$ \Delta \vec{v}_i^{u} = a_{vv}(s_i, \lVert V\vec{v}_i \rVert)\, U\vec{v}_i $$
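
A compact sketch of these update residuals (placeholder modules and names, not the ModelForge implementation; the gating functions $a_{ss}, a_{sv}, a_{vv}$ are assumed to come from one shared MLP as in the PaiNN paper):

import torch
import torch.nn as nn

class PaiNNUpdate(nn.Module):
    # Sketch of the PaiNN scalar and vector update residuals
    def __init__(self, n_features):
        super().__init__()
        self.U = nn.Linear(n_features, n_features, bias=False)  # U, applied along the feature axis
        self.V = nn.Linear(n_features, n_features, bias=False)  # V
        # shared network producing the three gating channels a_ss, a_sv, a_vv
        self.a = nn.Sequential(
            nn.Linear(2 * n_features, n_features),
            nn.SiLU(),
            nn.Linear(n_features, 3 * n_features),
        )

    def forward(self, s, v):
        # s: (N, F) scalar features, v: (N, 3, F) vector features
        Uv, Vv = self.U(v), self.V(v)                    # (N, 3, F) each
        Vv_norm = torch.norm(Vv, dim=1)                  # ||V v_i|| over the spatial axis, (N, F)
        a_ss, a_sv, a_vv = torch.split(
            self.a(torch.cat([s, Vv_norm], dim=-1)), s.shape[-1], dim=-1
        )
        delta_s = a_ss + a_sv * (Uv * Vv).sum(dim=1)     # scalar residual with <U v_i, V v_i>
        delta_v = a_vv.unsqueeze(1) * Uv                 # vector residual
        return s + delta_s, v + delta_v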

For further detail, see Schütt, Kristof T., Oliver T. Unke, and Michael Gastegger. 2021. “Equivariant Message Passing for the Prediction of Tensorial Properties and Molecular Spectra.”

AimNet2 (Atoms-in-molecules neural network potential)

AIMNet2 is a third-generation atoms-in-molecules neural network potential (NNP) designed for high generalizability across a broad range of organic species. The architecture extends traditional machine learning interatomic potentials (MLIPs) by incorporating both short-range, ML-parameterized terms and long-range physics-based interactions, such as electrostatics and dispersion. AIMNet2 can handle systems composed of multiple chemical elements, including charged species, making it versatile for applications in both molecular and macromolecular simulations.

In AIMNet2, the message-passing framework is central to capturing atomic interactions. Atomic embeddings, initially based on atomic numbers and possibly other atomic features, are iteratively updated during each message-passing step to refine the representation of the chemical environment. Both scalar and vector embeddings are used to describe atomic features, capturing isotropic (distance-dependent) and anisotropic (direction-dependent) interactions.

Message Function

The message function in AIMNet2 incorporates both radial (scalar) and vector (angular) contributions to capture comprehensive atomic interactions.

Radial (Scalar) Contributions

For each atom pair $i, j$, the radial contribution from atom $j$ to atom $i$ is calculated as:

$$ m_{ij}^{s} = g_{i,j} \odot s_{j} $$

  • $g_{i,j} \in \mathbb{R}^{G}$: Radial symmetry functions multiplied by the cutoff function, evaluated at the distance $d_{ij}$ between atoms $i$ and $j$.
  • $\mathbf{s}_{j} \in \mathbb{R}^{F}$: Scalar embedding of atom $j$.
  • $\odot$: Element-wise multiplication with appropriate broadcasting.

This operation results in a tensor of shape $(G, F)$, combining radial basis functions with atomic features.

In the code, this corresponds to:

# Compute radial contributions
avf_s = gs.unsqueeze(-1) * a_j.unsqueeze(1)  # Shape: (num_pairs, G, F_atom)
  • gs: Radial symmetry functions (f_ij * f_cutoff), shape (num_pairs, G).
  • a_j: Atomic features of atom $j$, shape (num_pairs, F_atom).

After summing over the radial basis functions $G$, we get:

avf_s = avf_s.sum(dim=1)  # Shape: (num_pairs, F_atom)

These contributions are then aggregated per atom $i$:

radial_contributions = torch.zeros(
    (number_of_atoms, F_atom),
    device=avf_s.device,
    dtype=avf_s.dtype,
)
radial_contributions.index_add_(0, idx_j, avf_s)
Vector (Angular) Contributions

To capture angular dependencies, the vector contributions are calculated using directional information between atoms and a learnable transformation tensor.

  1. Unit Direction Vector:

$$ u_{ij} = \frac{r_{i,j}}{d_{i,j}} $$

  • $r_{i,j} \in \mathbb{R}^{3}$: Displacement vector from atom $i$ to atom $j$.
  • $d_{i,j}$: Distance between atoms $i$ and $j$.
  2. Vector Symmetry Functions:

$$ v_{i,j} = u_{i,j} \otimes g_{i,j} $$

  • $\otimes$: Outer product, resulting in a tensor of shape $(3, G)$.
  • $\mathbf{g}_{ij} \in \mathbb{R}^{G}$: Radial symmetry functions.

In the code:

# Unit direction vectors
u_ij = r_ij / d_ij  # Shape: (num_pairs, 3)

# Compute gv with shape (num_pairs, 3, G)
gv = u_ij.unsqueeze(-1) * gs.unsqueeze(1)  # Shape: (num_pairs, 3, G)
  3. Vector Message Computation:

    The vector contributions are computed using an Einstein summation over atomic features, vector symmetry functions, and a learnable tensor $\mathbf{A}$:

$$ m_{i,j}^{v,h,d} = \sum_{f, g} s_{j}^{f} \cdot v_{i,j}^{d,g} \cdot A^{f h g} $$

  • $s_{j}^{f}$: Scalar feature $f$ of atom $j$.
  • $v_{ij}^{d,g}$: Component $d$ of vector symmetry function for basis $g$.
  • $A^{f h g} \in \mathbb{R}^{F \times H \times G}$: Learnable transformation tensor (agh in the code).
  • $h$: Index over vector feature dimensions $H$.
  • $d$: Spatial dimensions (3D vectors).

In the code:

# Compute per-pair vector contributions
# avf_v: Shape (num_pairs, H, 3)
avf_v = torch.einsum("pf, pdg, fhg -> phd", a_j, gv, agh)
  • a_j: Atomic features of atom $j$, shape (num_pairs, F_atom).
  • gv: Vector symmetry functions, shape (num_pairs, 3, G).
  • agh: Learnable tensor, shape (F_atom, H, G).

Note: The indices in the Einstein summation are arranged to match the dimensions:

  • pf: Pair index $p$, atomic feature $f$.
  • pdg: Pair index $p$, spatial dimension $d$, radial basis $g$.
  • fhg: Atomic feature $f$, vector feature $h$, radial basis $g$.
  • The result avf_v has shape (num_pairs, H, 3).
  4. Aggregation of Vector Contributions:

    The vector contributions are aggregated per atom $i$:

# Initialize tensor to accumulate vector contributions per atom
avf_v_sum = torch.zeros(
    (number_of_atoms, H, 3),
    device=device,
    dtype=avf_v.dtype,
)

# Aggregate per atom by summing the vectors
avf_v_sum.index_add_(0, idx_j, avf_v)  # Shape: (number_of_atoms, H, 3)
  5. Compute Norm of Vector Contributions:

    The norm over the spatial dimensions is computed:

$$ v_{i}^{h} = \left\lVert \sum_{j \in N(i)} m_{i,j}^{v,h} \right\rVert $$

In the code:

vector_contributions = torch.norm(avf_v_sum, dim=-1)  # Shape: (number_of_atoms, H)

Message Aggregation

For each atom $i$, messages from neighboring atoms are aggregated:

  • Radial Contributions: Accumulated in radial_contributions, shape (number_of_atoms, F_atom).
  • Vector Contributions: Stored in vector_contributions, shape (number_of_atoms, H).

These are combined to form the combined message for each atom.

Update Function

The update function integrates scalar and vector messages to refine atomic embeddings.

Combined Message

The combined message for each atom $i$ is constructed differently depending on whether it's the first interaction module or not.

  • First Interaction Module:

$$ \text{combined\_message}_{i} = \left[ \mathbf{m}_{i}^{s}, \mathbf{v}_{i} \right] $$

  • Subsequent Interaction Modules:

$$ \text{combined\_message}_{i} = \left[ \mathbf{m}_{i}^{s}, \mathbf{v}_{i}, \mathbf{m}_{i}^{s,\text{charge}}, \mathbf{v}_{i}^{\text{charge}} \right] $$

  • $\mathbf{m}_{i}^{s,\text{charge}}$ and $\mathbf{v}_{i}^{\text{charge}}$ are contributions computed from partial charges.

In the code:

if not self.is_first_module:
    # Combine messages from embeddings and charges
    combined_message = torch.cat(
        [
            radial_contributions_emb,    # (N, F_atom)
            vector_contributions_emb,    # (N, H)
            radial_contributions_charge, # (N, 1)
            vector_contributions_charge, # (N, H)
        ],
        dim=1,
    )
else:
    combined_message = torch.cat(
        [
            radial_contributions_emb,    # (N, F_atom)
            vector_contributions_emb,    # (N, H)
        ],
        dim=1,
    )
Atomic Embedding and Charge Updates

The combined message is passed through a single MLP to produce updates for atomic embeddings, partial charges, and a scaling factor:

$$ \begin{align*} \text{output} &= \text{MLP}(\text{combined\_message}_{i}) \\ \left[ \Delta q_{i}, f_{i}, \Delta \mathbf{a}_{i} \right] &= \text{split}(\text{output}) \end{align*} $$

In the code:

# Pass combined message through single MLP
out = self.mlp(combined_message)

# Split the output tensor into delta_q, f, and delta_a
delta_q, f, delta_a = torch.split(
    out, [1, 1, self.number_of_per_atom_features], dim=1
)
  • delta_q: Update for partial charges, shape (N, 1).
  • f: Scaling factor for charge updates, shape (N, 1).
  • delta_a: Update for atomic embeddings, shape (N, F_atom).

The atomic embeddings are updated:

# Update atomic embeddings
atomic_embedding = atomic_embedding + delta_a

The partial charges are updated with the scaling factor:

# Apply scaling factor `f` to `delta_q`
scaled_delta_q = f * delta_q

# Update partial charges
if is_first_module:
    partial_charges = scaled_delta_q  # Initialize charges
else:
    partial_charges = partial_charges + scaled_delta_q  # Incremental update

Charge conservation is enforced after each update.
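
A simple way to enforce this constraint is to redistribute any deviation from the target molecular charge evenly over the atoms of each molecule. The sketch below illustrates this idea (a hypothetical helper, assuming the per-atom charges are flattened to shape (N,) and a per-molecule target charge is available; the actual AIMNet2/ModelForge scheme may use a different, e.g. weighted, redistribution):

import torch

def enforce_charge_conservation(partial_charges, molecule_indices, total_charge):
    # partial_charges:  (N,) predicted per-atom charges
    # molecule_indices: (N,) molecule membership of each atom
    # total_charge:     (M,) target net charge per molecule
    n_molecules = total_charge.shape[0]
    # current sum of predicted charges per molecule
    charge_sum = torch.zeros(n_molecules, dtype=partial_charges.dtype).index_add_(
        0, molecule_indices, partial_charges
    )
    # number of atoms per molecule
    atom_counts = torch.zeros(n_molecules, dtype=partial_charges.dtype).index_add_(
        0, molecule_indices, torch.ones_like(partial_charges)
    )
    # distribute the excess charge evenly over the atoms of each molecule
    correction = (total_charge - charge_sum) / atom_counts
    return partial_charges + correction[molecule_indices]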

Note: The code snippets provided correspond directly to the mathematical operations described, ensuring clarity and alignment between the implementation and theoretical formulation.

PhysNet (Physically Informed Neural Network)

PhysNet is a neural network architecture designed for the prediction of energies, forces, and dipole moments of molecular systems. It incorporates both short-range machine-learned interactions and long-range physics-based terms such as Coulomb interactions. This makes PhysNet particularly suitable for capturing a variety of chemical environments, including those with significant long-range interactions such as non-covalent complexes and charged systems.

Message function: PhysNet (an invariant, 3rd generation network) uses a message-passing scheme similar to other graph neural networks (GNNs), where atomic features are iteratively updated by exchanging information between neighboring atoms. For each pair of atoms $(i,j)$, the message function depends on the embedding of atom $i$ and atom $j$, the interatomic distance $r_{ij}$ and involves radial basis functions to encode distance information:

$$ m_{i,j}^{l} = G^{l}g(d_{i,j}) \odot \sigma(W_j^l\sigma(x_j^l) + b_j^l) $$

where:

  • $m_{i,j}^{l}\in \mathbb{R}^F$ is the message representing the interaction of atom $i$ with atom $j$ at layer $l$
  • $g(d_{ij}) \in \mathbb{R}^K$ is the featurized distance between atoms $i$ and $j$
  • $b_j \in \mathbb{R}^F$ is a learnable bias vector
  • $\mathbf{W}_j \in \mathbb{R}^{F\times F}$ is a learnable weight matrix applied to the embedding of atom $j$
  • $\mathbf{G}^l \in \mathbb{R}^{F\times K}$ is the attention mask selecting features based on pairwise distances
  • $\odot$ denotes element-wise multiplication.

This message is then passed through several layers to form the input for the update function.

Update function: After receiving the messages from its neighbors, each atom updates its atomic embedding $x_i$ through an update function that aggregates the incoming messages. The update function is similar to other MPNN frameworks and is expressed as:

$$ x_{i}^{t+1} = U(x_i^t + \sum_{j \in N(i)} m_{i,j}) $$

where:

  • $x_i^t \in \mathbb{R}^{F}$ is the feature vector of atom $i$ at iteration $t$
  • $N(i)$ is the set of neighbors of atom $i$ within a specified cutoff distance
  • $m_{i,j}$ is the message received from atom $j$
  • $U$ is the update function, a learnable transformation applied to the aggregated messages.

What is particularly interesting in PhysNet is the update function $U$, which can be summarized as follows

$$ x_i^{l+1} = u^l \odot x_i^{l} + W^{l}\sigma(m_i^l) + b^l $$

This is a straightforward update function for node updates in neural networks, but $u^l \in \mathbb{R}^{F}$ should be noted: this represents a gating vector $u$, allowing individual entries of the feature vector to be damped or reinforced during the update.
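
A minimal sketch of this message and gated update (placeholder modules; $\sigma$ is approximated here with SiLU rather than PhysNet's shifted softplus, and this is not the ModelForge implementation):

import torch
import torch.nn as nn

class PhysNetInteraction(nn.Module):
    # Sketch of the PhysNet-style message and gated node update
    def __init__(self, n_rbf, n_features):
        super().__init__()
        self.G = nn.Linear(n_rbf, n_features, bias=False)   # attention mask G^l: R^K -> R^F
        self.W_j = nn.Linear(n_features, n_features)        # W_j^l with bias b_j^l
        self.W = nn.Linear(n_features, n_features)          # W^l with bias b^l
        self.u = nn.Parameter(torch.ones(n_features))       # gating vector u^l
        self.activation = nn.SiLU()                          # stands in for sigma

    def forward(self, x, g_ij, idx_i, idx_j):
        # x: (N, F) atomic features, g_ij: (P, K) featurized distances
        m_ij = self.G(g_ij) * self.activation(self.W_j(self.activation(x[idx_j])))  # (P, F)
        m_i = torch.zeros_like(x).index_add_(0, idx_i, m_ij)                        # sum over neighbors
        return self.u * x + self.W(self.activation(m_i))                            # gated update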

Unke, Oliver T., and Markus Meuwly. 2019. “PhysNet: A Neural Network for Predicting Energies, Forces, Dipole Moments and Partial Charges.” https://doi.org/10.1021/acs.jctc.9b00181.

DimeNet++

[coming soon]

TensorNet

TensorNet is a 3rd generation neural network potential (NNP) that is able to capture up to rank-2 tensor properties of an atomic system. The central idea of TensorNet is the irreducible tensor decomposition and the corresponding Cartesian tensor representation. For any rank-2 tensor $X$ defined on $\mathbb{R}^3$, the irreducible decomposition writes:

$$X=\frac{1}{3}Tr(X)I+\frac{1}{2}(X-X^T)+\frac{1}{2}(X+X^T-\frac{2}{3}Tr(X)I)\equiv I^X+A^X+S^X$$.

Here, $I^X$ has only 1 degree of freedom and is invariant under rotations, i.e., it transforms as a scalar; $A^X$ has 3 independent components as an antisymmetric tensor and rotates as a vector ($R\in SO(3), \mathbf{v}'=R\cdot \mathbf{v}$); $S^X$ is a traceless symmetric matrix with 5 independent components and rotates like a rank-2 tensor ($R\in SO(3), T'=RTR^T$). In order to compose an irreducible representation $X=I^X+A^X+S^X$ from a given vector $\mathbf{v}\in \mathbb{R}^3$, specifically a displacement vector between an atom pair, we have the following definitions:

$$ A^X\equiv \begin{pmatrix} 0 & \mathbf{v}_z & -\mathbf{v}_y \\ -\mathbf{v}_z & 0 & \mathbf{v}_x \\ \mathbf{v}_y & -\mathbf{v}_x & 0 \\ \end{pmatrix} $$

and $X\equiv A+\mathbf{v}\otimes \mathbf{v}$.
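
A small helper illustrating this decomposition for a batch of rank-2 tensors (a sketch only, not the ModelForge implementation):

import torch

def decompose_tensor(X):
    # Split a batch of rank-2 tensors X with shape (..., 3, 3) into I^X + A^X + S^X
    trace = X.diagonal(dim1=-2, dim2=-1).sum(-1)              # Tr(X), shape (...)
    eye = torch.eye(3, dtype=X.dtype, device=X.device)
    I_X = (trace / 3.0)[..., None, None] * eye                # isotropic (scalar) part
    A_X = 0.5 * (X - X.transpose(-2, -1))                     # antisymmetric (vector) part
    S_X = 0.5 * (X + X.transpose(-2, -1)) - I_X               # traceless symmetric part
    return I_X, A_X, S_X

By construction, I_X + A_X + S_X reconstructs X.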

The radial basis functions are defined as $e_k^{RBF}(r_{ij})=\exp(-\beta_k(\exp(-r_{ij})-\mu_k)^2)$, where $\beta_k=(2K^{-1}(1-\exp(-r_c)))^{-2}$ (constant for all $k$) and $\mu_k=\frac{k}{K}(1-\exp(-r_c)) + \exp(-r_c)$, with $K$ the dimension of the radial basis set.
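
A sketch of these radial basis functions (illustrative only; evaluated directly from the formulas above):

import math
import torch

def expnorm_rbf(r_ij, r_c, K):
    # e_k(r) = exp(-beta_k * (exp(-r) - mu_k)^2), evaluated for k = 1..K
    exp_rc = math.exp(-r_c)
    k = torch.arange(1, K + 1, dtype=r_ij.dtype)
    mu_k = (k / K) * (1.0 - exp_rc) + exp_rc              # centers mu_k
    beta_k = (2.0 / K * (1.0 - exp_rc)) ** -2             # constant width for all k
    return torch.exp(-beta_k * (torch.exp(-r_ij)[..., None] - mu_k) ** 2)  # shape (..., K)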

Now that we have the irreducible representation of displacement vectors and the radial basis functions, we can define the initial messages and update messages.

Initialization:

$$X^{(ij)}=\phi(r_{ij})Z_{ij}(f_I^{(0)}I_0^{(ij)}+f_A^{(0)}A_0^{(ij)}+f_S^{(0)}S_0^{(ij)}),$$

where radial symmetry functions are mapped with linear layers: $f_I^{(0)}=W^I(\phi e^{RBF}(r_{ij}))+b^I$, $f_A^{(0)}=W^A(\phi e^{RBF}(r_{ij}))+b^A$, $f_S^{(0)}=W^S(\phi e^{RBF}(r_{ij}))+b^S$. The per-atom tensors are then obtained by summing over neighbors, $X^{(i)}=\sum_{j} X^{(ij)}$, and refined with a norm-based gating:

$$f_I^{(i)}, f_A^{(i)}, f_S^{(i)}=SiLU(MLP(LayerNorm(Tr(X^TX))))$$

$$X^{(i)}\leftarrow f_I^{(i)}W^II^{(i)}+f_A^{(i)}W^AA^{(i)}+f_S^{(i)}W^SS^{(i)}$$

Update:

$$X^{(i)}\leftarrow X^{(i)}/(\lVert X^{(i)}\rVert+1), Y=W^II^{(i)}+W^AA^{(i)}+W^SS^{(i)}$$

$$f_I^{(ij)}, f_A^{(ij)}, f_S^{(ij)}=\phi(r_{ij})\,SiLU(MLP(e^{RBF}(r_{ij})))$$

$$M^{(ij)}=f_I^{(ij)}I^{(j)}+f_A^{(ij)}A^{(j)}+f_S^{(ij)}S^{(j)}, M^{(i)}=\sum M^{(ij)}$$

$I^{(i)}, A^{(i)}, S^{(i)}$ are calculated from the normalized irreducible decomposition of $Y^{(i)}M^{(i)}+M^{(i)}Y^{(i)}$.

Now update $Y^{(i)}$ and then $X^{(i)}$:

$$Y^{(i)}\leftarrow W^II^{(i)}+W^AA^{(i)}+W^SS^{(i)}, \quad X^{(i)}\leftarrow X^{(i)}+\Delta X^{(i)}=X^{(i)}+Y^{(i)}+(Y^{(i)})^2$$

For further details, refer to the TensorNet arXiv paper.

ANI2x