Merge pull request #114 from mila-iqia/fokker_planck_regularizer
Regularizers
rousseab authored Jan 9, 2025
2 parents 0fe4f6d + f07d879 commit 4871158
Showing 38 changed files with 3,277 additions and 73 deletions.
185 changes: 185 additions & 0 deletions experiments/analysis/visualize_score_vector_field_2D.py
@@ -0,0 +1,185 @@
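"""Visualize the score vector field for two pseudo-atoms in one dimension.

Evaluate the sigma-normalized scores of a trained AXL network and of the analytical
score network on a 2D grid of relative coordinates, then render one frame per
diffusion time and assemble the frames into an mp4 video with ffmpeg.
"""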
import glob
import subprocess
import tempfile
from pathlib import Path

import einops
import matplotlib.pyplot as plt
import torch
from tqdm import tqdm

from diffusion_for_multi_scale_molecular_dynamics.analysis import \
    PLOT_STYLE_PATH
from diffusion_for_multi_scale_molecular_dynamics.models.score_networks.analytical_score_network import (
    AnalyticalScoreNetwork, AnalyticalScoreNetworkParameters)
from diffusion_for_multi_scale_molecular_dynamics.namespace import (
    AXL, CARTESIAN_FORCES, NOISE, NOISY_AXL_COMPOSITION, TIME, UNIT_CELL)
from diffusion_for_multi_scale_molecular_dynamics.noise_schedulers.exploding_variance import \
    VarianceScheduler
from diffusion_for_multi_scale_molecular_dynamics.noise_schedulers.noise_parameters import \
    NoiseParameters
from diffusion_for_multi_scale_molecular_dynamics.sample_diffusion import \
    get_axl_network
from diffusion_for_multi_scale_molecular_dynamics.utils.basis_transformations import \
    map_relative_coordinates_to_unit_cell
from experiments.two_atoms_in_one_dimension.utils import \
    get_2d_vector_field_figure

plt.style.use(PLOT_STYLE_PATH)


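# Location of the trained model checkpoint and destination for the rendered video.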
base_path = Path(
    "/Users/brunorousseau/courtois/local_experiments/egnn_small_regularizer_orion/orion_working_dir/"
    "785ed337118e5c748ca70517ff8569ee/last_model"
)

output_path = Path("/Users/brunorousseau/courtois/local_experiments/videos")

output_name = "785ed337118e5c748ca70517ff8569ee"


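# Noise schedule bounds (sigma_min, sigma_max) and the data width sigma_d used by
# the analytical score network.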
sigma_min = 0.001
sigma_max = 0.2
sigma_d = 0.01

n1 = 100
n2 = 100
nt = 100

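# Two pseudo-atoms in one dimension, with equilibrium relative coordinates x0_1 and x0_2.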
spatial_dimension = 1
x0_1 = 0.25
x0_2 = 0.75

if __name__ == "__main__":
    noise_parameters = NoiseParameters(
        total_time_steps=nt, sigma_min=sigma_min, sigma_max=sigma_max
    )

    score_network_parameters = AnalyticalScoreNetworkParameters(
        number_of_atoms=2,
        num_atom_types=1,
        kmax=5,
        equilibrium_relative_coordinates=[[x0_1], [x0_2]],
        sigma_d=sigma_d,
        spatial_dimension=spatial_dimension,
        use_permutation_invariance=True,
    )

    analytical_score_network = AnalyticalScoreNetwork(score_network_parameters)

    checkpoint_path = glob.glob(str(base_path / "**/last_model*.ckpt"), recursive=True)[
        0
    ]
    checkpoint_name = Path(checkpoint_path).name
    axl_network = get_axl_network(checkpoint_path)

    list_times = torch.linspace(0.0, 1.0, nt)
    list_sigmas = VarianceScheduler(noise_parameters).get_sigma(list_times).numpy()

    # Build an n1 x n2 grid of relative coordinates for the two atoms.
    x1 = torch.linspace(0, 1, n1)
    x2 = torch.linspace(0, 1, n2)

    X1, X2_ = torch.meshgrid(x1, x2, indexing="xy")
    X2 = torch.flip(X2_, dims=[0])

    relative_coordinates = einops.repeat(
        [X1, X2], "natoms n1 n2 -> (n1 n2) natoms space", space=spatial_dimension
    ).contiguous()
    relative_coordinates = map_relative_coordinates_to_unit_cell(relative_coordinates)

    forces = torch.zeros_like(relative_coordinates)
    batch_size, natoms, _ = relative_coordinates.shape

    atom_types = torch.ones(batch_size, natoms, dtype=torch.int64)

    # For every diffusion time, evaluate the analytical (ground truth) probabilities
    # and the model's sigma-normalized scores on the grid.
    list_ground_truth_probabilities = []
    list_sigma_normalized_scores = []
    for time, sigma in tqdm(zip(list_times, list_sigmas), "SIGMAS"):
        grid_sigmas = sigma * torch.ones_like(relative_coordinates)
        flat_probabilities, flat_normalized_scores = (
            analytical_score_network.get_probabilities_and_normalized_scores(
                relative_coordinates, grid_sigmas
            )
        )
        probabilities = einops.rearrange(
            flat_probabilities, "(n1 n2) -> n1 n2", n1=n1, n2=n2
        )
        list_ground_truth_probabilities.append(probabilities)

        sigma_t = sigma * torch.ones(batch_size, 1)
        times = time * torch.ones(batch_size, 1)
        unit_cell = torch.ones(batch_size, 1, 1)

        composition = AXL(
            A=atom_types,
            X=relative_coordinates,
            L=torch.zeros_like(relative_coordinates),
        )

        batch = {
            NOISY_AXL_COMPOSITION: composition,
            NOISE: sigma_t,
            TIME: times,
            UNIT_CELL: unit_cell,
            CARTESIAN_FORCES: forces,
        }

        model_predictions = axl_network(batch)
        sigma_normalized_scores = einops.rearrange(
            model_predictions.X.detach(),
            "(n1 n2) natoms space -> n1 n2 natoms space",
            n1=n1,
            n2=n2,
        )

        list_sigma_normalized_scores.append(sigma_normalized_scores)

    sigma_normalized_scores = torch.stack(list_sigma_normalized_scores).squeeze(-1)
    ground_truth_probabilities = torch.stack(list_ground_truth_probabilities)

    # ================================================================================
    # Render one vector-field frame per diffusion time, then assemble the frames
    # into an mp4 video with ffmpeg.

    s = 2
    with tempfile.TemporaryDirectory() as tmpdirname:

        tmp_dir = Path(tmpdirname)

        for time_idx in tqdm(range(len(list_times)), "VIDEO"):
            sigma_t = list_sigmas[time_idx]
            time = list_times[time_idx].item()

            fig = get_2d_vector_field_figure(
                X1=X1,
                X2=X2,
                probabilities=ground_truth_probabilities[time_idx],
                sigma_normalized_scores=sigma_normalized_scores[time_idx],
                time=time,
                sigma_t=sigma_t,
                sigma_d=sigma_d,
                supsampling_scale=s,
            )

            output_image = tmp_dir / f"vector_field_{time_idx}.png"
            fig.savefig(output_image)
            plt.close(fig)

        output_path.mkdir(parents=True, exist_ok=True)
        output_file_path = output_path / f"vector_field_{output_name}.mp4"

        # ffmpeg -r 10 -start_number 0 -i vector_field_%d.png -vcodec libx264 -pix_fmt yuv420p mlp_vector_field.mp4
        commands = [
            "ffmpeg",
            "-r",
            "10",
            "-start_number",
            "0",
            "-i",
            str(tmp_dir / "vector_field_%d.png"),
            "-vcodec",
            "libx264",
            "-pix_fmt",
            "yuv420p",
            str(output_file_path),
        ]

        process = subprocess.run(commands, capture_output=True, text=True)
6 changes: 6 additions & 0 deletions experiments/regularization_toy_problem/__init__.py
@@ -0,0 +1,6 @@
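# Shared locations for the regularization toy-problem experiments and their results;
# the results directory is created if it does not already exist.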
from pathlib import Path

EXPERIMENTS_DIR = Path(__file__).parent / "experiments"

RESULTS_DIR: Path = Path(__file__).parent / "results"
RESULTS_DIR.mkdir(parents=True, exist_ok=True)
@@ -0,0 +1,132 @@
#================================================================================
# Configuration file for a diffusion experiment for 2 pseudo-atoms in 1D.
#================================================================================
exp_name: mlp
run_name: consistency_regularizer
max_epoch: 1000
log_every_n_steps: 1
gradient_clipping: 0.0
accumulate_grad_batches: 1

elements: [A]

# set to null to avoid setting a seed (can speed up GPU computation, but
# results will not be reproducible)
seed: 1234

# Data: a fake dataloader will recreate the same example over and over.
data:
  data_source: gaussian
  random_seed: 42
  number_of_atoms: 2
  sigma_d: 0.01
  equilibrium_relative_coordinates:
    - [0.25]
    - [0.75]

  train_dataset_size: 8_192
  valid_dataset_size: 1_024

  batch_size: 64 # batch size for everyone
  num_workers: 0
  max_atom: 2
  spatial_dimension: 1

# architecture
spatial_dimension: 1

model:
  loss:
    coordinates_algorithm: mse
    atom_types_ce_weight: 0.0
    atom_types_lambda_weight: 0.0
    relative_coordinates_lambda_weight: 1.0
    lattice_lambda_weight: 0.0
  score_network:
    architecture: mlp
    use_permutation_invariance: True
    spatial_dimension: 1
    number_of_atoms: 2
    num_atom_types: 1
    n_hidden_dimensions: 3
    hidden_dimensions_size: 64
    relative_coordinates_embedding_dimensions_size: 32
    noise_embedding_dimensions_size: 16
    time_embedding_dimensions_size: 16
    atom_type_embedding_dimensions_size: 8
    condition_embedding_size: 8
  noise:
    total_time_steps: 100
    sigma_min: 0.001
    sigma_max: 0.2

# optimizer and scheduler
optimizer:
  name: adamw
  learning_rate: 0.001
  weight_decay: 5.0e-8



regularizer:
  type: consistency
  maximum_number_of_steps: 5
  number_of_burn_in_epochs: 0
  regularizer_lambda_weight: 0.001

  noise:
    total_time_steps: 100
    sigma_min: 0.001
    sigma_max: 0.2

  sampling:
    num_atom_types: 1
    number_of_atoms: 2
    number_of_samples: 64
    spatial_dimension: 1
    number_of_corrector_steps: 0
    cell_dimensions: [1.0]


scheduler:
  name: CosineAnnealingLR
  T_max: 1000
  eta_min: 0.0

# early stopping
early_stopping:
  metric: validation_epoch_loss
  mode: min
  patience: 1000

model_checkpoint:
  monitor: validation_epoch_loss
  mode: min

score_viewer:
  record_every_n_epochs: 1

  score_viewer_parameters:
    sigma_min: 0.001
    sigma_max: 0.2
    number_of_space_steps: 100
    starting_relative_coordinates:
      - [0.0]
      - [1.0]
    ending_relative_coordinates:
      - [1.0]
      - [0.0]
  analytical_score_network:
    architecture: "analytical"
    spatial_dimension: 1
    number_of_atoms: 2
    num_atom_types: 1
    kmax: 5
    equilibrium_relative_coordinates:
      - [0.25]
      - [0.75]
    sigma_d: 0.01
    use_permutation_invariance: True

logging:
  - tensorboard
@@ -0,0 +1,14 @@
#!/bin/bash
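# Local macOS run script: point OpenMP to Homebrew's libomp headers, let unsupported
# MPS operations fall back to CPU, and launch train_diffusion.py on the CPU
# accelerator with the local config and output directories.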

export OMP_PATH="/opt/homebrew/opt/libomp/include/"
export PYTORCH_ENABLE_MPS_FALLBACK=1


CONFIG=config.yaml

OUTPUT=./output/run1

SRC=/Users/brunorousseau/PycharmProjects/diffusion_for_multi_scale_molecular_dynamics/src/diffusion_for_multi_scale_molecular_dynamics


python ${SRC}/train_diffusion.py --accelerator "cpu" --config $CONFIG --output $OUTPUT