Skip to content

Commit

Permalink
add serialization to pd
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz authored Jan 5, 2025
1 parent 219c417 commit 8f32640
Showing 1 changed file with 51 additions and 0 deletions.
51 changes: 51 additions & 0 deletions deepmd/pd/loss/ener.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
from deepmd.utils.data import (
DataRequirementItem,
)
from deepmd.utils.version import (
check_version_compatibility,
)


class EnergyStdLoss(TaskLoss):
Expand Down Expand Up @@ -422,3 +425,51 @@ def label_requirement(self) -> list[DataRequirementItem]:
)
)
return label_requirement

def serialize(self) -> dict:
"""Serialize the loss module.
Returns
-------
dict
The serialized loss module
"""
return {
"@class": "EnergyLoss",
"@version": 1,
"starter_learning_rate": self.starter_learning_rate,
"start_pref_e": self.start_pref_e,
"limit_pref_e": self.limit_pref_e,
"start_pref_f": self.start_pref_f,
"limit_pref_f": self.limit_pref_f,
"start_pref_v": self.start_pref_v,
"limit_pref_v": self.limit_pref_v,
"start_pref_ae": self.start_pref_ae,
"limit_pref_ae": self.limit_pref_ae,
"start_pref_pf": self.start_pref_pf,
"limit_pref_pf": self.limit_pref_pf,
"relative_f": self.relative_f,
"enable_atom_ener_coeff": self.enable_atom_ener_coeff,
"start_pref_gf": self.start_pref_gf,
"limit_pref_gf": self.limit_pref_gf,
"numb_generalized_coord": self.numb_generalized_coord,
}

@classmethod
def deserialize(cls, data: dict) -> "TaskLoss":
"""Deserialize the loss module.
Parameters
----------
data : dict
The serialized loss module
Returns
-------
Loss
The deserialized loss module
"""
data = data.copy()
check_version_compatibility(data.pop("@version"), 1, 1)
data.pop("@class")
return cls(**data)

0 comments on commit 8f32640

Please sign in to comment.