diff --git a/timemachine/cpp/src/kernels/k_periodic_torsion.cuh b/timemachine/cpp/src/kernels/k_periodic_torsion.cuh index f8d4f9f2a..99cc8b8c0 100644 --- a/timemachine/cpp/src/kernels/k_periodic_torsion.cuh +++ b/timemachine/cpp/src/kernels/k_periodic_torsion.cuh @@ -38,28 +38,25 @@ void __global__ k_periodic_torsion( int k_idx = torsion_idxs[t_idx * 4 + 2]; int l_idx = torsion_idxs[t_idx * 4 + 3]; - RealType rij[D]; - RealType rkj[D]; - RealType rkl[D]; + RealType r0[D]; + RealType r1[D]; + RealType r2[D]; - RealType rkj_norm_square = 0; + RealType r1_norm_square = 0; // (todo) cap to three dims, while keeping stride at 4 for (int d = 0; d < D; d++) { - RealType vij = coords[j_idx * D + d] - coords[i_idx * D + d]; - RealType vkj = coords[j_idx * D + d] - coords[k_idx * D + d]; - RealType vkl = coords[l_idx * D + d] - coords[k_idx * D + d]; - rij[d] = vij; - rkj[d] = vkj; - rkl[d] = vkl; - rkj_norm_square += vkj * vkj; + r0[d] = coords[i_idx * D + d] - coords[j_idx * D + d]; + r1[d] = coords[k_idx * D + d] - coords[j_idx * D + d]; + r2[d] = coords[k_idx * D + d] - coords[l_idx * D + d]; + r1_norm_square += r1[d] * r1[d]; } - RealType rkj_norm = sqrt(rkj_norm_square); + RealType r1_norm = sqrt(r1_norm_square); RealType n1[D], n2[D]; - cross_product(rij, rkj, n1); - cross_product(rkj, rkl, n2); + cross_product(r0, r1, n1); + cross_product(r1, r2, n2); RealType n1_norm_square, n2_norm_square; @@ -74,25 +71,25 @@ void __global__ k_periodic_torsion( RealType d_angle_dR1[D]; RealType d_angle_dR2[D]; - RealType rij_dot_rkj = dot_product(rij, rkj); - RealType rkl_dot_rkj = dot_product(rkl, rkj); + RealType r0_dot_r1 = dot_product(r0, r1); + RealType r2_dot_r1 = dot_product(r2, r1); for (int d = 0; d < D; d++) { - d_angle_dR0[d] = rkj_norm / n1_norm_square * n1[d]; - d_angle_dR3[d] = -rkj_norm / n2_norm_square * n2[d]; + d_angle_dR0[d] = r1_norm / n1_norm_square * n1[d]; + d_angle_dR3[d] = -r1_norm / n2_norm_square * n2[d]; d_angle_dR1[d] = - (rij_dot_rkj / rkj_norm_square - 1) * d_angle_dR0[d] - d_angle_dR3[d] * rkl_dot_rkj / rkj_norm_square; + (r0_dot_r1 / r1_norm_square - 1) * d_angle_dR0[d] - d_angle_dR3[d] * r2_dot_r1 / r1_norm_square; d_angle_dR2[d] = - (rkl_dot_rkj / rkj_norm_square - 1) * d_angle_dR3[d] - d_angle_dR0[d] * rij_dot_rkj / rkj_norm_square; + (r2_dot_r1 / r1_norm_square - 1) * d_angle_dR3[d] - d_angle_dR0[d] * r0_dot_r1 / r1_norm_square; } - RealType rkj_n = sqrt(dot_product(rkj, rkj)); + RealType r1_n = sqrt(dot_product(r1, r1)); for (int d = 0; d < D; d++) { - rkj[d] /= rkj_n; + r1[d] /= r1_n; } - RealType y = dot_product(n3, rkj); + RealType y = dot_product(n3, r1); RealType x = dot_product(n1, n2); RealType angle = atan2(y, x); @@ -104,7 +101,7 @@ void __global__ k_periodic_torsion( RealType phase = params[phase_idx]; RealType period = params[period_idx]; - RealType prefactor = kt * sin(period * angle - phase) * period; + RealType prefactor = -kt * sin(period * angle - phase) * period; if (du_dx) { for (int d = 0; d < D; d++) { diff --git a/timemachine/potentials/bonded.py b/timemachine/potentials/bonded.py index f203eae11..3a3ea7ccf 100644 --- a/timemachine/potentials/bonded.py +++ b/timemachine/potentials/bonded.py @@ -131,15 +131,15 @@ def harmonic_angle(conf, params, box, angle_idxs): return jnp.sum(energies, axis=-1) # reduce over all angles -def signed_torsion_angle(ci, cj, ck, cl): +def signed_torsion_angle(xa, xb, xc, xd): """ Batch compute the signed angle of a torsion angle. The torsion angle between two planes should be periodic but not necessarily symmetric. Parameters ---------- - ci, cj, ck, cl: shape [num_torsions, 3] np.ndarrays - atom coordinates defining torsion angle i-j-k-l + xa, xb, xc, xd: shape [num_torsions, 3] np.ndarrays + atom coordinates defining torsion angle a-b-c-d Returns ------- @@ -154,14 +154,17 @@ def signed_torsion_angle(ci, cj, ck, cl): # implementation as opposed to the OpenMM energy function to # avoid a singularity when the angle is zero. - rij = cj - ci - rkj = cj - ck - rkl = cl - ck + # (ytz): Feb 4th 2024, switch to OpenMM convention + # https://github.com/openmm/openmm/blob/master/platforms/reference/src/SimTKReference/ReferenceProperDihedralBond.cpp#L97C23-L97C32 - n1 = jnp.cross(rij, rkj) - n2 = jnp.cross(rkj, rkl) + r0 = xa - xb + r1 = xc - xb + r2 = xc - xd - y = jnp.sum(jnp.multiply(jnp.cross(n1, n2), rkj / jnp.linalg.norm(rkj, axis=-1, keepdims=True)), axis=-1) + n1 = jnp.cross(r0, r1) + n2 = jnp.cross(r1, r2) + + y = jnp.sum(jnp.multiply(jnp.cross(n1, n2), r1 / jnp.linalg.norm(r1, axis=-1, keepdims=True)), axis=-1) x = jnp.sum(jnp.multiply(n1, n2), -1) return jnp.arctan2(y, x)