From 439d7d8dd36e7b2c8efe0be87af266d0041446fe Mon Sep 17 00:00:00 2001 From: Bruno Rousseau Date: Tue, 24 Dec 2024 10:23:54 -0500 Subject: [PATCH 01/24] Compute log wrapped gaussians. --- .../score/wrapped_gaussian_score.py | 55 +++++++++++++++++++ tests/score/test_wrapped_gaussian_score.py | 51 +++++++++++++++-- 2 files changed, 100 insertions(+), 6 deletions(-) diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/score/wrapped_gaussian_score.py b/src/diffusion_for_multi_scale_molecular_dynamics/score/wrapped_gaussian_score.py index c7062291..e9339fa0 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/score/wrapped_gaussian_score.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/score/wrapped_gaussian_score.py @@ -30,6 +30,7 @@ from typing import Optional +import einops import numpy as np import torch @@ -37,6 +38,60 @@ U_THRESHOLD = torch.Tensor([0.5]) +def get_log_wrapped_gaussians( + relative_coordinates: torch.tensor, sigmas: torch.tensor, kmax: int +): + """Get Log Wrapped Gaussians. + + Args: + relative_coordinates : input relative coordinates: should be between 0 and 1. + relative_coordinates has dimensions [..., number_of_atoms, spatial_dimension], where (...) + are arbitrary batch dimensions. + sigmas : the values of sigma. Should have the same dimension as relative coordinates. + kmax : largest positive integer in the sum. The sum is from -kmax to +kmax. + + Returns: + log_wrapped_gaussians: log of wrapped gaussian values, of dimensions [...], namely the batch dimensions. + """ + device = relative_coordinates.device + assert ( + sigmas.device == device + ), "relative_coordinates and sigmas should be on the same device." + + assert ( + relative_coordinates.shape == sigmas.shape + ), "The relative coordinates and sigmas array should have the same shape." + + assert ( + len(relative_coordinates.shape) >= 3 + ), "relative_coordinates should have at least 3 dimensions." + + input_shape = relative_coordinates.shape + + # The dimension of list_k is [2 kmax + 1]. + list_k = torch.arange(-kmax, kmax + 1).to(device) + + # Broadcast to shape [Nu, 1] + column_u = einops.rearrange(relative_coordinates, "... -> (...) 1") + column_sigma = einops.rearrange(sigmas, "... -> (...) 1") + + norm = torch.tensor(2 * torch.pi).sqrt() * column_sigma.squeeze(-1) + + flat_log_norm = torch.log(norm) + + # Broadcast to shape [Nu, Nk] + exponentials = -0.5 * (column_u + list_k) ** 2 / column_sigma**2 + + flat_logsumexp = torch.logsumexp(exponentials, dim=-1) + + flat_log_gaussians = flat_logsumexp - flat_log_norm + + log_gaussians = flat_log_gaussians.reshape(input_shape) + + log_wrapped_gaussians = log_gaussians.sum(dim=[-2, -1]) + return log_wrapped_gaussians + + def get_sigma_normalized_score_brute_force( u: float, sigma: float, kmax: Optional[int] = None ) -> float: diff --git a/tests/score/test_wrapped_gaussian_score.py b/tests/score/test_wrapped_gaussian_score.py index 3cecc2a9..656b0628 100644 --- a/tests/score/test_wrapped_gaussian_score.py +++ b/tests/score/test_wrapped_gaussian_score.py @@ -1,3 +1,4 @@ +import einops import numpy as np import pytest import torch @@ -7,7 +8,8 @@ _get_s1b_exponential, _get_sigma_normalized_s2, _get_sigma_square_times_score_1_from_exponential, _get_small_sigma_large_u_mask, _get_small_sigma_small_u_mask, - get_sigma_normalized_score, get_sigma_normalized_score_brute_force) + get_log_wrapped_gaussians, get_sigma_normalized_score, + get_sigma_normalized_score_brute_force) @pytest.fixture(scope="module", autouse=True) @@ -50,12 +52,17 @@ def list_k(kmax): @pytest.fixture -def expected_sigma_normalized_scores(relative_coordinates, sigmas): +def large_sigmas(shape): + return (10.0 * torch.rand(shape)).clip(0.1) + + +@pytest.fixture +def expected_sigma_normalized_scores(relative_coordinates, large_sigmas): shape = relative_coordinates.shape list_sigma_normalized_scores = [] for u, sigma in zip( - relative_coordinates.numpy().flatten(), sigmas.numpy().flatten() + relative_coordinates.numpy().flatten(), large_sigmas.numpy().flatten() ): s = get_sigma_normalized_score_brute_force(u, sigma) list_sigma_normalized_scores.append(s) @@ -157,6 +164,36 @@ def test_get_sigma_normalized_s1_from_exponential( torch.testing.assert_close(expected_results, computed_results) +@pytest.mark.parametrize("shape", [(10, 10, 3, 2), (50, 8, 3)]) +@pytest.mark.parametrize("kmax", [1, 5, 10]) +def test_get_wrapped_gaussian(relative_coordinates, sigmas, shape, kmax): + + log_wrapped_gaussians = get_log_wrapped_gaussians(relative_coordinates, sigmas, kmax) + assert log_wrapped_gaussians.shape == shape[:-2] + + computed_wrapped_gaussians = einops.rearrange(torch.exp(log_wrapped_gaussians), " ... -> (...)") + + list_x = einops.rearrange( + relative_coordinates, "... natoms space -> (...) (natoms space)" + ) + list_sigma = einops.rearrange(sigmas, "... natoms space -> (...) (natoms space)") + + list_expected_wrapped_gaussians = [] + for row_x, row_sigma in zip(list_x, list_sigma): + wrapped_gaussian = torch.tensor(1.0) + for x, sigma in zip(row_x, row_sigma): + prob = torch.tensor(0.0) + for k in range(-kmax, kmax + 1): + log_prob = torch.distributions.Normal(loc=k, scale=sigma).log_prob(x) + prob += torch.exp(log_prob) + wrapped_gaussian *= prob + list_expected_wrapped_gaussians.append(wrapped_gaussian) + + expected_wrapped_gaussians = torch.tensor(list_expected_wrapped_gaussians) + + torch.testing.assert_close(computed_wrapped_gaussians, expected_wrapped_gaussians) + + @pytest.mark.parametrize("kmax", [1, 5, 10]) @pytest.mark.parametrize("shape", test_shapes) @pytest.mark.parametrize("numerical_type", [torch.double]) @@ -194,16 +231,18 @@ def test_get_sigma_normalized_s2(list_u, list_sigma, list_k, numerical_type): @pytest.mark.parametrize("kmax", [4]) @pytest.mark.parametrize("shape", test_shapes) def test_get_sigma_normalized_score( - relative_coordinates, sigmas, kmax, expected_sigma_normalized_scores + relative_coordinates, large_sigmas, kmax, expected_sigma_normalized_scores ): sigma_normalized_score_small_sigma = get_sigma_normalized_score( - relative_coordinates, sigmas, kmax + relative_coordinates, large_sigmas, kmax ) # The brute force calculation is fragile to the creation of NaNs. # Let's give the test a free pass when this happens. nan_mask = torch.where(expected_sigma_normalized_scores.isnan()) - expected_sigma_normalized_scores[nan_mask] = sigma_normalized_score_small_sigma[nan_mask] + expected_sigma_normalized_scores[nan_mask] = sigma_normalized_score_small_sigma[ + nan_mask + ] torch.testing.assert_close( sigma_normalized_score_small_sigma, From 12d787147d7ee38764f45808a596b207d56cf5c5 Mon Sep 17 00:00:00 2001 From: Bruno Rousseau Date: Tue, 24 Dec 2024 12:20:15 -0500 Subject: [PATCH 02/24] Permutation over atomic indices. --- .../utils/symmetry_utils.py | 28 ++++++++ tests/utils/test_symmetry_utils.py | 69 +++++++++++++++++++ 2 files changed, 97 insertions(+) create mode 100644 src/diffusion_for_multi_scale_molecular_dynamics/utils/symmetry_utils.py create mode 100644 tests/utils/test_symmetry_utils.py diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/utils/symmetry_utils.py b/src/diffusion_for_multi_scale_molecular_dynamics/utils/symmetry_utils.py new file mode 100644 index 00000000..5afdfcb7 --- /dev/null +++ b/src/diffusion_for_multi_scale_molecular_dynamics/utils/symmetry_utils.py @@ -0,0 +1,28 @@ +import itertools +from typing import Tuple + +import torch + + +def get_all_permutation_indices(number_of_atoms) -> Tuple[torch.Tensor, torch.Tensor]: + """Get all permutation indices. + + Produce all permutation indices to permute tensors which represent atoms. + + Args: + number_of_atoms: number of atoms to permute. + + Returns: + perm_indices: indices for all permutations. + inverse_perm_indices: indices for all inverse permutations. + """ + perm_indices = torch.stack( + [ + torch.tensor(perm) + for perm in itertools.permutations(range(number_of_atoms)) + ] + ) + + inverse_perm_indices = perm_indices.argsort(dim=1) + + return perm_indices, inverse_perm_indices diff --git a/tests/utils/test_symmetry_utils.py b/tests/utils/test_symmetry_utils.py new file mode 100644 index 00000000..52da5105 --- /dev/null +++ b/tests/utils/test_symmetry_utils.py @@ -0,0 +1,69 @@ +import pytest +import torch + +from diffusion_for_multi_scale_molecular_dynamics.utils.symmetry_utils import \ + get_all_permutation_indices + + +def factorial(n): + if n == 1: + return 1 + else: + return n * factorial(n - 1) + + +class TestPermutationIndices: + + @pytest.fixture + def number_of_atoms(self): + return 4 + + @pytest.fixture + def dim(self): + return 8 + + @pytest.fixture + def random_data(self, number_of_atoms, dim): + return torch.rand(number_of_atoms, dim) + + @pytest.fixture + def perm_indices_and_inverse_perm_indices(self, number_of_atoms, dim): + perm_indices, inverse_perm_indices = get_all_permutation_indices(number_of_atoms) + return perm_indices, inverse_perm_indices + + def test_shape(self, number_of_atoms, perm_indices_and_inverse_perm_indices): + perm_indices, inverse_perm_indices = perm_indices_and_inverse_perm_indices + number_of_permutations = factorial(number_of_atoms) + expected_shape = (number_of_permutations, number_of_atoms) + assert perm_indices.shape == expected_shape + assert inverse_perm_indices.shape == expected_shape + + def test_unique(self, number_of_atoms, perm_indices_and_inverse_perm_indices): + perm_indices, inverse_perm_indices = perm_indices_and_inverse_perm_indices + + perm_indices_set = set([tuple(list(indices)) for indices in perm_indices.numpy()]) + inv_perm_indices_set = set([tuple(list(indices)) for indices in inverse_perm_indices.numpy()]) + + number_of_permutations = factorial(number_of_atoms) + + assert len(perm_indices_set) == number_of_permutations + assert len(inv_perm_indices_set) == number_of_permutations + + def test_correct_range(self, number_of_atoms, perm_indices_and_inverse_perm_indices): + perm_indices, inverse_perm_indices = perm_indices_and_inverse_perm_indices + + expected_sorted_indices = torch.arange(number_of_atoms) + + for sorted_indices in perm_indices.sort(dim=1).values: + torch.testing.assert_close(expected_sorted_indices, sorted_indices) + + for sorted_indices in inverse_perm_indices.sort(dim=1).values: + torch.testing.assert_close(expected_sorted_indices, sorted_indices) + + def test_inverse(self, number_of_atoms, perm_indices_and_inverse_perm_indices, random_data): + perm_indices, inverse_perm_indices = perm_indices_and_inverse_perm_indices + + for indices, inverse_indices in zip(perm_indices, inverse_perm_indices): + permuted_data = random_data[indices].clone() + original_data = permuted_data[inverse_indices].clone() + torch.testing.assert_close(original_data, random_data) From 83e99b848cc6a566af58d45741cfe78bb0265977 Mon Sep 17 00:00:00 2001 From: Bruno Rousseau Date: Tue, 24 Dec 2024 12:23:22 -0500 Subject: [PATCH 03/24] Code a trivial factorial function to avoid importing something else. --- .../utils/symmetry_utils.py | 8 ++++++++ tests/utils/test_symmetry_utils.py | 11 ++--------- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/utils/symmetry_utils.py b/src/diffusion_for_multi_scale_molecular_dynamics/utils/symmetry_utils.py index 5afdfcb7..0875c6cd 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/utils/symmetry_utils.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/utils/symmetry_utils.py @@ -4,6 +4,14 @@ import torch +def factorial(n): + """Factorial function.""" + if n == 1: + return 1 + else: + return n * factorial(n - 1) + + def get_all_permutation_indices(number_of_atoms) -> Tuple[torch.Tensor, torch.Tensor]: """Get all permutation indices. diff --git a/tests/utils/test_symmetry_utils.py b/tests/utils/test_symmetry_utils.py index 52da5105..c769c07f 100644 --- a/tests/utils/test_symmetry_utils.py +++ b/tests/utils/test_symmetry_utils.py @@ -1,15 +1,8 @@ import pytest import torch -from diffusion_for_multi_scale_molecular_dynamics.utils.symmetry_utils import \ - get_all_permutation_indices - - -def factorial(n): - if n == 1: - return 1 - else: - return n * factorial(n - 1) +from diffusion_for_multi_scale_molecular_dynamics.utils.symmetry_utils import ( + factorial, get_all_permutation_indices) class TestPermutationIndices: From 58c373447287edae7f131a3e6d8b5d9e18a01ee3 Mon Sep 17 00:00:00 2001 From: Bruno Rousseau Date: Tue, 24 Dec 2024 12:30:18 -0500 Subject: [PATCH 04/24] Improved analytical score network. --- .../analytical_score_network.py | 293 ++++++++-------- .../test_analytical_score_network.py | 318 +++++++++++++----- 2 files changed, 384 insertions(+), 227 deletions(-) diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/models/score_networks/analytical_score_network.py b/src/diffusion_for_multi_scale_molecular_dynamics/models/score_networks/analytical_score_network.py index 67a65814..eb33979d 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/models/score_networks/analytical_score_network.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/models/score_networks/analytical_score_network.py @@ -1,9 +1,9 @@ """Analytical Score Network. This module implements an "exact" score network that is obtained under the approximation that the -atomic positions are just small displacements around some equilibrium positions and that the -energy is purely harmonic (ie, quadratic in the displacements). Furthermore, it is assumed that -the covariance matrix is proportional to the identity; this makes the lattice sums manageable. +atomic positions are just small Gaussian-distributed displacements around some equilibrium positions. +Furthermore, it is assumed that the covariance matrix is proportional to the identity; +this makes the lattice sums manageable. Optionally, the score network is made permutation invariant by summing on all atomic permutations. @@ -11,9 +11,8 @@ meant to generate 'production' results. """ -import itertools from dataclasses import dataclass -from typing import Any, AnyStr, Dict +from typing import Any, AnyStr, Dict, Tuple import einops import torch @@ -22,10 +21,12 @@ ScoreNetwork, ScoreNetworkParameters) from diffusion_for_multi_scale_molecular_dynamics.namespace import ( AXL, NOISE, NOISY_AXL_COMPOSITION) -from diffusion_for_multi_scale_molecular_dynamics.score.wrapped_gaussian_score import \ - get_sigma_normalized_score +from diffusion_for_multi_scale_molecular_dynamics.score.wrapped_gaussian_score import ( + get_log_wrapped_gaussians, get_sigma_normalized_score) from diffusion_for_multi_scale_molecular_dynamics.utils.basis_transformations import \ map_relative_coordinates_to_unit_cell +from diffusion_for_multi_scale_molecular_dynamics.utils.symmetry_utils import \ + get_all_permutation_indices @dataclass(kw_only=True) @@ -33,20 +34,36 @@ class AnalyticalScoreNetworkParameters(ScoreNetworkParameters): """Specific Hyper-parameters for analytical score networks.""" architecture: str = "analytical" - number_of_atoms: int # the number of atoms in a configuration. - kmax: int # the maximum lattice translation along any dimension. Translations will be [-kmax,..,kmax]. - equilibrium_relative_coordinates: ( - torch.Tensor - ) # Should have shape [number_of_atoms, spatial_dimensions] - # Harmonic energy is defined as U = 1/2 u^T . Phi . u for "u" - # the relative coordinate displacements. The 'inverse covariance' is beta Phi - # and should be unitless. This is assumed to be proportional to the identity, such that - # (beta Phi)^{-1} = sigma_d^2 Id, where the shape of Id is - # [number_of_atoms, spatial_dimensions,number_of_atoms, spatial_dimensions] - variance_parameter: float # sigma_d^2 - use_permutation_invariance: bool = ( - False # should the analytical score consider every coordinate permutations. - ) + + # the number of atoms in a configuration. + number_of_atoms: int + + # the maximum lattice translation along any dimension. Translations will be [-kmax,..,kmax]. + kmax: int + + # Should have shape [number_of_atoms, spatial_dimensions] + equilibrium_relative_coordinates: torch.Tensor + + # the data distribution variance. + sigma_d: float + + # should the analytical score consider every coordinate permutations. + # Careful! The number of permutations will scale as number_of_atoms!. This will not + # scale to large number of atoms. + use_permutation_invariance: bool = False + + def __post_init__(self): + """Post init.""" + assert ( + self.num_atom_types == 1 + ), "The analytical score network is only appropriate for a single atom type." + + assert self.sigma_d > 0.0, "the sigma_d parameter should be positive." + + assert self.equilibrium_relative_coordinates.shape == ( + self.number_of_atoms, + self.spatial_dimension, + ), "equilibrium relative coordinates have the wrong shape" class AnalyticalScoreNetwork(ScoreNetwork): @@ -63,26 +80,18 @@ def __init__(self, hyper_params: AnalyticalScoreNetworkParameters): """ super(AnalyticalScoreNetwork, self).__init__(hyper_params) - assert hyper_params.num_atom_types == 1, \ - "The analytical score network is only appropriate for a single atom type." - - self.number_of_atomic_classes = hyper_params.num_atom_types + 1 # account for the MASK class. + self.number_of_atomic_classes = ( + hyper_params.num_atom_types + 1 + ) # account for the MASK class. self.natoms = hyper_params.number_of_atoms self.spatial_dimension = hyper_params.spatial_dimension self.nd = self.natoms * self.spatial_dimension self.kmax = hyper_params.kmax - assert ( - hyper_params.variance_parameter >= 0.0 - ), "the variance parameter should be non negative." - self.sigma_d_square = hyper_params.variance_parameter - - assert hyper_params.equilibrium_relative_coordinates.shape == ( - self.natoms, - self.spatial_dimension, - ), "equilibrium relative coordinates have the wrong shape" + self.sigma_d_square = hyper_params.sigma_d**2 self.use_permutation_invariance = hyper_params.use_permutation_invariance + self.device = hyper_params.equilibrium_relative_coordinates.device # shape: [number_of_translations] @@ -116,140 +125,136 @@ def _get_all_equilibrium_permutations( number_of_atoms = relative_coordinates.shape[0] # Shape : [number of permutations, number of atoms] - perm_indices = torch.stack( - [ - torch.tensor(perm) - for perm in itertools.permutations(range(number_of_atoms)) - ] - ) + perm_indices, _ = get_all_permutation_indices(number_of_atoms) + equilibrium_permutations = relative_coordinates[perm_indices] return equilibrium_permutations - def _forward_unchecked( - self, batch: Dict[AnyStr, Any], conditional: bool = False - ) -> AXL: - """Forward unchecked. - - This method assumes that the input data has already been checked with respect to expectations - and computes the scores assuming that the data is in the correct format. + def get_log_wrapped_gaussians_and_normalized_scores_centered_on_equilibrium_positions( + self, relative_coordinates: torch.tensor, sigmas_t: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Get all probabilities and normalized scores centered on equilibrium positions. Args: - batch : dictionary containing the data to be processed by the model. - conditional (optional): CURRENTLY DOES NOTHING. + relative_coordinates : input relative coordinates: should be between 0 and 1. + relative_coordinates has dimensions [batch, number_of_atoms, spatial_dimension] + sigmas_t : the values of sigma(t). Should have the same dimension as relative coordinates. Returns: - output : an AXL namedtuple with the coordinates scores computed by the model as a - [batch_size, n_atom, spatial_dimension] tensor. Empty tensors are returned for the atom types and - lattice. + list_log_wrapped_gaussians: list of log wrapped gaussians, of + dimensions [number_of_equilibrium_positions, batch] + list_sigma_normalized_scores : list of sigma normalized scores, of + dimensions [number_of_equilibrium_positions, batch, natoms, spatial_dimension] """ - sigmas = batch[NOISE] # dimension: [batch_size, 1] - xt = batch[NOISY_AXL_COMPOSITION].X - batch_size = xt.shape[0] - xt.requires_grad_(True) - - list_unnormalized_log_prob = [] - for x0 in self.all_x0: - unnormalized_log_prob = self._compute_unnormalized_log_probability( - sigmas, xt, x0 - ) - list_unnormalized_log_prob.append(unnormalized_log_prob) + assert ( + relative_coordinates.shape == sigmas_t.shape + ), "relative_coordinates and sigmas_t have different shapes." + assert ( + len(relative_coordinates.shape) == 3 + ), "relative_coordinates should have 3 dimensions." - list_unnormalized_log_prob = torch.stack(list_unnormalized_log_prob) - log_probs = torch.logsumexp(list_unnormalized_log_prob, dim=0, keepdim=False) + effective_sigmas = torch.sqrt(self.sigma_d_square + sigmas_t**2) - grad_outputs = [torch.ones_like(log_probs)] + number_of_equilibrium_positions = self.all_x0.shape[0] + batch_size = relative_coordinates.shape[0] - scores = torch.autograd.grad( - outputs=[log_probs], inputs=[xt], grad_outputs=grad_outputs - )[0] + x = einops.repeat( + relative_coordinates, + "batch natoms space -> (batch n) natoms space", + n=number_of_equilibrium_positions, + ) - # We actually want sigma x score. - broadcast_sigmas = einops.repeat( - sigmas, "batch 1 -> batch n d", n=self.natoms, d=self.spatial_dimension + x0 = einops.repeat( + self.all_x0, "n natoms space -> (batch n) natoms space", batch=batch_size ) - sigma_normalized_scores = broadcast_sigmas * scores - # Mimic perfect predictions of single possible atomic type. - atomic_logits = torch.zeros(batch_size, self.natoms, self.number_of_atomic_classes) - atomic_logits[..., -1] = -torch.inf + u = map_relative_coordinates_to_unit_cell(x - x0) - axl_scores = AXL( - A=atomic_logits, - X=sigma_normalized_scores, - L=torch.zeros_like(sigma_normalized_scores), + repeated_sigmas_t = einops.repeat( + sigmas_t, + "batch natoms space -> (batch n) natoms space", + n=number_of_equilibrium_positions, ) - return axl_scores - - def _compute_unnormalized_log_probability( - self, sigmas: torch.Tensor, xt: torch.Tensor, x_eq: torch.Tensor - ) -> torch.Tensor: + repeated_effective_sigmas = einops.repeat( + effective_sigmas, + "batch natoms space -> (batch n) natoms space", + n=number_of_equilibrium_positions, + ) - batch_size = sigmas.shape[0] + list_log_wrapped_gaussians = get_log_wrapped_gaussians( + u, repeated_effective_sigmas, self.kmax + ) - # Recast various spatial arrays to the correct dimensions to combine them, - # in dimensions [batch, nd, number_of_translations] - effective_variance = einops.repeat( - sigmas**2 + self.sigma_d_square, - "batch 1 -> batch nd t", - t=self.number_of_translations, - nd=self.nd, + # We leverage the fact that the probability is a wrapped Gaussian to extract the + # score. However, the normalized score thus obtained is improperly normalized. + # the empirical scores are normalized with the time-dependent sigmas; the "data" sigma + # are unknown (or even ill-defined) in general! + list_effective_sigma_normalized_scores = get_sigma_normalized_score( + u, repeated_effective_sigmas, self.kmax ) + list_scores = list_effective_sigma_normalized_scores / repeated_effective_sigmas + list_sigma_normalized_scores = repeated_sigmas_t * list_scores - sampling_coordinates = einops.repeat( - xt, - "batch natom d -> batch (natom d) t", + wrapped_gaussians = einops.rearrange( + list_log_wrapped_gaussians, + "(batch n) -> n batch", + n=number_of_equilibrium_positions, batch=batch_size, - t=self.number_of_translations, ) - equilibrium_coordinates = einops.repeat( - x_eq, - "natom d -> batch (natom d) t", + sigma_normalized_scores = einops.rearrange( + list_sigma_normalized_scores, + "(batch n) natoms space -> n batch natoms space", + n=number_of_equilibrium_positions, batch=batch_size, - t=self.number_of_translations, ) - translations = einops.repeat( - self.translations_k, "t -> batch nd t", batch=batch_size, nd=self.nd - ) + return wrapped_gaussians, sigma_normalized_scores - exponent = ( - -0.5 - * (sampling_coordinates - equilibrium_coordinates - translations) ** 2 - / effective_variance - ) - # logsumexp on lattice translation vectors, then sum on spatial indices - unnormalized_log_prob = torch.logsumexp(exponent, dim=2, keepdim=False).sum( - dim=1 - ) + def get_probabilities_and_normalized_scores( + self, relative_coordinates: torch.tensor, sigmas_t: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Get probabilities and normalized scores. - return unnormalized_log_prob + Args: + relative_coordinates: relative coordinates, of dimensions [batch_size, natoms, spatial_dimension] + sigmas_t : the values of sigma(t). Should have the same dimension as relative coordinates. + Returns: + probabilities : probability P(x, t) at the input relative coordinates. Dimension [batch_size]. + normalized_scores : normalized scores sigma S(x, t) at the input relative coordinates. + Dimension [batch_size, natoms, space_dimension]. + """ + batch_size, natoms, space_dimension = relative_coordinates.shape + # list_log_w has dimensions [number_of_equilibrium_positions, batch_size] + # list_s has dimensions [number_of_equilibrium_positions, batch_size, natoms, spatial_dimensions] + list_log_w, list_s = ( + self.get_log_wrapped_gaussians_and_normalized_scores_centered_on_equilibrium_positions( + relative_coordinates, sigmas_t + ) + ) -class TargetScoreBasedAnalyticalScoreNetwork(AnalyticalScoreNetwork): - """Target Score-Based Analytical Score Network. + number_of_equilibrium_positions = list_log_w.shape[0] - An analytical score network that leverages the computation of the target for the score network. - This can only work if the permutation equivariance is turned off. This should produce exactly the same results - as the AnalyticalScoreNetwork, but does not require gradient calculation. - """ + probabilities = ( + torch.exp(list_log_w).sum(dim=0) / number_of_equilibrium_positions + ) - def __init__(self, hyper_params: AnalyticalScoreNetworkParameters): - """__init__. + list_weights = einops.repeat( + torch.softmax(list_log_w, dim=0), + "n batch -> n batch natoms space", + natoms=natoms, + space=space_dimension, + ) - Args: - hyper_params : hyper parameters from the config file. - """ - super(TargetScoreBasedAnalyticalScoreNetwork, self).__init__(hyper_params) - assert ( - not hyper_params.use_permutation_invariance - ), "This implementation is only valid in the absence of permutation equivariance." - self.x0 = self.all_x0[0] + normalized_scores = (list_weights * list_s).sum(dim=0) + + return probabilities, normalized_scores def _forward_unchecked( self, batch: Dict[AnyStr, Any], conditional: bool = False - ) -> torch.Tensor: + ) -> AXL: """Forward unchecked. This method assumes that the input data has already been checked with respect to expectations @@ -260,32 +265,26 @@ def _forward_unchecked( conditional (optional): CURRENTLY DOES NOTHING. Returns: - output : the scores computed by the model as a [batch_size, n_atom, spatial_dimension] tensor. + output : an AXL namedtuple with + - the coordinates scores computed by the model as a [batch_size, n_atom, spatial_dimension] tensor. + - perfect atom type predictions, assuming a single atom type possibility. + - a tensor of zeros for the lattice parameters. """ sigmas = batch[NOISE] # dimension: [batch_size, 1] xt = batch[NOISY_AXL_COMPOSITION].X batch_size = xt.shape[0] broadcast_sigmas = einops.repeat( - sigmas, - "batch 1 -> batch natoms spatial_dimension", - natoms=self.natoms, - spatial_dimension=self.spatial_dimension, - ) - - broadcast_effective_sigmas = (broadcast_sigmas**2 + self.sigma_d_square).sqrt() - - delta_relative_coordinates = map_relative_coordinates_to_unit_cell(xt - self.x0) - misnormalized_scores = get_sigma_normalized_score( - delta_relative_coordinates, broadcast_effective_sigmas, kmax=self.kmax + sigmas, "batch 1 -> batch n d", n=self.natoms, d=self.spatial_dimension ) - - sigma_normalized_scores = ( - broadcast_sigmas / broadcast_effective_sigmas * misnormalized_scores + _, sigma_normalized_scores = self.get_probabilities_and_normalized_scores( + relative_coordinates=xt, sigmas_t=broadcast_sigmas ) # Mimic perfect predictions of single possible atomic type. - atomic_logits = torch.zeros(batch_size, self.natoms, self.number_of_atomic_classes) + atomic_logits = torch.zeros( + batch_size, self.natoms, self.number_of_atomic_classes + ) atomic_logits[..., -1] = -torch.inf axl_scores = AXL( diff --git a/tests/models/score_network/test_analytical_score_network.py b/tests/models/score_network/test_analytical_score_network.py index a0537d54..8e4256d9 100644 --- a/tests/models/score_network/test_analytical_score_network.py +++ b/tests/models/score_network/test_analytical_score_network.py @@ -1,25 +1,19 @@ import itertools +import einops import pytest import torch from diffusion_for_multi_scale_molecular_dynamics.models.score_networks.analytical_score_network import ( - AnalyticalScoreNetwork, AnalyticalScoreNetworkParameters, - TargetScoreBasedAnalyticalScoreNetwork) + AnalyticalScoreNetwork, AnalyticalScoreNetworkParameters) from diffusion_for_multi_scale_molecular_dynamics.namespace import ( AXL, NOISE, NOISY_AXL_COMPOSITION, TIME, UNIT_CELL) +from diffusion_for_multi_scale_molecular_dynamics.utils.symmetry_utils import \ + factorial from tests.models.score_network.base_test_score_network import \ BaseTestScoreNetwork -def factorial(n): - - if n == 1: - return 1 - else: - return n * factorial(n - 1) - - class TestAnalyticalScoreNetwork(BaseTestScoreNetwork): @pytest.fixture(scope="class", autouse=True) def set_default_type_to_float64(self): @@ -51,27 +45,13 @@ def number_of_atoms(self, request): def equilibrium_relative_coordinates(self, number_of_atoms, spatial_dimension): return torch.rand(number_of_atoms, spatial_dimension) - """ - @pytest.fixture - def atom_types(self, batch_size, number_of_atoms, num_atom_types): - return torch.randint( - 0, - num_atom_types, - ( - batch_size, - number_of_atoms, - ), - ) - """ + @pytest.fixture(params=[True, False]) + def use_permutation_invariance(self, request): + return request.param - @pytest.fixture(params=["finite", "zero"]) - def variance_parameter(self, request): - if request.param == "zero": - return 0.0 - elif request.param == "finite": - # Make the spring constants pretty large so that the displacements will be small - inverse_variance = float(1000 * torch.rand(1)) - return 1.0 / inverse_variance + @pytest.fixture(params=[0.01, 0.1, 0.5]) + def sigma_d(self, request): + return request.param @pytest.fixture() def batch(self, batch_size, number_of_atoms, spatial_dimension, atom_types): @@ -97,18 +77,18 @@ def score_network_parameters( spatial_dimension, kmax, equilibrium_relative_coordinates, - variance_parameter, + sigma_d, use_permutation_invariance, - num_atom_types + num_atom_types, ): hyper_params = AnalyticalScoreNetworkParameters( number_of_atoms=number_of_atoms, spatial_dimension=spatial_dimension, kmax=kmax, equilibrium_relative_coordinates=equilibrium_relative_coordinates, - variance_parameter=variance_parameter, + sigma_d=sigma_d, use_permutation_invariance=use_permutation_invariance, - num_atom_types=num_atom_types + num_atom_types=num_atom_types, ) return hyper_params @@ -116,10 +96,6 @@ def score_network_parameters( def score_network(self, score_network_parameters): return AnalyticalScoreNetwork(score_network_parameters) - @pytest.fixture() - def target_score_based_score_network(self, score_network_parameters): - return TargetScoreBasedAnalyticalScoreNetwork(score_network_parameters) - def test_all_translations(self, kmax): computed_translations = AnalyticalScoreNetwork._get_all_translations(kmax) expected_translations = torch.tensor(list(range(-kmax, kmax + 1))) @@ -151,73 +127,255 @@ def test_get_all_equilibrium_permutations( torch.testing.assert_close(expected_permutations, computed_permutations) - @pytest.mark.parametrize("use_permutation_invariance", [False]) - def test_compute_unnormalized_log_probability( + def compute_log_wrapped_gaussian_for_testing( self, + relative_coordinates, equilibrium_relative_coordinates, - variance_parameter, + sigmas, + sigma_d, kmax, - batch, - score_network, ): - sigmas = batch[NOISE] # dimension: [batch_size, 1] - xt = batch[NOISY_AXL_COMPOSITION].X - computed_log_prob = score_network._compute_unnormalized_log_probability( - sigmas, xt, equilibrium_relative_coordinates + """Compute the log of a Wrapped Gaussian, for testing purposes.""" + batch_size, natoms, spatial_dimension = relative_coordinates.shape + + assert sigmas.shape == ( + batch_size, + 1, + ), "Unexpected shape for the sigmas tensor." + + assert equilibrium_relative_coordinates.shape == ( + natoms, + spatial_dimension, + ), "A single equilibrium configuration should be used." + + nd = natoms * spatial_dimension + + list_translations = torch.arange(-kmax, kmax + 1) + nt = len(list_translations) + + # Recast various spatial arrays to the correct dimensions to combine them, + # in dimensions [batch, nd, number_of_translations] + effective_variance = einops.repeat( + sigmas**2 + sigma_d**2, "batch 1 -> batch nd t", t=nt, nd=nd ) - batch_size = sigmas.shape[0] + x = einops.repeat( + relative_coordinates, "batch natoms space -> batch (natoms space) t", t=nt + ) - expected_log_prob = torch.zeros(batch_size) - for batch_idx in range(batch_size): - sigma = sigmas[batch_idx, 0] + x0 = einops.repeat( + equilibrium_relative_coordinates, + "natoms space -> batch (natoms space) t", + batch=batch_size, + t=nt, + ) - for i in range(score_network.natoms): - for alpha in range(score_network.spatial_dimension): + translations = einops.repeat( + list_translations, "t -> batch nd t", batch=batch_size, nd=nd + ) - eq_coordinate = equilibrium_relative_coordinates[i, alpha] - coordinate = xt[batch_idx, i, alpha] + exponent = -0.5 * (x - x0 - translations) ** 2 / effective_variance + # logsumexp on lattice translation vectors, then sum on spatial indices + unnormalized_log_prob = torch.logsumexp(exponent, dim=-1, keepdim=False).sum( + dim=1 + ) - sum_on_k = torch.tensor(0.0) - for k in range(-kmax, kmax + 1): - exponent = ( - -0.5 - * (coordinate - eq_coordinate - k) ** 2 - / (sigma**2 + variance_parameter) - ) - sum_on_k += torch.exp(exponent) + sigma2 = effective_variance[ + :, :, 0 + ] # We shouldn't sum the normalization term on k. + normalization_term = torch.sum( + 0.5 * torch.log(2.0 * torch.pi * sigma2), dim=[-1] + ) - expected_log_prob[batch_idx] += torch.log(sum_on_k) + log_wrapped_gaussian = unnormalized_log_prob - normalization_term - # Let's give a free pass to any problematic expected values, which are calculated with a fragile - # brute force approach - problem_mask = torch.logical_or(torch.isnan(expected_log_prob), torch.isinf(expected_log_prob)) - expected_log_prob[problem_mask] = computed_log_prob[problem_mask] + return log_wrapped_gaussian - torch.testing.assert_close(expected_log_prob, computed_log_prob) + @pytest.fixture() + def all_equilibrium_permutations(self, score_network): + return score_network.all_x0 + + @pytest.fixture() + def expected_wrapped_gaussians( + self, batch, all_equilibrium_permutations, sigma_d, kmax + ): + relative_coordinates = batch[NOISY_AXL_COMPOSITION].X + sigmas = batch[NOISE] + + list_log_w = [] + for x0 in all_equilibrium_permutations: + log_w = self.compute_log_wrapped_gaussian_for_testing( + relative_coordinates, x0, sigmas, sigma_d=sigma_d, kmax=kmax + ) + list_log_w.append(log_w) + + expected_wrapped_gaussians = torch.stack(list_log_w) + return expected_wrapped_gaussians + + def test_log_wrapped_gaussians_computation( + self, expected_wrapped_gaussians, score_network, batch + ): + relative_coordinates = batch[NOISY_AXL_COMPOSITION].X + sigmas = batch[NOISE] + batch_size, natoms, space_dimensions = relative_coordinates.shape + + sigmas_t = einops.repeat( + sigmas, + "batch 1 -> batch natoms space", + natoms=natoms, + space=space_dimensions, + ) + computed_wrapped_gaussians, _ = ( + score_network.get_log_wrapped_gaussians_and_normalized_scores_centered_on_equilibrium_positions( + relative_coordinates, sigmas_t + ) + ) + + torch.testing.assert_close( + expected_wrapped_gaussians, computed_wrapped_gaussians + ) + + def compute_sigma_normalized_scores_by_autograd_for_testing( + self, + relative_coordinates, + equilibrium_relative_coordinates, + sigmas, + sigma_d, + kmax, + ): + """Compute scores by autograd, for testing.""" + batch_size, natoms, spatial_dimension = relative_coordinates.shape + xt = relative_coordinates.clone() + xt.requires_grad_(True) + + log_w = self.compute_log_wrapped_gaussian_for_testing( + xt, equilibrium_relative_coordinates, sigmas, sigma_d=sigma_d, kmax=kmax + ) + grad_outputs = [torch.ones_like(log_w)] + + scores = torch.autograd.grad( + outputs=[log_w], inputs=[xt], grad_outputs=grad_outputs + )[0] + + # We actually want sigma x score. + broadcast_sigmas = einops.repeat( + sigmas, + "batch 1 -> batch natoms space", + natoms=natoms, + space=spatial_dimension, + ) + + sigma_normalized_scores = broadcast_sigmas * scores + return sigma_normalized_scores + + @pytest.fixture() + def expected_sigma_normalized_scores( + self, batch, all_equilibrium_permutations, sigma_d, kmax + ): + relative_coordinates = batch[NOISY_AXL_COMPOSITION].X + sigmas = batch[NOISE] + + list_sigma_normalized_scores = [] + for x0 in all_equilibrium_permutations: + sigma_scores = self.compute_sigma_normalized_scores_by_autograd_for_testing( + relative_coordinates, x0, sigmas, sigma_d=sigma_d, kmax=kmax + ) + list_sigma_normalized_scores.append(sigma_scores) + sigma_normalized_scores = torch.stack(list_sigma_normalized_scores) + return sigma_normalized_scores + + def test_normalized_score_computation( + self, expected_sigma_normalized_scores, score_network, batch + ): + relative_coordinates = batch[NOISY_AXL_COMPOSITION].X + sigmas = batch[NOISE] + batch_size, natoms, space_dimensions = relative_coordinates.shape + + sigmas_t = einops.repeat( + sigmas, + "batch 1 -> batch natoms space", + natoms=natoms, + space=space_dimensions, + ) + _, computed_normalized_scores = ( + score_network.get_log_wrapped_gaussians_and_normalized_scores_centered_on_equilibrium_positions( + relative_coordinates, sigmas_t + ) + ) + + torch.testing.assert_close( + expected_sigma_normalized_scores, computed_normalized_scores + ) @pytest.mark.parametrize( "number_of_atoms, use_permutation_invariance", [(1, False), (1, True), (2, False), (2, True), (8, False)], ) - def test_analytical_score_network( - self, score_network, batch, batch_size, number_of_atoms, spatial_dimension + def test_analytical_score_network_shapes( + self, + score_network, + batch, + batch_size, + number_of_atoms, + num_atom_types, + spatial_dimension, ): - normalized_scores = score_network.forward(batch) + model_axl = score_network.forward(batch) + + normalized_scores = model_axl.X + atom_type_preds = model_axl.A - assert normalized_scores.X.shape == ( + assert normalized_scores.shape == ( batch_size, number_of_atoms, spatial_dimension, ) + assert atom_type_preds.shape == ( + batch_size, + number_of_atoms, + num_atom_types + 1, + ) - @pytest.mark.parametrize("use_permutation_invariance", [False]) - @pytest.mark.parametrize("number_of_atoms", [1, 2, 8]) - def test_compare_score_networks( - self, score_network, target_score_based_score_network, batch + @pytest.mark.parametrize( + "number_of_atoms, use_permutation_invariance", + [(1, False), (1, True), (2, False), (2, True), (8, False)], + ) + def test_analytical_score_network( + self, score_network, batch, sigma_d, kmax, equilibrium_relative_coordinates ): + model_axl = score_network.forward(batch) - normalized_scores1 = score_network.forward(batch) - normalized_scores2 = target_score_based_score_network.forward(batch) + computed_normalized_scores = model_axl.X - torch.testing.assert_close(normalized_scores1, normalized_scores2) + relative_coordinates = batch[NOISY_AXL_COMPOSITION].X + batch_size, natoms, spatial_dimension = relative_coordinates.shape + sigmas = batch[NOISE] + + list_log_w = [] + list_s = [] + for x0 in score_network.all_x0: + log_w = self.compute_log_wrapped_gaussian_for_testing( + relative_coordinates, x0, sigmas, sigma_d, kmax + ) + list_log_w.append(log_w) + + s = self.compute_sigma_normalized_scores_by_autograd_for_testing( + relative_coordinates, x0, sigmas, sigma_d, kmax + ) + list_s.append(s) + + list_log_w = torch.stack(list_log_w) + list_s = torch.stack(list_s) + + list_weights = einops.repeat( + torch.softmax(list_log_w, dim=0), + "n batch -> n batch natoms space", + natoms=natoms, + space=spatial_dimension, + ) + + expected_normalized_scores = torch.sum(list_weights * list_s, dim=0) + + torch.testing.assert_close( + expected_normalized_scores, computed_normalized_scores + ) From ecfe3cea149146f9a1cb276c3d956b1a67ea7e51 Mon Sep 17 00:00:00 2001 From: Bruno Rousseau Date: Tue, 24 Dec 2024 13:17:44 -0500 Subject: [PATCH 05/24] Refactor name of datamodule module. --- data/process_lammps_data.py | 2 +- .../patches/fixed_position_data_loader.py | 2 +- experiments/dataset_analysis/dataset_covariance.py | 2 +- experiments/dataset_analysis/energy_consistency_analysis.py | 2 +- .../{data_loader.py => lammps_for_diffusion_data_module.py} | 0 .../train_diffusion.py | 2 +- tests/data/diffusion/test_data_loader.py | 2 +- 7 files changed, 6 insertions(+), 6 deletions(-) rename src/diffusion_for_multi_scale_molecular_dynamics/data/diffusion/{data_loader.py => lammps_for_diffusion_data_module.py} (100%) diff --git a/data/process_lammps_data.py b/data/process_lammps_data.py index 12713b3e..909e87ce 100644 --- a/data/process_lammps_data.py +++ b/data/process_lammps_data.py @@ -3,7 +3,7 @@ import logging import tempfile -from diffusion_for_multi_scale_molecular_dynamics.data.diffusion.data_loader import ( +from diffusion_for_multi_scale_molecular_dynamics.data.diffusion.lammps_for_diffusion_data_module import ( LammpsForDiffusionDataModule, LammpsLoaderParameters) from diffusion_for_multi_scale_molecular_dynamics.utils.logging_utils import \ setup_analysis_logger diff --git a/experiments/atom_types_only_experiments/patches/fixed_position_data_loader.py b/experiments/atom_types_only_experiments/patches/fixed_position_data_loader.py index 007cea40..48c69517 100644 --- a/experiments/atom_types_only_experiments/patches/fixed_position_data_loader.py +++ b/experiments/atom_types_only_experiments/patches/fixed_position_data_loader.py @@ -7,7 +7,7 @@ from equilibrium_structure import create_equilibrium_sige_structure from torch_geometric.data import DataLoader -from diffusion_for_multi_scale_molecular_dynamics.data.diffusion.data_loader import \ +from diffusion_for_multi_scale_molecular_dynamics.data.diffusion.lammps_for_diffusion_data_module import \ LammpsLoaderParameters from diffusion_for_multi_scale_molecular_dynamics.data.element_types import \ ElementTypes diff --git a/experiments/dataset_analysis/dataset_covariance.py b/experiments/dataset_analysis/dataset_covariance.py index 70994a7e..abc19042 100644 --- a/experiments/dataset_analysis/dataset_covariance.py +++ b/experiments/dataset_analysis/dataset_covariance.py @@ -13,7 +13,7 @@ from diffusion_for_multi_scale_molecular_dynamics import (ANALYSIS_RESULTS_DIR, DATA_DIR) -from diffusion_for_multi_scale_molecular_dynamics.data.diffusion.data_loader import ( +from diffusion_for_multi_scale_molecular_dynamics.data.diffusion.lammps_for_diffusion_data_module import ( LammpsForDiffusionDataModule, LammpsLoaderParameters) from diffusion_for_multi_scale_molecular_dynamics.utils.basis_transformations import \ map_relative_coordinates_to_unit_cell diff --git a/experiments/dataset_analysis/energy_consistency_analysis.py b/experiments/dataset_analysis/energy_consistency_analysis.py index ec797b20..ba080327 100644 --- a/experiments/dataset_analysis/energy_consistency_analysis.py +++ b/experiments/dataset_analysis/energy_consistency_analysis.py @@ -19,7 +19,7 @@ PLEASANT_FIG_SIZE, PLOT_STYLE_PATH) from diffusion_for_multi_scale_molecular_dynamics.callbacks.sampling_visualization_callback import \ SamplingVisualizationCallback -from diffusion_for_multi_scale_molecular_dynamics.data.diffusion.data_loader import ( +from diffusion_for_multi_scale_molecular_dynamics.data.diffusion.lammps_for_diffusion_data_module import ( LammpsForDiffusionDataModule, LammpsLoaderParameters) from diffusion_for_multi_scale_molecular_dynamics.oracle.lammps import \ get_energy_and_forces_from_lammps diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/data/diffusion/data_loader.py b/src/diffusion_for_multi_scale_molecular_dynamics/data/diffusion/lammps_for_diffusion_data_module.py similarity index 100% rename from src/diffusion_for_multi_scale_molecular_dynamics/data/diffusion/data_loader.py rename to src/diffusion_for_multi_scale_molecular_dynamics/data/diffusion/lammps_for_diffusion_data_module.py diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/train_diffusion.py b/src/diffusion_for_multi_scale_molecular_dynamics/train_diffusion.py index e95987c6..aafc92ac 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/train_diffusion.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/train_diffusion.py @@ -12,7 +12,7 @@ from diffusion_for_multi_scale_molecular_dynamics.callbacks.callback_loader import \ create_all_callbacks -from diffusion_for_multi_scale_molecular_dynamics.data.diffusion.data_loader import ( +from diffusion_for_multi_scale_molecular_dynamics.data.diffusion.lammps_for_diffusion_data_module import ( LammpsForDiffusionDataModule, LammpsLoaderParameters) from diffusion_for_multi_scale_molecular_dynamics.data.element_types import \ ElementTypes diff --git a/tests/data/diffusion/test_data_loader.py b/tests/data/diffusion/test_data_loader.py index a01297e0..7942bb17 100644 --- a/tests/data/diffusion/test_data_loader.py +++ b/tests/data/diffusion/test_data_loader.py @@ -5,7 +5,7 @@ import pytest import torch -from diffusion_for_multi_scale_molecular_dynamics.data.diffusion.data_loader import ( +from diffusion_for_multi_scale_molecular_dynamics.data.diffusion.lammps_for_diffusion_data_module import ( LammpsForDiffusionDataModule, LammpsLoaderParameters) from diffusion_for_multi_scale_molecular_dynamics.data.element_types import ( NULL_ELEMENT, ElementTypes) From b1f7255eaddec8c12635df511c3a0f6024698101 Mon Sep 17 00:00:00 2001 From: Bruno Rousseau Date: Tue, 24 Dec 2024 13:19:26 -0500 Subject: [PATCH 06/24] Rename lammps preprocessor code. --- .../data/diffusion/lammps_for_diffusion_data_module.py | 2 +- .../{data_preprocess.py => lammps_processor_for_diffusion.py} | 0 tests/data/diffusion/test_data_preprocess.py | 2 +- 3 files changed, 2 insertions(+), 2 deletions(-) rename src/diffusion_for_multi_scale_molecular_dynamics/data/diffusion/{data_preprocess.py => lammps_processor_for_diffusion.py} (100%) diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/data/diffusion/lammps_for_diffusion_data_module.py b/src/diffusion_for_multi_scale_molecular_dynamics/data/diffusion/lammps_for_diffusion_data_module.py index 5fbc9d17..3ae96751 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/data/diffusion/lammps_for_diffusion_data_module.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/data/diffusion/lammps_for_diffusion_data_module.py @@ -12,7 +12,7 @@ import torch.nn.functional as F from torch.utils.data import DataLoader -from diffusion_for_multi_scale_molecular_dynamics.data.diffusion.data_preprocess import \ +from diffusion_for_multi_scale_molecular_dynamics.data.diffusion.lammps_processor_for_diffusion import \ LammpsProcessorForDiffusion from diffusion_for_multi_scale_molecular_dynamics.data.element_types import ( NULL_ELEMENT, ElementTypes) diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/data/diffusion/data_preprocess.py b/src/diffusion_for_multi_scale_molecular_dynamics/data/diffusion/lammps_processor_for_diffusion.py similarity index 100% rename from src/diffusion_for_multi_scale_molecular_dynamics/data/diffusion/data_preprocess.py rename to src/diffusion_for_multi_scale_molecular_dynamics/data/diffusion/lammps_processor_for_diffusion.py diff --git a/tests/data/diffusion/test_data_preprocess.py b/tests/data/diffusion/test_data_preprocess.py index 6684448f..954a7972 100644 --- a/tests/data/diffusion/test_data_preprocess.py +++ b/tests/data/diffusion/test_data_preprocess.py @@ -4,7 +4,7 @@ import pandas as pd import pytest -from diffusion_for_multi_scale_molecular_dynamics.data.diffusion.data_preprocess import \ +from diffusion_for_multi_scale_molecular_dynamics.data.diffusion.lammps_processor_for_diffusion import \ LammpsProcessorForDiffusion from diffusion_for_multi_scale_molecular_dynamics.namespace import ( CARTESIAN_FORCES, CARTESIAN_POSITIONS, RELATIVE_COORDINATES) From 51a612c135b131d7a21f398693d6fbe9d5e1d39d Mon Sep 17 00:00:00 2001 From: Bruno Rousseau Date: Tue, 24 Dec 2024 13:21:23 -0500 Subject: [PATCH 07/24] Better name. --- ...ata_preprocess.py => test_lammps_processor_for_diffusion.py} | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) rename tests/data/diffusion/{test_data_preprocess.py => test_lammps_processor_for_diffusion.py} (99%) diff --git a/tests/data/diffusion/test_data_preprocess.py b/tests/data/diffusion/test_lammps_processor_for_diffusion.py similarity index 99% rename from tests/data/diffusion/test_data_preprocess.py rename to tests/data/diffusion/test_lammps_processor_for_diffusion.py index 954a7972..1fba751d 100644 --- a/tests/data/diffusion/test_data_preprocess.py +++ b/tests/data/diffusion/test_lammps_processor_for_diffusion.py @@ -12,7 +12,7 @@ from tests.fake_data_utils import generate_parquet_dataframe -class TestDataProcess(TestDiffusionDataBase): +class TestLammpsProcessorForDiffusion(TestDiffusionDataBase): @pytest.fixture def processor(self, paths): From 7f10189913d30b634911e690bdfde3ea6ee54711 Mon Sep 17 00:00:00 2001 From: Bruno Rousseau Date: Tue, 24 Dec 2024 13:23:14 -0500 Subject: [PATCH 08/24] Align test name. --- ..._data_loader.py => test_lammps_for_diffusion_data_module.py} | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) rename tests/data/diffusion/{test_data_loader.py => test_lammps_for_diffusion_data_module.py} (99%) diff --git a/tests/data/diffusion/test_data_loader.py b/tests/data/diffusion/test_lammps_for_diffusion_data_module.py similarity index 99% rename from tests/data/diffusion/test_data_loader.py rename to tests/data/diffusion/test_lammps_for_diffusion_data_module.py index 7942bb17..44b76ce6 100644 --- a/tests/data/diffusion/test_data_loader.py +++ b/tests/data/diffusion/test_lammps_for_diffusion_data_module.py @@ -40,7 +40,7 @@ def convert_configurations_to_dataset( return configuration_dataset -class TestDiffusionDataLoader(TestDiffusionDataBase): +class TestLammpsForDiffusionDataModule(TestDiffusionDataBase): @pytest.fixture def element_types(self, unique_elements): From d6fc645fb5c96d19dd8eac94cd0abb2afd5252be Mon Sep 17 00:00:00 2001 From: Bruno Rousseau Date: Tue, 24 Dec 2024 14:22:12 -0500 Subject: [PATCH 09/24] Better naming and clean data module instantiation. --- data/process_lammps_data.py | 4 +- .../patches/fixed_position_data_loader.py | 6 +-- .../dataset_analysis/dataset_covariance.py | 4 +- .../energy_consistency_analysis.py | 4 +- .../data/diffusion/data_module_parameters.py | 39 +++++++++++++++ .../data/diffusion/instantiate_data_module.py | 49 +++++++++++++++++++ .../lammps_for_diffusion_data_module.py | 22 ++++----- .../train_diffusion.py | 27 +++++----- .../test_lammps_for_diffusion_data_module.py | 4 +- 9 files changed, 121 insertions(+), 38 deletions(-) create mode 100644 src/diffusion_for_multi_scale_molecular_dynamics/data/diffusion/data_module_parameters.py create mode 100644 src/diffusion_for_multi_scale_molecular_dynamics/data/diffusion/instantiate_data_module.py diff --git a/data/process_lammps_data.py b/data/process_lammps_data.py index 909e87ce..03679be9 100644 --- a/data/process_lammps_data.py +++ b/data/process_lammps_data.py @@ -4,7 +4,7 @@ import tempfile from diffusion_for_multi_scale_molecular_dynamics.data.diffusion.lammps_for_diffusion_data_module import ( - LammpsForDiffusionDataModule, LammpsLoaderParameters) + LammpsDataModuleParameters, LammpsForDiffusionDataModule) from diffusion_for_multi_scale_molecular_dynamics.utils.logging_utils import \ setup_analysis_logger from diffusion_for_multi_scale_molecular_dynamics.utils.main_utils import \ @@ -31,7 +31,7 @@ def main(): logger.info(f" --processed_datadir : {args.processed_datadir}") logger.info(f" --config: {args.config}") - data_params = LammpsLoaderParameters(**hyper_params) + data_params = LammpsDataModuleParameters(**hyper_params) with tempfile.TemporaryDirectory() as tmp_work_dir: data_module = LammpsForDiffusionDataModule(lammps_run_dir=lammps_run_dir, diff --git a/experiments/atom_types_only_experiments/patches/fixed_position_data_loader.py b/experiments/atom_types_only_experiments/patches/fixed_position_data_loader.py index 48c69517..dfb52dac 100644 --- a/experiments/atom_types_only_experiments/patches/fixed_position_data_loader.py +++ b/experiments/atom_types_only_experiments/patches/fixed_position_data_loader.py @@ -8,7 +8,7 @@ from torch_geometric.data import DataLoader from diffusion_for_multi_scale_molecular_dynamics.data.diffusion.lammps_for_diffusion_data_module import \ - LammpsLoaderParameters + LammpsDataModuleParameters from diffusion_for_multi_scale_molecular_dynamics.data.element_types import \ ElementTypes from diffusion_for_multi_scale_molecular_dynamics.namespace import ( @@ -24,7 +24,7 @@ def __init__( self, lammps_run_dir: str, # dummy processed_dataset_dir: str, - hyper_params: LammpsLoaderParameters, + hyper_params: LammpsDataModuleParameters, working_cache_dir: Optional[str] = None, # dummy ): """Init method.""" @@ -99,7 +99,7 @@ def clean_up(self): elements = ["Si", "Ge"] processed_dataset_dir = Path("/experiments/atom_types_only_experiments") - hyper_params = LammpsLoaderParameters( + hyper_params = LammpsDataModuleParameters( batch_size=64, train_batch_size=1024, valid_batch_size=1024, diff --git a/experiments/dataset_analysis/dataset_covariance.py b/experiments/dataset_analysis/dataset_covariance.py index abc19042..d94f44b1 100644 --- a/experiments/dataset_analysis/dataset_covariance.py +++ b/experiments/dataset_analysis/dataset_covariance.py @@ -14,7 +14,7 @@ from diffusion_for_multi_scale_molecular_dynamics import (ANALYSIS_RESULTS_DIR, DATA_DIR) from diffusion_for_multi_scale_molecular_dynamics.data.diffusion.lammps_for_diffusion_data_module import ( - LammpsForDiffusionDataModule, LammpsLoaderParameters) + LammpsDataModuleParameters, LammpsForDiffusionDataModule) from diffusion_for_multi_scale_molecular_dynamics.utils.basis_transformations import \ map_relative_coordinates_to_unit_cell from diffusion_for_multi_scale_molecular_dynamics.utils.logging_utils import \ @@ -40,7 +40,7 @@ cache_dir = lammps_run_dir / "cache" -data_params = LammpsLoaderParameters(batch_size=2048, max_atom=max_atom) +data_params = LammpsDataModuleParameters(batch_size=2048, max_atom=max_atom) if __name__ == "__main__": setup_analysis_logger() diff --git a/experiments/dataset_analysis/energy_consistency_analysis.py b/experiments/dataset_analysis/energy_consistency_analysis.py index ba080327..4ca57cdb 100644 --- a/experiments/dataset_analysis/energy_consistency_analysis.py +++ b/experiments/dataset_analysis/energy_consistency_analysis.py @@ -20,7 +20,7 @@ from diffusion_for_multi_scale_molecular_dynamics.callbacks.sampling_visualization_callback import \ SamplingVisualizationCallback from diffusion_for_multi_scale_molecular_dynamics.data.diffusion.lammps_for_diffusion_data_module import ( - LammpsForDiffusionDataModule, LammpsLoaderParameters) + LammpsDataModuleParameters, LammpsForDiffusionDataModule) from diffusion_for_multi_scale_molecular_dynamics.oracle.lammps import \ get_energy_and_forces_from_lammps from diffusion_for_multi_scale_molecular_dynamics.utils.logging_utils import \ @@ -38,7 +38,7 @@ cache_dir = str(EXPERIMENT_ANALYSIS_DIR / "cache" / dataset_name) -data_params = LammpsLoaderParameters(batch_size=64, max_atom=8) +data_params = LammpsDataModuleParameters(batch_size=64, max_atom=8) sample_size = 1000 diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/data/diffusion/data_module_parameters.py b/src/diffusion_for_multi_scale_molecular_dynamics/data/diffusion/data_module_parameters.py new file mode 100644 index 00000000..5fa59715 --- /dev/null +++ b/src/diffusion_for_multi_scale_molecular_dynamics/data/diffusion/data_module_parameters.py @@ -0,0 +1,39 @@ +from dataclasses import dataclass +from typing import Optional + + +@dataclass(kw_only=True) +class DataModuleParameters: + """Base Hyper-parameters for Data Modules.""" + + # The data source must be specified in the concrete class. + data_source = None + + # Either batch_size XOR train_batch_size and valid_batch_size should be specified. + batch_size: Optional[int] = None + train_batch_size: Optional[int] = None + valid_batch_size: Optional[int] = None + num_workers: int = 0 + max_atom: int = 64 + spatial_dimension: int = 3 # the dimension of Euclidean space where atoms live. + elements: list[str] # the elements that can exist. + + def __post_init__(self): + """Post init.""" + assert self.data_source is not None, "The data source must be set." + + if self.batch_size is None: + assert ( + self.valid_batch_size is not None + ), "If batch_size is None, valid_batch_size must be specified." + assert ( + self.train_batch_size is not None + ), "If batch_size is None, train_batch_size must be specified." + + else: + assert ( + self.valid_batch_size is None + ), "If batch_size is specified, valid_batch_size must be None." + assert ( + self.train_batch_size is None + ), "If batch_size is specified, train_batch_size must be None." diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/data/diffusion/instantiate_data_module.py b/src/diffusion_for_multi_scale_molecular_dynamics/data/diffusion/instantiate_data_module.py new file mode 100644 index 00000000..3f28c89d --- /dev/null +++ b/src/diffusion_for_multi_scale_molecular_dynamics/data/diffusion/instantiate_data_module.py @@ -0,0 +1,49 @@ +"""Functions to instantiate a data loader based on the provided hyperparameters.""" +import argparse +import logging +from typing import Any, AnyStr, Dict + +import pytorch_lightning as pl + +from diffusion_for_multi_scale_molecular_dynamics.data.diffusion.lammps_for_diffusion_data_module import ( + LammpsDataModuleParameters, LammpsForDiffusionDataModule) + +logger = logging.getLogger(__name__) + + +def load_data_module(hyper_params: Dict[AnyStr, Any], args: argparse.Namespace) -> pl.LightningDataModule: + """Load data module. + + This method creates the data module based on configuration and input arguments. + + Args: + hyper_params: configuration parameters. + args: parsed command line arguments. + + Returns: + data_module: the data module corresponding to the configuration and input arguments. + """ + assert 'data' in hyper_params, \ + "The configuration should contain a 'data' block describing the data source." + + data_config = hyper_params["data"] + + data_source = "LAMMPS" + if "data_source" in data_config: + data_source = data_config["data_source"] + else: + data_config["data_source"] = data_source + + match data_source: + case "LAMMPS": + data_params = LammpsDataModuleParameters(**data_config, elements=hyper_params["elements"]) + data_module = LammpsForDiffusionDataModule(hyper_params=data_params, + lammps_run_dir=args.data, + processed_dataset_dir=args.processed_datadir, + working_cache_dir=args.dataset_working_dir) + case _: + raise NotImplementedError( + f"Data source '{data_source}' is not implemented" + ) + + return data_module diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/data/diffusion/lammps_for_diffusion_data_module.py b/src/diffusion_for_multi_scale_molecular_dynamics/data/diffusion/lammps_for_diffusion_data_module.py index 3ae96751..42d5bc30 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/data/diffusion/lammps_for_diffusion_data_module.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/data/diffusion/lammps_for_diffusion_data_module.py @@ -12,6 +12,8 @@ import torch.nn.functional as F from torch.utils.data import DataLoader +from diffusion_for_multi_scale_molecular_dynamics.data.diffusion.data_module_parameters import \ + DataModuleParameters from diffusion_for_multi_scale_molecular_dynamics.data.diffusion.lammps_processor_for_diffusion import \ LammpsProcessorForDiffusion from diffusion_for_multi_scale_molecular_dynamics.data.element_types import ( @@ -23,17 +25,9 @@ @dataclass(kw_only=True) -class LammpsLoaderParameters: - """Base Hyper-parameters for score networks.""" - - # Either batch_size XOR train_batch_size and valid_batch_size should be specified. - batch_size: Optional[int] = None - train_batch_size: Optional[int] = None - valid_batch_size: Optional[int] = None - num_workers: int = 0 - max_atom: int = 64 - spatial_dimension: int = 3 # the dimension of Euclidean space where atoms live. - elements: list[str] # the elements that can exist. +class LammpsDataModuleParameters(DataModuleParameters): + """Hyper-Parameters for a Lammps-based data module.""" + data_source: str = "LAMMPS" class LammpsForDiffusionDataModule(pl.LightningDataModule): @@ -43,7 +37,7 @@ def __init__( self, lammps_run_dir: str, processed_dataset_dir: str, - hyper_params: LammpsLoaderParameters, + hyper_params: LammpsDataModuleParameters, working_cache_dir: Optional[str] = None, ): """Initialize a dataset of LAMMPS structures for training a diffusion model. @@ -60,7 +54,11 @@ def __init__( # check_and_log_hp(["batch_size", "num_workers"], hyper_params) # validate the hyperparameters # TODO add the padding parameters for number of atoms self.lammps_run_dir = lammps_run_dir + assert self.lammps_run_dir is not None, \ + "The LAMMPS run directory must be specified to use the LAMMPS data source." self.processed_dataset_dir = processed_dataset_dir + assert self.processed_dataset_dir is not None, \ + "The LAMMPS processed dataset directory must be specified to use the LAMMPS data source." self.working_cache_dir = working_cache_dir self.num_workers = hyper_params.num_workers self.max_atom = hyper_params.max_atom # number of atoms to pad tensors diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/train_diffusion.py b/src/diffusion_for_multi_scale_molecular_dynamics/train_diffusion.py index aafc92ac..7f3d44b3 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/train_diffusion.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/train_diffusion.py @@ -12,8 +12,8 @@ from diffusion_for_multi_scale_molecular_dynamics.callbacks.callback_loader import \ create_all_callbacks -from diffusion_for_multi_scale_molecular_dynamics.data.diffusion.lammps_for_diffusion_data_module import ( - LammpsForDiffusionDataModule, LammpsLoaderParameters) +from diffusion_for_multi_scale_molecular_dynamics.data.diffusion.instantiate_data_module import \ + load_data_module from diffusion_for_multi_scale_molecular_dynamics.data.element_types import \ ElementTypes from diffusion_for_multi_scale_molecular_dynamics.loggers.logger_loader import \ @@ -45,15 +45,19 @@ def main(args: typing.Optional[typing.Any] = None): help="config file with generic hyper-parameters, such as optimizer, " "batch_size, ... - in yaml format", ) - parser.add_argument("--data", help="path to a LAMMPS data set", required=True) + parser.add_argument("--data", + help="path to a LAMMPS data set. REQUIRED if the data source is 'LAMMPS'.", + default=None, + required=False) parser.add_argument( "--processed_datadir", - help="path to the processed data directory", - required=True, + help="path to the processed data directory. REQUIRED if the data source is 'LAMMPS'.", + default=None, + required=False, ) parser.add_argument( "--dataset_working_dir", - help="path to the Datasets working directory. Defaults to None", + help="path to the Datasets working directory. Only relevant if the data source is 'LAMMPS'. Defaults to None", default=None, ) parser.add_argument( @@ -108,7 +112,7 @@ def main(args: typing.Optional[typing.Any] = None): run(args, output_dir, hyper_params) -def run(args, output_dir, hyper_params): +def run(args: argparse.Namespace, output_dir, hyper_params): """Create and run the dataloaders, training loops, etc. Args: @@ -123,14 +127,7 @@ def run(args, output_dir, hyper_params): ElementTypes.validate_elements(hyper_params["elements"]) - data_params = LammpsLoaderParameters(**hyper_params["data"], elements=hyper_params["elements"]) - - datamodule = LammpsForDiffusionDataModule( - lammps_run_dir=args.data, - processed_dataset_dir=args.processed_datadir, - hyper_params=data_params, - working_cache_dir=args.dataset_working_dir, - ) + datamodule = load_data_module(hyper_params, args) model = load_diffusion_model(hyper_params) diff --git a/tests/data/diffusion/test_lammps_for_diffusion_data_module.py b/tests/data/diffusion/test_lammps_for_diffusion_data_module.py index 44b76ce6..26000e25 100644 --- a/tests/data/diffusion/test_lammps_for_diffusion_data_module.py +++ b/tests/data/diffusion/test_lammps_for_diffusion_data_module.py @@ -6,7 +6,7 @@ import torch from diffusion_for_multi_scale_molecular_dynamics.data.diffusion.lammps_for_diffusion_data_module import ( - LammpsForDiffusionDataModule, LammpsLoaderParameters) + LammpsDataModuleParameters, LammpsForDiffusionDataModule) from diffusion_for_multi_scale_molecular_dynamics.data.element_types import ( NULL_ELEMENT, ElementTypes) from diffusion_for_multi_scale_molecular_dynamics.namespace import ( @@ -141,7 +141,7 @@ def test_pad_dataset(self, input_data_for_padding, number_of_atoms, max_atom_for @pytest.fixture def data_module_hyperparameters(self, number_of_atoms, spatial_dimension, unique_elements): - return LammpsLoaderParameters( + return LammpsDataModuleParameters( batch_size=2, num_workers=0, max_atom=number_of_atoms, From 14825ca2e0856e8b530c95c943e3acfd4901d990 Mon Sep 17 00:00:00 2001 From: Bruno Rousseau Date: Tue, 24 Dec 2024 16:19:23 -0500 Subject: [PATCH 10/24] A new data module for on-the-fly Gaussian datasets. --- .../data/diffusion/gaussian_data_module.py | 144 ++++++++++++++++++ .../diffusion/test_gaussian_data_module.py | 83 ++++++++++ 2 files changed, 227 insertions(+) create mode 100644 src/diffusion_for_multi_scale_molecular_dynamics/data/diffusion/gaussian_data_module.py create mode 100644 tests/data/diffusion/test_gaussian_data_module.py diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/data/diffusion/gaussian_data_module.py b/src/diffusion_for_multi_scale_molecular_dynamics/data/diffusion/gaussian_data_module.py new file mode 100644 index 00000000..fd2b1328 --- /dev/null +++ b/src/diffusion_for_multi_scale_molecular_dynamics/data/diffusion/gaussian_data_module.py @@ -0,0 +1,144 @@ +import logging +from dataclasses import dataclass +from typing import List, Optional + +import einops +import pytorch_lightning as pl +import torch +from torch.utils.data import DataLoader + +from diffusion_for_multi_scale_molecular_dynamics.data.diffusion.data_module_parameters import \ + DataModuleParameters +from diffusion_for_multi_scale_molecular_dynamics.data.element_types import \ + ElementTypes +from diffusion_for_multi_scale_molecular_dynamics.namespace import ( + ATOM_TYPES, CARTESIAN_FORCES, RELATIVE_COORDINATES) +from diffusion_for_multi_scale_molecular_dynamics.utils.basis_transformations import \ + map_relative_coordinates_to_unit_cell + +logger = logging.getLogger(__name__) + + +@dataclass(kw_only=True) +class GaussianDataModuleParameters(DataModuleParameters): + """Hyper-parameters for a Gaussian, in memory data module.""" + data_source = "gaussian" + + random_seed: int + # the number of atoms in a configuration. + number_of_atoms: int + + # Standard deviation of the Gaussian distribution. + # The covariance matrix is assumed to be proportional to the identity matrix. + sigma_d: float = 0.01 + + # mean of the Gaussian Distribution + equilibrium_relative_coordinates: List[List[float]] + + train_dataset_size: int = 8_192 + valid_dataset_size: int = 1_024 + + def __post_init__(self): + """Post init.""" + assert self.sigma_d > 0.0, "the sigma_d parameter should be positive." + + assert len(self.equilibrium_relative_coordinates) == self.number_of_atoms, \ + "There should be exactly one list of equilibrium coordinates per atom." + + for x in self.equilibrium_relative_coordinates: + assert len(x) == self.spatial_dimension, \ + "The equilibrium coordinates should be consistent with the spatial dimension." + + assert len(self.elements) == 1, "There can only be one element type for the gaussian data module." + + +class GaussianDataModule(pl.LightningDataModule): + """Gaussian Data Module. + + Data module class that creates an in-memory dataset of relative coordinates that follow a Gaussian distribution. + """ + + def __init__( + self, + hyper_params: GaussianDataModuleParameters, + ): + """Init method.""" + super().__init__() + + self.random_seed = hyper_params.random_seed + self.number_of_atoms = hyper_params.number_of_atoms + self.spatial_dimension = hyper_params.spatial_dimension + self.sigma_d = hyper_params.sigma_d + self.equilibrium_coordinates = torch.tensor(hyper_params.equilibrium_relative_coordinates, + dtype=torch.float) + + self.train_dataset_size = hyper_params.train_dataset_size + self.valid_dataset_size = hyper_params.valid_dataset_size + + assert hyper_params.batch_size, "batch_size must be specified" + + self.batch_size = hyper_params.batch_size + self.train_size = hyper_params.train_batch_size + self.valid_size = hyper_params.valid_batch_size + + self.num_workers = hyper_params.num_workers + + self.element_types = ElementTypes(hyper_params.elements) + + def setup(self, stage: Optional[str] = None): + """Setup method.""" + self.train_dataset = [] + self.valid_dataset = [] + + rng = torch.Generator() + rng.manual_seed(self.random_seed) + + box = torch.ones(self.spatial_dimension, dtype=torch.float) + + atom_types = torch.zeros(self.number_of_atoms, dtype=torch.long) + + for dataset, batch_size in zip([self.train_dataset, self.valid_dataset], + [self.train_dataset_size, self.valid_dataset_size]): + + mean = einops.repeat(self.equilibrium_coordinates, + "natoms space -> batch natoms space", batch=batch_size) + std = self.sigma_d * torch.ones_like(mean) + relative_coordinates = map_relative_coordinates_to_unit_cell( + torch.normal(mean=mean, std=std, generator=rng).to(torch.float)) + + for x in relative_coordinates: + row = { + "natom": self.number_of_atoms, + "box": box, + RELATIVE_COORDINATES: x, + ATOM_TYPES: atom_types, + CARTESIAN_FORCES: torch.zeros_like(x), + "potential_energy": torch.tensor([0.0], dtype=torch.float), + } + dataset.append(row) + + def train_dataloader(self) -> DataLoader: + """Create the training dataloader using the training data parser.""" + return DataLoader( + self.train_dataset, + batch_size=self.batch_size, + shuffle=True, + num_workers=self.num_workers, + ) + + def val_dataloader(self): + """Create the validation dataloader using the validation data parser.""" + return DataLoader( + self.valid_dataset, + batch_size=self.batch_size, + shuffle=False, + num_workers=self.num_workers, + ) + + def test_dataloader(self): + """Creates the testing dataloader using the testing data parser.""" + raise NotImplementedError("Test set is not defined at the moment.") + + def clean_up(self): + """Nothing to clean.""" + pass diff --git a/tests/data/diffusion/test_gaussian_data_module.py b/tests/data/diffusion/test_gaussian_data_module.py new file mode 100644 index 00000000..99520d45 --- /dev/null +++ b/tests/data/diffusion/test_gaussian_data_module.py @@ -0,0 +1,83 @@ +import pytest +import torch + +from diffusion_for_multi_scale_molecular_dynamics.data.diffusion.gaussian_data_module import ( + GaussianDataModule, GaussianDataModuleParameters) +from diffusion_for_multi_scale_molecular_dynamics.namespace import \ + RELATIVE_COORDINATES + + +class TestGaussianDataModule: + + @pytest.fixture() + def batch_size(self): + return 4 + + @pytest.fixture() + def train_dataset_size(self): + return 16 + + @pytest.fixture() + def valid_dataset_size(self): + return 8 + + @pytest.fixture() + def number_of_atoms(self): + return 4 + + @pytest.fixture() + def spatial_dimension(self): + return 2 + + @pytest.fixture() + def sigma_d(self): + return 0.01 + + @pytest.fixture() + def equilibrium_relative_coordinates(self, number_of_atoms, spatial_dimension): + list_x = torch.rand(number_of_atoms, spatial_dimension) + equilibrium_relative_coordinates = [list(x) for x in list_x.numpy()] + return equilibrium_relative_coordinates + + @pytest.fixture + def data_module_hyperparameters(self, batch_size, train_dataset_size, valid_dataset_size, + number_of_atoms, spatial_dimension, sigma_d, equilibrium_relative_coordinates): + return GaussianDataModuleParameters( + batch_size=batch_size, + random_seed=42, + num_workers=0, + sigma_d=sigma_d, + number_of_atoms=number_of_atoms, + spatial_dimension=spatial_dimension, + equilibrium_relative_coordinates=equilibrium_relative_coordinates, + train_dataset_size=train_dataset_size, + valid_dataset_size=valid_dataset_size, + elements=['DUMMY'] + ) + + @pytest.fixture() + def data_module(self, data_module_hyperparameters): + + data_module = GaussianDataModule(hyper_params=data_module_hyperparameters) + data_module.setup() + + return data_module + + def test_train_data_loader(self, data_module, train_dataset_size, number_of_atoms, spatial_dimension): + self._check_data_loader(data_module.train_dataloader(), number_of_atoms, spatial_dimension, train_dataset_size) + + def test_validation_data_loader(self, data_module, valid_dataset_size, number_of_atoms, spatial_dimension): + self._check_data_loader(data_module.val_dataloader(), number_of_atoms, spatial_dimension, valid_dataset_size) + + def _check_data_loader(self, dataloader, number_of_atoms, spatial_dimension, dataset_size): + count = 0 + for batch in dataloader: + x = batch[RELATIVE_COORDINATES] + assert torch.all(x >= 0.) + assert torch.all(x < 1.) + batch_size, natoms, space = x.shape + count += batch_size + assert natoms == number_of_atoms + assert space == spatial_dimension + + assert count == dataset_size From 8d54ef4c5a7d08b7f60c3b736de469496d8217c3 Mon Sep 17 00:00:00 2001 From: Bruno Rousseau Date: Wed, 25 Dec 2024 10:10:30 -0500 Subject: [PATCH 11/24] The ability to instantiate a Gaussian data module. --- .../data/diffusion/instantiate_data_module.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/data/diffusion/instantiate_data_module.py b/src/diffusion_for_multi_scale_molecular_dynamics/data/diffusion/instantiate_data_module.py index 3f28c89d..925d5a93 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/data/diffusion/instantiate_data_module.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/data/diffusion/instantiate_data_module.py @@ -5,6 +5,8 @@ import pytorch_lightning as pl +from diffusion_for_multi_scale_molecular_dynamics.data.diffusion.gaussian_data_module import ( + GaussianDataModule, GaussianDataModuleParameters) from diffusion_for_multi_scale_molecular_dynamics.data.diffusion.lammps_for_diffusion_data_module import ( LammpsDataModuleParameters, LammpsForDiffusionDataModule) @@ -41,6 +43,10 @@ def load_data_module(hyper_params: Dict[AnyStr, Any], args: argparse.Namespace) lammps_run_dir=args.data, processed_dataset_dir=args.processed_datadir, working_cache_dir=args.dataset_working_dir) + + case "gaussian": + data_params = GaussianDataModuleParameters(**data_config, elements=hyper_params["elements"]) + data_module = GaussianDataModule(data_params) case _: raise NotImplementedError( f"Data source '{data_source}' is not implemented" From bfb8506f2a666843a69a688e600ba030dd54ed07 Mon Sep 17 00:00:00 2001 From: Bruno Rousseau Date: Wed, 25 Dec 2024 10:28:14 -0500 Subject: [PATCH 12/24] Make it easier to instantiate the analytical score network from a config file. --- .../analytical_score_network.py | 27 ++++++++++--------- .../test_analytical_score_network.py | 10 ++++--- 2 files changed, 20 insertions(+), 17 deletions(-) diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/models/score_networks/analytical_score_network.py b/src/diffusion_for_multi_scale_molecular_dynamics/models/score_networks/analytical_score_network.py index eb33979d..1a54650d 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/models/score_networks/analytical_score_network.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/models/score_networks/analytical_score_network.py @@ -12,7 +12,7 @@ """ from dataclasses import dataclass -from typing import Any, AnyStr, Dict, Tuple +from typing import Any, AnyStr, Dict, List, Tuple import einops import torch @@ -41,8 +41,7 @@ class AnalyticalScoreNetworkParameters(ScoreNetworkParameters): # the maximum lattice translation along any dimension. Translations will be [-kmax,..,kmax]. kmax: int - # Should have shape [number_of_atoms, spatial_dimensions] - equilibrium_relative_coordinates: torch.Tensor + equilibrium_relative_coordinates: List[List[float]] # the data distribution variance. sigma_d: float @@ -60,10 +59,12 @@ def __post_init__(self): assert self.sigma_d > 0.0, "the sigma_d parameter should be positive." - assert self.equilibrium_relative_coordinates.shape == ( - self.number_of_atoms, - self.spatial_dimension, - ), "equilibrium relative coordinates have the wrong shape" + assert len(self.equilibrium_relative_coordinates) == self.number_of_atoms, \ + "There should be exactly one list of equilibrium coordinates per atom." + + for x in self.equilibrium_relative_coordinates: + assert len(x) == self.spatial_dimension, \ + "The equilibrium coordinates should be consistent with the spatial dimension." class AnalyticalScoreNetwork(ScoreNetwork): @@ -72,11 +73,12 @@ class AnalyticalScoreNetwork(ScoreNetwork): This 'score network' is for exploring and debugging. """ - def __init__(self, hyper_params: AnalyticalScoreNetworkParameters): + def __init__(self, hyper_params: AnalyticalScoreNetworkParameters, device=torch.device("cpu")): """__init__. Args: hyper_params : hyper parameters from the config file. + device: device to use. """ super(AnalyticalScoreNetwork, self).__init__(hyper_params) @@ -92,15 +94,14 @@ def __init__(self, hyper_params: AnalyticalScoreNetworkParameters): self.use_permutation_invariance = hyper_params.use_permutation_invariance - self.device = hyper_params.equilibrium_relative_coordinates.device + self.device = device # shape: [number_of_translations] self.translations_k = self._get_all_translations(self.kmax).to(self.device) self.number_of_translations = len(self.translations_k) - self.equilibrium_relative_coordinates = ( - hyper_params.equilibrium_relative_coordinates - ) + self.equilibrium_relative_coordinates = torch.tensor(hyper_params.equilibrium_relative_coordinates, + dtype=torch.float) if self.use_permutation_invariance: # Shape : [natom!, natoms, spatial dimension] @@ -110,7 +111,7 @@ def __init__(self, hyper_params: AnalyticalScoreNetworkParameters): else: # Shape : [1, natoms, spatial dimension] self.all_x0 = einops.rearrange( - hyper_params.equilibrium_relative_coordinates, "natom d -> 1 natom d" + self.equilibrium_relative_coordinates, "natom d -> 1 natom d" ) @staticmethod diff --git a/tests/models/score_network/test_analytical_score_network.py b/tests/models/score_network/test_analytical_score_network.py index 8e4256d9..4446a647 100644 --- a/tests/models/score_network/test_analytical_score_network.py +++ b/tests/models/score_network/test_analytical_score_network.py @@ -43,7 +43,8 @@ def number_of_atoms(self, request): @pytest.fixture def equilibrium_relative_coordinates(self, number_of_atoms, spatial_dimension): - return torch.rand(number_of_atoms, spatial_dimension) + list_x = torch.rand(number_of_atoms, spatial_dimension) + return [list(x.numpy()) for x in list_x] @pytest.fixture(params=[True, False]) def use_permutation_invariance(self, request): @@ -104,19 +105,20 @@ def test_all_translations(self, kmax): def test_get_all_equilibrium_permutations( self, number_of_atoms, spatial_dimension, equilibrium_relative_coordinates ): + + eq_rel_coords = torch.tensor(equilibrium_relative_coordinates) expected_permutations = [] for permutation_indices in itertools.permutations(range(number_of_atoms)): expected_permutations.append( - equilibrium_relative_coordinates[list(permutation_indices)] + eq_rel_coords[list(permutation_indices)] ) expected_permutations = torch.stack(expected_permutations) computed_permutations = ( AnalyticalScoreNetwork._get_all_equilibrium_permutations( - equilibrium_relative_coordinates - ) + eq_rel_coords) ) assert computed_permutations.shape == ( From 99bbaf479ba06b6cd3ea53da3b831885114e68be Mon Sep 17 00:00:00 2001 From: Bruno Rousseau Date: Wed, 25 Dec 2024 10:40:00 -0500 Subject: [PATCH 13/24] Add analytical score network to possible models to instantiate. --- .../models/score_networks/score_network_factory.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/models/score_networks/score_network_factory.py b/src/diffusion_for_multi_scale_molecular_dynamics/models/score_networks/score_network_factory.py index 02adb3bb..f69ff720 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/models/score_networks/score_network_factory.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/models/score_networks/score_network_factory.py @@ -3,6 +3,8 @@ from diffusion_for_multi_scale_molecular_dynamics.models.score_networks import ( ScoreNetwork, ScoreNetworkParameters) +from diffusion_for_multi_scale_molecular_dynamics.models.score_networks.analytical_score_network import ( + AnalyticalScoreNetwork, AnalyticalScoreNetworkParameters) from diffusion_for_multi_scale_molecular_dynamics.models.score_networks.diffusion_mace_score_network import ( DiffusionMACEScoreNetwork, DiffusionMACEScoreNetworkParameters) from diffusion_for_multi_scale_molecular_dynamics.models.score_networks.egnn_score_network import ( @@ -18,12 +20,14 @@ create_parameters_from_configuration_dictionary SCORE_NETWORKS_BY_ARCH = dict( + analytical=AnalyticalScoreNetwork, mlp=MLPScoreNetwork, mace=MACEScoreNetwork, diffusion_mace=DiffusionMACEScoreNetwork, egnn=EGNNScoreNetwork, ) SCORE_NETWORK_PARAMETERS_BY_ARCH = dict( + analytical=AnalyticalScoreNetworkParameters, mlp=MLPScoreNetworkParameters, mace=MACEScoreNetworkParameters, diffusion_mace=DiffusionMACEScoreNetworkParameters, From 778b4d993a18fa6d06a4ee1e824d8c931139b5b4 Mon Sep 17 00:00:00 2001 From: Bruno Rousseau Date: Wed, 25 Dec 2024 10:46:26 -0500 Subject: [PATCH 14/24] Fix exp details. --- .../utils/logging_utils.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/utils/logging_utils.py b/src/diffusion_for_multi_scale_molecular_dynamics/utils/logging_utils.py index b1a5ecfe..5d00e661 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/utils/logging_utils.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/utils/logging_utils.py @@ -122,14 +122,12 @@ def log_exp_details(script_location, args): git_hash = get_git_hash(script_location) hostname = socket.gethostname() dependencies = freeze.freeze() - details = ( - "\nhostname: {}\ngit code hash: {}\ndata folder: {}\ndata folder (abs): {}\n\n" - "dependencies:\n{}".format( - hostname, - git_hash, - args.data, - os.path.abspath(args.data), - "\n".join(dependencies), - ) - ) + details = f"\nhostname: {hostname}\ngit code hash: {git_hash}" + if args.data is not None: + details += f"\ndata folder: {args.data}\ndata folder (abs): {os.path.abspath(args.data)}\n\n" + else: + details += "\nNO DATA FOLDER PROVIDED\n\n" + + details += f"dependencies:\n{dependencies}" + logger.info("Experiment info:" + details + "\n") From 27c566a791e74ef0e6d3f083eb701940c2c099d1 Mon Sep 17 00:00:00 2001 From: Bruno Rousseau Date: Wed, 25 Dec 2024 10:50:43 -0500 Subject: [PATCH 15/24] Cleaner data module instantiation. --- .../data/diffusion/instantiate_data_module.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/data/diffusion/instantiate_data_module.py b/src/diffusion_for_multi_scale_molecular_dynamics/data/diffusion/instantiate_data_module.py index 925d5a93..dafe8df5 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/data/diffusion/instantiate_data_module.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/data/diffusion/instantiate_data_module.py @@ -29,12 +29,7 @@ def load_data_module(hyper_params: Dict[AnyStr, Any], args: argparse.Namespace) "The configuration should contain a 'data' block describing the data source." data_config = hyper_params["data"] - - data_source = "LAMMPS" - if "data_source" in data_config: - data_source = data_config["data_source"] - else: - data_config["data_source"] = data_source + data_source = data_config.pop("data_source", "LAMMPS") match data_source: case "LAMMPS": From 714b58386efbcfe0842767c14db4a07a5b29a277 Mon Sep 17 00:00:00 2001 From: Bruno Rousseau Date: Wed, 25 Dec 2024 12:21:06 -0500 Subject: [PATCH 16/24] More comments for debug server. --- .../train_diffusion.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/train_diffusion.py b/src/diffusion_for_multi_scale_molecular_dynamics/train_diffusion.py index 7f3d44b3..c3274e7a 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/train_diffusion.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/train_diffusion.py @@ -220,6 +220,8 @@ def train( # Uncomment the following in order to use Pycharm's Remote Debugging server, which allows to # launch python commands through a bash script (and through Orion!). VERY useful for debugging. # This requires a professional edition of Pycharm and installing the pydevd_pycharm package with pip. + # The debug server stopped workin in 2024.3. There is a workaround. See: + # https://www.reddit.com/r/pycharm/comments/1gs1lgk/python_debug_server_issues/ # import pydevd_pycharm # pydevd_pycharm.settrace('localhost', port=56636, stdoutToServer=True, stderrToServer=True) main() From bb7e1f3d369148daeb9aaad4a52899d9a695db43 Mon Sep 17 00:00:00 2001 From: Bruno Rousseau Date: Wed, 25 Dec 2024 12:22:10 -0500 Subject: [PATCH 17/24] Make it possible to turn off the optimizer. This way, we can pass the analytical score network through the main code. --- .../models/axl_diffusion_lightning_model.py | 7 ++++++- .../models/optimizer.py | 20 ++++++++++++++++--- 2 files changed, 23 insertions(+), 4 deletions(-) diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/models/axl_diffusion_lightning_model.py b/src/diffusion_for_multi_scale_molecular_dynamics/models/axl_diffusion_lightning_model.py index 9a5cece5..ed39afe0 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/models/axl_diffusion_lightning_model.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/models/axl_diffusion_lightning_model.py @@ -14,7 +14,7 @@ from diffusion_for_multi_scale_molecular_dynamics.metrics.kolmogorov_smirnov_metrics import \ KolmogorovSmirnovMetrics from diffusion_for_multi_scale_molecular_dynamics.models.optimizer import ( - OptimizerParameters, load_optimizer) + OptimizerParameters, check_if_optimizer_is_none, load_optimizer) from diffusion_for_multi_scale_molecular_dynamics.models.scheduler import ( SchedulerParameters, load_scheduler_dictionary) from diffusion_for_multi_scale_molecular_dynamics.models.score_networks.score_network import \ @@ -93,6 +93,11 @@ def __init__(self, hyper_params: AXLDiffusionParameters): logger=False ) # It is not the responsibility of this class to log its parameters. + if check_if_optimizer_is_none(self.hyper_params.optimizer_parameters): + # If the config indicates None as the optimizer, then no optimization should + # take place. + self.automatic_optimization = False + # the score network is expected to produce an output as an AXL namedtuple: # atom: unnormalized estimate of p(a_0 | a_t) # relative coordinates: estimate of \sigma \nabla_{x_t} p_{t|0}(x_t | x_0) diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/models/optimizer.py b/src/diffusion_for_multi_scale_molecular_dynamics/models/optimizer.py index 1607f7f0..868425c5 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/models/optimizer.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/models/optimizer.py @@ -1,6 +1,6 @@ import logging from dataclasses import asdict, dataclass -from typing import Any, Dict +from typing import Any, Dict, Union import torch from torch import optim @@ -20,8 +20,22 @@ class OptimizerParameters: weight_decay: float = 0.0 -OPTIMIZERS_BY_NAME = dict(adam=optim.Adam, adamw=optim.AdamW) -OPTIMIZER_PARAMETERS_BY_NAME = dict(adam=OptimizerParameters, adamw=OptimizerParameters) +OPTIMIZERS_BY_NAME = {'adam': optim.Adam, 'adamw': optim.AdamW, 'None': optim.Adam} +OPTIMIZER_PARAMETERS_BY_NAME = {'adam': OptimizerParameters, 'adamw': OptimizerParameters, 'None': OptimizerParameters} + + +def check_if_optimizer_is_none(optimizer_parameters: Union[OptimizerParameters, None]) -> bool: + """Check if the optimizer is None. + + This is a useful check to see if pytorch-lightning optimization should be turned off. + + Args: + optimizer_parameters: optimizer parameters + + Returns: + Bool: whether or not the optimizer_parameters is None. + """ + return optimizer_parameters.name == 'None' def create_optimizer_parameters( From 89fb2ad820b1e8cc952e673ba812e793412dd220 Mon Sep 17 00:00:00 2001 From: Bruno Rousseau Date: Wed, 25 Dec 2024 14:30:56 -0500 Subject: [PATCH 18/24] Class to plot the scores along a direction. --- .../analysis/score_viewer.py | 172 ++++++++++++++++++ 1 file changed, 172 insertions(+) create mode 100644 src/diffusion_for_multi_scale_molecular_dynamics/analysis/score_viewer.py diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/analysis/score_viewer.py b/src/diffusion_for_multi_scale_molecular_dynamics/analysis/score_viewer.py new file mode 100644 index 00000000..8819b8f2 --- /dev/null +++ b/src/diffusion_for_multi_scale_molecular_dynamics/analysis/score_viewer.py @@ -0,0 +1,172 @@ +from dataclasses import dataclass +from typing import List + +import einops +import torch +from matplotlib import pyplot as plt + +from diffusion_for_multi_scale_molecular_dynamics.analysis import ( + PLEASANT_FIG_SIZE, PLOT_STYLE_PATH) +from diffusion_for_multi_scale_molecular_dynamics.models.score_networks import \ + ScoreNetwork +from diffusion_for_multi_scale_molecular_dynamics.models.score_networks.analytical_score_network import ( + AnalyticalScoreNetwork, AnalyticalScoreNetworkParameters) +from diffusion_for_multi_scale_molecular_dynamics.namespace import ( + AXL, CARTESIAN_FORCES, NOISE, NOISY_AXL_COMPOSITION, TIME, UNIT_CELL) +from diffusion_for_multi_scale_molecular_dynamics.noise_schedulers.exploding_variance import \ + VarianceScheduler +from diffusion_for_multi_scale_molecular_dynamics.noise_schedulers.noise_parameters import \ + NoiseParameters +from diffusion_for_multi_scale_molecular_dynamics.utils.basis_transformations import \ + map_relative_coordinates_to_unit_cell + +plt.style.use(PLOT_STYLE_PATH) + + +@dataclass(kw_only=True) +class ScoreViewerParameters: + """Parameters for the Score Viewer class.""" + sigma_min: float + sigma_max: float + + number_of_space_steps: int = 1000 + + # Starting and ending relative coordinates should be of shape [number of atoms, spatial dimension] + starting_relative_coordinates: List[List[float]] + ending_relative_coordinates: List[List[float]] + + +class ScoreViewer: + """Score Viewer. + + This class drives the generation of figures that show the score along specified + directions. + """ + + def __init__(self, score_viewer_parameters: ScoreViewerParameters, + analytical_score_network_parameters: AnalyticalScoreNetworkParameters): + """Init method.""" + total_time_steps = 8 + + noise_parameters = NoiseParameters(total_time_steps=total_time_steps, + sigma_min=score_viewer_parameters.sigma_min, + sigma_max=score_viewer_parameters.sigma_max) + + self.times = torch.tensor([0., 0.1, 0.2, 0.3, 0.4, 0.8, 0.9, 1.0]) + self.sigmas = VarianceScheduler(noise_parameters).get_sigma(self.times).numpy() + + self.analytical_score_network = AnalyticalScoreNetwork(analytical_score_network_parameters) + + self.natoms = analytical_score_network_parameters.number_of_atoms + + self.start = torch.tensor(score_viewer_parameters.starting_relative_coordinates) + self.end = torch.tensor(score_viewer_parameters.ending_relative_coordinates) + + self.number_of_space_steps = score_viewer_parameters.number_of_space_steps + + self.relative_coordinates, self.displacements = self.get_relative_coordinates_and_displacement() + self.direction_vector = self.get_direction_vector() + + def get_relative_coordinates_and_displacement(self): + """Get the relative coordinates and the displacement.""" + direction = (self.end - self.start) / self.number_of_space_steps + steps = torch.arange(self.number_of_space_steps + 1) + relative_coordinates = self.start.unsqueeze(0) + steps.view(-1, 1, 1) * direction.unsqueeze(0) + relative_coordinates = map_relative_coordinates_to_unit_cell(relative_coordinates) + + displacements = steps * direction.norm() + return relative_coordinates, displacements + + def get_direction_vector(self): + """Get direction vector.""" + direction_vector = einops.rearrange(self.end - self.start, "natoms space -> (natoms space)") + return direction_vector / direction_vector.norm() + + def get_batch(self, time: float, sigma: float): + """Get batch.""" + batch_size = self.relative_coordinates.shape[0] + + sigmas_t = sigma * torch.ones(batch_size, 1) + times = time * torch.ones(batch_size, 1) + unit_cell = torch.ones(batch_size, 1, 1) + forces = torch.zeros_like(self.relative_coordinates) + atom_types = torch.zeros(batch_size, self.natoms, dtype=torch.int64) + + composition = AXL(A=atom_types, + X=self.relative_coordinates, + L=torch.zeros_like(self.relative_coordinates)) + + batch = {NOISY_AXL_COMPOSITION: composition, + NOISE: sigmas_t, + TIME: times, + UNIT_CELL: unit_cell, + CARTESIAN_FORCES: forces} + return batch + + def create_figure(self, score_network: ScoreNetwork): + """Create a matplotlib figure.""" + figsize = (2 * PLEASANT_FIG_SIZE[0], PLEASANT_FIG_SIZE[0]) + fig = plt.figure(figsize=figsize) + + ax1 = fig.add_subplot(241) + ax2 = fig.add_subplot(242) + ax3 = fig.add_subplot(243) + ax4 = fig.add_subplot(244) + ax5 = fig.add_subplot(245) + ax6 = fig.add_subplot(246) + ax7 = fig.add_subplot(247) + ax8 = fig.add_subplot(248) + + list_ax = [ax1, ax2, ax3, ax4, ax5, ax6, ax7, ax8] + + list_params = [dict(color='green', lw=4, label='Analytical Normalized Score'), + dict(color='red', lw=2, label='Model Normalized Score')] + + list_score_networks = [self.analytical_score_network, score_network] + + for time, sigma, ax in zip(self.times, self.sigmas, list_ax): + batch = self.get_batch(time, sigma) + + for model, params in zip(list_score_networks, list_params): + sigma_normalized_scores = model(batch).X.detach() + vectors = einops.rearrange(sigma_normalized_scores, + "batch natoms space -> batch (natoms space)") + projected_sigma_normalized_scores = torch.matmul(vectors, self.direction_vector) + ax.plot(self.displacements, projected_sigma_normalized_scores, **params) + + ax.set_title(f"t = {time: 3.2f}," + r" $\sigma(t)$ = " + f"{sigma:5.3f}") + ax.spines["top"].set_visible(True) + ax.spines["right"].set_visible(True) + + ymin, ymax = ax.set_ylim() + ax.set_ylim(-ymax, ymax) + ax.set_xlim(self.displacements[0] - 0.01, self.displacements[-1] + 0.01) + + ax8.legend(loc=0) + fig.tight_layout() + + return fig + + +if __name__ == '__main__': + analytical_score_network_parameters = ( + AnalyticalScoreNetworkParameters(number_of_atoms=2, + spatial_dimension=1, + num_atom_types=1, + kmax=5, + sigma_d=0.01, + equilibrium_relative_coordinates=[[0.25], [0.75]], + use_permutation_invariance=True)) + + score_viewer_parameters = ScoreViewerParameters(sigma_min=0.001, + sigma_max=0.2, + starting_relative_coordinates=[[0.], [1.]], + ending_relative_coordinates=[[1.], [0.]]) + + score_viewer = ScoreViewer(score_viewer_parameters, analytical_score_network_parameters) + + score_network = AnalyticalScoreNetwork(analytical_score_network_parameters) + + fig = score_viewer.create_figure(score_network) + + plt.show() From 9da231e8dcccb8ada372f51da063cb32a0ff296e Mon Sep 17 00:00:00 2001 From: Bruno Rousseau Date: Thu, 26 Dec 2024 07:54:20 -0500 Subject: [PATCH 19/24] Fancier score viewer. --- .../analysis/score_viewer.py | 323 ++++++++++++++---- 1 file changed, 251 insertions(+), 72 deletions(-) diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/analysis/score_viewer.py b/src/diffusion_for_multi_scale_molecular_dynamics/analysis/score_viewer.py index 8819b8f2..7d62576e 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/analysis/score_viewer.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/analysis/score_viewer.py @@ -1,5 +1,6 @@ +from collections import defaultdict from dataclasses import dataclass -from typing import List +from typing import Dict, List, Tuple import einops import torch @@ -26,10 +27,11 @@ @dataclass(kw_only=True) class ScoreViewerParameters: """Parameters for the Score Viewer class.""" + sigma_min: float sigma_max: float - number_of_space_steps: int = 1000 + number_of_space_steps: int = 100 # Starting and ending relative coordinates should be of shape [number of atoms, spatial dimension] starting_relative_coordinates: List[List[float]] @@ -40,22 +42,36 @@ class ScoreViewer: """Score Viewer. This class drives the generation of figures that show the score along specified - directions. + directions. The figure is composed of 8 panes (matplotlib "axes") that show + the projected normalized score and various baselines along the specified 1D direction. + The projection of the score is on the tangent to the 1D line going from starting to ending + relative coordinates. """ - def __init__(self, score_viewer_parameters: ScoreViewerParameters, - analytical_score_network_parameters: AnalyticalScoreNetworkParameters): + def __init__( + self, + score_viewer_parameters: ScoreViewerParameters, + analytical_score_network_parameters: AnalyticalScoreNetworkParameters, + ): """Init method.""" total_time_steps = 8 - noise_parameters = NoiseParameters(total_time_steps=total_time_steps, - sigma_min=score_viewer_parameters.sigma_min, - sigma_max=score_viewer_parameters.sigma_max) + noise_parameters = NoiseParameters( + total_time_steps=total_time_steps, + sigma_min=score_viewer_parameters.sigma_min, + sigma_max=score_viewer_parameters.sigma_max, + ) + + self.times = torch.tensor([0.0, 0.1, 0.2, 0.3, 0.4, 0.8, 0.9, 1.0]) + self.number_of_axes = 8 - self.times = torch.tensor([0., 0.1, 0.2, 0.3, 0.4, 0.8, 0.9, 1.0]) self.sigmas = VarianceScheduler(noise_parameters).get_sigma(self.times).numpy() - self.analytical_score_network = AnalyticalScoreNetwork(analytical_score_network_parameters) + self.analytical_score_network = AnalyticalScoreNetwork( + analytical_score_network_parameters + ) + + number_of_equilibrium_points = len(self.analytical_score_network.all_x0) self.natoms = analytical_score_network_parameters.number_of_atoms @@ -64,25 +80,49 @@ def __init__(self, score_viewer_parameters: ScoreViewerParameters, self.number_of_space_steps = score_viewer_parameters.number_of_space_steps - self.relative_coordinates, self.displacements = self.get_relative_coordinates_and_displacement() - self.direction_vector = self.get_direction_vector() + self.relative_coordinates, self.displacements = ( + self._get_relative_coordinates_and_displacement() + ) + self.direction_vector = self._get_direction_vector() + + # Compute the various references and baselines once and for all, keeping them in memory. + self.projected_analytical_scores = self._compute_projected_scores( + self.analytical_score_network + ) + + self.projected_gaussian_scores_dict = ( + self._get_naive_projected_gaussian_scores() + ) - def get_relative_coordinates_and_displacement(self): + self.all_in_distribution_bands = self._get_in_distribution_bands() + + self.plot_style_parameters = self._get_plot_style_parameters( + number_of_equilibrium_points + ) + + def _get_relative_coordinates_and_displacement(self): """Get the relative coordinates and the displacement.""" - direction = (self.end - self.start) / self.number_of_space_steps - steps = torch.arange(self.number_of_space_steps + 1) - relative_coordinates = self.start.unsqueeze(0) + steps.view(-1, 1, 1) * direction.unsqueeze(0) - relative_coordinates = map_relative_coordinates_to_unit_cell(relative_coordinates) + direction = (self.end - self.start) / (self.number_of_space_steps + 1) + # Avoid the first point, which tends to klank because of periodicity. + steps = torch.arange(self.number_of_space_steps + 1)[1:] + relative_coordinates = self.start.unsqueeze(0) + steps.view( + -1, 1, 1 + ) * direction.unsqueeze(0) + relative_coordinates = map_relative_coordinates_to_unit_cell( + relative_coordinates + ) displacements = steps * direction.norm() return relative_coordinates, displacements - def get_direction_vector(self): + def _get_direction_vector(self): """Get direction vector.""" - direction_vector = einops.rearrange(self.end - self.start, "natoms space -> (natoms space)") + direction_vector = einops.rearrange( + self.end - self.start, "natoms space -> (natoms space)" + ) return direction_vector / direction_vector.norm() - def get_batch(self, time: float, sigma: float): + def _get_batch(self, time: float, sigma: float): """Get batch.""" batch_size = self.relative_coordinates.shape[0] @@ -92,81 +132,220 @@ def get_batch(self, time: float, sigma: float): forces = torch.zeros_like(self.relative_coordinates) atom_types = torch.zeros(batch_size, self.natoms, dtype=torch.int64) - composition = AXL(A=atom_types, - X=self.relative_coordinates, - L=torch.zeros_like(self.relative_coordinates)) - - batch = {NOISY_AXL_COMPOSITION: composition, - NOISE: sigmas_t, - TIME: times, - UNIT_CELL: unit_cell, - CARTESIAN_FORCES: forces} + composition = AXL( + A=atom_types, + X=self.relative_coordinates, + L=torch.zeros_like(self.relative_coordinates), + ) + + batch = { + NOISY_AXL_COMPOSITION: composition, + NOISE: sigmas_t, + TIME: times, + UNIT_CELL: unit_cell, + CARTESIAN_FORCES: forces, + } return batch + def _compute_projected_scores(self, score_network: ScoreNetwork): + """Compute projected scores.""" + list_projected_scores = [] + for time, sigma in zip(self.times, self.sigmas): + batch = self._get_batch(time, sigma) + + sigma_normalized_scores = score_network(batch).X.detach() + vectors = einops.rearrange( + sigma_normalized_scores, "batch natoms space -> batch (natoms space)" + ) + projected_sigma_normalized_scores = torch.matmul( + vectors, self.direction_vector + ) + list_projected_scores.append(projected_sigma_normalized_scores) + + return list_projected_scores + + def _get_naive_projected_gaussian_scores(self) -> Dict: + """Compute the scores as if coming from a simple, single Gaussian.""" + projected_gaussian_scores_dict = defaultdict(list) + for time, sigma in zip(self.times, self.sigmas): + + prefactor = -sigma / ( + sigma**2 + self.analytical_score_network.sigma_d_square + ) + + for idx, x0 in enumerate(self.analytical_score_network.all_x0): + + equilibrium_relative_coordinates = einops.repeat( + x0, + "natoms space -> batch natoms space", + batch=self.number_of_space_steps, + ) + + directions = ( + self.relative_coordinates - equilibrium_relative_coordinates + ) + vectors = einops.rearrange( + directions, "batch natoms space -> batch (natoms space)" + ) + projected_directions = torch.matmul(vectors, self.direction_vector) + gaussian_normalized_score = prefactor * projected_directions + projected_gaussian_scores_dict[idx].append(gaussian_normalized_score) + + return projected_gaussian_scores_dict + + def _get_in_distribution_bands(self) -> List[List[Tuple]]: + """Create the limits where the relative coordinates are within sigma_eff of an equilibrium point.""" + # Start of the 1D visualization line + origin = einops.rearrange(self.start, "natoms space -> (natoms space)") + + # Tangent vector to the 1D visualization line + d_hat = einops.rearrange( + self.end - self.start, "natoms space -> (natoms space)" + ) + d_hat = d_hat / d_hat.norm() + + list_bands = [] + for sigma in self.sigmas: + effective_sigma_square = torch.tensor( + sigma**2 + self.analytical_score_network.sigma_d_square + ) + + bands = [] + for idx, x0 in enumerate(self.analytical_score_network.all_x0): + center = einops.rearrange(x0, "natoms space -> (natoms space)") + + # We can build a simple quadratic equation to identify where the visualization line + # intersects the sphere centered at 'center' with radius effective_sigma. + v = origin - center + a = 1.0 + b = 2.0 * torch.dot(d_hat, v) + c = torch.dot(v, v) - effective_sigma_square + + discriminant = b**2 - 4 * a * c + + if discriminant < 0: + band = None + else: + dmin = (-b - torch.sqrt(discriminant)) / (2 * a) + dmax = (-b + torch.sqrt(discriminant)) / (2 * a) + band = (dmin, dmax) + bands.append(band) + list_bands.append(bands) + return list_bands + + def _get_plot_style_parameters(self, number_of_equilibrium_points: int): + """Create the linestyles for each item in the plots.""" + gaussian_linestyle = dict(ls="--", color="black", lw=2, alpha=0.5) + + list_params = [ + dict(ls="-", color="green", lw=4, label="Analytical Normalized Score"), + dict(ls="-", color="red", lw=2, label="Model Normalized Score"), + ] + + gaussian_params = dict(gaussian_linestyle) + gaussian_params["label"] = "Gaussian Approximation" + list_params.append(gaussian_params) + + for _ in range(number_of_equilibrium_points - 1): + gaussian_params = dict(gaussian_linestyle) + gaussian_params["label"] = "__nolegend__" + list_params.append(gaussian_params) + + return list_params + def create_figure(self, score_network: ScoreNetwork): - """Create a matplotlib figure.""" + """Create Figure. + + Create a matplotlib figure showing the projected normalized scores for the model + along with various baselines. + """ + model_projected_scores = self._compute_projected_scores(score_network) + figsize = (2 * PLEASANT_FIG_SIZE[0], PLEASANT_FIG_SIZE[0]) fig = plt.figure(figsize=figsize) - ax1 = fig.add_subplot(241) - ax2 = fig.add_subplot(242) - ax3 = fig.add_subplot(243) - ax4 = fig.add_subplot(244) - ax5 = fig.add_subplot(245) - ax6 = fig.add_subplot(246) - ax7 = fig.add_subplot(247) - ax8 = fig.add_subplot(248) - - list_ax = [ax1, ax2, ax3, ax4, ax5, ax6, ax7, ax8] + ax_id = 240 + for idx, (time, sigma) in enumerate(zip(self.times, self.sigmas)): + ax_id += 1 + ax = fig.add_subplot(ax_id) - list_params = [dict(color='green', lw=4, label='Analytical Normalized Score'), - dict(color='red', lw=2, label='Model Normalized Score')] + ax.set_title(f"t = {time: 3.2f}," + r" $\sigma(t)$ = " + f"{sigma:5.3f}") + ax.spines["top"].set_visible(True) + ax.spines["right"].set_visible(True) - list_score_networks = [self.analytical_score_network, score_network] + list_projected_scores = [ + self.projected_analytical_scores[idx], + model_projected_scores[idx], + ] - for time, sigma, ax in zip(self.times, self.sigmas, list_ax): - batch = self.get_batch(time, sigma) + maximum_projected_score_values = ( + torch.stack(list_projected_scores).abs().max() + ) - for model, params in zip(list_score_networks, list_params): - sigma_normalized_scores = model(batch).X.detach() - vectors = einops.rearrange(sigma_normalized_scores, - "batch natoms space -> batch (natoms space)") - projected_sigma_normalized_scores = torch.matmul(vectors, self.direction_vector) - ax.plot(self.displacements, projected_sigma_normalized_scores, **params) + for ( + all_gaussian_projected_scores + ) in self.projected_gaussian_scores_dict.values(): + list_projected_scores.append(all_gaussian_projected_scores[idx]) - ax.set_title(f"t = {time: 3.2f}," + r" $\sigma(t)$ = " + f"{sigma:5.3f}") - ax.spines["top"].set_visible(True) - ax.spines["right"].set_visible(True) + for projected_scores, params in zip( + list_projected_scores, self.plot_style_parameters + ): + ax.plot(self.displacements, projected_scores, **params) - ymin, ymax = ax.set_ylim() + ymax = 1.2 * maximum_projected_score_values ax.set_ylim(-ymax, ymax) ax.set_xlim(self.displacements[0] - 0.01, self.displacements[-1] + 0.01) - ax8.legend(loc=0) + bands = self.all_in_distribution_bands[idx] + label = r"$x_0 \pm \sigma_{eff}$" + for band in bands: + if band is None: + continue + ax.fill_betweenx( + y=[-ymax, ymax], + x1=band[0], + x2=band[1], + color="green", + alpha=0.10, + label=label, + ) + label = "__nolegend__" + + # The last ax gets the legend. + ax.legend(loc=0) fig.tight_layout() return fig -if __name__ == '__main__': - analytical_score_network_parameters = ( - AnalyticalScoreNetworkParameters(number_of_atoms=2, - spatial_dimension=1, - num_atom_types=1, - kmax=5, - sigma_d=0.01, - equilibrium_relative_coordinates=[[0.25], [0.75]], - use_permutation_invariance=True)) - - score_viewer_parameters = ScoreViewerParameters(sigma_min=0.001, - sigma_max=0.2, - starting_relative_coordinates=[[0.], [1.]], - ending_relative_coordinates=[[1.], [0.]]) - - score_viewer = ScoreViewer(score_viewer_parameters, analytical_score_network_parameters) +if __name__ == "__main__": + # A simple demonstration of how the Score Viewer works. We naively use an analytical score network + # as the external score network, such that the 'model' results will overlap with the analytical score baseline. + analytical_score_network_parameters = AnalyticalScoreNetworkParameters( + number_of_atoms=2, + spatial_dimension=1, + num_atom_types=1, + kmax=5, + sigma_d=0.01, + equilibrium_relative_coordinates=[[0.25], [0.75]], + use_permutation_invariance=True, + ) + + score_viewer_parameters = ScoreViewerParameters( + sigma_min=0.001, + sigma_max=0.2, + starting_relative_coordinates=[[0.0], [1.0]], + ending_relative_coordinates=[[1.0], [0.0]], + ) + + score_viewer = ScoreViewer( + score_viewer_parameters, analytical_score_network_parameters + ) score_network = AnalyticalScoreNetwork(analytical_score_network_parameters) fig = score_viewer.create_figure(score_network) + fig.suptitle("Demonstration") + fig.tight_layout() plt.show() From d0d0b2350d904b1347fa41827caad56ee5dec0dc Mon Sep 17 00:00:00 2001 From: Bruno Rousseau Date: Thu, 26 Dec 2024 09:06:18 -0500 Subject: [PATCH 20/24] Bring scores back to cpu. --- .../analysis/score_viewer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/analysis/score_viewer.py b/src/diffusion_for_multi_scale_molecular_dynamics/analysis/score_viewer.py index 7d62576e..db42e1ef 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/analysis/score_viewer.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/analysis/score_viewer.py @@ -153,7 +153,7 @@ def _compute_projected_scores(self, score_network: ScoreNetwork): for time, sigma in zip(self.times, self.sigmas): batch = self._get_batch(time, sigma) - sigma_normalized_scores = score_network(batch).X.detach() + sigma_normalized_scores = score_network(batch).X.detach().cpu() vectors = einops.rearrange( sigma_normalized_scores, "batch natoms space -> batch (natoms space)" ) From f335322f500cbbfda8156089a1b7db503d180637 Mon Sep 17 00:00:00 2001 From: Bruno Rousseau Date: Thu, 26 Dec 2024 10:11:44 -0500 Subject: [PATCH 21/24] New callback to show scores along a path. --- .../callbacks/callback_loader.py | 3 + .../callbacks/score_viewer_callback.py | 79 +++++++++++++++++++ 2 files changed, 82 insertions(+) create mode 100644 src/diffusion_for_multi_scale_molecular_dynamics/callbacks/score_viewer_callback.py diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/callbacks/callback_loader.py b/src/diffusion_for_multi_scale_molecular_dynamics/callbacks/callback_loader.py index 934b3fd7..1c755d26 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/callbacks/callback_loader.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/callbacks/callback_loader.py @@ -7,6 +7,8 @@ instantiate_loss_monitoring_callback from diffusion_for_multi_scale_molecular_dynamics.callbacks.sampling_visualization_callback import \ instantiate_sampling_visualization_callback +from diffusion_for_multi_scale_molecular_dynamics.callbacks.score_viewer_callback import \ + instantiate_score_viewer_callback from diffusion_for_multi_scale_molecular_dynamics.callbacks.standard_callbacks import ( CustomProgressBar, instantiate_early_stopping_callback, instantiate_model_checkpoint_callbacks) @@ -16,6 +18,7 @@ model_checkpoint=instantiate_model_checkpoint_callbacks, sampling_visualization=instantiate_sampling_visualization_callback, loss_monitoring=instantiate_loss_monitoring_callback, + score_viewer=instantiate_score_viewer_callback ) diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/callbacks/score_viewer_callback.py b/src/diffusion_for_multi_scale_molecular_dynamics/callbacks/score_viewer_callback.py new file mode 100644 index 00000000..4ae56d3e --- /dev/null +++ b/src/diffusion_for_multi_scale_molecular_dynamics/callbacks/score_viewer_callback.py @@ -0,0 +1,79 @@ +from dataclasses import dataclass +from typing import Any, AnyStr, Dict + +from matplotlib import pyplot as plt +from pytorch_lightning import Callback, LightningModule, Trainer + +from diffusion_for_multi_scale_molecular_dynamics.analysis.score_viewer import ( + ScoreViewer, ScoreViewerParameters) +from diffusion_for_multi_scale_molecular_dynamics.loggers.logger_loader import \ + log_figure +from diffusion_for_multi_scale_molecular_dynamics.models.score_networks.analytical_score_network import \ + AnalyticalScoreNetworkParameters + + +@dataclass(kw_only=True) +class ScoreViewerCallbackParameters: + """Parameters to decide what to plot and write to disk.""" + + record_every_n_epochs: int = 1 + + score_viewer_parameters: ScoreViewerParameters + analytical_score_network_parameters: AnalyticalScoreNetworkParameters + + +def instantiate_score_viewer_callback( + callback_params: Dict[AnyStr, Any], output_directory: str, verbose: bool +) -> Dict[str, Callback]: + """Instantiate the Diffusion Sampling callback.""" + analytical_score_network_parameters = ( + AnalyticalScoreNetworkParameters(**callback_params['analytical_score_network'])) + + score_viewer_parameters = ScoreViewerParameters(**callback_params['score_viewer_parameters']) + + score_viewer_callback_parameters = ScoreViewerCallbackParameters( + record_every_n_epochs=callback_params['record_every_n_epochs'], + score_viewer_parameters=score_viewer_parameters, + analytical_score_network_parameters=analytical_score_network_parameters) + + callback = ScoreViewerCallback( + score_viewer_callback_parameters, output_directory + ) + + return dict(score_viewer=callback) + + +class ScoreViewerCallback(Callback): + """Score Viewer Callback.""" + + def __init__(self, score_viewer_callback_parameters: ScoreViewerCallbackParameters, output_directory: str): + """Init method.""" + self.record_every_n_epochs = score_viewer_callback_parameters.record_every_n_epochs + self.score_viewer = ScoreViewer( + score_viewer_parameters=score_viewer_callback_parameters.score_viewer_parameters, + analytical_score_network_parameters=score_viewer_callback_parameters.analytical_score_network_parameters) + + def _compute_results_at_this_epoch(self, current_epoch: int) -> bool: + """Check if results should be computed at this epoch.""" + return current_epoch % self.record_every_n_epochs == 0 + + def on_validation_end(self, trainer: Trainer, pl_model: LightningModule) -> None: + """On validation epoch end.""" + if not self._compute_results_at_this_epoch(trainer.current_epoch): + return + + figure = self.score_viewer.create_figure(score_network=pl_model.axl_network) + figure.suptitle(f"Epoch {trainer.current_epoch}, Step {trainer.global_step}") + # Set the DPI so we can actually see something in the logger window. + figure.set_dpi(100) + figure.tight_layout() + + for pl_logger in trainer.loggers: + log_figure( + figure=figure, + global_step=trainer.current_epoch, + dataset="validation", + pl_logger=pl_logger, + name="projected_normalized_scores", + ) + plt.close(figure) From 4071d3a5972ce7f3d52de97bb949d7a06855ca21 Mon Sep 17 00:00:00 2001 From: Bruno Rousseau Date: Thu, 26 Dec 2024 15:41:19 -0500 Subject: [PATCH 22/24] Fix assert error comment. --- .../models/score_networks/score_network.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/models/score_networks/score_network.py b/src/diffusion_for_multi_scale_molecular_dynamics/models/score_networks/score_network.py index 20b067c0..42d2146d 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/models/score_networks/score_network.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/models/score_networks/score_network.py @@ -92,7 +92,8 @@ def _check_batch(self, batch: Dict[AnyStr, torch.Tensor]): assert ( len(relative_coordinates_shape) == 3 and relative_coordinates_shape[2] == self.spatial_dimension - ), "The relative coordinates are expected to be in a tensor of shape [batch_size, number_of_atoms, 3]" + ), ("The relative coordinates are expected to be in a tensor of " + "shape [batch_size, number_of_atoms, spatial_dimension]") assert torch.logical_and( relative_coordinates >= 0.0, relative_coordinates < 1.0 From 98afd8751092453d2b717ff10226959536354669 Mon Sep 17 00:00:00 2001 From: Bruno Rousseau Date: Thu, 26 Dec 2024 17:15:58 -0500 Subject: [PATCH 23/24] Ensure that the EGNN score network works in 1D, 2D and 3D. --- .../models/egnn_utils.py | 3 +- .../models/graph_utils.py | 4 +- .../score_networks/egnn_score_network.py | 1 + .../utils/neighbors.py | 77 +++++++++++++---- .../test_score_network_general_tests.py | 7 +- tests/utils/test_neighbors.py | 82 ++++++++++++++++--- 6 files changed, 143 insertions(+), 31 deletions(-) diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/models/egnn_utils.py b/src/diffusion_for_multi_scale_molecular_dynamics/models/egnn_utils.py index fa2417fd..2b6d0374 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/models/egnn_utils.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/models/egnn_utils.py @@ -109,6 +109,7 @@ def get_edges_with_radial_cutoff( unit_cell: torch.Tensor, radial_cutoff: float = 4.0, drop_duplicate_edges: bool = True, + spatial_dimension: int = 3 ) -> torch.Tensor: """Get edges for a batch with a cutoff based on distance. @@ -127,7 +128,7 @@ def get_edges_with_radial_cutoff( relative_coordinates, unit_cell ) adj_matrix, _, _, _ = get_adj_matrix( - cartesian_coordinates, unit_cell, radial_cutoff + cartesian_coordinates, unit_cell, radial_cutoff, spatial_dimension ) # adj_matrix is a n_edges x 2 tensor with duplicates with different shifts. # the uplifting in 2 x spatial_dimension manages the shifts in a natural way. This means we can ignore the shifts diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/models/graph_utils.py b/src/diffusion_for_multi_scale_molecular_dynamics/models/graph_utils.py index 3f3defc1..48799828 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/models/graph_utils.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/models/graph_utils.py @@ -8,7 +8,7 @@ def get_adj_matrix( - positions: torch.Tensor, basis_vectors: torch.Tensor, radial_cutoff: float = 4.0 + positions: torch.Tensor, basis_vectors: torch.Tensor, radial_cutoff: float = 4.0, spatial_dimension: int = 3 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """Create the adjacency and shift matrices. @@ -32,7 +32,7 @@ def get_adj_matrix( batch_size, number_of_atoms, spatial_dimensions = positions.shape adjacency_info = get_periodic_adjacency_information( - positions, basis_vectors, radial_cutoff + positions, basis_vectors, radial_cutoff, spatial_dimension ) # The indices in the adjacency matrix must be shifted to account for the batching diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/models/score_networks/egnn_score_network.py b/src/diffusion_for_multi_scale_molecular_dynamics/models/score_networks/egnn_score_network.py index 1068ba5c..325b33b7 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/models/score_networks/egnn_score_network.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/models/score_networks/egnn_score_network.py @@ -214,6 +214,7 @@ def _forward_unchecked( batch[UNIT_CELL], self.radial_cutoff, drop_duplicate_edges=self.drop_duplicate_edges, + spatial_dimension=self.spatial_dimension ) edges = edges.to(relative_coordinates.device) diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/utils/neighbors.py b/src/diffusion_for_multi_scale_molecular_dynamics/utils/neighbors.py index 5b652d3d..970bfa23 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/utils/neighbors.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/utils/neighbors.py @@ -8,6 +8,7 @@ import itertools from collections import namedtuple +import einops import numpy as np import torch from pykeops.torch import LazyTensor @@ -32,7 +33,7 @@ def get_periodic_adjacency_information( - cartesian_positions: torch.Tensor, basis_vectors: torch.Tensor, radial_cutoff: float + cartesian_positions: torch.Tensor, basis_vectors: torch.Tensor, radial_cutoff: float, spatial_dimension: int = 3 ) -> AdjacencyInfo: """Get periodic adjacency information. @@ -61,7 +62,7 @@ def get_periodic_adjacency_information( Args: cartesian_positions : atomic positions, assumed to be within the unit cell, in Euclidean space, in Angstrom. - Dimension [batch_size, max_number_of_atoms, 3] + Dimension [batch_size, max_number_of_atoms, spatial_dimension] basis_vectors : vectors that define the unit cell, (a1, a2, a3). The basis vectors are assumed to be vertically stacked, namely [-- a1 --] @@ -69,6 +70,7 @@ def get_periodic_adjacency_information( [-- a3 --] Dimension [batch_size, 3, 3]. radial_cutoff : largest distance between neighbors, in Angstrom. + spatial_dimension: the dimension of space. Returns: adjacency_info: an AdjacencyInfo object that contains @@ -83,7 +85,6 @@ def get_periodic_adjacency_information( ), "Wrong number of dimensions for relative_coordinates" assert len(basis_vectors.shape) == 3, "Wrong number of dimensions for basis_vectors" - spatial_dimension = 3 # We define this to avoid "magic numbers" in the code below. batch_size, max_natom, spatial_dimension_ = cartesian_positions.shape assert ( spatial_dimension_ == spatial_dimension @@ -103,7 +104,7 @@ def get_periodic_adjacency_information( # Check that the radial cutoff does not lead to possible neighbors beyond the first shell. shortest_cell_crossing_distances = _get_shortest_distance_that_crosses_unit_cell( - basis_vectors + basis_vectors, spatial_dimension=spatial_dimension ) assert torch.all(shortest_cell_crossing_distances > radial_cutoff), ( "The radial cutoff is so large that neighbors could be located " @@ -112,7 +113,7 @@ def get_periodic_adjacency_information( # The relative coordinates lattice vectors have dimensions [number of lattice vectors, spatial_dimension] relative_lattice_vectors = _get_relative_coordinates_lattice_vectors( - number_of_shells=1 + number_of_shells=1, spatial_dimension=spatial_dimension ).to(device) number_of_relative_lattice_vectors = len(relative_lattice_vectors) @@ -266,7 +267,7 @@ def _get_shifted_positions( def _get_shortest_distance_that_crosses_unit_cell( - basis_vectors: torch.Tensor, + basis_vectors: torch.Tensor, spatial_dimension: int = 3 ) -> torch.Tensor: """Get the shortest distance that crosses unit cell. @@ -280,21 +281,15 @@ def _get_shortest_distance_that_crosses_unit_cell( / v / --------------------------- - Args: basis_vectors : basis vectors that define the unit cell. - Dimension [batch_size, spatial_dimension = 3] + Dimension [batch_size, spatial_dimension] Returns: shortest_distances: shortest distance that can cross the unit cell, from one side to its other parallel side. Dimension [batch_size]. """ - # It is straightforward to show that the distance between two parallel planes, - # (say the plane spanned by (a1, a2) crossing the origin and the plane spanned by (a1, a2) crossing the point a3) - # is given by unit_normal DOT a3. The unit normal to the plane is proportional to the cross product of a1 and a2. - # - # This idea must be repeated for the three pairs of planes bounding the unit cell. - spatial_dimension = 3 + assert spatial_dimension in {1, 2, 3}, "The spatial dimension must be 1, 2 or 3." assert len(basis_vectors.shape) == 3, "basis_vectors has wrong shape." assert ( basis_vectors.shape[1] == spatial_dimension @@ -302,6 +297,60 @@ def _get_shortest_distance_that_crosses_unit_cell( assert ( basis_vectors.shape[2] == spatial_dimension ), "Basis vectors in wrong spatial dimension." + + match spatial_dimension: + case 1: + return _get_shortest_distance_that_crosses_unit_cell_1d(basis_vectors) + case 2: + return _get_shortest_distance_that_crosses_unit_cell_2d(basis_vectors) + case 3: + return _get_shortest_distance_that_crosses_unit_cell_3d(basis_vectors) + case _: + raise RuntimeError("Spatial dimension must be 1, 2 or 3.") + + +def _get_shortest_distance_that_crosses_unit_cell_1d( + basis_vectors: torch.Tensor, +) -> torch.Tensor: + """Get the shortest distance that crosses unit cell in 1D.""" + distances = basis_vectors.norm(dim=[-1, -2]) + return distances + + +def _get_shortest_distance_that_crosses_unit_cell_2d( + basis_vectors: torch.Tensor, +) -> torch.Tensor: + """Get the shortest distance that crosses unit cell in 2D.""" + a1 = basis_vectors[:, 0, :] + a2 = basis_vectors[:, 1, :] + + dot_product = einops.einsum(a1, a2, "b i, b i -> b") + + norm_a1 = torch.norm(a1, dim=-1) + norm_a2 = torch.norm(a2, dim=-1) + + orthogonal_a2 = a2 - (dot_product / norm_a1**2).unsqueeze(1) * a1 + distances_1 = orthogonal_a2.norm(dim=-1) + + orthogonal_a1 = a1 - (dot_product / norm_a2**2).unsqueeze(1) * a2 + distances_2 = orthogonal_a1.norm(dim=-1) + + distances = ( + torch.stack([distances_1, distances_2], dim=1).min(dim=1).values + ) + + return distances + + +def _get_shortest_distance_that_crosses_unit_cell_3d( + basis_vectors: torch.Tensor, +) -> torch.Tensor: + """Get the shortest distance that crosses unit cell in 3D.""" + # It is straightforward to show that the distance between two parallel planes, + # (say the plane spanned by (a1, a2) crossing the origin and the plane spanned by (a1, a2) crossing the point a3) + # is given by unit_normal DOT a3. The unit normal to the plane is proportional to the cross product of a1 and a2. + # + # This idea must be repeated for the three pairs of planes bounding the unit cell. a1 = basis_vectors[:, 0, :] a2 = basis_vectors[:, 1, :] a3 = basis_vectors[:, 2, :] diff --git a/tests/models/score_network/test_score_network_general_tests.py b/tests/models/score_network/test_score_network_general_tests.py index 30ddc2b3..5acd766e 100644 --- a/tests/models/score_network/test_score_network_general_tests.py +++ b/tests/models/score_network/test_score_network_general_tests.py @@ -319,10 +319,15 @@ def score_network(self, score_network_parameters): class TestEGNNScoreNetwork(BaseScoreNetworkGeneralTests): + @pytest.fixture(params=[1, 2, 3]) + def spatial_dimension(self, request): + return request.param + @pytest.fixture(params=[("fully_connected", None), ("radial_cutoff", 3.0)]) - def score_network_parameters(self, request, num_atom_types): + def score_network_parameters(self, request, spatial_dimension, num_atom_types): edges, radial_cutoff = request.param return EGNNScoreNetworkParameters( + spatial_dimension=spatial_dimension, edges=edges, radial_cutoff=radial_cutoff, num_atom_types=num_atom_types ) diff --git a/tests/utils/test_neighbors.py b/tests/utils/test_neighbors.py index b58c1f8e..4d294ff5 100644 --- a/tests/utils/test_neighbors.py +++ b/tests/utils/test_neighbors.py @@ -35,9 +35,25 @@ def number_of_atoms(): return 32 +@pytest.fixture() +def spatial_dimension(request): + return 3 + + +@pytest.fixture +def basis_vectors(batch_size, spatial_dimension): + # orthogonal boxes with dimensions between 5 and 10. + orthogonal_boxes = torch.stack( + [torch.diag(5.0 + 5.0 * torch.rand(spatial_dimension)) for _ in range(batch_size)] + ) + # add a bit of noise to make the vectors not quite orthogonal + basis_vectors = orthogonal_boxes + 0.1 * torch.randn(batch_size, spatial_dimension, spatial_dimension) + return basis_vectors + + @pytest.fixture -def relative_coordinates(batch_size, number_of_atoms): - return torch.rand(batch_size, number_of_atoms, 3) +def relative_coordinates(batch_size, number_of_atoms, spatial_dimension): + return torch.rand(batch_size, number_of_atoms, spatial_dimension) @pytest.fixture @@ -47,9 +63,9 @@ def positions(relative_coordinates, basis_vectors): @pytest.fixture -def lattice_vectors(batch_size, basis_vectors, number_of_shells): +def lattice_vectors(batch_size, basis_vectors, number_of_shells, spatial_dimension): relative_lattice_vectors = _get_relative_coordinates_lattice_vectors( - number_of_shells + number_of_shells, spatial_dimension ) batched_relative_lattice_vectors = relative_lattice_vectors.repeat(batch_size, 1, 1) lattice_vectors = get_positions_from_coordinates( @@ -218,31 +234,71 @@ def test_get_periodic_adjacency_information( ) +@pytest.mark.parametrize("spatial_dimension", [1, 2, 3]) def test_get_periodic_neighbour_indices_and_displacements_large_cutoff( - basis_vectors, relative_coordinates + basis_vectors, relative_coordinates, spatial_dimension ): # Check that the code crashes if the radial cutoff is too big! shortest_cell_crossing_distances = _get_shortest_distance_that_crosses_unit_cell( - basis_vectors + basis_vectors, spatial_dimension=spatial_dimension ).min() - large_radial_cutoff = shortest_cell_crossing_distances + 0.1 - small_radial_cutoff = shortest_cell_crossing_distances - 0.1 + large_radial_cutoff = (shortest_cell_crossing_distances + 0.1).item() + small_radial_cutoff = (shortest_cell_crossing_distances - 0.1).item() # Should run get_periodic_adjacency_information( - relative_coordinates, basis_vectors, small_radial_cutoff + relative_coordinates, basis_vectors, small_radial_cutoff, spatial_dimension=spatial_dimension ) with pytest.raises(AssertionError): # Should crash get_periodic_adjacency_information( - relative_coordinates, basis_vectors, large_radial_cutoff + relative_coordinates, basis_vectors, large_radial_cutoff, spatial_dimension=spatial_dimension ) @pytest.mark.parametrize("number_of_shells", [1, 2, 3]) -def test_get_relative_coordinates_lattice_vectors(number_of_shells): +def test_get_relative_coordinates_lattice_vectors_1d(number_of_shells): + + expected_lattice_vectors = [] + + for nx in torch.arange(-number_of_shells, number_of_shells + 1): + lattice_vector = torch.tensor([nx]) + expected_lattice_vectors.append(lattice_vector) + + expected_lattice_vectors = torch.stack(expected_lattice_vectors).to( + dtype=torch.float32 + ) + computed_lattice_vectors = _get_relative_coordinates_lattice_vectors( + number_of_shells, spatial_dimension=1 + ) + + torch.testing.assert_close(expected_lattice_vectors, computed_lattice_vectors) + + +@pytest.mark.parametrize("number_of_shells", [1, 2, 3]) +def test_get_relative_coordinates_lattice_vectors_2d(number_of_shells): + + expected_lattice_vectors = [] + + for nx in torch.arange(-number_of_shells, number_of_shells + 1): + for ny in torch.arange(-number_of_shells, number_of_shells + 1): + lattice_vector = torch.tensor([nx, ny]) + expected_lattice_vectors.append(lattice_vector) + + expected_lattice_vectors = torch.stack(expected_lattice_vectors).to( + dtype=torch.float32 + ) + computed_lattice_vectors = _get_relative_coordinates_lattice_vectors( + number_of_shells, spatial_dimension=2 + ) + + torch.testing.assert_close(expected_lattice_vectors, computed_lattice_vectors) + + +@pytest.mark.parametrize("number_of_shells", [1, 2, 3]) +def test_get_relative_coordinates_lattice_vectors_3d(number_of_shells): expected_lattice_vectors = [] @@ -256,7 +312,7 @@ def test_get_relative_coordinates_lattice_vectors(number_of_shells): dtype=torch.float32 ) computed_lattice_vectors = _get_relative_coordinates_lattice_vectors( - number_of_shells + number_of_shells, spatial_dimension=3 ) torch.testing.assert_close(expected_lattice_vectors, computed_lattice_vectors) @@ -289,7 +345,7 @@ def test_get_shifted_positions(positions, lattice_vectors): ) -def test_get_shortest_distance_that_crosses_unit_cell(basis_vectors): +def test_get_shortest_distance_that_crosses_unit_cell_3d(basis_vectors): expected_shortest_distances = [] for matrix in basis_vectors.numpy(): a1, a2, a3 = matrix From fd25a21641265347a7f10566c87be7c88b32143c Mon Sep 17 00:00:00 2001 From: Bruno Rousseau Date: Thu, 26 Dec 2024 19:33:38 -0500 Subject: [PATCH 24/24] an example using the on-the-fly data module --- .../config_diffusion_egnn_2_atoms_in_1D.yaml | 117 ++++++++++++++++++ 1 file changed, 117 insertions(+) create mode 100644 examples/config_files/diffusion/config_diffusion_egnn_2_atoms_in_1D.yaml diff --git a/examples/config_files/diffusion/config_diffusion_egnn_2_atoms_in_1D.yaml b/examples/config_files/diffusion/config_diffusion_egnn_2_atoms_in_1D.yaml new file mode 100644 index 00000000..a234112b --- /dev/null +++ b/examples/config_files/diffusion/config_diffusion_egnn_2_atoms_in_1D.yaml @@ -0,0 +1,117 @@ +#================================================================================ +# Configuration file for a diffusion experiment for 2 pseudo-atoms in 1D. +# +# An 'on-the-fly' Gaussian dataset is created and used for training. +#================================================================================ +exp_name: egnn_2_atoms_in_1D +run_name: run1 +max_epoch: 1000 +log_every_n_steps: 1 +gradient_clipping: 0.0 +accumulate_grad_batches: 1 # make this number of forward passes before doing a backprop step + +elements: [A] + +# set to null to avoid setting a seed (can speed up GPU computation, but +# results will not be reproducible) +seed: 1234 + +# On-the-fly Data Module that creates a Gaussian dataset. +data: + data_source: gaussian + random_seed: 42 + number_of_atoms: 2 + sigma_d: 0.01 + equilibrium_relative_coordinates: + - [0.25] + - [0.75] + + train_dataset_size: 8_192 + valid_dataset_size: 1_024 + + batch_size: 64 + num_workers: 0 + max_atom: 2 + spatial_dimension: 1 + + +spatial_dimension: 1 + +model: + loss: + coordinates_algorithm: mse + atom_types_ce_weight: 0.0 + atom_types_lambda_weight: 0.0 + relative_coordinates_lambda_weight: 1.0 + lattice_lambda_weight: 0.0 + score_network: + architecture: egnn + spatial_dimension: 1 + num_atom_types: 1 + n_layers: 4 + coordinate_hidden_dimensions_size: 128 + coordinate_n_hidden_dimensions: 4 + coords_agg: "mean" + message_hidden_dimensions_size: 128 + message_n_hidden_dimensions: 4 + node_hidden_dimensions_size: 128 + node_n_hidden_dimensions: 4 + attention: False + normalize: True + residual: True + tanh: False + edges: fully_connected + noise: + total_time_steps: 100 + sigma_min: 0.001 + sigma_max: 0.2 + +# optimizer and scheduler +optimizer: + name: adamw + learning_rate: 0.001 + weight_decay: 5.0e-8 + + +scheduler: + name: CosineAnnealingLR + T_max: 1000 + eta_min: 0.0 + +# early stopping +early_stopping: + metric: validation_epoch_loss + mode: min + patience: 1000 + +model_checkpoint: + monitor: validation_epoch_loss + mode: min + +score_viewer: + record_every_n_epochs: 1 + + score_viewer_parameters: + sigma_min: 0.001 + sigma_max: 0.2 + number_of_space_steps: 100 + starting_relative_coordinates: + - [0.0] + - [1.0] + ending_relative_coordinates: + - [1.0] + - [0.0] + analytical_score_network: + architecture: "analytical" + spatial_dimension: 1 + number_of_atoms: 2 + num_atom_types: 1 + kmax: 5 + equilibrium_relative_coordinates: + - [0.25] + - [0.75] + sigma_d: 0.01 + use_permutation_invariance: True + +logging: + - tensorboard