diff --git a/tds/metadynamics_3d.py b/tds/metadynamics_3d.py index 03b0cb0..f418a0f 100644 --- a/tds/metadynamics_3d.py +++ b/tds/metadynamics_3d.py @@ -64,10 +64,16 @@ def cell(self): @property def mesh(self): if self._mesh is None: - self._mesh = np.einsum('j...,ji->...i', np.meshgrid( - *[np.linspace(0, 1, np.rint(c / self.spacing).astype(int)) for c in self.cell], - indexing='ij' - ), self.unit_cell.cell) + linspace = [] + for c in self.cell: + ll = np.linspace(0, 1, np.rint(c / self.spacing).astype(int), endpoint=False) + ll += 0.5 * (ll[1] - ll[0]) + linspace.append(ll) + self._mesh = np.einsum( + 'j...,ji->...i', + np.meshgrid(*linspace, indexing='ij'), + self.unit_cell.cell + ) return self._mesh @property @@ -106,17 +112,24 @@ def append_positions(self, x, symmetrize=True): dx /= self.sigma B = self.increment * np.exp(-dist**2 / (2 * self.sigma**2)) np.add.at(self.B, unraveled_indices, B) - np.add.at(self.dBds, unraveled_indices, np.einsum('i,...,->...i', dx, B, 1 / self.sigma)) + np.add.at(self.dBds, unraveled_indices, np.einsum('...i,...->...i', dx, B / self.sigma)) if self.use_gradient: - xx = (np.outer(dx, dx) - np.eye(3)) / self.sigma**2 - np.add.at(self.ddBdds, unraveled_indices, np.einsum('ij,...->...ij', xx, B)) + xx = (np.einsum('...i,...j->...ij', dx, dx) - np.eye(3)) / self.sigma**2 + np.add.at(self.ddBdds, unraveled_indices, np.einsum('...ij,...->...ij', xx, B)) self._x_lst.extend(x) def _get_index(self, x): - return np.unravel_index(self.tree_mesh.query(self.x_to_s(x))[1], self.mesh.shape[:-1]) + return np.unravel_index( + self.tree_mesh.query(self.x_to_s(x))[1], self.mesh.shape[:-1] + ) def get_force(self, x): - return self.dBds[self._get_index(x)] + index = self._get_index(x) + dBds = self.dBds[self._get_index(x)] + if self.use_gradient: + dx = x - self.mesh[index] + dBds += np.einsum('...j,ij->...i', dx, self.ddBdds[index]) + return dBds @property def x_lst(self): diff --git a/tests/test_metadynamics_3d.py b/tests/test_metadynamics_3d.py index bf3c898..8226d62 100644 --- a/tests/test_metadynamics_3d.py +++ b/tests/test_metadynamics_3d.py @@ -40,7 +40,9 @@ def test_get_symmetric_x(self): def test_append_position(self): x = np.random.random(3) self.unit_cell.append_positions(x) - self.assertAlmostEqual(np.linalg.norm(x - self.unit_cell.x_lst, axis=-1).min(), 0) + self.assertAlmostEqual( + np.linalg.norm(x - self.unit_cell.x_lst, axis=-1).min(), 0 + ) def test_get_neighbors(self): x = self.unit_cell._get_symmetric_x(np.random.random(3)) @@ -49,7 +51,10 @@ def test_get_neighbors(self): x = self.unit_cell.x_to_s(np.random.random((1, 3))) dist, dx, indices = self.unit_cell._get_neighbors(x) self.assertTrue( - np.allclose(dist, np.linalg.norm(self.unit_cell.mesh - x, axis=-1)[indices]) + np.allclose( + dist, + np.linalg.norm(self.unit_cell.mesh - x, axis=-1)[indices] + ) ) self.assertTrue(np.allclose(dx, (self.unit_cell.mesh - x)[indices])) @@ -59,21 +64,41 @@ def test_get_energy(self): self.assertLessEqual(self.unit_cell.get_energy(x), -self.unit_cell.increment) x = self.unit_cell.x_to_s(x) dist, _ = self.unit_cell.tree_output.query( - x, k=self.unit_cell._num_neighbors_x_lst, distance_upper_bound=self.unit_cell.cutoff + x, + k=self.unit_cell._num_neighbors_x_lst, + distance_upper_bound=self.unit_cell.cutoff ) self.assertLess( np.sum(dist < np.inf, axis=-1).max(), self.unit_cell._num_neighbors_x_lst ) + def test_B(self): + x = self.unit_cell.mesh[ + tuple(np.random.randint(self.unit_cell.mesh.shape[:-1])) + ] + self.unit_cell.append_positions(x) + ind = self.unit_cell._get_index(x) + self.assertGreater(self.unit_cell.B[ind], self.unit_cell.increment) + self.assertLess(np.linalg.eigh(self.unit_cell.ddBdds[ind])[0].max(), 0.) + def test_get_force(self): x = np.random.random((1, 3)) x = self.unit_cell.x_to_s(x) self.unit_cell.append_positions(x, symmetrize=False) force_max = self.unit_cell.increment / self.unit_cell.sigma * np.exp(-1 / 2) self.assertLess( - abs(force_max - np.linalg.norm(self.unit_cell.dBds, axis=-1).max()), 1.0e-4 + abs(force_max - np.linalg.norm(self.unit_cell.dBds, axis=-1).max()), + 1.0e-4 ) + def test_get_index(self): + x = 100 * np.random.randn(3) + ind_meta = self.unit_cell._get_index(x) + x = self.unit_cell.x_to_s(x) + ind_min = np.argmin(np.linalg.norm(self.unit_cell.mesh - x, axis=-1)) + ind_min = np.unravel_index(ind_min, self.unit_cell.mesh.shape[:-1]) + self.assertTrue(np.array_equal(ind_min, ind_meta)) + if __name__ == "__main__": unittest.main()