-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #114 from mila-iqia/fokker_planck_regularizer
Regularizers
- Loading branch information
Showing
38 changed files
with
3,277 additions
and
73 deletions.
There are no files selected for viewing
185 changes: 185 additions & 0 deletions
185
experiments/analysis/visualize_score_vector_field_2D.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,185 @@ | ||
import glob | ||
import subprocess | ||
import tempfile | ||
from pathlib import Path | ||
|
||
import einops | ||
import matplotlib.pyplot as plt | ||
import torch | ||
from tqdm import tqdm | ||
|
||
from diffusion_for_multi_scale_molecular_dynamics.analysis import \ | ||
PLOT_STYLE_PATH | ||
from diffusion_for_multi_scale_molecular_dynamics.models.score_networks.analytical_score_network import ( | ||
AnalyticalScoreNetwork, AnalyticalScoreNetworkParameters) | ||
from diffusion_for_multi_scale_molecular_dynamics.namespace import ( | ||
AXL, CARTESIAN_FORCES, NOISE, NOISY_AXL_COMPOSITION, TIME, UNIT_CELL) | ||
from diffusion_for_multi_scale_molecular_dynamics.noise_schedulers.exploding_variance import \ | ||
VarianceScheduler | ||
from diffusion_for_multi_scale_molecular_dynamics.noise_schedulers.noise_parameters import \ | ||
NoiseParameters | ||
from diffusion_for_multi_scale_molecular_dynamics.sample_diffusion import \ | ||
get_axl_network | ||
from diffusion_for_multi_scale_molecular_dynamics.utils.basis_transformations import \ | ||
map_relative_coordinates_to_unit_cell | ||
from experiments.two_atoms_in_one_dimension.utils import \ | ||
get_2d_vector_field_figure | ||
|
||
plt.style.use(PLOT_STYLE_PATH) | ||
|
||
|
||
base_path = Path( | ||
"/Users/brunorousseau/courtois/local_experiments/egnn_small_regularizer_orion/orion_working_dir/" | ||
"785ed337118e5c748ca70517ff8569ee/last_model" | ||
) | ||
|
||
output_path = Path("/Users/brunorousseau/courtois/local_experiments/videos") | ||
|
||
output_name = "785ed337118e5c748ca70517ff8569ee" | ||
|
||
|
||
sigma_min = 0.001 | ||
sigma_max = 0.2 | ||
sigma_d = 0.01 | ||
|
||
n1 = 100 | ||
n2 = 100 | ||
nt = 100 | ||
|
||
spatial_dimension = 1 | ||
x0_1 = 0.25 | ||
x0_2 = 0.75 | ||
|
||
if __name__ == "__main__": | ||
noise_parameters = NoiseParameters( | ||
total_time_steps=nt, sigma_min=sigma_min, sigma_max=sigma_max | ||
) | ||
|
||
score_network_parameters = AnalyticalScoreNetworkParameters( | ||
number_of_atoms=2, | ||
num_atom_types=1, | ||
kmax=5, | ||
equilibrium_relative_coordinates=[[x0_1], [x0_2]], | ||
sigma_d=sigma_d, | ||
spatial_dimension=spatial_dimension, | ||
use_permutation_invariance=True, | ||
) | ||
|
||
analytical_score_network = AnalyticalScoreNetwork(score_network_parameters) | ||
|
||
checkpoint_path = glob.glob(str(base_path / "**/last_model*.ckpt"), recursive=True)[ | ||
0 | ||
] | ||
checkpoint_name = Path(checkpoint_path).name | ||
axl_network = get_axl_network(checkpoint_path) | ||
|
||
list_times = torch.linspace(0.0, 1.0, nt) | ||
list_sigmas = VarianceScheduler(noise_parameters).get_sigma(list_times).numpy() | ||
|
||
x1 = torch.linspace(0, 1, n1) | ||
x2 = torch.linspace(0, 1, n2) | ||
|
||
X1, X2_ = torch.meshgrid(x1, x2, indexing="xy") | ||
X2 = torch.flip(X2_, dims=[0]) | ||
|
||
relative_coordinates = einops.repeat( | ||
[X1, X2], "natoms n1 n2 -> (n1 n2) natoms space", space=spatial_dimension | ||
).contiguous() | ||
relative_coordinates = map_relative_coordinates_to_unit_cell(relative_coordinates) | ||
|
||
forces = torch.zeros_like(relative_coordinates) | ||
batch_size, natoms, _ = relative_coordinates.shape | ||
|
||
atom_types = torch.ones(batch_size, natoms, dtype=torch.int64) | ||
|
||
list_ground_truth_probabilities = [] | ||
list_sigma_normalized_scores = [] | ||
for time, sigma in tqdm(zip(list_times, list_sigmas), "SIGMAS"): | ||
grid_sigmas = sigma * torch.ones_like(relative_coordinates) | ||
flat_probabilities, flat_normalized_scores = ( | ||
analytical_score_network.get_probabilities_and_normalized_scores( | ||
relative_coordinates, grid_sigmas | ||
) | ||
) | ||
probabilities = einops.rearrange( | ||
flat_probabilities, "(n1 n2) -> n1 n2", n1=n1, n2=n2 | ||
) | ||
list_ground_truth_probabilities.append(probabilities) | ||
|
||
sigma_t = sigma * torch.ones(batch_size, 1) | ||
times = time * torch.ones(batch_size, 1) | ||
unit_cell = torch.ones(batch_size, 1, 1) | ||
|
||
composition = AXL( | ||
A=atom_types, | ||
X=relative_coordinates, | ||
L=torch.zeros_like(relative_coordinates), | ||
) | ||
|
||
batch = { | ||
NOISY_AXL_COMPOSITION: composition, | ||
NOISE: sigma_t, | ||
TIME: times, | ||
UNIT_CELL: unit_cell, | ||
CARTESIAN_FORCES: forces, | ||
} | ||
|
||
model_predictions = axl_network(batch) | ||
sigma_normalized_scores = einops.rearrange( | ||
model_predictions.X.detach(), | ||
"(n1 n2) natoms space -> n1 n2 natoms space", | ||
n1=n1, | ||
n2=n2, | ||
) | ||
|
||
list_sigma_normalized_scores.append(sigma_normalized_scores) | ||
|
||
sigma_normalized_scores = torch.stack(list_sigma_normalized_scores).squeeze(-1) | ||
ground_truth_probabilities = torch.stack(list_ground_truth_probabilities) | ||
|
||
# ================================================================================ | ||
|
||
s = 2 | ||
with tempfile.TemporaryDirectory() as tmpdirname: | ||
|
||
tmp_dir = Path(tmpdirname) | ||
|
||
for time_idx in tqdm(range(len(list_times)), "VIDEO"): | ||
sigma_t = list_sigmas[time_idx] | ||
time = list_times[time_idx].item() | ||
|
||
fig = get_2d_vector_field_figure( | ||
X1=X1, | ||
X2=X2, | ||
probabilities=ground_truth_probabilities[time_idx], | ||
sigma_normalized_scores=sigma_normalized_scores[time_idx], | ||
time=time, | ||
sigma_t=sigma_t, | ||
sigma_d=sigma_d, | ||
supsampling_scale=s, | ||
) | ||
|
||
output_image = tmp_dir / f"vector_field_{time_idx}.png" | ||
fig.savefig(output_image) | ||
plt.close(fig) | ||
|
||
output_path.mkdir(parents=True, exist_ok=True) | ||
output_file_path = output_path / f"vector_field_{output_name}.mp4" | ||
|
||
# ffmpeg -r 10 -start_number 0 -i vector_field_%d.png -vcodec libx264 -pix_fmt yuv420p mlp_vector_field.mp4 | ||
commands = [ | ||
"ffmpeg", | ||
"-r", | ||
"10", | ||
"-start_number", | ||
"0", | ||
"-i", | ||
str(tmp_dir / "vector_field_%d.png"), | ||
"-vcodec", | ||
"libx264", | ||
"-pix_fmt", | ||
"yuv420p", | ||
str(output_file_path), | ||
] | ||
|
||
process = subprocess.run(commands, capture_output=True, text=True) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
from pathlib import Path | ||
|
||
EXPERIMENTS_DIR = Path(__file__).parent / "experiments" | ||
|
||
RESULTS_DIR: Path = Path(__file__).parent / "results" | ||
RESULTS_DIR.mkdir(parents=True, exist_ok=True) |
132 changes: 132 additions & 0 deletions
132
experiments/regularization_toy_problem/experiments/consistency_regularizer/config.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,132 @@ | ||
#================================================================================ | ||
# Configuration file for a diffusion experiment for 2 pseudo-atoms in 1D. | ||
#================================================================================ | ||
exp_name: mlp | ||
run_name: consistency_regularizer | ||
max_epoch: 1000 | ||
log_every_n_steps: 1 | ||
gradient_clipping: 0.0 | ||
accumulate_grad_batches: 1 | ||
|
||
elements: [A] | ||
|
||
# set to null to avoid setting a seed (can speed up GPU computation, but | ||
# results will not be reproducible) | ||
seed: 1234 | ||
|
||
# Data: a fake dataloader will recreate the same example over and over. | ||
data: | ||
data_source: gaussian | ||
random_seed: 42 | ||
number_of_atoms: 2 | ||
sigma_d: 0.01 | ||
equilibrium_relative_coordinates: | ||
- [0.25] | ||
- [0.75] | ||
|
||
train_dataset_size: 8_192 | ||
valid_dataset_size: 1_024 | ||
|
||
batch_size: 64 # batch size for everyone | ||
num_workers: 0 | ||
max_atom: 2 | ||
spatial_dimension: 1 | ||
|
||
# architecture | ||
spatial_dimension: 1 | ||
|
||
model: | ||
loss: | ||
coordinates_algorithm: mse | ||
atom_types_ce_weight: 0.0 | ||
atom_types_lambda_weight: 0.0 | ||
relative_coordinates_lambda_weight: 1.0 | ||
lattice_lambda_weight: 0.0 | ||
score_network: | ||
architecture: mlp | ||
use_permutation_invariance: True | ||
spatial_dimension: 1 | ||
number_of_atoms: 2 | ||
num_atom_types: 1 | ||
n_hidden_dimensions: 3 | ||
hidden_dimensions_size: 64 | ||
relative_coordinates_embedding_dimensions_size: 32 | ||
noise_embedding_dimensions_size: 16 | ||
time_embedding_dimensions_size: 16 | ||
atom_type_embedding_dimensions_size: 8 | ||
condition_embedding_size: 8 | ||
noise: | ||
total_time_steps: 100 | ||
sigma_min: 0.001 | ||
sigma_max: 0.2 | ||
|
||
# optimizer and scheduler | ||
optimizer: | ||
name: adamw | ||
learning_rate: 0.001 | ||
weight_decay: 5.0e-8 | ||
|
||
|
||
|
||
regularizer: | ||
type: consistency | ||
maximum_number_of_steps: 5 | ||
number_of_burn_in_epochs: 0 | ||
regularizer_lambda_weight: 0.001 | ||
|
||
noise: | ||
total_time_steps: 100 | ||
sigma_min: 0.001 | ||
sigma_max: 0.2 | ||
|
||
sampling: | ||
num_atom_types: 1 | ||
number_of_atoms: 2 | ||
number_of_samples: 64 | ||
spatial_dimension: 1 | ||
number_of_corrector_steps: 0 | ||
cell_dimensions: [1.0] | ||
|
||
|
||
scheduler: | ||
name: CosineAnnealingLR | ||
T_max: 1000 | ||
eta_min: 0.0 | ||
|
||
# early stopping | ||
early_stopping: | ||
metric: validation_epoch_loss | ||
mode: min | ||
patience: 1000 | ||
|
||
model_checkpoint: | ||
monitor: validation_epoch_loss | ||
mode: min | ||
|
||
score_viewer: | ||
record_every_n_epochs: 1 | ||
|
||
score_viewer_parameters: | ||
sigma_min: 0.001 | ||
sigma_max: 0.2 | ||
number_of_space_steps: 100 | ||
starting_relative_coordinates: | ||
- [0.0] | ||
- [1.0] | ||
ending_relative_coordinates: | ||
- [1.0] | ||
- [0.0] | ||
analytical_score_network: | ||
architecture: "analytical" | ||
spatial_dimension: 1 | ||
number_of_atoms: 2 | ||
num_atom_types: 1 | ||
kmax: 5 | ||
equilibrium_relative_coordinates: | ||
- [0.25] | ||
- [0.75] | ||
sigma_d: 0.01 | ||
use_permutation_invariance: True | ||
|
||
logging: | ||
- tensorboard |
14 changes: 14 additions & 0 deletions
14
experiments/regularization_toy_problem/experiments/consistency_regularizer/run.sh
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
#!/bin/bash | ||
|
||
export OMP_PATH="/opt/homebrew/opt/libomp/include/" | ||
export PYTORCH_ENABLE_MPS_FALLBACK=1 | ||
|
||
|
||
CONFIG=config.yaml | ||
|
||
OUTPUT=./output/run1 | ||
|
||
SRC=/Users/brunorousseau/PycharmProjects/diffusion_for_multi_scale_molecular_dynamics/src/diffusion_for_multi_scale_molecular_dynamics | ||
|
||
|
||
python ${SRC}/train_diffusion.py --accelerator "cpu" --config $CONFIG --output $OUTPUT |
Oops, something went wrong.