Skip to content

Commit

Permalink
Merge pull request #327 from chrisiacovella/improved_dataset_control
Browse files Browse the repository at this point in the history
Improved dataset control
  • Loading branch information
chrisiacovella authored Dec 5, 2024
2 parents f8c7834 + 3ac9e6c commit f8503c2
Show file tree
Hide file tree
Showing 20 changed files with 388 additions and 70 deletions.
2 changes: 2 additions & 0 deletions docs/datasets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -103,4 +103,6 @@ Explanation of fields in `qm9.toml`:
- `dataset_name`: Specifies the name of the dataset. For this example, it is QM9.
- `num_workers`: Determines the number of workers used by the DataLoader for data loading. Increasing the number of workers can speed up data loading but requires more memory.
- `version_select`: Indicates the version of the dataset to use. In this example, it points to a small subset of the dataset for quick testing. To use the full QM9 dataset, set this variable to `latest`.
- `properties_of_interest`: Lists the properties of interest to load from the hdf5 file.
- `properties_assignment`: Maps the properties of interest to the corresponding fields in the dataset. This mapping is crucial for the correct loading of properties during training; note that many datasets contain multiple properties that can potentially be swapped (e.g., energy calculated with or without dispersion corrections, different charge population schemes, different levels of theory, etc.). Any property listed here must appear in the `properties_of_interest` list; the code will raise a validation error if this condition is not met.

1 change: 1 addition & 0 deletions docs/getting_started.rst
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ These architectures can be trained on the following datasets (distributed via ze
- QM9
- SPICE1 (/openff)
- SPICE2
- tmQM

By default, potentials predict the total energy and per-atom forces within a given cutoff radius and can be trained on energies and forces.

Expand Down
2 changes: 1 addition & 1 deletion modelforge/curation/scripts/curate_spice114_openff.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def main():

# We'll want to provide some simple means of versioning
# if we make updates to either the underlying dataset, curation modules, or parameters given to the code
version = "1"
version = "2"
# version of the dataset to curate
version_select = f"v_0"

Expand Down
2 changes: 1 addition & 1 deletion modelforge/curation/spice_1_openff_curation.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def _init_record_entries_series(self):
"dft_total_force": "series_atom",
"formation_energy": "series_mol",
"mbis_charges": "series_atom",
"scf_dipole": "series_atom",
"scf_dipole": "series_mol",
}

# we will use the retry package to allow us to resume download if we lose connection to the server
Expand Down
82 changes: 26 additions & 56 deletions modelforge/dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,62 +22,7 @@
if TYPE_CHECKING:
from modelforge.potential.processing import AtomicSelfEnergies

from enum import Enum

from pydantic import BaseModel, ConfigDict, Field


class CaseInsensitiveEnum(str, Enum):
    """
    Enum class that allows case-insensitive lookup of its members.

    When a value does not match any member exactly, the lookup is retried
    ignoring case, so ``MyEnum("abc")`` resolves to the member whose value
    is ``"ABC"``.
    """

    @classmethod
    def _missing_(cls, value):
        """
        Resolve a value that did not match any member exactly.

        Parameters
        ----------
        value : object
            The value passed to the enum constructor.

        Returns
        -------
        CaseInsensitiveEnum or None
            The member whose value equals ``value`` ignoring case, or the
            result of the parent ``_missing_`` hook when nothing matches.
        """
        # Only strings can be compared case-insensitively. Deferring to the
        # parent hook for anything else lets Enum raise its regular
        # ValueError instead of an AttributeError from ``value.lower()``.
        if not isinstance(value, str):
            return super()._missing_(value)

        lowered = value.lower()
        for member in cls:
            if member.value.lower() == lowered:
                return member
        return super()._missing_(value)


class DataSetName(CaseInsensitiveEnum):
    """
    Names of the datasets that can be selected via ``dataset_name``.

    Inherits from CaseInsensitiveEnum, so the spelling of a name is matched
    without regard to case.
    """

    # QM9 and ANI families
    QM9 = "QM9"
    ANI1X = "ANI1X"
    ANI2X = "ANI2X"
    # SPICE families
    SPICE1 = "SPICE1"
    SPICE2 = "SPICE2"
    SPICE1_OPENFF = "SPICE1_OPENFF"
    # mixed-case canonical spellings
    PHALKETHOH = "PhAlkEthOH"
    TMQM = "tmQM"


class DatasetParameters(BaseModel):
    """
    Validated container for the dataset section of a configuration.

    Attributes
    ----------
    dataset_name : DataSetName
        The name of the dataset.
    version_select : str
        The version of the dataset to use.
    num_workers : int
        The number of workers to use for the DataLoader; must be greater than 0.
    pin_memory : bool
        Whether to pin memory for the DataLoader.
    regenerate_processed_cache : bool
        Whether to regenerate the processed cache; defaults to False.
    """

    model_config = ConfigDict(
        use_enum_values=True,
        arbitrary_types_allowed=True,
        validate_assignment=True,
    )

    dataset_name: DataSetName
    version_select: str
    num_workers: int = Field(gt=0)
    pin_memory: bool
    regenerate_processed_cache: bool = False
from modelforge.dataset.parameters import DatasetParameters


# Define the input class
Expand Down Expand Up @@ -922,6 +867,8 @@ def __init__(
regenerate_cache: bool = False,
regenerate_dataset_statistic: bool = False,
regenerate_processed_cache: bool = True,
properties_of_interest: Optional[PropertyNames] = None,
properties_assignment: Optional[Dict[str, str]] = None,
):
"""
Initializes a DataModule for PyTorch Lightning, handling data preparation and loading with the specified configuration.
Expand Down Expand Up @@ -957,6 +904,14 @@ def __init__(
Directory to store the files.
regenerate_cache : bool, defaults to False
Whether to regenerate the cache.
regenerate_dataset_statistic : bool, defaults to False
Whether to regenerate the dataset statistics.
regenerate_processed_cache : bool, defaults to True
Whether to regenerate the processed cache.
properties_of_interest : Optional[PropertyNames]
The properties to include in the dataset.
properties_assignment : Optional[Dict[str, str]]
Mapping that associates properties of interest from the hdf5 dataset with the internal property names used by the code.
"""
from modelforge.potential.neighbors import Pairlist
import os
Expand All @@ -979,6 +934,9 @@ def __init__(
self.val_dataset: Optional[TorchDataset] = None
self.test_dataset: Optional[TorchDataset] = None

self.properties_of_interest = properties_of_interest
self.properties_assignment = properties_assignment

# make sure we can handle a path with a ~ in it
self.local_cache_dir = os.path.expanduser(local_cache_dir)
# create the local cache directory if it does not exist
Expand Down Expand Up @@ -1038,6 +996,14 @@ def prepare_data(
local_cache_dir=self.local_cache_dir,
regenerate_cache=self.regenerate_cache,
)
if self.properties_of_interest is not None:
dataset.properties_of_interest = self.properties_of_interest

if self.properties_assignment is not None:
from modelforge.utils import PropertyNames

dataset._properties_names = PropertyNames(**self.properties_assignment)

torch_dataset = self._create_torch_dataset(dataset)
# if dataset statistics is present load it from disk
if (
Expand Down Expand Up @@ -1474,6 +1440,8 @@ def initialize_datamodule(
regression_ase: bool = False,
regenerate_dataset_statistic: bool = False,
local_cache_dir="./",
properties_of_interest: Optional[PropertyNames] = None,
properties_assignment: Optional[Dict[str, str]] = None,
) -> DataModule:
"""
Initialize a dataset for a given mode.
Expand All @@ -1489,6 +1457,8 @@ def initialize_datamodule(
regression_ase=regression_ase,
regenerate_dataset_statistic=regenerate_dataset_statistic,
local_cache_dir=local_cache_dir,
properties_of_interest=properties_of_interest,
properties_assignment=properties_assignment,
)
data_module.prepare_data()
data_module.setup()
Expand Down
100 changes: 100 additions & 0 deletions modelforge/dataset/parameters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
from enum import Enum
from typing import Optional, List, Dict
from pydantic import BaseModel, ConfigDict, Field, model_validator
from typing_extensions import Self


class CaseInsensitiveEnum(str, Enum):
    """
    Enum class that allows case-insensitive lookup of its members.

    When a value does not match any member exactly, the lookup is retried
    ignoring case, so ``MyEnum("abc")`` resolves to the member whose value
    is ``"ABC"``.
    """

    @classmethod
    def _missing_(cls, value):
        """
        Resolve a value that did not match any member exactly.

        Parameters
        ----------
        value : object
            The value passed to the enum constructor.

        Returns
        -------
        CaseInsensitiveEnum or None
            The member whose value equals ``value`` ignoring case, or the
            result of the parent ``_missing_`` hook when nothing matches.
        """
        # Only strings can be compared case-insensitively. Deferring to the
        # parent hook for anything else lets Enum raise its regular
        # ValueError instead of an AttributeError from ``value.lower()``.
        if not isinstance(value, str):
            return super()._missing_(value)

        lowered = value.lower()
        for member in cls:
            if member.value.lower() == lowered:
                return member
        return super()._missing_(value)


class ParametersBase(BaseModel):
    """
    Common parent for the parameter classes.

    Holds the shared pydantic configuration so that it does not have to be
    repeated on every parameters class that derives from this one.
    """

    model_config = ConfigDict(
        extra="forbid",
        validate_assignment=True,
        arbitrary_types_allowed=True,
        use_enum_values=True,
    )


class DataSetName(CaseInsensitiveEnum):
    """
    Names of the datasets that can be selected via ``dataset_name``.

    Inherits from CaseInsensitiveEnum, so the spelling of a name is matched
    without regard to case.
    """

    # QM9 and ANI families
    QM9 = "QM9"
    ANI1X = "ANI1X"
    ANI2X = "ANI2X"
    # SPICE families
    SPICE1 = "SPICE1"
    SPICE2 = "SPICE2"
    SPICE1_OPENFF = "SPICE1_OPENFF"
    # mixed-case canonical spellings
    PHALKETHOH = "PhAlkEthOH"
    TMQM = "tmQM"


class PropertiesDefinition(ParametersBase):
    """
    Association between internal property names and dataset field names.

    Each attribute is an internal name; its value is the name of the dataset
    property assigned to it. Optional entries default to None when no
    dataset property is assigned.
    """

    # required assignments
    atomic_numbers: str
    positions: str
    E: str
    # optional assignments
    F: Optional[str] = None
    dipole_moment: Optional[str] = None
    total_charge: Optional[str] = None


class DatasetParameters(BaseModel):
    """
    Validated container for the dataset section of a configuration.

    Attributes
    ----------
    dataset_name : DataSetName
        The name of the dataset.
    version_select : str
        The version of the dataset to use.
    num_workers : int
        The number of workers to use for the DataLoader; must be greater than 0.
    pin_memory : bool
        Whether to pin memory for the DataLoader.
    regenerate_processed_cache : bool
        Whether to regenerate the processed cache; defaults to False.
    properties_of_interest : List[str]
        The properties of interest to load from the hdf5 file.
    properties_assignment : PropertiesDefinition
        Association between the properties of interest and the internal
        naming convention.
    """

    model_config = ConfigDict(
        use_enum_values=True,
        arbitrary_types_allowed=True,
        validate_assignment=True,
    )

    dataset_name: DataSetName
    version_select: str
    num_workers: int = Field(gt=0)
    pin_memory: bool
    regenerate_processed_cache: bool = False
    properties_of_interest: List[str]
    properties_assignment: PropertiesDefinition

    @model_validator(mode="after")
    def validate_properties(self) -> Self:
        """
        Check that every assigned property is listed in properties_of_interest.

        Entries of properties_assignment that are None (unset optional
        assignments) are ignored. Per the original note, datasets validate
        properties_of_interest against the properties available in the
        dataset, so that check is not repeated here.

        Raises
        ------
        ValueError
            If an assigned property is not in properties_of_interest.
        """
        assigned = self.properties_assignment.model_dump()
        for prop in assigned.values():
            # Unset optional assignments and listed properties are both fine.
            if prop is None or prop in self.properties_of_interest:
                continue
            raise ValueError(
                f"Property {prop} is not in the properties_of_interest."
            )

        return self
60 changes: 58 additions & 2 deletions modelforge/dataset/yaml_files/spice1openff.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,62 @@
dataset: spice1openff
latest: full_dataset_v1
latest_test: nc_1000_v1
latest: full_dataset_v2
latest_test: nc_1000_v2
full_dataset_v2:
version: 2
gz_data_file:
doi: 10.5281/zenodo.14264431
length: 2545870978
md5: 8e172839e8812747cbbf84f117168dc0
name: SPICE1_OpenFF_dataset_v2.hdf5.gz
hdf5_data_file:
md5: 7fbf84cc98af8cd63235be88eed30f51
name: SPICE1_OpenFF_dataset_v2.hdf5
processed_data_file:
md5: null
name: SPICE1_OpenFF_dataset_v2_processed.npz
url: https://zenodo.org/records/14264431/files/spice_114_openff_dataset_v2.hdf5.gz
nc_1000_v2:
version: 2
doi: 10.5281/zenodo.14269264
gz_data_file:
length: 2534128
md5: 9535d75cfee3facc8c1d62f1650fac07
name: SPICE1_OpenFF_dataset_v2_nc_1000.hdf5.gz
hdf5_data_file:
md5: 31e8a613f99d215b2fbfc33c77cdeff2
name: SPICE1_OpenFF_dataset_v2_nc_1000.hdf5
processed_data_file:
md5: null
name: SPICE1_OpenFF_dataset_v2_nc_1000_processed.npz
url: https://zenodo.org/records/14269264/files/spice_114_openff_dataset_v2_ntc_1000.hdf5.gz
full_dataset_v2_HCNOFClS:
version: 2
doi: 10.5281/zenodo.14264561
gz_data_file:
length: 2306573715
md5: 9de6a9e9e490beddfecc24332cba7f11
name: SPICE1_OpenFF_dataset_v2_HCNOFClS.hdf5.gz
hdf5_data_file:
md5: 23c6da20754023daa37c5e24f4d3c432
name: SPICE1_OpenFF_dataset_v2_HCNOFClS.hdf5
processed_data_file:
md5: null
name: SPICE1_OpenFF_dataset_v2_HCNOFClS_processed.npz
url: https://zenodo.org/records/14264561/files/spice_114_openff_dataset_v2_HCNOFClS.hdf5.gz
nc_1000_v2_HCNOFClS:
version: 2
doi: 10.5281/zenodo.14269058
gz_data_file:
length: 2534128
md5: fc07d6ac041b3ed835b20a50b97901d8
name: SPICE1_OpenFF_dataset_v2_nc_1000_HCNOFClS.hdf5.gz
hdf5_data_file:
md5: 03244a62f184611bbb6680d6ef6f13e3
name: SPICE1_OpenFF_dataset_v2_nc_1000_HCNOFClS.hdf5
processed_data_file:
md5: null
name: SPICE1_OpenFF_dataset_v2_nc_1000_HCNOFClS_processed.npz
url: https://zenodo.org/records/14269058/files/spice_114_openff_dataset_v2_ntc_1000_HCNOFClS.hdf5.gz
full_dataset_v1:
version: 1
gz_data_file:
Expand Down
2 changes: 1 addition & 1 deletion modelforge/potential/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ class AtomicNumber(BaseModel):
number_of_per_atom_features: int = 32


class Featurization(BaseModel):
class Featurization(ParametersBase):
properties_to_featurize: List[str]
atomic_number: AtomicNumber = Field(default_factory=AtomicNumber)

Expand Down
6 changes: 6 additions & 0 deletions modelforge/tests/data/config.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,12 @@ dataset_name = "QM9"
version_select = "nc_1000_v0"
num_workers = 4
pin_memory = true
properties_of_interest = ["atomic_numbers", "geometry", "internal_energy_at_0K", "dipole_moment"]

[dataset.properties_assignment]
atomic_numbers = "atomic_numbers"
positions = "geometry"
E = "internal_energy_at_0K"

[training]
number_of_epochs = 2
Expand Down
7 changes: 7 additions & 0 deletions modelforge/tests/data/dataset_defaults/ani1x.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,10 @@ dataset_name = "ANI1x"
version_select = "nc_1000_v0"
num_workers = 4
pin_memory = true
properties_of_interest = ["geometry", "atomic_numbers", "wb97x_dz.energy", "wb97x_dz.forces"]

[dataset.properties_assignment]
atomic_numbers="atomic_numbers"
positions="geometry"
E="wb97x_dz.energy"
F="wb97x_dz.forces"
9 changes: 8 additions & 1 deletion modelforge/tests/data/dataset_defaults/ani2x.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,11 @@
dataset_name = "ANI2x"
version_select = "nc_1000_v0"
num_workers = 4
pin_memory = true
pin_memory = true
properties_of_interest = ["atomic_numbers", "geometry", "energies", "forces"]

[dataset.properties_assignment]
atomic_numbers = "atomic_numbers"
positions = "geometry"
E = "energies"
F = "forces"
10 changes: 9 additions & 1 deletion modelforge/tests/data/dataset_defaults/phalkethoh.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,12 @@
dataset_name = "PHALKETHOH"
version_select = "nc_1000_v1"
num_workers = 4
pin_memory = true
pin_memory = true
properties_of_interest = ["geometry", "atomic_numbers", "dft_total_energy", "dft_total_force", "total_charge", "scf_dipole"]

[dataset.properties_assignment]
atomic_numbers = "atomic_numbers"
positions = "geometry"
E = "dft_total_energy"
F = "dft_total_force"
dipole_moment = "scf_dipole"
Loading

0 comments on commit f8503c2

Please sign in to comment.