Loss #8
Merged
Changes from all commits (44 commits):
787f95f  rousseab  Fixed model loader.
31ae1cb  rousseab  Sampler of noisy positions.
c6c60ec  rousseab  ignore results folder
10a31cf  rousseab  A score base class.
ce66bb2  rousseab  More comments in base score class, use variables for field names.
a1a376b  rousseab  Class to sample time steps.
cb77ee6  rousseab  Fix iterators to only cover what is needed.
36f81bf  rousseab  Fix variable name bjork.
2c97132  rousseab  Variance sampler.
64786ad  rousseab  A plotting script to show what sigma looks like.
2d4d975  rousseab  Plot the target noise, too.
41ff340  rousseab  Remove needless end2end testing.
1545841  rousseab  Fixed model loader.
311015a  rousseab  Sampler of noisy positions.
27bfff7  rousseab  Merge remote-tracking branch 'origin/loss' into loss
46dadfe  rousseab  Fixed model loader.
3de9bc7  rousseab  Sampler of noisy positions.
2e9d584  rousseab  Merge remote-tracking branch 'origin/loss' into loss
152c173  rousseab  Use dataclass hyper-parameters for the score network class.
f8ffcb2  rousseab  More robust typing for optimizer.
3802630  rousseab  Remoevd needless file.
5292e3b  rousseab  New optimizer file.
af1b502  rousseab  A better name for the noisy positions.
ee766b9  rousseab  A cleaner way to draw noise samples.
9eae929  rousseab  Method to return all noise arrays.
61ef471  rousseab  Fix plotting script.
8fbd19b  rousseab  Remove needless time sampler code.
25c627f  rousseab  Reshaping utility to broadcast batch quantity.
8679dec  rousseab  Reshaping utility to broadcast batch quantity.
1a1108a  rousseab  Simplify the noisy position sampler.
537e11f  rousseab  Use reshape to avoid non-contiguous bjork.
01c9901  rousseab  Generic step to compute the loss.
386b26a  rousseab  Improved docstring.
342344e  rousseab  Position diffusion lightning model that computes loss.
4418bbb  rousseab  Cleaner way of setting seeds in tests.
425df35  rousseab  Fixed an inconsistency in the definition of "sigma normalized".
4537e47  rousseab  Don't be too demanding on the target computation.
d7aefc2  rousseab  Cleaner testing asserts.
9a52b3b  rousseab  Cleaner testing asserts.
cbfd768  rousseab  Fix plot.
1c4f674  rousseab  A small sanity check experiments to see if we can overfit fake data.
5def43e  rousseab  Fix issues with pytorch's bad version of modulo.
4503c72  rousseab  A beefier mlp.
63ebeb6  rousseab  An overifiting sanity check.
@@ -0,0 +1,41 @@
import logging
from dataclasses import dataclass
from enum import Enum

import torch
from torch import optim

logger = logging.getLogger(__name__)


class ValidOptimizerNames(Enum):
    """Valid optimizer names."""
    adam = "adam"
    sgd = "sgd"


@dataclass(kw_only=True)
class OptimizerParameters:
    """Parameters for the optimizer."""
    name: ValidOptimizerNames
    learning_rate: float


def load_optimizer(hyper_params: OptimizerParameters, model: torch.nn.Module) -> optim.Optimizer:
    """Instantiate the optimizer.

    Args:
        hyper_params : hyperparameters defining the optimizer
        model : A neural network model.

    Returns:
        optimizer : The optimizer for the given model
    """
    match hyper_params.name:
        case ValidOptimizerNames.adam:
            optimizer = optim.Adam(model.parameters(), lr=hyper_params.learning_rate)
        case ValidOptimizerNames.sgd:
            optimizer = optim.SGD(model.parameters(), lr=hyper_params.learning_rate)
        case _:
            raise ValueError(f"optimizer {hyper_params.name} not supported")
    return optimizer
Depending on the regime, we might have to change the eps variable in Adam. It might be premature to add it at this stage; it is easy enough to do later. Just something to keep in mind.