
Loss #8

Merged · 44 commits · Mar 20, 2024

Commits
787f95f
Fixed model loader.
rousseab Mar 13, 2024
31ae1cb
Sampler of noisy positions.
rousseab Mar 13, 2024
c6c60ec
ignore results folder
rousseab Mar 6, 2024
10a31cf
A score base class.
rousseab Mar 6, 2024
ce66bb2
More comments in base score class, use variables for field names.
rousseab Mar 12, 2024
a1a376b
Class to sample time steps.
rousseab Mar 12, 2024
cb77ee6
Fix iterators to only cover what is needed.
rousseab Mar 12, 2024
36f81bf
Fix variable name bjork.
rousseab Mar 12, 2024
2c97132
Variance sampler.
rousseab Mar 12, 2024
64786ad
A plotting script to show what sigma looks like.
rousseab Mar 13, 2024
2d4d975
Plot the target noise, too.
rousseab Mar 13, 2024
41ff340
Remove needless end2end testing.
rousseab Mar 13, 2024
1545841
Fixed model loader.
rousseab Mar 13, 2024
311015a
Sampler of noisy positions.
rousseab Mar 13, 2024
27bfff7
Merge remote-tracking branch 'origin/loss' into loss
rousseab Mar 13, 2024
46dadfe
Fixed model loader.
rousseab Mar 13, 2024
3de9bc7
Sampler of noisy positions.
rousseab Mar 13, 2024
2e9d584
Merge remote-tracking branch 'origin/loss' into loss
rousseab Mar 14, 2024
152c173
Use dataclass hyper-parameters for the score network class.
rousseab Mar 15, 2024
f8ffcb2
More robust typing for optimizer.
rousseab Mar 15, 2024
3802630
Removed needless file.
rousseab Mar 15, 2024
5292e3b
New optimizer file.
rousseab Mar 15, 2024
af1b502
A better name for the noisy positions.
rousseab Mar 15, 2024
ee766b9
A cleaner way to draw noise samples.
rousseab Mar 16, 2024
9eae929
Method to return all noise arrays.
rousseab Mar 16, 2024
61ef471
Fix plotting script.
rousseab Mar 16, 2024
8fbd19b
Remove needless time sampler code.
rousseab Mar 16, 2024
25c627f
Reshaping utility to broadcast batch quantity.
rousseab Mar 16, 2024
8679dec
Reshaping utility to broadcast batch quantity.
rousseab Mar 16, 2024
1a1108a
Simplify the noisy position sampler.
rousseab Mar 16, 2024
537e11f
Use reshape to avoid non-contiguous bjork.
rousseab Mar 16, 2024
01c9901
Generic step to compute the loss.
rousseab Mar 16, 2024
386b26a
Improved docstring.
rousseab Mar 18, 2024
342344e
Position diffusion lightning model that computes loss.
rousseab Mar 18, 2024
4418bbb
Cleaner way of setting seeds in tests.
rousseab Mar 18, 2024
425df35
Fixed an inconsistency in the definition of "sigma normalized".
rousseab Mar 18, 2024
4537e47
Don't be too demanding on the target computation.
rousseab Mar 18, 2024
d7aefc2
Cleaner testing asserts.
rousseab Mar 18, 2024
9a52b3b
Cleaner testing asserts.
rousseab Mar 18, 2024
cbfd768
Fix plot.
rousseab Mar 18, 2024
1c4f674
A small sanity check experiment to see if we can overfit fake data.
rousseab Mar 18, 2024
5def43e
Fix issues with pytorch's bad version of modulo. (See the note after this commit list.)
rousseab Mar 18, 2024
4503c72
A beefier mlp.
rousseab Mar 18, 2024
63ebeb6
An overfitting sanity check.
rousseab Mar 18, 2024
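A note on commit 5def43e ("Fix issues with pytorch's bad version of modulo"): the actual fix is not shown in this diff, but the usual pitfall when wrapping relative coordinates into [0, 1) is sketched below. This helper is hypothetical, not the repo's code: torch.fmod keeps the sign of the dividend, so negatives are never wrapped, and even torch.remainder can round up to exactly the divisor for tiny negative float32 inputs.

import torch

# Hypothetical helper; the repo's actual fix in 5def43e may differ.
def wrap_to_unit_interval(x: torch.Tensor) -> torch.Tensor:
    """Map relative positions into [0, 1), guarding against modulo edge cases."""
    # torch.fmod(torch.tensor(-0.25), 1.0) == -0.25: negatives stay negative.
    # torch.remainder(torch.tensor(-1e-9), 1.0) can round to exactly 1.0 in
    # float32, breaking the invariant that the result is strictly below 1.
    u = torch.remainder(x, 1.0)
    return torch.where(u == 1.0, torch.zeros_like(u), u)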
38 changes: 15 additions & 23 deletions crystal_diffusion/analysis/exploding_variance_analysis.py
@@ -8,38 +8,29 @@

 from crystal_diffusion import ANALYSIS_RESULTS_DIR
 from crystal_diffusion.analysis import PLEASANT_FIG_SIZE, PLOT_STYLE_PATH
-from crystal_diffusion.samplers.time_sampler import TimeParameters, TimeSampler
 from crystal_diffusion.samplers.variance_sampler import (
-    ExplodingVarianceSampler, VarianceParameters)
+    ExplodingVarianceSampler, NoiseParameters)
 from crystal_diffusion.score.wrapped_gaussian_score import \
     get_sigma_normalized_score

 plt.style.use(PLOT_STYLE_PATH)

 if __name__ == '__main__':

-    variance_parameters = VarianceParameters()
-    time_parameters = TimeParameters(total_time_steps=1000)
-    time_sampler = TimeSampler(time_parameters=time_parameters)
-    variance_sampler = ExplodingVarianceSampler(variance_parameters=variance_parameters,
-                                                time_sampler=time_sampler)
-
-    indices = torch.arange(time_parameters.total_time_steps)
-    times = time_sampler.get_time_steps(indices)
-    sigmas = torch.sqrt(variance_sampler.get_variances(indices))
-    gs = torch.sqrt(variance_sampler.get_g_squared_factors(indices[1:]))
+    noise_parameters = NoiseParameters(total_time_steps=1000)
+    variance_sampler = ExplodingVarianceSampler(noise_parameters=noise_parameters)
+
+    noise = variance_sampler.get_all_noise()

     # A first figure to compare the "smart" and the "brute force" results
     fig1 = plt.figure(figsize=PLEASANT_FIG_SIZE)
     fig1.suptitle("Noise Schedule")

     ax1 = fig1.add_subplot(221)
     ax2 = fig1.add_subplot(223)
     ax3 = fig1.add_subplot(122)

-    ax1.plot(times, sigmas, '-', c='k', lw=2)
-    ax2.plot(times[1:], gs, '-', c='k', lw=2)
+    ax1.plot(noise.time, noise.sigma, '-', c='k', lw=2)
+    ax2.plot(noise.time[1:], noise.g[1:], '-', c='k', lw=2)

     ax1.set_ylabel('$\\sigma(t)$')
     ax2.set_ylabel('$g(t)$')

@@ -55,17 +46,18 @@

     kmax = 4
     indices = torch.tensor([1, 250, 750, 999])
-    times = time_sampler.get_time_steps(indices)
-    sigmas = torch.sqrt(variance_sampler.get_variances(indices))
-    gs_squared = variance_sampler.get_g_squared_factors(indices)
+    times = noise.time.take(indices)
+    sigmas = noise.sigma.take(indices)
+    gs_squared = noise.g_squared.take(indices)

     for t, sigma in zip(times, sigmas):
-        target_scores = get_sigma_normalized_score(relative_positions,
-                                                   torch.ones_like(relative_positions) * sigma,
-                                                   kmax=kmax)
-        ax3.plot(relative_positions, sigma * target_scores, label=f"t = {t:3.2f}")
+        target_sigma_normalized_scores = get_sigma_normalized_score(relative_positions,
+                                                                    torch.ones_like(relative_positions) * sigma,
+                                                                    kmax=kmax)
+        ax3.plot(relative_positions, target_sigma_normalized_scores, label=f"t = {t:3.2f}")

-    ax3.set_title("Target Noise")
+    ax3.set_title("Target Normalized Score")
     ax3.set_xlabel("relative position, u")
     ax3.set_ylabel("$\\sigma(t) \\times S(u, t)$")
     ax3.legend(loc=0)
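Net effect of this file's changes: the TimeSampler/VarianceParameters pair is replaced by a single NoiseParameters + ExplodingVarianceSampler API. A minimal usage sketch, assuming (as the diff suggests) that get_all_noise() returns a container with time, sigma, g and g_squared tensors:

import torch

from crystal_diffusion.samplers.variance_sampler import (
    ExplodingVarianceSampler, NoiseParameters)

noise_parameters = NoiseParameters(total_time_steps=1000)
sampler = ExplodingVarianceSampler(noise_parameters=noise_parameters)

# One container replaces the separate get_time_steps / get_variances /
# get_g_squared_factors calls of the old API.
noise = sampler.get_all_noise()
sigmas_at_chosen_times = noise.sigma.take(torch.tensor([1, 250, 750, 999]))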
8 changes: 4 additions & 4 deletions crystal_diffusion/analysis/target_score_analysis.py
@@ -10,8 +10,8 @@

 from crystal_diffusion import ANALYSIS_RESULTS_DIR
 from crystal_diffusion.analysis import PLEASANT_FIG_SIZE, PLOT_STYLE_PATH
 from crystal_diffusion.score.wrapped_gaussian_score import (
-    SIGMA_THRESHOLD, get_expected_sigma_normalized_score_brute_force,
-    get_sigma_normalized_score)
+    SIGMA_THRESHOLD, get_sigma_normalized_score,
+    get_sigma_normalized_score_brute_force)

 plt.style.use(PLOT_STYLE_PATH)

@@ -32,7 +32,7 @@

     sigma = sigma_factor * SIGMA_THRESHOLD

     sigmas = torch.ones_like(relative_positions) * sigma
-    list_scores_brute = np.array([get_expected_sigma_normalized_score_brute_force(u, sigma) for u in list_u])
+    list_scores_brute = np.array([get_sigma_normalized_score_brute_force(u, sigma) for u in list_u])
     list_scores = get_sigma_normalized_score(relative_positions, sigmas, kmax=kmax).numpy()
     error = list_scores - list_scores_brute

@@ -73,7 +73,7 @@

                  ms=ms, c=color, lw=2, alpha=0.25, label=f'kmax = {kmax}')

     list_scores_brute = np.array([
-        get_expected_sigma_normalized_score_brute_force(u, sigma, kmax=4 * kmax) for sigma in sigmas])
+        get_sigma_normalized_score_brute_force(u, sigma, kmax=4 * kmax) for sigma in sigmas])
     ax4.semilogy(sigma_factors, list_scores_brute, 'o-',
                  ms=ms, c=color, lw=2, alpha=0.25, label=f'kmax = {4 * kmax}')
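For reference, the quantity both functions compute is the sigma-normalized score sigma(t) * S(u, t) = sigma * d/du log p(u), where p wraps a Gaussian onto the unit interval (matching the plot's y-label above). A brute-force sketch, hedged: the repo's get_sigma_normalized_score_brute_force may use a different truncation or sign convention.

import numpy as np

def wrapped_gaussian_sigma_normalized_score(u: float, sigma: float, kmax: int = 4) -> float:
    """sigma * d/du log p(u), with p(u) proportional to sum_k exp(-(u + k)^2 / (2 sigma^2))."""
    ks = np.arange(-kmax, kmax + 1)
    weights = np.exp(-0.5 * (u + ks) ** 2 / sigma ** 2)
    # d/du log p = sum_k [-(u + k) / sigma^2] w_k / sum_k w_k; multiplying by
    # sigma leaves one factor of sigma in the denominator.
    return float(-((u + ks) / sigma * weights).sum() / weights.sum())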
55 changes: 0 additions & 55 deletions crystal_diffusion/models/my_model.py

This file was deleted.

51 changes: 0 additions & 51 deletions crystal_diffusion/models/optim.py

This file was deleted.

41 changes: 41 additions & 0 deletions crystal_diffusion/models/optimizer.py
@@ -0,0 +1,41 @@
import logging
from dataclasses import dataclass
from enum import Enum

import torch
from torch import optim

logger = logging.getLogger(__name__)


class ValidOptimizerNames(Enum):
    """Valid optimizer names."""
    adam = "adam"
    sgd = "sgd"


@dataclass(kw_only=True)
class OptimizerParameters:
    """Parameters for the optimizer."""
    name: ValidOptimizerNames
    learning_rate: float


def load_optimizer(hyper_params: OptimizerParameters, model: torch.nn.Module) -> optim.Optimizer:
    """Instantiate the optimizer.

    Args:
        hyper_params : hyperparameters defining the optimizer
        model : A neural network model.

    Returns:
        optimizer : The optimizer for the given model
    """
    match hyper_params.name:
        case ValidOptimizerNames.adam:
            optimizer = optim.Adam(model.parameters(), lr=hyper_params.learning_rate)
        case ValidOptimizerNames.sgd:
            optimizer = optim.SGD(model.parameters(), lr=hyper_params.learning_rate)
        case _:
            raise ValueError(f"optimizer {hyper_params.name} not supported")
    return optimizer

Collaborator (review comment on the optim.Adam line): Depending on the regime, we might have to change the eps variable in Adam. It might be premature to add it at this stage; it is easy enough to do later. Just something to keep in mind.
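A quick usage sketch of the new module; the Linear model below is a stand-in for illustration, while load_optimizer, OptimizerParameters and ValidOptimizerNames come from the file above. Per the review comment, an eps field could later be added to OptimizerParameters and forwarded to optim.Adam.

import torch

from crystal_diffusion.models.optimizer import (
    OptimizerParameters, ValidOptimizerNames, load_optimizer)

model = torch.nn.Linear(8, 1)  # stand-in model for illustration
hyper_params = OptimizerParameters(name=ValidOptimizerNames.adam, learning_rate=1e-3)
optimizer = load_optimizer(hyper_params, model)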