Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ensure consistent dihedral angle computations with OpenMM and mdanalysis #1247

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 21 additions & 24 deletions timemachine/cpp/src/kernels/k_periodic_torsion.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -38,28 +38,25 @@ void __global__ k_periodic_torsion(
int k_idx = torsion_idxs[t_idx * 4 + 2];
int l_idx = torsion_idxs[t_idx * 4 + 3];

RealType rij[D];
RealType rkj[D];
RealType rkl[D];
RealType r0[D];
RealType r1[D];
RealType r2[D];

RealType rkj_norm_square = 0;
RealType r1_norm_square = 0;

// (todo) cap to three dims, while keeping stride at 4
for (int d = 0; d < D; d++) {
RealType vij = coords[j_idx * D + d] - coords[i_idx * D + d];
RealType vkj = coords[j_idx * D + d] - coords[k_idx * D + d];
RealType vkl = coords[l_idx * D + d] - coords[k_idx * D + d];
rij[d] = vij;
rkj[d] = vkj;
rkl[d] = vkl;
rkj_norm_square += vkj * vkj;
r0[d] = coords[i_idx * D + d] - coords[j_idx * D + d];
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

While here, might be good to use the same terminology for the reference and the kernel. Here you use i, j, k, l and in the reference you use xa, xb, xc, xd

r1[d] = coords[k_idx * D + d] - coords[j_idx * D + d];
r2[d] = coords[k_idx * D + d] - coords[l_idx * D + d];
r1_norm_square += r1[d] * r1[d];
}

RealType rkj_norm = sqrt(rkj_norm_square);
RealType r1_norm = sqrt(r1_norm_square);
RealType n1[D], n2[D];

cross_product(rij, rkj, n1);
cross_product(rkj, rkl, n2);
cross_product(r0, r1, n1);
cross_product(r1, r2, n2);

RealType n1_norm_square, n2_norm_square;

Expand All @@ -74,25 +71,25 @@ void __global__ k_periodic_torsion(
RealType d_angle_dR1[D];
RealType d_angle_dR2[D];

RealType rij_dot_rkj = dot_product(rij, rkj);
RealType rkl_dot_rkj = dot_product(rkl, rkj);
RealType r0_dot_r1 = dot_product(r0, r1);
RealType r2_dot_r1 = dot_product(r2, r1);

for (int d = 0; d < D; d++) {
d_angle_dR0[d] = rkj_norm / n1_norm_square * n1[d];
d_angle_dR3[d] = -rkj_norm / n2_norm_square * n2[d];
d_angle_dR0[d] = r1_norm / n1_norm_square * n1[d];
d_angle_dR3[d] = -r1_norm / n2_norm_square * n2[d];
d_angle_dR1[d] =
(rij_dot_rkj / rkj_norm_square - 1) * d_angle_dR0[d] - d_angle_dR3[d] * rkl_dot_rkj / rkj_norm_square;
(r0_dot_r1 / r1_norm_square - 1) * d_angle_dR0[d] - d_angle_dR3[d] * r2_dot_r1 / r1_norm_square;
d_angle_dR2[d] =
(rkl_dot_rkj / rkj_norm_square - 1) * d_angle_dR3[d] - d_angle_dR0[d] * rij_dot_rkj / rkj_norm_square;
(r2_dot_r1 / r1_norm_square - 1) * d_angle_dR3[d] - d_angle_dR0[d] * r0_dot_r1 / r1_norm_square;
}

RealType rkj_n = sqrt(dot_product(rkj, rkj));
RealType r1_n = sqrt(dot_product(r1, r1));

for (int d = 0; d < D; d++) {
rkj[d] /= rkj_n;
r1[d] /= r1_n;
}

RealType y = dot_product(n3, rkj);
RealType y = dot_product(n3, r1);
RealType x = dot_product(n1, n2);
RealType angle = atan2(y, x);

Expand All @@ -104,7 +101,7 @@ void __global__ k_periodic_torsion(
RealType phase = params[phase_idx];
RealType period = params[period_idx];

RealType prefactor = kt * sin(period * angle - phase) * period;
RealType prefactor = -kt * sin(period * angle - phase) * period;

if (du_dx) {
for (int d = 0; d < D; d++) {
Expand Down
21 changes: 12 additions & 9 deletions timemachine/potentials/bonded.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,15 +141,15 @@ def harmonic_angle(conf, params, box, angle_idxs, cos_angles=True):
return jnp.sum(energies, -1) # reduce over all angles


def signed_torsion_angle(ci, cj, ck, cl):
def signed_torsion_angle(xa, xb, xc, xd):
"""
Batch compute the signed angle of a torsion angle. The torsion angle
between two planes should be periodic but not necessarily symmetric.

Parameters
----------
ci, cj, ck, cl: shape [num_torsions, 3] np.ndarrays
atom coordinates defining torsion angle i-j-k-l
xa, xb, xc, xd: shape [num_torsions, 3] np.ndarrays
atom coordinates defining torsion angle a-b-c-d

Returns
-------
Expand All @@ -164,14 +164,17 @@ def signed_torsion_angle(ci, cj, ck, cl):
# implementation as opposed to the OpenMM energy function to
# avoid a singularity when the angle is zero.

rij = cj - ci
rkj = cj - ck
rkl = cl - ck
# (ytz): Feb 4th 2024, switch to OpenMM convention
# https://github.com/openmm/openmm/blob/master/platforms/reference/src/SimTKReference/ReferenceProperDihedralBond.cpp#L97C23-L97C32
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Be good to have a permalink in case openmm changes this code ever


n1 = jnp.cross(rij, rkj)
n2 = jnp.cross(rkj, rkl)
r0 = xa - xb
r1 = xc - xb
r2 = xc - xd

y = jnp.sum(jnp.multiply(jnp.cross(n1, n2), rkj / jnp.linalg.norm(rkj, axis=-1, keepdims=True)), axis=-1)
n1 = jnp.cross(r0, r1)
n2 = jnp.cross(r1, r2)

y = jnp.sum(jnp.multiply(jnp.cross(n1, n2), r1 / jnp.linalg.norm(r1, axis=-1, keepdims=True)), axis=-1)
x = jnp.sum(jnp.multiply(n1, n2), -1)

return jnp.arctan2(y, x)
Expand Down