Merge pull request #322 from choderalab/dev-profiler-from-training-config

Add training parameter "profiler"
chrisiacovella authored Nov 21, 2024
2 parents 3e57f87 + 16bbae2 commit 024e297
Showing 3 changed files with 16 additions and 0 deletions.
1 change: 1 addition & 0 deletions modelforge/dataset/phalkethoh.py
@@ -43,6 +43,7 @@ class PhAlkEthOHDataset(HDF5Dataset):
        positions="geometry",
        E="dft_total_energy",
        F="dft_total_force",
        dipole_moment="scf_dipole",
    )

    _available_properties = [
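The one-line addition above maps the dataset's dipole_moment property to the scf_dipole key stored in the PhAlkEthOH HDF5 file. As a rough sketch of how such a name map gets consumed (the dict values come from the diff; the loader helper and the record/property file layout are hypothetical, not modelforge API):

# Hypothetical sketch: translating internal property names to on-disk
# HDF5 keys. The mapping mirrors the diff above; load_record_property
# and the file layout are illustrative assumptions.
import h5py

PROPERTY_MAP = {
    "positions": "geometry",
    "E": "dft_total_energy",
    "F": "dft_total_force",
    "dipole_moment": "scf_dipole",
}

def load_record_property(h5_path: str, record: str, prop: str):
    """Return one record's array for an internal property name."""
    with h5py.File(h5_path, "r") as f:
        return f[record][PROPERTY_MAP[prop]][()]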
12 changes: 12 additions & 0 deletions modelforge/train/parameters.py
@@ -96,6 +96,17 @@ class Loggers(CaseInsensitiveEnum):
    tensorboard = "tensorboard"


class Profilers(CaseInsensitiveEnum):
    """
    Enum class for the experiment profiler
    """

    simple = "simple"
    advanced = "advanced"
    pytorch = "pytorch"
    xla = "xla"


class TensorboardConfig(ParametersBase):
    save_dir: str

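For context (CaseInsensitiveEnum itself lives elsewhere in modelforge and is not part of this diff), the new Profilers enum accepts profiler names regardless of case. A minimal self-contained sketch, assuming the base class implements _missing_ along these lines:

from enum import Enum

class CaseInsensitiveEnum(str, Enum):
    # Assumed implementation: retry the lookup with a lower-cased value
    # when the exact string is not a member.
    @classmethod
    def _missing_(cls, value):
        if isinstance(value, str):
            for member in cls:
                if member.value == value.lower():
                    return member
        return None

class Profilers(CaseInsensitiveEnum):
    simple = "simple"
    advanced = "advanced"
    pytorch = "pytorch"
    xla = "xla"

assert Profilers("PyTorch") is Profilers.pytorch  # lookup ignores case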
@@ -389,6 +400,7 @@ def ensure_logger_configuration(self) -> "ExperimentLogger":
    limit_train_batches: Union[float, int, None] = None
    limit_val_batches: Union[float, int, None] = None
    limit_test_batches: Union[float, int, None] = None
    profiler: Optional[Profilers] = None
    optimizer: Type[torch.optim.Optimizer] = torch.optim.AdamW
    min_number_of_epochs: Union[int, None] = None

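Because the new field is declared as Optional[Profilers] on a pydantic model, a config value such as "simple" is validated against the enum, an unrecognized name fails fast, and omitting the key leaves profiling disabled. A hedged sketch (TrainingParameters is a stand-in name; only the profiler field is taken from the diff above):

from typing import Optional
from pydantic import BaseModel

class TrainingParameters(BaseModel):
    # Stand-in for the real training-parameter model in parameters.py;
    # only this field comes from the diff above.
    profiler: Optional[Profilers] = None

params = TrainingParameters(profiler="simple")
assert params.profiler is Profilers.simple
assert TrainingParameters().profiler is None  # no profiler by default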
3 changes: 3 additions & 0 deletions modelforge/train/training.py
@@ -1798,6 +1798,8 @@ def setup_trainer(self) -> Trainer:
            strategy = DDPStrategy(find_unused_parameters=True)
        else:
            strategy = "auto"
        if self.training_parameter.profiler is not None:
            log.debug(f"Using profiler {self.training_parameter.profiler}")

        trainer = Trainer(
            strategy=strategy,
@@ -1813,6 +1815,7 @@
            limit_train_batches=self.training_parameter.limit_train_batches,
            limit_val_batches=self.training_parameter.limit_val_batches,
            limit_test_batches=self.training_parameter.limit_test_batches,
            profiler=self.training_parameter.profiler,
            num_sanity_val_steps=1,
            gradient_clip_val=5.0,  # FIXME: hardcoded for now
            log_every_n_steps=self.runtime_parameter.log_every_n_steps,
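Lightning's Trainer accepts profiler either as a Profiler instance or as a string name ("simple", "advanced", "pytorch", "xla"), so the enum's string value passes straight through, and None leaves profiling off. A minimal sketch of the hand-off, with Trainer arguments trimmed to the relevant ones:

from lightning import Trainer  # or: from pytorch_lightning import Trainer

profiler = Profilers.simple  # as read from the training parameters
trainer = Trainer(
    profiler=profiler.value,  # "simple" -> Lightning's SimpleProfiler
    max_epochs=1,
)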
