diff --git a/crystal_diffusion/analysis/exploding_variance_analysis.py b/crystal_diffusion/analysis/exploding_variance_analysis.py index a19b8133..ed19ae26 100644 --- a/crystal_diffusion/analysis/exploding_variance_analysis.py +++ b/crystal_diffusion/analysis/exploding_variance_analysis.py @@ -8,9 +8,8 @@ from crystal_diffusion import ANALYSIS_RESULTS_DIR from crystal_diffusion.analysis import PLEASANT_FIG_SIZE, PLOT_STYLE_PATH -from crystal_diffusion.samplers.time_sampler import TimeParameters, TimeSampler from crystal_diffusion.samplers.variance_sampler import ( - ExplodingVarianceSampler, VarianceParameters) + ExplodingVarianceSampler, NoiseParameters) from crystal_diffusion.score.wrapped_gaussian_score import \ get_sigma_normalized_score @@ -18,19 +17,11 @@ if __name__ == '__main__': - variance_parameters = VarianceParameters() - time_parameters = TimeParameters(total_time_steps=1000) + noise_parameters = NoiseParameters(total_time_steps=1000) + variance_sampler = ExplodingVarianceSampler(noise_parameters=noise_parameters) - time_sampler = TimeSampler(time_parameters=time_parameters) - variance_sampler = ExplodingVarianceSampler(variance_parameters=variance_parameters, - time_sampler=time_sampler) + noise = variance_sampler.get_all_noise() - indices = torch.arange(time_parameters.total_time_steps) - times = time_sampler.get_time_steps(indices) - sigmas = torch.sqrt(variance_sampler.get_variances(indices)) - gs = torch.sqrt(variance_sampler.get_g_squared_factors(indices[1:])) - - # A first figure to compare the "smart" and the "brute force" results fig1 = plt.figure(figsize=PLEASANT_FIG_SIZE) fig1.suptitle("Noise Schedule") @@ -38,8 +29,8 @@ ax2 = fig1.add_subplot(223) ax3 = fig1.add_subplot(122) - ax1.plot(times, sigmas, '-', c='k', lw=2) - ax2.plot(times[1:], gs, '-', c='k', lw=2) + ax1.plot(noise.time, noise.sigma, '-', c='k', lw=2) + ax2.plot(noise.time[1:], noise.g[1:], '-', c='k', lw=2) ax1.set_ylabel('$\\sigma(t)$') ax2.set_ylabel('$g(t)$') @@ -55,17 +46,18 @@ kmax = 4 indices = torch.tensor([1, 250, 750, 999]) - times = time_sampler.get_time_steps(indices) - sigmas = torch.sqrt(variance_sampler.get_variances(indices)) - gs_squared = variance_sampler.get_g_squared_factors(indices) + + times = noise.time.take(indices) + sigmas = noise.sigma.take(indices) + gs_squared = noise.g_squared.take(indices) for t, sigma in zip(times, sigmas): - target_scores = get_sigma_normalized_score(relative_positions, - torch.ones_like(relative_positions) * sigma, - kmax=kmax) - ax3.plot(relative_positions, sigma * target_scores, label=f"t = {t:3.2f}") + target_sigma_normalized_scores = get_sigma_normalized_score(relative_positions, + torch.ones_like(relative_positions) * sigma, + kmax=kmax) + ax3.plot(relative_positions, target_sigma_normalized_scores, label=f"t = {t:3.2f}") - ax3.set_title("Target Noise") + ax3.set_title("Target Normalized Score") ax3.set_xlabel("relative position, u") ax3.set_ylabel("$\\sigma(t) \\times S(u, t)$") ax3.legend(loc=0) diff --git a/crystal_diffusion/analysis/target_score_analysis.py b/crystal_diffusion/analysis/target_score_analysis.py index 35aa2de3..b4cfc30f 100644 --- a/crystal_diffusion/analysis/target_score_analysis.py +++ b/crystal_diffusion/analysis/target_score_analysis.py @@ -10,8 +10,8 @@ from crystal_diffusion import ANALYSIS_RESULTS_DIR from crystal_diffusion.analysis import PLEASANT_FIG_SIZE, PLOT_STYLE_PATH from crystal_diffusion.score.wrapped_gaussian_score import ( - SIGMA_THRESHOLD, get_expected_sigma_normalized_score_brute_force, - get_sigma_normalized_score) + SIGMA_THRESHOLD, get_sigma_normalized_score, + get_sigma_normalized_score_brute_force) plt.style.use(PLOT_STYLE_PATH) @@ -32,7 +32,7 @@ sigma = sigma_factor * SIGMA_THRESHOLD sigmas = torch.ones_like(relative_positions) * sigma - list_scores_brute = np.array([get_expected_sigma_normalized_score_brute_force(u, sigma) for u in list_u]) + list_scores_brute = np.array([get_sigma_normalized_score_brute_force(u, sigma) for u in list_u]) list_scores = get_sigma_normalized_score(relative_positions, sigmas, kmax=kmax).numpy() error = list_scores - list_scores_brute @@ -73,7 +73,7 @@ ms=ms, c=color, lw=2, alpha=0.25, label=f'kmax = {kmax}') list_scores_brute = np.array([ - get_expected_sigma_normalized_score_brute_force(u, sigma, kmax=4 * kmax) for sigma in sigmas]) + get_sigma_normalized_score_brute_force(u, sigma, kmax=4 * kmax) for sigma in sigmas]) ax4.semilogy(sigma_factors, list_scores_brute, 'o-', ms=ms, c=color, lw=2, alpha=0.25, label=f'kmax = {4 * kmax}') diff --git a/crystal_diffusion/models/my_model.py b/crystal_diffusion/models/my_model.py deleted file mode 100644 index 9c22bda3..00000000 --- a/crystal_diffusion/models/my_model.py +++ /dev/null @@ -1,55 +0,0 @@ -import logging -import typing - -import pytorch_lightning as pl - -from crystal_diffusion.models.optim import load_optimizer - -logger = logging.getLogger(__name__) - - -class BaseModel(pl.LightningModule): - """Base class for Pytorch Lightning model - useful to reuse the same *_step methods.""" - - def configure_optimizers(self): - """Returns the combination of optimizer(s) and learning rate scheduler(s) to train with. - - Here, we read all the optimization-related hyperparameters from the config dictionary and - create the required optimizer/scheduler combo. - - This function will be called automatically by the pytorch lightning trainer implementation. - See https://pytorch-lightning.readthedocs.io/en/latest/common/optimizers.html for more info - on the expected returned elements. - """ - # we use the generic loading function from the `model_loader` module, but it could be made - # a direct part of the model (useful if we want layer-dynamic optimization) - return load_optimizer(self.hparams, self) - - def _generic_step( - self, - batch: typing.Any, - batch_idx: int, - ) -> typing.Any: - """Runs the prediction + evaluation step for training/validation/testing.""" - input_data, targets = batch - preds = self(input_data) # calls the forward pass of the model - loss = self.loss_fn(preds, targets) - return loss - - def training_step(self, batch, batch_idx): - """Runs a prediction step for training, returning the loss.""" - loss = self._generic_step(batch, batch_idx) - self.log("train_loss", loss) - self.log("epoch", self.current_epoch) - self.log("step", self.global_step) - return loss # this function is required, as the loss returned here is used for backprop - - def validation_step(self, batch, batch_idx): - """Runs a prediction step for validation, logging the loss.""" - loss = self._generic_step(batch, batch_idx) - self.log("val_loss", loss) - - def test_step(self, batch, batch_idx): - """Runs a prediction step for testing, logging the loss.""" - loss = self._generic_step(batch, batch_idx) - self.log("test_loss", loss) diff --git a/crystal_diffusion/models/optim.py b/crystal_diffusion/models/optim.py deleted file mode 100644 index a1246847..00000000 --- a/crystal_diffusion/models/optim.py +++ /dev/null @@ -1,51 +0,0 @@ -import logging - -import torch -from torch import optim - -logger = logging.getLogger(__name__) - - -def load_optimizer(hyper_params, model): # pragma: no cover - """Instantiate the optimizer. - - Args: - hyper_params (dict): hyper parameters from the config file - model (obj): A neural network model object. - - Returns: - optimizer (obj): The optimizer for the given model - """ - optimizer_name = hyper_params["optimizer"] - # __TODO__ fix optimizer list - if optimizer_name == "adam": - optimizer = optim.Adam(model.parameters()) - elif optimizer_name == "sgd": - optimizer = optim.SGD(model.parameters()) - else: - raise ValueError("optimizer {} not supported".format(optimizer_name)) - return optimizer - - -def load_loss(hyper_params): # pragma: no cover - r"""Instantiate the loss. - - You can add some math directly in your docstrings, however don't forget the `r` - to indicate it is being treated as restructured text. For example, an L1-loss can be - defined as: - - .. math:: - \text{loss}(x, y) = \frac{1}{n} \sum_{i} z_{i} - - Args: - hyper_params (dict): hyper parameters from the config file - - Returns: - loss (obj): The loss for the given model - """ - loss_name = hyper_params["loss"] - if loss_name == "cross_entropy": - loss = torch.nn.CrossEntropyLoss() - else: - raise ValueError("loss {} not supported".format(loss_name)) - return loss diff --git a/crystal_diffusion/models/optimizer.py b/crystal_diffusion/models/optimizer.py new file mode 100644 index 00000000..8aa9ae8c --- /dev/null +++ b/crystal_diffusion/models/optimizer.py @@ -0,0 +1,41 @@ +import logging +from dataclasses import dataclass +from enum import Enum + +import torch +from torch import optim + +logger = logging.getLogger(__name__) + + +class ValidOptimizerNames(Enum): + """Valid optimizer names.""" + adam = "adam" + sgd = "sgd" + + +@dataclass(kw_only=True) +class OptimizerParameters: + """Parameters for the optimizer.""" + name: ValidOptimizerNames + learning_rate: float + + +def load_optimizer(hyper_params: OptimizerParameters, model: torch.nn.Module) -> optim.Optimizer: + """Instantiate the optimizer. + + Args: + hyper_params : hyperparameters defining the optimizer + model : A neural network model. + + Returns: + optimizer : The optimizer for the given model + """ + match hyper_params.name: + case ValidOptimizerNames.adam: + optimizer = optim.Adam(model.parameters(), lr=hyper_params.learning_rate) + case ValidOptimizerNames.sgd: + optimizer = optim.SGD(model.parameters(), lr=hyper_params.learning_rate) + case _: + raise ValueError(f"optimizer {hyper_params.name} not supported") + return optimizer diff --git a/crystal_diffusion/models/position_diffusion_lightning_model.py b/crystal_diffusion/models/position_diffusion_lightning_model.py new file mode 100644 index 00000000..6b4db88e --- /dev/null +++ b/crystal_diffusion/models/position_diffusion_lightning_model.py @@ -0,0 +1,211 @@ +import logging +import typing +from dataclasses import dataclass + +import pytorch_lightning as pl +import torch + +from crystal_diffusion.models.optimizer import (OptimizerParameters, + load_optimizer) +from crystal_diffusion.models.score_network import (MLPScoreNetwork, + MLPScoreNetworkParameters) +from crystal_diffusion.samplers.noisy_position_sampler import ( + NoisyPositionSampler, map_positions_to_unit_cell) +from crystal_diffusion.samplers.variance_sampler import ( + ExplodingVarianceSampler, NoiseParameters) +from crystal_diffusion.score.wrapped_gaussian_score import \ + get_sigma_normalized_score +from crystal_diffusion.utils.tensor_utils import \ + broadcast_batch_tensor_to_all_dimensions + +logger = logging.getLogger(__name__) + + +@dataclass(kw_only=True) +class PositionDiffusionParameters: + """Position Diffusion parameters.""" + + score_network_parameters: MLPScoreNetworkParameters + optimizer_parameters: OptimizerParameters + noise_parameters: NoiseParameters + kmax_target_score: int = ( + 4 # convergence parameter for the Ewald-like sum of the perturbation kernel. + ) + + +class PositionDiffusionLightningModel(pl.LightningModule): + """Position Diffusion Lightning Model. + + This lightning model can train a score network predict the noise for relative positions. + """ + + def __init__(self, hyper_params: PositionDiffusionParameters): + """Init method. + + This initializes the class. + """ + super().__init__() + + self.hyper_params = hyper_params + + # we will model sigma x score + self.sigma_normalized_score_network = MLPScoreNetwork( + hyper_params.score_network_parameters + ) + + self.noisy_position_sampler = NoisyPositionSampler() + self.variance_sampler = ExplodingVarianceSampler(hyper_params.noise_parameters) + + def configure_optimizers(self): + """Returns the combination of optimizer(s) and learning rate scheduler(s) to train with. + + Here, we read all the optimization-related hyperparameters from the config dictionary and + create the required optimizer/scheduler combo. + + This function will be called automatically by the pytorch lightning trainer implementation. + See https://pytorch-lightning.readthedocs.io/en/latest/common/optimizers.html for more info + on the expected returned elements. + """ + return load_optimizer(self.hyper_params.optimizer_parameters, self) + + def _generic_step( + self, + batch: typing.Any, + batch_idx: int, + ) -> typing.Any: + """Generic step. + + This "generic step" computes the loss for any of the possible lightning "steps". + + The loss is defined as: + L = 1 / T int_0^T dt lambda(t) E_{x0 ~ p_data} E_{xt~ p_{t| 0}} + [|S_theta(xt, t) - nabla_{xt} log p_{t | 0} (xt | x0)|^2] + + Where + T : time range of the noising process + S_theta : score network + p_{t| 0} : perturbation kernel + nabla log p : the target score + lambda(t) : is arbitrary, but chosen for convenience. + + In this implementation, we choose lambda(t) = sigma(t)^2 ( a standard choice from the literature), such + that the score network and the target scores that are used are actually "sigma normalized" versions, ie, + pre-multiplied by sigma. + + The loss that is computed is a Monte Carlo estimate of L, where we sample a mini-batch of relative position + configurations {x0}; each of these configurations is noised with a random t value, with corresponding + {sigma(t)} and {xt}. + + Args: + batch : a dictionary that should contain a data sample. + batch_idx : index of the batch + + Returns: + loss : the computed loss. + """ + # The relative positions have dimensions [batch_size, number_of_atoms, spatial_dimension]. + assert "relative_positions" in batch, "The field 'relative_positions' is missing from the input." + x0 = batch["relative_positions"] + shape = x0.shape + assert len(shape) == 3, ( + f"the shape of the relative_positions array should be [batch_size, number_of_atoms, spatial_dimensions]. " + f"Got shape = {shape}." + ) + batch_size = shape[0] + + noise_sample = self.variance_sampler.get_random_noise_sample(batch_size) + + # noise_sample.sigma has dimension [batch_size]. Broadcast these sigma values to be + # of shape [batch_size, number_of_atoms, spatial_dimension], which can be interpreted + # as [batch_size, (configuration)]. All the sigma values must be the same for a given configuration. + sigmas = broadcast_batch_tensor_to_all_dimensions( + batch_values=noise_sample.sigma, final_shape=shape + ) + + xt = self.noisy_position_sampler.get_noisy_position_sample(x0, sigmas) + + target_normalized_scores = self._get_target_normalized_score(xt, x0, sigmas) + + predicted_normalized_scores = self._get_predicted_normalized_score( + xt, noise_sample.time + ) + + loss = torch.nn.functional.mse_loss( + predicted_normalized_scores, target_normalized_scores, reduction="mean" + ) + return loss + + def _get_target_normalized_score( + self, + noisy_relative_positions: torch.Tensor, + real_relative_positions: torch.Tensor, + sigmas: torch.Tensor, + ) -> torch.Tensor: + """Get target normalized score. + + It is assumed that the inputs are consistent, ie, the noisy relative positions correspond + to the real relative positions noised with sigmas. It is also assumed that sigmas has + been broadcast so that the same value sigma(t) is applied to all atoms + dimensions within a configuration. + + Args: + noisy_relative_positions : noised relative positions. + Tensor of dimensions [batch_size, number_of_atoms, spatial_dimension] + real_relative_positions : original relative positions, before the addition of noise. + Tensor of dimensions [batch_size, number_of_atoms, spatial_dimension] + sigmas : + Tensor of dimensions [batch_size, number_of_atoms, spatial_dimension] + + Returns: + target normalized score: sigma times target score, ie, sigma times nabla_xt log P_{t|0}(xt| x0). + Tensor of dimensions [batch_size, number_of_atoms, spatial_dimension] + """ + delta_relative_positions = map_positions_to_unit_cell(noisy_relative_positions - real_relative_positions) + target_normalized_scores = get_sigma_normalized_score( + delta_relative_positions, sigmas, kmax=self.hyper_params.kmax_target_score + ) + return target_normalized_scores + + def _get_predicted_normalized_score( + self, noisy_relative_positions: torch.Tensor, time: torch.Tensor + ) -> torch.Tensor: + """Get predicted normalized score. + + Args: + noisy_relative_positions : noised relative positions. + Tensor of dimensions [batch_size, number_of_atoms, spatial_dimension] + time : Noise times for the noisy relative positions. It is assumed that the inputs are consistent, ie, + the time values in this array correspond to the noise times used to create the noisy relative positions. + Tensor of dimensions [batch_size]. + + Returns: + predicted normalized score: sigma times predicted score, ie, sigma times S_theta(xt, t). + Tensor of dimensions [batch_size, number_of_atoms, spatial_dimension] + """ + pos_key = self.sigma_normalized_score_network.position_key + time_key = self.sigma_normalized_score_network.timestep_key + augmented_batch = { + pos_key: noisy_relative_positions, + time_key: time.reshape(-1, 1), + } + predicted_normalized_scores = self.sigma_normalized_score_network( + augmented_batch + ) + return predicted_normalized_scores + + def training_step(self, batch, batch_idx): + """Runs a prediction step for training, returning the loss.""" + loss = self._generic_step(batch, batch_idx) + self.log("train_loss", loss) + self.log("epoch", self.current_epoch) + self.log("step", self.global_step) + return loss # this function is required, as the loss returned here is used for backprop + + def validation_step(self, batch, batch_idx): + """Runs a prediction step for validation, logging the loss.""" + loss = self._generic_step(batch, batch_idx) + self.log("val_loss", loss) + + def test_step(self, batch, batch_idx): + """Runs a prediction step for testing, logging the loss.""" + loss = self._generic_step(batch, batch_idx) + self.log("test_loss", loss) diff --git a/crystal_diffusion/models/score_network.py b/crystal_diffusion/models/score_network.py index 52d5cd7b..2fcf529e 100644 --- a/crystal_diffusion/models/score_network.py +++ b/crystal_diffusion/models/score_network.py @@ -4,25 +4,29 @@ Relative coordinates are with respect to lattice vectors which define the periodic unit cell. """ -from typing import Any, AnyStr, Dict +from dataclasses import dataclass +from typing import AnyStr, Dict import torch from torch import nn +@dataclass(kw_only=True) +class BaseScoreNetworkParameters: + """Base Hyper-parameters for score networks.""" + spatial_dimension: int = 3 # the dimension of Euclidean space where atoms live. + + class BaseScoreNetwork(torch.nn.Module): """Base score network. This base class defines the interface that all score networks should have in order to be easily interchangeable (ie, polymorphic). """ - position_key = "relative_positions" # unitless positions in the lattice coordinate basis + position_key = "noisy_relative_positions" # unitless positions in the lattice coordinate basis timestep_key = "time" - spatial_dimension = ( - 3 # the spatial dimension of the space where the atoms live, ie 3D space. - ) - def __init__(self, hyper_params: Dict[AnyStr, Any]): + def __init__(self, hyper_params: BaseScoreNetworkParameters): """__init__. Args: @@ -30,6 +34,7 @@ def __init__(self, hyper_params: Dict[AnyStr, Any]): """ super(BaseScoreNetwork, self).__init__() self._hyper_params = hyper_params + self.spatial_dimension = hyper_params.spatial_dimension def _check_batch(self, batch: Dict[AnyStr, torch.Tensor]): """Check batch. @@ -112,23 +117,28 @@ def _forward_unchecked(self, batch: Dict[AnyStr, torch.Tensor]) -> torch.Tensor: raise NotImplementedError +@dataclass(kw_only=True) +class MLPScoreNetworkParameters(BaseScoreNetworkParameters): + """Specific Hyper-parameters for MLP score networks.""" + number_of_atoms: int # the number of atoms in a configuration. + hidden_dim: int + + class MLPScoreNetwork(BaseScoreNetwork): """Simple Model Class. Inherits from the given framework's model class. This is a simple MLP model. """ - def __init__(self, hyper_params: Dict[AnyStr, Any]): + def __init__(self, hyper_params: MLPScoreNetworkParameters): """__init__. Args: hyper_params (dict): hyper parameters from the config file. """ super(MLPScoreNetwork, self).__init__(hyper_params) - hidden_dim = hyper_params["hidden_dim"] - self._natoms = hyper_params[ - "number_of_atoms" - ] # the number of atoms in a configuration. + hidden_dim = hyper_params.hidden_dim + self._natoms = hyper_params.number_of_atoms output_dimension = self.spatial_dimension * self._natoms input_dimension = output_dimension + 1 @@ -137,6 +147,8 @@ def __init__(self, hyper_params: Dict[AnyStr, Any]): self.mlp_layers = nn.Sequential( nn.Linear(input_dimension, hidden_dim), nn.ReLU(), + nn.Linear(hidden_dim, hidden_dim), + nn.ReLU(), nn.Linear(hidden_dim, output_dimension), ) @@ -159,7 +171,7 @@ def _forward_unchecked(self, batch: Dict[AnyStr, torch.Tensor]) -> torch.Tensor: Returns: computed_scores : the scores computed by the model. """ - positions = batch[self.position_key] # shape [batch_size, number_of_atoms, 3] + positions = batch[self.position_key] # shape [batch_size, number_of_atoms, spatial_dimension] times = batch[self.timestep_key] # shape [batch_size, 1] input = torch.cat([self.flatten(positions), times], dim=1) diff --git a/crystal_diffusion/samplers/noisy_position_sampler.py b/crystal_diffusion/samplers/noisy_position_sampler.py new file mode 100644 index 00000000..c101e7d5 --- /dev/null +++ b/crystal_diffusion/samplers/noisy_position_sampler.py @@ -0,0 +1,84 @@ +"""Noisy Position Sampler. + +This module is responsible for sampling relative positions from the perturbation kernel. +""" +from typing import Tuple + +import torch + + +def map_positions_to_unit_cell(positions: torch.Tensor) -> torch.Tensor: + """Map positions back to unit cell. + + The function torch.remainder does not always bring back the positions in the range [0, 1). + If the input is very small and negative, torch.remainder returns 1. This is problematic + when using floats instead of doubles. + + This method makes sure that the positions are mapped back in the [0, 1) range and does reasonable + things when the position is close to the edge. + + See issues: + https://github.com/pytorch/pytorch/issues/37743 + https://github.com/pytorch/pytorch/issues/24861 + + Args: + positions : atomic positions, tensor of arbitrary shape. + + Returns: + relative_positions: atomic positions in the unit cell, ie, in the range [0, 1). + """ + relative_positions = torch.remainder(positions, 1.0) + relative_positions[relative_positions == 1.] = 0. + return relative_positions + + +class NoisyPositionSampler: + """Noisy Position Sampler. + + This class provides methods to generate noisy positions, given real positions and + a sigma parameter. + + The random samples are produced by a separate method to make this code easy to test. + """ + @staticmethod + def _get_gaussian_noise(shape: Tuple[int]) -> torch.Tensor: + """Get Gaussian noise. + + Get a sample from N(0, 1) of dimensions shape. + + Args: + shape : the shape of the sample. + + Returns: + gaussian_noise: a sample from N(0, 1) of dimensions shape. + """ + return torch.randn(shape) + + @staticmethod + def get_noisy_position_sample(real_relative_positions: torch.Tensor, sigmas: torch.Tensor) -> torch.Tensor: + """Get noisy positions sample. + + This method draws a sample from the perturbation kernel centered on the real_relative_positions + and with a variance parameter sigma. The sample is brought back into the periodic unit cell. + + Note that sigmas is assumed to be of the same shape as real_relative_positions. There is no + check that the sigmas are "all the same" for a given batch index: it is the user's responsibility to + provide a consistent sigma, if the desired behavior is to noise a batch of configurations consistently. + + + Args: + real_relative_positions : relative coordinates of real data. Should be between 0 and 1. + relative_positions is assumed to have an arbitrary shape. + sigmas : variance of the perturbation kernel. Tensor is assumed to be of the same shape as + real_relative_positions. + + Returns: + noisy_relative_positions: a sample of noised relative positions, of the same shape as relative_positions. + """ + assert real_relative_positions.shape == sigmas.shape, \ + "sigmas array is expected to be of the same shape as the real_relative_positions array" + + z_scores = NoisyPositionSampler._get_gaussian_noise(real_relative_positions.shape) + noise = sigmas * z_scores + noisy_relative_positions = map_positions_to_unit_cell(real_relative_positions + noise) + return noisy_relative_positions diff --git a/crystal_diffusion/samplers/time_sampler.py b/crystal_diffusion/samplers/time_sampler.py deleted file mode 100644 index d3b79fed..00000000 --- a/crystal_diffusion/samplers/time_sampler.py +++ /dev/null @@ -1,90 +0,0 @@ -from dataclasses import dataclass -from typing import Tuple - -import torch - - -@dataclass -class TimeParameters: - """Time sampling parameters.""" - total_time_steps: int - random_seed: int = 1234 - - -class TimeSampler: - """Time Sampler. - - This class will produce time step samples as needed. - The times will be normalized to be between 0 and 1. - """ - def __init__(self, time_parameters: TimeParameters): - """Init method. - - Args: - time_parameters: parameters needed to instantiate the time sampler. - """ - self.time_parameters = time_parameters - self.time_step_array = torch.linspace(0, 1, time_parameters.total_time_steps) - - self._maximum_index = time_parameters.total_time_steps - 1 - self._minimum_index = 1 # we don't want to sample "0". - self._rng = torch.manual_seed(time_parameters.random_seed) - - def get_random_time_step_indices(self, shape: Tuple[int]) -> torch.Tensor: - """Random time step indices. - - Generate random indices that correspond to valid time steps. - This sampling avoids index "0", which corresponds to time "0". - - Args: - shape: shape of the random index array. - - Returns: - time_step_indices: random time step indices in a tensor of shape "shape". - """ - random_indices = torch.randint(self._minimum_index, - self._maximum_index + 1, # +1 because the maximum value is not sampled - size=shape, - generator=self._rng) - return random_indices - - def get_time_steps(self, indices: torch.Tensor) -> torch.Tensor: - """Get time steps. - - Extract the time steps from the internal data structure, given desired indices. - - Args: - indices: the time step indices. Must be between 0 and "total_number_of_time_steps" - 1. - - Returns: - time_steps: random time step indices in a tensor of shape "shape". - """ - assert torch.all(indices >= 0), "indices must be non-negative" - assert torch.all(indices <= self._maximum_index), \ - f"indices must be smaller than or equal to {self._maximum_index}" - return self.time_step_array.take(indices) - - def get_forward_iterator(self): - """Get forward iterator. - - Iterate over time steps for indices in 0,..., T-1, where T is the maximum index. - This is useful for the NOISING process, which is not formally needed in the formalism, - but might be useful for sanity checking / debugging. - - Returns: - forward_iterator: an iterator over (index, time_step) that iterates over increasing values. - """ - return iter(enumerate(self.time_step_array[:-1])) - - def get_backward_iterator(self): - """Get backward iterator. - - Iterate over time steps for indices in T,...,1 , where T is the maximum index. - This is useful for the DENOISING process. - - Returns: - backward_iterator: an iterator over (index, time_step) that iterates over decreasing values. - """ - maximum_index = len(self.time_step_array) - 1 - indices = range(maximum_index, 0, -1) - return zip(indices, self.time_step_array.flip(dims=(0,))) diff --git a/crystal_diffusion/samplers/variance_sampler.py b/crystal_diffusion/samplers/variance_sampler.py index e176fed6..cb449362 100644 --- a/crystal_diffusion/samplers/variance_sampler.py +++ b/crystal_diffusion/samplers/variance_sampler.py @@ -1,13 +1,17 @@ +from collections import namedtuple from dataclasses import dataclass +from typing import Tuple import torch -from crystal_diffusion.samplers.time_sampler import TimeSampler +Noise = namedtuple("Noise", ["time", "sigma", "sigma_squared", "g", "g_squared"]) @dataclass -class VarianceParameters: +class NoiseParameters: """Variance parameters.""" + + total_time_steps: int # Default values come from the paper: # "Torsional Diffusion for Molecular Conformer Generation", # The original values in the paper are @@ -21,75 +25,112 @@ class VarianceParameters: class ExplodingVarianceSampler: """Exploding Variance Sampler. - This class is responsible for creating the variances + This class is responsible for creating the all the quantities needed for noise generation. This implementation will use "exponential diffusion" as discussed in the paper "Torsional Diffusion for Molecular Conformer Generation". """ - def __init__( - self, variance_parameters: VarianceParameters, time_sampler: TimeSampler - ): + def __init__(self, noise_parameters: NoiseParameters): """Init method. Args: - variance_parameters: parameters that define the variance schedule. - time_sampler: object that can sample time steps. + noise_parameters: parameters that define the noise schedule. """ - self._sigma_square_array = self._create_sigma_square_array( - variance_parameters, time_sampler - ) - self._g_square_array = self._create_g_square_array(self._sigma_square_array) + self.noise_parameters = noise_parameters + self._time_array = torch.linspace(0, 1, noise_parameters.total_time_steps) - self._maximum_index = len(self._sigma_square_array) - 1 + self._sigma_array = self._create_sigma_array(noise_parameters, self._time_array) + self._sigma_squared_array = self._sigma_array**2 - def _create_sigma_square_array( - self, variance_parameters: VarianceParameters, time_sampler: TimeSampler - ) -> torch.Tensor: + self._g_squared_array = self._create_g_squared_array(self._sigma_squared_array) + self._g_array = torch.sqrt(self._g_squared_array) + + self._maximum_random_index = noise_parameters.total_time_steps - 1 + self._minimum_random_index = 1 # we don't want to randomly sample "0". - t = time_sampler.time_step_array + @staticmethod + def _create_sigma_array( + noise_parameters: NoiseParameters, time_array: torch.Tensor + ) -> torch.Tensor: + sigma_min = noise_parameters.sigma_min + sigma_max = noise_parameters.sigma_max - sigma_min = variance_parameters.sigma_min - sigma_max = variance_parameters.sigma_max + sigma = sigma_min ** (1.0 - time_array) * sigma_max**time_array + return sigma - sigma = sigma_min**(1.0 - t) * sigma_max**t - return sigma**2 + @staticmethod + def _create_g_squared_array(sigma_squared_array: torch.Tensor) -> torch.Tensor: + nan_tensor = torch.tensor([float("nan")]) + return torch.cat( + [nan_tensor, sigma_squared_array[1:] - sigma_squared_array[:-1]] + ) - def _create_g_square_array(self, sigma_square_array: torch.Tensor) -> torch.Tensor: - nan_tensor = torch.tensor([float('nan')]) - return torch.cat([nan_tensor, sigma_square_array[1:] - sigma_square_array[:-1]]) + def _get_random_time_step_indices(self, shape: Tuple[int]) -> torch.Tensor: + """Random time step indices. - def get_variances(self, indices: torch.Tensor) -> torch.Tensor: - """Get variances. + Generate random indices that correspond to valid time steps. + This sampling avoids index "0", which corresponds to time "0". Args: - indices : indices to be extracted. + shape: shape of the random index array. Returns: - variances: the variances at the specified indices. + time_step_indices: random time step indices in a tensor of shape "shape". """ - assert torch.all(indices >= 0), "indices must be non-negative" - assert torch.all(indices <= self._maximum_index), \ - f"indices must be smaller than or equal to {self._maximum_index}" - return self._sigma_square_array.take(indices) - - def get_g_squared_factors(self, indices: torch.Tensor) -> torch.Tensor: - """Get g squared factors. + random_indices = torch.randint( + self._minimum_random_index, + self._maximum_random_index + + 1, # +1 because the maximum value is not sampled + size=shape, + ) + return random_indices - The g squared factors are defined as: + def get_random_noise_sample(self, batch_size: int) -> Noise: + """Get random noise sample. - g(t)^2 = sigma(t)^2 - sigma(t-1)^2 + It is assumed that a batch is of the form [batch_size, (dimensions of a configuration)]. + In order to train a diffusion model, a configuration must be "noised" to a time t with a parameter sigma(t). + Different values can be used for different configurations: correspondingly, this method returns + one random time per element in the batch. - Note that this is ill-defined at t=0. Args: - indices: indices at which to get the g squared factors. + batch_size : number of configurations in a batch, Returns: - g_square_factors : g squared factors at specified indices + noise_sample: a collection of all the noise parameters (t, sigma, sigma^2, g, g^2) + for some random indices. All the arrays are of dimension [batch_size]. """ - assert torch.all(indices > 0), "g squared factor is ill defined at index zero." - assert torch.all(indices <= self._maximum_index), \ - f"indices must be smaller than or equal to {self._maximum_index}" - return self._g_square_array.take(indices) + indices = self._get_random_time_step_indices((batch_size,)) + times = self._time_array.take(indices) + sigmas = self._sigma_array.take(indices) + sigmas_squared = self._sigma_squared_array.take(indices) + gs = self._g_array.take(indices) + gs_squared = self._g_squared_array.take(indices) + + return Noise( + time=times, + sigma=sigmas, + sigma_squared=sigmas_squared, + g=gs, + g_squared=gs_squared, + ) + + def get_all_noise(self) -> Noise: + """Get all noise. + + All the internal noise parameter arrays, passed as a Noise object. + + Returns: + all_noise: a collection of all the noise parameters (t, sigma, sigma^2, g, g^2) + for all indices. The arrays are all of dimension [total_time_steps]. + """ + return Noise( + time=self._time_array, + sigma=self._sigma_array, + sigma_squared=self._sigma_squared_array, + g=self._g_array, + g_squared=self._g_squared_array, + ) diff --git a/crystal_diffusion/score/wrapped_gaussian_score.py b/crystal_diffusion/score/wrapped_gaussian_score.py index 965263fd..2f9a560a 100644 --- a/crystal_diffusion/score/wrapped_gaussian_score.py +++ b/crystal_diffusion/score/wrapped_gaussian_score.py @@ -14,15 +14,18 @@ sigma is derived to avoid division by a number that is very close to zero, or summing very large terms that can overlflow. -Also, what is computed is actually "sigma2 x S" (ie, the "sigma normalized score"). This is because -S ~ 1/ sigma^2; since sigma can be small, this makes the raw score arbitrarily large, and it is better to -manipulate numbers of small magnitude. +Also, what is computed is actually "sigma x S" (ie, the "sigma normalized score"). This is because, as argued +in section 4.2 of the paper + "Generative Modeling by Estimating Gradients of the Data Distribution", Song & Ermon +at convergence we expect |S| ~ 1 / sigma. Normalizing the score should lead to numbers of the same order of magnitude. Relevant papers: "Torsional Diffusion for Molecular Conformer Generation", Bowen Jing, Gabriele Corso, Jeffrey Chang, Regina Barzilay, Tommi Jaakkola "Riemannian Score-Based Generative Modelling", Bortoli et al. + + "Generative Modeling by Estimating Gradients of the Data Distribution", Song & Ermon """ from typing import Optional @@ -33,7 +36,7 @@ U_THRESHOLD = 0.5 -def get_expected_sigma_normalized_score_brute_force(u: float, sigma: float, kmax: Optional[int] = None) -> float: +def get_sigma_normalized_score_brute_force(u: float, sigma: float, kmax: Optional[int] = None) -> float: """Brute force implementation. A brute force implementation of the sigma normalized score to check that the main code is correct. @@ -61,7 +64,10 @@ def get_expected_sigma_normalized_score_brute_force(u: float, sigma: float, kmax z += exp sigma2_derivative_z += -upk * exp - return sigma2_derivative_z / z + sigma2_score = sigma2_derivative_z / z + sigma_score = sigma2_score / sigma + + return sigma_score def get_sigma_normalized_score( @@ -93,7 +99,7 @@ def get_sigma_normalized_score( total_number_of_elements = relative_positions.nelement() list_u = relative_positions.view(total_number_of_elements) - list_sigma = sigmas.view(total_number_of_elements) + list_sigma = sigmas.reshape(total_number_of_elements) # The dimension of list_k is [2 kmax + 1]. list_k = torch.arange(-kmax, kmax + 1) @@ -110,8 +116,8 @@ def get_sigma_normalized_score( _get_large_sigma_mask, ] score_calculators = [ - _get_sigma_normalized_s1a, - _get_sigma_normalized_s1b, + _get_sigma_normalized_score_1a, + _get_sigma_normalized_score_1b, _get_sigma_normalized_s2, ] @@ -210,7 +216,7 @@ def _get_s1b_exponential( return exponential -def _get_sigma_normalized_s1_from_exponential( +def _get_sigma_square_times_score_1_from_exponential( exponential: torch.Tensor, list_u: torch.Tensor, list_k: torch.Tensor ) -> torch.Tensor: """Get one of the contributions to the S1 score, assuming the exponentials has been computed and is passed as input. @@ -221,21 +227,21 @@ def _get_sigma_normalized_s1_from_exponential( list_k : the integer values that will be summed over, with shape [Nk]. Returns: - sigma_normalized_score_component : the corresponding sigma score, with shape [Nu] + sigma_square_times_score_component : the corresponding score multiplied by sigma^2, with shape [Nu]. """ # Sum is on Nk column_u = list_u.view(list_u.nelement(), 1) numerator = ((torch.ones_like(column_u) * list_k) * exponential).sum(dim=1) denominator = exponential.sum(dim=1) - list_sigma_normalized_score = -list_u - (numerator / denominator) - return list_sigma_normalized_score + list_sigma_square_times_score = -list_u - (numerator / denominator) + return list_sigma_square_times_score -def _get_sigma_normalized_s1a( +def _get_sigma_normalized_score_1a( list_u: torch.Tensor, list_sigma: torch.Tensor, list_k: torch.Tensor ) -> torch.Tensor: - """Get the sigma normalized score for small sigma and 0 <= u < 0.5. + """Get the sigma times the score for small sigma and 0 <= u < 0.5. This method assumes that the inputs are appropriate. @@ -245,16 +251,18 @@ def _get_sigma_normalized_s1a( list_k : the integer values that will be summed over, with shape [Nk]. Returns: - list_normalized_score : the sigma^2 x s1a scores, with shape [Nu]. + list_sigma_normalized_score : the sigma x s1a scores, with shape [Nu]. """ exponential = _get_s1a_exponential(list_u, list_sigma, list_k) - return _get_sigma_normalized_s1_from_exponential(exponential, list_u, list_k) + list_sigma_square_times_score = _get_sigma_square_times_score_1_from_exponential(exponential, list_u, list_k) + list_normalized_score = list_sigma_square_times_score / list_sigma + return list_normalized_score -def _get_sigma_normalized_s1b( +def _get_sigma_normalized_score_1b( list_u: torch.Tensor, list_sigma: torch.Tensor, list_k: torch.Tensor ) -> torch.Tensor: - """Get the sigma normalized score for small sigma and 0.5 <= u < 1. + """Get the sigma times the score for small sigma and 0.5 <= u < 1. This method assumes that the inputs are appropriate. @@ -264,10 +272,12 @@ def _get_sigma_normalized_s1b( list_k : the integer values that will be summed over, with shape [Nk]. Returns: - list_column_normalized_score : the sigma^2 x s1b scores, with shape [Nu]. + list_sigma_normalized_score : the sigma x s1b scores, with shape [Nu]. """ exponential = _get_s1b_exponential(list_u, list_sigma, list_k) - return _get_sigma_normalized_s1_from_exponential(exponential, list_u, list_k) + list_sigma_square_times_score = _get_sigma_square_times_score_1_from_exponential(exponential, list_u, list_k) + list_normalized_score = list_sigma_square_times_score / list_sigma + return list_normalized_score def _get_sigma_normalized_s2( @@ -283,8 +293,10 @@ def _get_sigma_normalized_s2( list_k : the integer values that will be summed over, with shape [Nk]. Returns: - list_normalized_score : the sigma^2 x s2 scores, with shape [Nu]. + list_normalized_score : the sigma x s2 scores, with shape [Nu]. """ + numerical_type = list_u.dtype + column_u = list_u.view(list_u.nelement(), 1) column_sigma = list_sigma.view(list_u.nelement(), 1) @@ -295,18 +307,20 @@ def _get_sigma_normalized_s2( g = torch.ones_like(column_u) * list_k sig = column_sigma * torch.ones_like(list_k) - exp_upk = (-np.pi * upk**2).exp() - exp_sigma_g = (-2.0 * np.pi**2 * sigma_g**2).exp() - exp_g = (-np.pi * g**2).exp() + pi = torch.tensor(np.pi, dtype=numerical_type) + + exp_upk = (-pi * upk**2).exp() + exp_sigma_g = (-2.0 * pi**2 * sigma_g**2).exp() + exp_g = (-pi * g**2).exp() - g_exponential_combination = np.sqrt(2.0 * np.pi) * sig * exp_sigma_g - exp_g + g_exponential_combination = torch.sqrt(2.0 * pi) * sig * exp_sigma_g - exp_g - cos = torch.cos(2.0 * np.pi * gu) - sin = torch.sin(2.0 * np.pi * gu) + cos = torch.cos(2.0 * pi * gu) + sin = torch.sin(2.0 * pi * gu) # The sum is over Nk, leaving arrays of dimensions [Nu] z2 = exp_upk.sum(dim=1) + (g_exponential_combination * cos).sum(dim=1) - deriv_z2 = -2.0 * np.pi * ((upk * exp_upk).sum(dim=1) + (g * g_exponential_combination * sin).sum(dim=1)) - list_sigma_normalized_scores_s2 = list_sigma**2 * deriv_z2 / z2 + deriv_z2 = -2.0 * pi * ((upk * exp_upk).sum(dim=1) + (g * g_exponential_combination * sin).sum(dim=1)) + list_sigma_normalized_scores_s2 = list_sigma * deriv_z2 / z2 return list_sigma_normalized_scores_s2 diff --git a/crystal_diffusion/utils/tensor_utils.py b/crystal_diffusion/utils/tensor_utils.py new file mode 100644 index 00000000..97dab61b --- /dev/null +++ b/crystal_diffusion/utils/tensor_utils.py @@ -0,0 +1,34 @@ +from typing import Tuple + +import torch + + +def broadcast_batch_tensor_to_all_dimensions(batch_values: torch.Tensor, final_shape: Tuple[int]) -> torch.Tensor: + """Broadcast batch tensor to all dimensions. + + A data batch is typically a tensor of shape [batch_size, n1, n2, ...] where n1, n2, etc constitute + one example of the data. This method broadcasts a tensor of shape [batch_size] to a tensor of shape + [batch_size, n1, n2, ...] where all the values for the non-batch dimension are equal to the value + for the given batch index. + + This is useful when we want to multiply every value in the data example by the same number. + + Args: + batch_values : values to be braodcasted, of shape [batch_size] + final_shape : shape of the final tensor, [batch_size, n1, n2, ...] + + Returns: + broadcast_values : tensor of shape [batch_size, n1, n2, ...], where all entries are identical + along non-batch dimensions. + """ + assert len(batch_values.shape) == 1, "The batch values should be a one-dimensional tensor." + batch_size = len(batch_values) + + assert final_shape[0] == batch_size, "The final shape should have the batch_size as its first dimension." + + # reshape the batch_values array to have the same dimension as final_shape, with all values identical + # for a given batch index. + number_of_dimensions = len(final_shape) + reshape_dimension = [-1] + (number_of_dimensions - 1) * [1] + broadcast_values = batch_values.reshape(reshape_dimension).expand(final_shape) + return broadcast_values diff --git a/sanity_checks/__init__.py b/sanity_checks/__init__.py new file mode 100644 index 00000000..e311896f --- /dev/null +++ b/sanity_checks/__init__.py @@ -0,0 +1,3 @@ +from pathlib import Path + +SANITY_CHECK_FOLDER = str(Path(__file__).parent) diff --git a/sanity_checks/overfit_fake_data.py b/sanity_checks/overfit_fake_data.py new file mode 100644 index 00000000..2b8d80d8 --- /dev/null +++ b/sanity_checks/overfit_fake_data.py @@ -0,0 +1,56 @@ +"""Overfit fake data. + +A simple sanity check experiment to verify that we can overfit a batch of random data. +""" +import os + +import pytorch_lightning +import torch +from pytorch_lightning import Trainer +from pytorch_lightning.loggers import TensorBoardLogger +from torch.utils.data import DataLoader + +from crystal_diffusion.models.optimizer import (OptimizerParameters, + ValidOptimizerNames) +from crystal_diffusion.models.position_diffusion_lightning_model import ( + PositionDiffusionLightningModel, PositionDiffusionParameters) +from crystal_diffusion.models.score_network import MLPScoreNetworkParameters +from crystal_diffusion.samplers.variance_sampler import NoiseParameters +from sanity_checks import SANITY_CHECK_FOLDER + +batch_size = 16 +number_of_atoms = 4 +spatial_dimension = 2 + +score_network_parameters = MLPScoreNetworkParameters( + number_of_atoms=number_of_atoms, + hidden_dim=32, + spatial_dimension=spatial_dimension, +) + +optimizer_parameters = OptimizerParameters(name=ValidOptimizerNames("adam"), + learning_rate=0.01) + +noise_parameters = NoiseParameters(total_time_steps=10) + +hyper_params = PositionDiffusionParameters( + score_network_parameters=score_network_parameters, + optimizer_parameters=optimizer_parameters, + noise_parameters=noise_parameters, +) + + +tbx_logger = TensorBoardLogger(save_dir=os.path.join(SANITY_CHECK_FOLDER, "tensorboard"), name="overfit_fake_data") + +if __name__ == '__main__': + + pytorch_lightning.seed_everything(123) + + all_positions = torch.rand(batch_size, number_of_atoms, spatial_dimension) + data = [dict(relative_positions=configuration) for configuration in all_positions] + train_dataloader = DataLoader(data, batch_size=batch_size) + + lightning_model = PositionDiffusionLightningModel(hyper_params) + + trainer = Trainer(accelerator='cpu', max_epochs=10000, logger=tbx_logger, log_every_n_steps=1) + trainer.fit(lightning_model, train_dataloaders=train_dataloader) diff --git a/tests/models/test_optimizer.py b/tests/models/test_optimizer.py new file mode 100644 index 00000000..33bbcf70 --- /dev/null +++ b/tests/models/test_optimizer.py @@ -0,0 +1,35 @@ +import pytest +import torch + +from crystal_diffusion.models.optimizer import (OptimizerParameters, + ValidOptimizerNames, + load_optimizer) + + +class FakeNeuralNet(torch.nn.Module): + """A fake neural net for testing that we can attach an optimizer.""" + def __init__(self): + super(FakeNeuralNet, self).__init__() + self.linear_layer = torch.nn.Linear(in_features=4, out_features=1) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.linear_layer(x) + + +@pytest.fixture() +def model(): + return FakeNeuralNet() + + +@pytest.fixture() +def optimizer_parameters(optimizer_name): + valid_optimizer_name = ValidOptimizerNames(optimizer_name) + return OptimizerParameters(name=valid_optimizer_name, learning_rate=0.01) + + +@pytest.mark.parametrize("optimizer_name", [option.value for option in list(ValidOptimizerNames)]) +def test_load_optimizer(optimizer_name, optimizer_parameters, model): + # This is more of a "smoke test": can the optimizer be instantiated and run without crashing. + optimizer = load_optimizer(optimizer_parameters, model) + optimizer.zero_grad() + optimizer.step() diff --git a/tests/models/test_position_diffusion_lightning_model.py b/tests/models/test_position_diffusion_lightning_model.py new file mode 100644 index 00000000..19299bec --- /dev/null +++ b/tests/models/test_position_diffusion_lightning_model.py @@ -0,0 +1,200 @@ +import pytest +import torch +from pytorch_lightning import LightningDataModule, Trainer +from torch.utils.data import DataLoader, random_split + +from crystal_diffusion.models.optimizer import (OptimizerParameters, + ValidOptimizerNames) +from crystal_diffusion.models.position_diffusion_lightning_model import ( + PositionDiffusionLightningModel, PositionDiffusionParameters) +from crystal_diffusion.models.score_network import (MLPScoreNetwork, + MLPScoreNetworkParameters) +from crystal_diffusion.samplers.variance_sampler import NoiseParameters +from crystal_diffusion.score.wrapped_gaussian_score import \ + get_sigma_normalized_score_brute_force +from crystal_diffusion.utils.tensor_utils import \ + broadcast_batch_tensor_to_all_dimensions + +available_accelerators = ["cpu"] +if torch.cuda.is_available(): + available_accelerators.append("gpu") + + +class FakePositionsDataModule(LightningDataModule): + def __init__( + self, + batch_size: int = 4, + dataset_size: int = 33, + number_of_atoms: int = 8, + spatial_dimension: int = 2, + ): + super().__init__() + self.batch_size = batch_size + all_positions = torch.rand(dataset_size, number_of_atoms, spatial_dimension) + self.data = [ + dict(relative_positions=configuration) for configuration in all_positions + ] + self.train_data, self.val_data, self.test_data = None, None, None + + def setup(self, stage: str): + self.train_data, self.val_data, self.test_data = random_split( + self.data, lengths=[0.5, 0.3, 0.2] + ) + + def train_dataloader(self): + return DataLoader(self.train_data, batch_size=self.batch_size) + + def val_dataloader(self): + return DataLoader(self.val_data, batch_size=self.batch_size) + + def test_dataloader(self): + return DataLoader(self.test_data, batch_size=self.batch_size) + + +@pytest.mark.parametrize("spatial_dimension", [2, 3]) +class TestPositionDiffusionLightningModel: + @pytest.fixture(scope="class", autouse=True) + def set_random_seed(self): + torch.manual_seed(23452342) + + @pytest.fixture() + def batch_size(self): + return 4 + + @pytest.fixture() + def number_of_atoms(self): + return 8 + + @pytest.fixture() + def hyper_params(self, number_of_atoms, spatial_dimension): + score_network_parameters = MLPScoreNetworkParameters( + number_of_atoms=number_of_atoms, + hidden_dim=16, + spatial_dimension=spatial_dimension, + ) + + optimizer_parameters = OptimizerParameters( + name=ValidOptimizerNames("adam"), learning_rate=0.01 + ) + noise_parameters = NoiseParameters(total_time_steps=15) + + hyper_params = PositionDiffusionParameters( + score_network_parameters=score_network_parameters, + optimizer_parameters=optimizer_parameters, + noise_parameters=noise_parameters, + ) + return hyper_params + + @pytest.fixture() + def real_relative_positions(self, batch_size, number_of_atoms, spatial_dimension): + relative_positions = torch.rand(batch_size, number_of_atoms, spatial_dimension) + return relative_positions + + @pytest.fixture() + def noisy_relative_positions(self, batch_size, number_of_atoms, spatial_dimension): + noisy_relative_positions = torch.rand( + batch_size, number_of_atoms, spatial_dimension + ) + return noisy_relative_positions + + @pytest.fixture() + def batch(self, real_relative_positions): + return dict(relative_positions=real_relative_positions) + + @pytest.fixture() + def fake_datamodule(self, batch_size, number_of_atoms, spatial_dimension): + data_module = FakePositionsDataModule( + batch_size=batch_size, + number_of_atoms=number_of_atoms, + spatial_dimension=spatial_dimension, + ) + return data_module + + @pytest.fixture() + def times(self, batch_size): + times = torch.rand(batch_size) + return times + + @pytest.fixture() + def sigmas(self, batch_size, number_of_atoms, spatial_dimension): + sigma_values = 0.5 * torch.rand(batch_size) # smaller sigmas for harder tests! + sigmas = broadcast_batch_tensor_to_all_dimensions( + sigma_values, final_shape=(batch_size, number_of_atoms, spatial_dimension) + ) + return sigmas + + @pytest.fixture() + def lightning_model(self, hyper_params): + lightning_model = PositionDiffusionLightningModel(hyper_params) + return lightning_model + + @pytest.fixture() + def brute_force_target_normalized_score( + self, noisy_relative_positions, real_relative_positions, sigmas + ): + shape = noisy_relative_positions.shape + + expected_scores = [] + for xt, x0, sigma in zip( + noisy_relative_positions.flatten(), + real_relative_positions.flatten(), + sigmas.flatten(), + ): + u = torch.remainder(xt - x0, 1.0) + + # Note that the brute force algorithm is not robust and can sometimes produce NaNs in single precision! + # Let's compute in double precision to avoid NaNs. + expected_score = get_sigma_normalized_score_brute_force( + u.to(torch.double), sigma.to(torch.double), kmax=20 + ).to(torch.float) + expected_scores.append(expected_score) + + expected_scores = torch.tensor(expected_scores).reshape(shape) + assert not torch.any( + expected_scores.isnan() + ), "The brute force algorithm produced NaN scores. Review input." + return expected_scores + + def test_get_target_normalized_score( + self, + lightning_model, + noisy_relative_positions, + real_relative_positions, + sigmas, + brute_force_target_normalized_score, + ): + computed_target_normalized_scores = ( + lightning_model._get_target_normalized_score( + noisy_relative_positions, real_relative_positions, sigmas + ) + ) + + torch.testing.assert_allclose(computed_target_normalized_scores, + brute_force_target_normalized_score, + atol=1e-7, + rtol=1e-4) + + def test_get_predicted_normalized_score( + self, mocker, lightning_model, noisy_relative_positions, times + ): + mocker.patch.object(MLPScoreNetwork, "_forward_unchecked") + + _ = lightning_model._get_predicted_normalized_score( + noisy_relative_positions, times + ) + + list_calls = MLPScoreNetwork._forward_unchecked.mock_calls + assert len(list_calls) == 1 + input_batch = list_calls[0][1][0] + + assert MLPScoreNetwork.position_key in input_batch + torch.testing.assert_allclose(input_batch[MLPScoreNetwork.position_key], noisy_relative_positions) + + assert MLPScoreNetwork.timestep_key in input_batch + torch.testing.assert_allclose(input_batch[MLPScoreNetwork.timestep_key], times.reshape(-1, 1)) + + @pytest.mark.parametrize("accelerator", available_accelerators) + def test_smoke_test(self, lightning_model, fake_datamodule, accelerator): + trainer = Trainer(fast_dev_run=3, accelerator=accelerator) + trainer.fit(lightning_model, fake_datamodule) + trainer.test(lightning_model, fake_datamodule) diff --git a/tests/models/test_score_network.py b/tests/models/test_score_network.py index 5bae56c1..44e35ae8 100644 --- a/tests/models/test_score_network.py +++ b/tests/models/test_score_network.py @@ -2,20 +2,26 @@ import torch from crystal_diffusion.models.score_network import (BaseScoreNetwork, - MLPScoreNetwork) + BaseScoreNetworkParameters, + MLPScoreNetwork, + MLPScoreNetworkParameters) +@pytest.mark.parametrize("spatial_dimension", [2, 3]) class TestScoreNetworkCheck: + @pytest.fixture(scope="class", autouse=True) + def set_random_seed(self): + torch.manual_seed(123) + @pytest.fixture() - def base_score_network(self): - return BaseScoreNetwork({}) + def base_score_network(self, spatial_dimension): + return BaseScoreNetwork(BaseScoreNetworkParameters(spatial_dimension=spatial_dimension)) @pytest.fixture() - def good_batch(self): - torch.manual_seed(123) + def good_batch(self, spatial_dimension): batch_size = 16 - positions = torch.rand(batch_size, 8, 3) + positions = torch.rand(batch_size, 8, spatial_dimension) times = torch.rand(batch_size, 1) return {BaseScoreNetwork.position_key: positions, BaseScoreNetwork.timestep_key: times} @@ -67,6 +73,7 @@ def test_check_batch_bad(self, base_score_network, bad_batch): base_score_network._check_batch(bad_batch) +@pytest.mark.parametrize("spatial_dimension", [2, 3]) class TestMLPScoreNetwork: @pytest.fixture() @@ -78,26 +85,27 @@ def number_of_atoms(self): return 8 @pytest.fixture() - def expected_score_shape(self, batch_size, number_of_atoms): - return batch_size, number_of_atoms, 3 + def expected_score_shape(self, batch_size, number_of_atoms, spatial_dimension): + return batch_size, number_of_atoms, spatial_dimension @pytest.fixture() - def good_batch(self, batch_size, number_of_atoms): - torch.manual_seed(123) - positions = torch.rand(batch_size, number_of_atoms, 3) + def good_batch(self, batch_size, number_of_atoms, spatial_dimension): + positions = torch.rand(batch_size, number_of_atoms, spatial_dimension) times = torch.rand(batch_size, 1) return {BaseScoreNetwork.position_key: positions, BaseScoreNetwork.timestep_key: times} @pytest.fixture() - def bad_batch(self, batch_size, number_of_atoms): - torch.manual_seed(123) - positions = torch.rand(batch_size, number_of_atoms // 2, 3) + def bad_batch(self, batch_size, number_of_atoms, spatial_dimension): + positions = torch.rand(batch_size, number_of_atoms // 2, spatial_dimension) times = torch.rand(batch_size, 1) return {BaseScoreNetwork.position_key: positions, BaseScoreNetwork.timestep_key: times} @pytest.fixture() - def score_network(self, number_of_atoms): - return MLPScoreNetwork({'hidden_dim': 16, 'number_of_atoms': number_of_atoms}) + def score_network(self, number_of_atoms, spatial_dimension): + hyper_params = MLPScoreNetworkParameters(spatial_dimension=spatial_dimension, + number_of_atoms=number_of_atoms, + hidden_dim=16) + return MLPScoreNetwork(hyper_params) def test_check_batch_bad(self, score_network, bad_batch): with pytest.raises(AssertionError): diff --git a/tests/samplers/test_noisy_position_sampler.py b/tests/samplers/test_noisy_position_sampler.py new file mode 100644 index 00000000..0c875506 --- /dev/null +++ b/tests/samplers/test_noisy_position_sampler.py @@ -0,0 +1,95 @@ +import numpy as np +import pytest +import torch + +from crystal_diffusion.samplers.noisy_position_sampler import ( + NoisyPositionSampler, map_positions_to_unit_cell) + + +def test_remainder_failure(): + # This test demonstrates how torch.remainder does not do what we want. + epsilon = -torch.tensor(1.e-8) + position_not_in_unit_cell = torch.remainder(epsilon, 1.0) + assert position_not_in_unit_cell == 1.0 + + +@pytest.mark.parametrize("shape", [(10,), (10, 20), (3, 4, 5)]) +def test_map_positions_to_unit_cell_hard(shape): + positions = 1e-8 * (torch.rand((10,)) - 0.5) + computed_positions = map_positions_to_unit_cell(positions) + + positive_positions_mask = positions >= 0. + assert torch.all(positions[positive_positions_mask] == computed_positions[positive_positions_mask]) + torch.testing.assert_allclose(computed_positions[~positive_positions_mask], + torch.zeros_like(computed_positions[~positive_positions_mask])) + + +@pytest.mark.parametrize("shape", [(100, 8, 16)]) +def test_map_positions_to_unit_cell_easy(shape): + # Very unlikely to hit the edge cases. + positions = 10. * (torch.rand((10,)) - 0.5) + expected_values = torch.remainder(positions, 1.) + computed_values = map_positions_to_unit_cell(positions) + torch.testing.assert_allclose(computed_values, expected_values) + + +@pytest.mark.parametrize("shape", [(10, 1), (4, 5, 3), (2, 2, 2, 2)]) +class TestNoisyPositionSampler: + + @pytest.fixture(scope="class", autouse=True) + def set_random_seed(self): + torch.manual_seed(23423) + + @pytest.fixture() + def real_relative_positions(self, shape): + return torch.rand(shape) + + @pytest.fixture() + def sigmas(self, shape): + return torch.rand(shape) + + @pytest.fixture() + def computed_noisy_relative_positions(self, real_relative_positions, sigmas): + return NoisyPositionSampler.get_noisy_position_sample( + real_relative_positions, sigmas + ) + + @pytest.fixture() + def fake_gaussian_sample(self, shape): + # Note: this is NOT a Gaussian distribution. That's ok, it's fake data for testing! + return torch.rand(shape) + + def test_shape(self, computed_noisy_relative_positions, shape): + assert computed_noisy_relative_positions.shape == shape + + def test_range(self, computed_noisy_relative_positions): + assert torch.all(computed_noisy_relative_positions >= 0.0) + assert torch.all(computed_noisy_relative_positions < 1.0) + + def test_get_noisy_position_sample( + self, mocker, real_relative_positions, sigmas, fake_gaussian_sample + ): + mocker.patch.object( + NoisyPositionSampler, + "_get_gaussian_noise", + return_value=fake_gaussian_sample, + ) + + computed_samples = NoisyPositionSampler.get_noisy_position_sample( + real_relative_positions, sigmas + ) + + flat_sigmas = sigmas.flatten() + flat_positions = real_relative_positions.flatten() + flat_computed_samples = computed_samples.flatten() + flat_fake_gaussian_sample = fake_gaussian_sample.flatten() + + for sigma, x0, computed_sample, epsilon in zip( + flat_sigmas, + flat_positions, + flat_computed_samples, + flat_fake_gaussian_sample, + ): + expected_sample = np.mod(x0 + sigma * epsilon, 1).float() + + torch.testing.assert_allclose(computed_sample, expected_sample) diff --git a/tests/samplers/test_time_sampler.py b/tests/samplers/test_time_sampler.py deleted file mode 100644 index a72d6c2d..00000000 --- a/tests/samplers/test_time_sampler.py +++ /dev/null @@ -1,81 +0,0 @@ -import pytest -import torch - -from crystal_diffusion.samplers.time_sampler import TimeParameters, TimeSampler - - -@pytest.mark.parametrize('total_time_steps', [3, 10, 17]) -class TestTimeSampler: - @pytest.fixture() - def time_parameters(self, total_time_steps): - time_parameters = TimeParameters(total_time_steps=total_time_steps, random_seed=0) - return time_parameters - - @pytest.fixture() - def time_sampler(self, time_parameters): - return TimeSampler(time_parameters) - - @pytest.mark.parametrize('shape', [(100,), (3, 8), (1, 2, 3)]) - def test_random_time_step_indices(self, time_sampler, shape, total_time_steps): - computed_time_step_indices = time_sampler.get_random_time_step_indices(shape) - assert computed_time_step_indices.shape == shape - assert torch.all(computed_time_step_indices > 0) - assert torch.all(computed_time_step_indices <= total_time_steps) - - @pytest.mark.parametrize('shape', [(100,), (3, 8), (1, 2, 3)]) - def test_random_time_steps(self, time_sampler, shape, total_time_steps): - indices = time_sampler.get_random_time_step_indices(shape) - computed_time_steps = time_sampler.get_time_steps(indices) - assert computed_time_steps.shape == shape - - assert torch.all(computed_time_steps > 0.0) - assert torch.all(computed_time_steps <= 1.0) - - # The random values should come from an array of the form [dt, 2 dt, 3 dt,.... N dt] - delta_time = 1. / (total_time_steps - 1.) - number_of_time_segments = torch.round(computed_time_steps / delta_time) - expected_time_steps = number_of_time_segments * delta_time - - assert torch.all(torch.isclose(expected_time_steps, computed_time_steps)) - - @pytest.fixture() - def all_indices_and_times(self, total_time_steps): - dt = torch.tensor(1. / (total_time_steps - 1.)).to(torch.double) - all_time_indices = torch.arange(total_time_steps) - all_times = (all_time_indices * dt).float() - return all_time_indices, all_times - - def test_forward_iterator(self, all_indices_and_times, time_sampler): - - all_time_indices, all_times = all_indices_and_times - expected_indices = all_time_indices[:-1] - expected_times = all_times[:-1] - - computed_indices = [] - computed_times = [] - for index, time in time_sampler.get_forward_iterator(): - computed_indices.append(index) - computed_times.append(time) - - computed_indices = torch.tensor(computed_indices) - computed_times = torch.tensor(computed_times) - - assert torch.all(torch.isclose(computed_times, expected_times)) - assert torch.all(torch.isclose(computed_indices, expected_indices)) - - def test_backward_iterator(self, all_indices_and_times, time_sampler): - all_time_indices, all_times = all_indices_and_times - expected_indices = all_time_indices[1:].flip(dims=(0,)) - expected_times = all_times[1:].flip(dims=(0,)) - - computed_indices = [] - computed_times = [] - for index, time in time_sampler.get_backward_iterator(): - computed_indices.append(index) - computed_times.append(time) - - computed_indices = torch.tensor(computed_indices) - computed_times = torch.tensor(computed_times) - - assert torch.all(torch.isclose(computed_times, expected_times)) - assert torch.all(torch.isclose(computed_indices, expected_indices)) diff --git a/tests/samplers/test_variance_sampler.py b/tests/samplers/test_variance_sampler.py index dbe5ec07..0f997ef3 100644 --- a/tests/samplers/test_variance_sampler.py +++ b/tests/samplers/test_variance_sampler.py @@ -1,89 +1,101 @@ import pytest import torch -from crystal_diffusion.samplers.time_sampler import TimeParameters, TimeSampler from crystal_diffusion.samplers.variance_sampler import ( - ExplodingVarianceSampler, VarianceParameters) + ExplodingVarianceSampler, NoiseParameters) -@pytest.mark.parametrize('total_time_steps', [3, 10, 17]) +@pytest.mark.parametrize("total_time_steps", [3, 10, 17]) class TestExplodingVarianceSampler: @pytest.fixture() - def time_parameters(self, total_time_steps): - time_parameters = TimeParameters(total_time_steps=total_time_steps, random_seed=0) - return time_parameters + def noise_parameters(self, total_time_steps): + return NoiseParameters(total_time_steps=total_time_steps) @pytest.fixture() - def time_sampler(self, time_parameters): - return TimeSampler(time_parameters) + def variance_sampler(self, noise_parameters): + return ExplodingVarianceSampler(noise_parameters=noise_parameters) @pytest.fixture() - def variance_parameters(self, time_sampler): - return VarianceParameters() + def expected_times(self, total_time_steps): + times = torch.linspace(0.0, 1.0, total_time_steps) + return times @pytest.fixture() - def variance_sampler(self, variance_parameters, time_sampler): - return ExplodingVarianceSampler(variance_parameters=variance_parameters, - time_sampler=time_sampler) + def expected_sigmas(self, expected_times, noise_parameters): + smin = noise_parameters.sigma_min + smax = noise_parameters.sigma_max - @pytest.fixture() - def expected_variances(self, variance_parameters, time_sampler, time_parameters): - - times = torch.linspace(0., 1., time_parameters.total_time_steps) - - smin = variance_parameters.sigma_min - smax = variance_parameters.sigma_max - - sigmas = smin**(1. - times) * smax ** times - variances = sigmas**2 - return variances + sigmas = smin ** (1.0 - expected_times) * smax**expected_times + return sigmas @pytest.fixture() def indices(self, time_sampler, shape): return time_sampler.get_random_time_step_indices(shape) - @pytest.fixture() - def expected_variances_by_index(self, expected_variances, indices): - result = [] - for i in indices.flatten(): - result.append(expected_variances[i]) - return torch.tensor(result).reshape(indices.shape) + def test_time_array(self, variance_sampler, expected_times): + torch.testing.assert_allclose(variance_sampler._time_array, expected_times) - def test_sigma_square_array(self, variance_sampler, expected_variances): - assert torch.all(torch.isclose(variance_sampler._sigma_square_array, expected_variances)) + def test_sigma_and_sigma_squared_arrays(self, variance_sampler, expected_sigmas): + torch.testing.assert_allclose(variance_sampler._sigma_array, expected_sigmas) + torch.testing.assert_allclose(variance_sampler._sigma_squared_array, expected_sigmas**2) - def test_g_square_array(self, variance_sampler, expected_variances): + def test_g_and_g_square_array(self, variance_sampler, expected_sigmas): + expected_sigmas_square = expected_sigmas**2 - expected_g_squared_array = [float('nan')] - for sigma2_t, sigma2_tm1 in zip(expected_variances[1:], expected_variances[:-1]): + expected_g_squared_array = [float("nan")] + for sigma2_t, sigma2_tm1 in zip( + expected_sigmas_square[1:], expected_sigmas_square[:-1] + ): g2 = sigma2_t - sigma2_tm1 expected_g_squared_array.append(g2) expected_g_squared_array = torch.tensor(expected_g_squared_array) - - assert torch.isnan(variance_sampler._g_square_array[0]) - assert torch.all(torch.isclose(variance_sampler._g_square_array[1:], expected_g_squared_array[1:])) - - @pytest.mark.parametrize('shape', [(100,), (3, 8), (1, 2, 3)]) - def test_get_variances(self, variance_sampler, indices, expected_variances_by_index): - computed_variances = variance_sampler.get_variances(indices) - assert torch.all(torch.isclose(computed_variances, expected_variances_by_index)) - - def test_get_g_squared_factors(self, variance_sampler, expected_variances): - - indices = [] - expected_g_squared_factors = [] - previous_variance = expected_variances[0] - - for index, current_variance in enumerate(expected_variances[1:], 1): - indices.append(index) - current_variance = expected_variances[index] - g2 = current_variance - previous_variance - expected_g_squared_factors.append(g2) - previous_variance = current_variance - expected_g_squared_factors = torch.tensor(expected_g_squared_factors) - - indices = torch.tensor(indices) - computed_g_squared_factors = variance_sampler.get_g_squared_factors(indices) - - assert torch.all(torch.isclose(computed_g_squared_factors, expected_g_squared_factors)) + expected_g_array = torch.sqrt(expected_g_squared_array) + + assert torch.isnan(variance_sampler._g_array[0]) + assert torch.isnan(variance_sampler._g_squared_array[0]) + torch.testing.assert_allclose(variance_sampler._g_array[1:], expected_g_array[1:]) + torch.testing.assert_allclose(variance_sampler._g_squared_array[1:], expected_g_squared_array[1:]) + + def test_get_random_time_step_indices(self, variance_sampler, total_time_steps): + # Check that we never sample zero. + random_indices = variance_sampler._get_random_time_step_indices(shape=(1000,)) + assert torch.all(random_indices > 0) + assert torch.all(random_indices < total_time_steps) + + @pytest.mark.parametrize("batch_size", [1, 10, 100]) + def test_get_random_noise_parameter_sample( + self, mocker, variance_sampler, batch_size + ): + random_indices = variance_sampler._get_random_time_step_indices(shape=(1000,)) + mocker.patch.object( + variance_sampler, + "_get_random_time_step_indices", + return_value=random_indices, + ) + + noise_sample = variance_sampler.get_random_noise_sample(batch_size) + + expected_times = variance_sampler._time_array.take(random_indices) + expected_sigmas = variance_sampler._sigma_array.take(random_indices) + expected_sigmas_squared = variance_sampler._sigma_squared_array.take( + random_indices + ) + expected_gs = variance_sampler._g_array.take(random_indices) + expected_gs_squared = variance_sampler._g_squared_array.take(random_indices) + + torch.testing.assert_allclose(noise_sample.time, expected_times) + torch.testing.assert_allclose(noise_sample.sigma, expected_sigmas) + torch.testing.assert_allclose(noise_sample.sigma_squared, expected_sigmas_squared) + torch.testing.assert_allclose(noise_sample.g, expected_gs) + torch.testing.assert_allclose(noise_sample.g_squared, expected_gs_squared) + + def test_get_all_noise(self, variance_sampler): + noise = variance_sampler.get_all_noise() + torch.testing.assert_allclose(noise.time, variance_sampler._time_array) + torch.testing.assert_allclose(noise.sigma, variance_sampler._sigma_array) + torch.testing.assert_allclose(noise.sigma_squared, variance_sampler._sigma_squared_array) + assert torch.isnan(noise.g[0]) + assert torch.isnan(noise.g_squared[0]) + torch.testing.assert_allclose(noise.g[1:], variance_sampler._g_array[1:]) + torch.testing.assert_allclose(noise.g_squared[1:], variance_sampler._g_squared_array[1:]) diff --git a/tests/score/test_wrapped_gaussian_score.py b/tests/score/test_wrapped_gaussian_score.py index f969737c..b21187d2 100644 --- a/tests/score/test_wrapped_gaussian_score.py +++ b/tests/score/test_wrapped_gaussian_score.py @@ -4,22 +4,24 @@ from crystal_diffusion.score.wrapped_gaussian_score import ( SIGMA_THRESHOLD, _get_large_sigma_mask, _get_s1a_exponential, - _get_s1b_exponential, _get_sigma_normalized_s1_from_exponential, - _get_sigma_normalized_s2, _get_small_sigma_large_u_mask, - _get_small_sigma_small_u_mask, - get_expected_sigma_normalized_score_brute_force, - get_sigma_normalized_score) + _get_s1b_exponential, _get_sigma_normalized_s2, + _get_sigma_square_times_score_1_from_exponential, + _get_small_sigma_large_u_mask, _get_small_sigma_small_u_mask, + get_sigma_normalized_score, get_sigma_normalized_score_brute_force) + + +@pytest.fixture(scope="module", autouse=True) +def set_random_seed(): + torch.manual_seed(1234) @pytest.fixture def relative_positions(shape): - torch.manual_seed(1234) return torch.rand(shape) @pytest.fixture def sigmas(shape): - torch.manual_seed(4321) return torch.rand(shape) * 5.0 * SIGMA_THRESHOLD @@ -44,13 +46,13 @@ def expected_sigma_normalized_scores(relative_positions, sigmas): list_sigma_normalized_scores = [] for u, sigma in zip(relative_positions.numpy().flatten(), sigmas.numpy().flatten()): - s = get_expected_sigma_normalized_score_brute_force(u, sigma) + s = get_sigma_normalized_score_brute_force(u, sigma) list_sigma_normalized_scores.append(s) return torch.tensor(list_sigma_normalized_scores).reshape(shape) -test_shapes = [(10,), (3, 4, 5), (10, 5)] +test_shapes = [(100,), (3, 4, 5), (10, 5)] @pytest.mark.parametrize("shape", test_shapes) @@ -104,7 +106,6 @@ def test_mask_2(self, expected_shape, list_u, list_sigma): class TestExponentials: @pytest.fixture() def fake_exponential(self, list_u, list_k): - torch.manual_seed(6345345) return torch.rand(len(list_u), len(list_k)) def test_get_s1a_exponential(self, list_k, list_u, list_sigma): @@ -132,7 +133,7 @@ def test_get_s1b_exponential(self, list_k, list_u, list_sigma): def test_get_sigma_normalized_s1_from_exponential( self, fake_exponential, list_u, list_k ): - computed_results = _get_sigma_normalized_s1_from_exponential( + computed_results = _get_sigma_square_times_score_1_from_exponential( fake_exponential, list_u, list_k ) @@ -147,27 +148,32 @@ def test_get_sigma_normalized_s1_from_exponential( @pytest.mark.parametrize("kmax", [1, 5, 10]) @pytest.mark.parametrize("shape", test_shapes) -def test_get_sigma_normalized_s2(list_u, list_sigma, list_k): - list_computed_s2 = _get_sigma_normalized_s2(list_u, list_sigma, list_k) +@pytest.mark.parametrize("numerical_type", [torch.double]) +def test_get_sigma_normalized_s2(list_u, list_sigma, list_k, numerical_type): + # TODO: the test fails for numerical_type = torch.float. Should we be worried about numerical error here? + list_u_cast = list_u.to(numerical_type) + list_sigma_cast = list_sigma.to(numerical_type) + pi = torch.tensor(np.pi, dtype=numerical_type) - for u, sigma, computed_value in zip(list_u, list_sigma, list_computed_s2): - z2 = 0.0 - deriv_z2 = 0.0 + list_computed_s2 = _get_sigma_normalized_s2(list_u_cast, list_sigma_cast, list_k) + + list_expected_s2 = [] + for u, sigma in zip(list_u_cast, list_sigma_cast): + z2 = torch.tensor(0.0, dtype=numerical_type) + deriv_z2 = torch.tensor(0.0, dtype=numerical_type) for k in list_k: - g_term = torch.exp(-2 * np.pi**2 * sigma**2 * k**2) - torch.exp( - -np.pi * k**2 - ) / sigma / np.sqrt(2.0 * np.pi) - z2 += torch.exp(-np.pi * (u + k) ** 2) + np.sqrt( - 2.0 * np.pi - ) * sigma * g_term * torch.cos(2 * np.pi * k * u) - deriv_z2 += -2.0 * np.pi * (u + k) * torch.exp(-np.pi * (u + k) ** 2) - ( - 2.0 * np.pi - ) ** 1.5 * sigma * k * g_term * torch.sin(2.0 * np.pi * k * u) - - expected_value = sigma**2 * deriv_z2 / z2 - - torch.testing.assert_close(computed_value, expected_value) + g_term = torch.sqrt(2.0 * pi) * sigma * (-2.0 * pi**2 * sigma**2 * k**2).exp() - (-pi * k**2).exp() + z2 += (-pi * (u + k) ** 2).exp() + g_term * torch.cos(2 * pi * k * u) + deriv_z2 += (-2.0 * pi * (u + k) * (-pi * (u + k) ** 2).exp() + - 2.0 * pi * k * g_term * torch.sin(2.0 * pi * k * u)) + + expected_value = sigma * deriv_z2 / z2 + list_expected_s2.append(expected_value) + + list_expected_s2 = torch.tensor(list_expected_s2) + + torch.testing.assert_close(list_computed_s2, list_expected_s2) @pytest.mark.parametrize("kmax", [4]) diff --git a/tests/utils/__init__.py b/tests/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/utils/test_tensor_utils.py b/tests/utils/test_tensor_utils.py new file mode 100644 index 00000000..79ccd508 --- /dev/null +++ b/tests/utils/test_tensor_utils.py @@ -0,0 +1,34 @@ +import pytest +import torch + +from crystal_diffusion.utils.tensor_utils import \ + broadcast_batch_tensor_to_all_dimensions + + +@pytest.fixture(scope="module", autouse=True) +def set_random_seed(): + torch.manual_seed(2345234) + + +@pytest.fixture() +def batch_values(batch_size): + return torch.rand(batch_size) + + +@pytest.fixture() +def final_shape(batch_size, number_of_dimensions): + shape = torch.randint(low=1, high=5, size=(number_of_dimensions,)) + shape[0] = batch_size + return tuple(shape.numpy()) + + +@pytest.mark.parametrize("batch_size", [4, 8]) +@pytest.mark.parametrize("number_of_dimensions", [1, 2, 3]) +def test_broadcast_batch_tensor_to_all_dimensions(batch_size, batch_values, final_shape): + broadcast_values = broadcast_batch_tensor_to_all_dimensions(batch_values, final_shape) + + value_arrays = broadcast_values.reshape(batch_size, -1) + + for expected_value, computed_values in zip(batch_values, value_arrays): + expected_values = torch.ones_like(computed_values) * expected_value + torch.testing.assert_allclose(expected_values, computed_values)