Skip to content

Commit

Permalink
add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
samwaseda committed Jan 11, 2022
1 parent c4458d0 commit a014acc
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 13 deletions.
31 changes: 22 additions & 9 deletions tds/metadynamics_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,16 @@ def cell(self):
@property
def mesh(self):
if self._mesh is None:
self._mesh = np.einsum('j...,ji->...i', np.meshgrid(
*[np.linspace(0, 1, np.rint(c / self.spacing).astype(int)) for c in self.cell],
indexing='ij'
), self.unit_cell.cell)
linspace = []
for c in self.cell:
ll = np.linspace(0, 1, np.rint(c / self.spacing).astype(int), endpoint=False)
ll += 0.5 * (ll[1] - ll[0])
linspace.append(ll)
self._mesh = np.einsum(
'j...,ji->...i',
np.meshgrid(*linspace, indexing='ij'),
self.unit_cell.cell
)
return self._mesh

@property
Expand Down Expand Up @@ -106,17 +112,24 @@ def append_positions(self, x, symmetrize=True):
dx /= self.sigma
B = self.increment * np.exp(-dist**2 / (2 * self.sigma**2))
np.add.at(self.B, unraveled_indices, B)
np.add.at(self.dBds, unraveled_indices, np.einsum('i,...,->...i', dx, B, 1 / self.sigma))
np.add.at(self.dBds, unraveled_indices, np.einsum('...i,...->...i', dx, B / self.sigma))
if self.use_gradient:
xx = (np.outer(dx, dx) - np.eye(3)) / self.sigma**2
np.add.at(self.ddBdds, unraveled_indices, np.einsum('ij,...->...ij', xx, B))
xx = (np.einsum('...i,...j->...ij', dx, dx) - np.eye(3)) / self.sigma**2
np.add.at(self.ddBdds, unraveled_indices, np.einsum('...ij,...->...ij', xx, B))
self._x_lst.extend(x)

def _get_index(self, x):
return np.unravel_index(self.tree_mesh.query(self.x_to_s(x))[1], self.mesh.shape[:-1])
return np.unravel_index(
self.tree_mesh.query(self.x_to_s(x))[1], self.mesh.shape[:-1]
)

def get_force(self, x):
return self.dBds[self._get_index(x)]
index = self._get_index(x)
dBds = self.dBds[self._get_index(x)]
if self.use_gradient:
dx = x - self.mesh[index]
dBds += np.einsum('...j,ij->...i', dx, self.ddBdds[index])
return dBds

@property
def x_lst(self):
Expand Down
33 changes: 29 additions & 4 deletions tests/test_metadynamics_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@ def test_get_symmetric_x(self):
def test_append_position(self):
x = np.random.random(3)
self.unit_cell.append_positions(x)
self.assertAlmostEqual(np.linalg.norm(x - self.unit_cell.x_lst, axis=-1).min(), 0)
self.assertAlmostEqual(
np.linalg.norm(x - self.unit_cell.x_lst, axis=-1).min(), 0
)

def test_get_neighbors(self):
x = self.unit_cell._get_symmetric_x(np.random.random(3))
Expand All @@ -49,7 +51,10 @@ def test_get_neighbors(self):
x = self.unit_cell.x_to_s(np.random.random((1, 3)))
dist, dx, indices = self.unit_cell._get_neighbors(x)
self.assertTrue(
np.allclose(dist, np.linalg.norm(self.unit_cell.mesh - x, axis=-1)[indices])
np.allclose(
dist,
np.linalg.norm(self.unit_cell.mesh - x, axis=-1)[indices]
)
)
self.assertTrue(np.allclose(dx, (self.unit_cell.mesh - x)[indices]))

Expand All @@ -59,21 +64,41 @@ def test_get_energy(self):
self.assertLessEqual(self.unit_cell.get_energy(x), -self.unit_cell.increment)
x = self.unit_cell.x_to_s(x)
dist, _ = self.unit_cell.tree_output.query(
x, k=self.unit_cell._num_neighbors_x_lst, distance_upper_bound=self.unit_cell.cutoff
x,
k=self.unit_cell._num_neighbors_x_lst,
distance_upper_bound=self.unit_cell.cutoff
)
self.assertLess(
np.sum(dist < np.inf, axis=-1).max(), self.unit_cell._num_neighbors_x_lst
)

def test_B(self):
x = self.unit_cell.mesh[
tuple(np.random.randint(self.unit_cell.mesh.shape[:-1]))
]
self.unit_cell.append_positions(x)
ind = self.unit_cell._get_index(x)
self.assertGreater(self.unit_cell.B[ind], self.unit_cell.increment)
self.assertLess(np.linalg.eigh(self.unit_cell.ddBdds[ind])[0].max(), 0.)

def test_get_force(self):
x = np.random.random((1, 3))
x = self.unit_cell.x_to_s(x)
self.unit_cell.append_positions(x, symmetrize=False)
force_max = self.unit_cell.increment / self.unit_cell.sigma * np.exp(-1 / 2)
self.assertLess(
abs(force_max - np.linalg.norm(self.unit_cell.dBds, axis=-1).max()), 1.0e-4
abs(force_max - np.linalg.norm(self.unit_cell.dBds, axis=-1).max()),
1.0e-4
)

def test_get_index(self):
x = 100 * np.random.randn(3)
ind_meta = self.unit_cell._get_index(x)
x = self.unit_cell.x_to_s(x)
ind_min = np.argmin(np.linalg.norm(self.unit_cell.mesh - x, axis=-1))
ind_min = np.unravel_index(ind_min, self.unit_cell.mesh.shape[:-1])
self.assertTrue(np.array_equal(ind_min, ind_meta))


if __name__ == "__main__":
unittest.main()

0 comments on commit a014acc

Please sign in to comment.