Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(pt): train with energy Hessian #4169

Open
wants to merge 220 commits into
base: devel
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 204 commits
Commits
Show all changes
220 commits
Select commit Hold shift + click to select a range
347b064
Update argcheck.py
1azyking Aug 6, 2024
9a6b896
Update data.py
1azyking Aug 6, 2024
717e672
Update training.py
1azyking Aug 6, 2024
9ce1c5f
Update wrapper.py
1azyking Aug 6, 2024
2bedfc3
Create ener_hess.py
1azyking Aug 6, 2024
afb0227
Update __init__.py
1azyking Aug 6, 2024
5742674
Update __init__.py
1azyking Aug 6, 2024
75f4dce
Update make_hessian_model.py
1azyking Aug 6, 2024
0165eba
Create ener_hess_model.py
1azyking Aug 6, 2024
e560573
Update make_hessian_model.py
1azyking Aug 9, 2024
5384731
Update test.py
1azyking Aug 14, 2024
ab73bde
Update deep_eval.py
1azyking Aug 14, 2024
1befb61
Update deep_pot.py
1azyking Aug 14, 2024
9ae7163
Update deep_eval.py
1azyking Aug 14, 2024
6dd723e
Update __init__.py
1azyking Aug 14, 2024
4f078b8
Update test.py
1azyking Aug 14, 2024
bedb6ec
Update deep_eval.py
1azyking Aug 14, 2024
484c62e
Update deep_pot.py
1azyking Aug 14, 2024
7502ae8
Update deep_eval.py
1azyking Aug 14, 2024
950a4da
Update ener_hess.py
1azyking Aug 14, 2024
2ccc19a
Update ener_hess_model.py
1azyking Aug 14, 2024
bc92bde
Update make_hessian_model.py
1azyking Aug 14, 2024
d69e5f7
Update training.py
1azyking Aug 14, 2024
e2acfe2
Update wrapper.py
1azyking Aug 14, 2024
bffe6eb
Update deep_eval.py
1azyking Aug 14, 2024
31de3e4
Update deep_eval.py
1azyking Aug 15, 2024
ae06bfc
Create test_dp_hessian_model.py
1azyking Aug 15, 2024
24f76f1
Update argcheck.py
1azyking Aug 21, 2024
a933f7a
Update test.py
1azyking Aug 21, 2024
37e9ba4
Update tensor.py
1azyking Aug 21, 2024
0dd8124
Update test.py
1azyking Aug 26, 2024
65ef47a
Update env_mat.py
1azyking Sep 29, 2024
1004f76
Update env_mat.py
1azyking Sep 29, 2024
97f09b0
Update training.py
1azyking Sep 29, 2024
967952b
Create input.json
1azyking Sep 29, 2024
7db3688
Delete examples/hess directory
1azyking Sep 29, 2024
d140afa
Create input.json
1azyking Sep 29, 2024
2e41d8a
Create tmp
1azyking Sep 29, 2024
f54a1fd
Add files via upload
1azyking Sep 29, 2024
5bf9b3a
Delete examples/hessian/data/H8C4N2O/set.000/tmp
1azyking Sep 29, 2024
e8c4b75
Add files via upload
1azyking Sep 29, 2024
af9544e
Create tmp
1azyking Sep 29, 2024
4a7ac97
Add files via upload
1azyking Sep 29, 2024
7376f08
Delete examples/hessian/data/H10C5N2O/set.000/tmp
1azyking Sep 29, 2024
0cd6951
Add files via upload
1azyking Sep 29, 2024
a556e41
Update input.json
1azyking Sep 29, 2024
f1ed807
Create input.json
1azyking Sep 29, 2024
f68983f
Update env_mat.py
1azyking Sep 29, 2024
fea198e
Update input.json
1azyking Sep 29, 2024
19f4879
Update input.json
1azyking Sep 29, 2024
87d765a
Delete examples/hessian/data/H10C5N2O/set.000/virial.npy
1azyking Sep 29, 2024
43c0bbf
Delete examples/hessian/data/H8C4N2O/set.000/virial.npy
1azyking Sep 29, 2024
3fa9663
Create train-energy-hessian.md
1azyking Sep 29, 2024
c6d6519
Restore tensor.py
1azyking Sep 29, 2024
023ce20
Update deep_pot.py
1azyking Sep 29, 2024
5166351
Update deep_eval.py
1azyking Sep 29, 2024
a0a93b4
Update test.py
1azyking Sep 29, 2024
c8fb851
Update argcheck.py
1azyking Sep 29, 2024
1d117ca
Update data.py
1azyking Sep 29, 2024
570477b
Update deep_eval.py
1azyking Sep 29, 2024
a781b16
Update training.py
1azyking Sep 29, 2024
63da9a4
Update ener_hess.py
1azyking Sep 29, 2024
876ccc6
Update __init__.py
1azyking Sep 29, 2024
febc0fa
Update __init__.py
1azyking Sep 29, 2024
307b175
Update ener_hess_model.py
1azyking Sep 29, 2024
9ea1a6f
Update make_hessian_model.py
1azyking Sep 29, 2024
0a9a0ef
Update train-energy-hessian.md
1azyking Sep 29, 2024
1df1329
Update __init__.py
1azyking Sep 29, 2024
0880d2b
Update train-energy-hessian.md
1azyking Sep 29, 2024
3a185ea
Update __init__.py
1azyking Sep 29, 2024
e4f6c0d
Update ener_hess_model.py
1azyking Sep 30, 2024
ecbb8da
Update deep_pot.py
1azyking Oct 3, 2024
8654f49
Update deepmd/entrypoints/test.py
1azyking Oct 3, 2024
af86914
Update training.py
1azyking Oct 4, 2024
deb13c3
Update test.py
1azyking Oct 5, 2024
40135ec
Update deep_pot.py
1azyking Oct 5, 2024
3ef3891
Update ener_hess.py
1azyking Oct 5, 2024
4bcac3d
Update wrapper.py
1azyking Oct 5, 2024
86f6d43
Update test_dp_hessian_model.py
1azyking Oct 5, 2024
9522e8b
Update test_dp_hessian_model.py
1azyking Oct 5, 2024
73ac324
Update test_dp_hessian_model.py
1azyking Oct 5, 2024
260d07c
Update training.py
1azyking Oct 5, 2024
a41c0f4
Update test.py
1azyking Oct 5, 2024
4879f4c
Update input.json
1azyking Oct 5, 2024
23fae86
Update test.py
1azyking Oct 5, 2024
868f69f
Update ener_hess.py
1azyking Oct 5, 2024
8062980
Merge branch 'devel' into enerhess
1azyking Oct 5, 2024
54411bd
Update ener_hess.py
1azyking Oct 5, 2024
cd69988
Merge branch 'devel' into enerhess
1azyking Oct 9, 2024
17e62cd
Update training.py
1azyking Oct 9, 2024
776c2ef
Update training.py
1azyking Oct 9, 2024
ee926a8
Update argcheck.py
1azyking Oct 10, 2024
e9d1dce
Update deep_eval.py
1azyking Oct 10, 2024
128f894
Merge branch 'devel' into enerhess
1azyking Oct 10, 2024
3c6f47a
Update deep_eval.py
1azyking Oct 11, 2024
68bfb92
Merge branch 'devel' into enerhess
1azyking Oct 13, 2024
8a8faea
Update __init__.py
1azyking Oct 13, 2024
206787b
Update __init__.py
1azyking Oct 13, 2024
d2cb23d
Update data.py
1azyking Oct 13, 2024
cfcd9f7
Update deep_eval.py
1azyking Oct 13, 2024
afba7af
Update __init__.py
1azyking Oct 13, 2024
fb6e49a
Update training.py
1azyking Oct 13, 2024
95e55cf
Update deep_eval.py
1azyking Oct 13, 2024
90572c9
Update ener_hess.py
1azyking Oct 13, 2024
d42f041
Update ener_hess_model.py
1azyking Oct 13, 2024
0983bf0
Merge branch 'devel' into enerhess
1azyking Oct 13, 2024
672d975
Update deep_eval.py
1azyking Oct 15, 2024
1b79250
Update training.py
1azyking Oct 15, 2024
11af67a
Update deep_eval.py
1azyking Oct 15, 2024
79d5d62
Update data.py
1azyking Oct 15, 2024
3d09cde
Update deep_eval.py
1azyking Oct 15, 2024
a4c94f9
Update test.py
1azyking Oct 15, 2024
81d4cc7
Update __init__.py
1azyking Oct 15, 2024
51f31bd
Update __init__.py
1azyking Oct 15, 2024
a1adf35
Update ener_hess_model.py
1azyking Oct 15, 2024
75bda4b
Update training.py
1azyking Oct 15, 2024
baa5581
Update test_dp_hessian_model.py
1azyking Oct 15, 2024
1ac9484
Update train-energy-hessian.md
1azyking Oct 15, 2024
ade2616
Merge branch 'devel' into enerhess
1azyking Oct 15, 2024
4bd2906
Update train-energy-hessian.md
1azyking Oct 15, 2024
19cbc34
Update __init__.py
1azyking Oct 15, 2024
cff11b0
Update argcheck.py
1azyking Oct 15, 2024
6b0ad9d
Update ener_hess.py
1azyking Oct 15, 2024
6d5a463
Update __init__.py
1azyking Oct 15, 2024
69b5cb4
Update training.py
1azyking Oct 15, 2024
e891b47
Update test.py
1azyking Oct 15, 2024
bb10634
Update test.py
1azyking Oct 15, 2024
baeb71b
Update train-energy-hessian.md
1azyking Oct 15, 2024
68416d6
Update train-energy-hessian.md
1azyking Oct 15, 2024
f660694
Update train-energy-hessian.md
1azyking Oct 15, 2024
f4fc6ee
Update test.py
1azyking Oct 15, 2024
86b5451
Update test.py
1azyking Oct 15, 2024
1f766c3
Fixed bugs and formatted codes for fitting with hessian
anyangml Oct 17, 2024
83b4cc3
Merge branch 'devel' into enerhess
1azyking Oct 17, 2024
7d3bfda
Fixed bugs and formatted codes for fitting with hessian
anyangml Oct 17, 2024
c080c8a
Merge branch 'devel' into enerhess
1azyking Oct 17, 2024
b084877
Merge branch 'devel' into enerhess
1azyking Oct 18, 2024
551d5f9
Remove eval_model func.
1azyking Oct 21, 2024
faea02b
Merge branch 'devel' into enerhess
1azyking Oct 21, 2024
efc8ab7
Update deep_eval.py
1azyking Oct 23, 2024
c9405ed
Update deep_eval.py
1azyking Oct 23, 2024
fefe066
Update argcheck.py
1azyking Oct 23, 2024
f456699
Update training.py
1azyking Oct 23, 2024
3ee0bcb
Update deep_eval.py
1azyking Oct 23, 2024
10ed979
Update deep_eval.py
1azyking Oct 23, 2024
ca8f8c0
Update deep_eval.py
1azyking Oct 23, 2024
7a5c68b
Update deep_eval.py
1azyking Oct 23, 2024
ec8b2ef
Update test.py
1azyking Oct 23, 2024
f9706f6
Resolving conversations
1azyking Oct 23, 2024
934f5a0
Update system.md
1azyking Oct 23, 2024
87b0bb4
Update train-energy-hessian.md
1azyking Oct 23, 2024
24231ed
Update train-energy-hessian.md
1azyking Oct 23, 2024
8669fa5
Update index.rst
1azyking Oct 23, 2024
2fbcc64
Update overall.md
1azyking Oct 23, 2024
91f2435
Update train-energy-hessian.md
1azyking Oct 23, 2024
0774f79
Update input.json
1azyking Oct 23, 2024
ca7a96d
Update test_examples.py
1azyking Oct 23, 2024
793534d
Update __init__.py
1azyking Oct 23, 2024
d2553cf
Update ener.py
1azyking Oct 23, 2024
5eb7e46
Delete deepmd/pt/loss/ener_hess.py
1azyking Oct 23, 2024
4868f55
Update train-energy-hessian.md
1azyking Oct 23, 2024
3fe1c8d
Merge ener_hess_loss to ener_loss
1azyking Oct 23, 2024
637fef6
Merge branch 'devel' into enerhess
1azyking Oct 23, 2024
e3dbaaf
Update deepmd/calculator.py
1azyking Oct 24, 2024
b65af4c
Update deepmd/driver.py
1azyking Oct 24, 2024
73fc8fe
Update test_models.py
1azyking Oct 24, 2024
e2ad4f0
Update __init__.py
1azyking Oct 24, 2024
267e266
Update deepmd/entrypoints/test.py
1azyking Oct 27, 2024
f0bb34d
Resolving conversations
1azyking Oct 29, 2024
0256c0b
Resolving conversations
1azyking Oct 29, 2024
d698ef9
Merge branch 'devel' into enerhess
1azyking Oct 29, 2024
90d2468
Update env_mat.py
1azyking Oct 29, 2024
002defd
Update env_mat.py
1azyking Oct 29, 2024
737a415
Update env_mat.py
1azyking Oct 29, 2024
d04600d
Update test.py
1azyking Oct 29, 2024
4f11b66
Update env_mat.py
1azyking Oct 29, 2024
d1ff245
Update env_mat.py
1azyking Oct 29, 2024
7c3d126
Update test.py
1azyking Oct 29, 2024
ffce842
Update deep_pot.py
1azyking Oct 29, 2024
652ac0c
Update env_mat.py
1azyking Oct 29, 2024
bd46993
Merge branch 'devel' into enerhess
1azyking Oct 29, 2024
664767a
Update deepmd/infer/deep_pot.py
1azyking Oct 30, 2024
444b20d
Update deepmd/pt/model/descriptor/env_mat.py
1azyking Oct 30, 2024
4e8a9b5
Update argcheck.py
1azyking Oct 30, 2024
4343d16
Update deepmd/infer/deep_pot.py
1azyking Oct 30, 2024
e577b59
Resolving conversations
1azyking Nov 10, 2024
c6bafeb
Update overall.md
1azyking Nov 10, 2024
1e7c558
Resolving conversations
1azyking Nov 10, 2024
d28d50d
Update train-energy-hessian.md
1azyking Nov 10, 2024
3c5f2c0
Merge branch 'devel' into enerhess
1azyking Nov 10, 2024
c61df40
Resolving conversations
1azyking Nov 10, 2024
5014191
Merge branch 'devel' into enerhess
1azyking Nov 10, 2024
63719f0
Update deep_eval.py
1azyking Nov 10, 2024
f5c5e2a
Update deep_eval.py
1azyking Nov 11, 2024
53d5a23
Merge branch 'devel' into enerhess
1azyking Nov 11, 2024
7626d70
Merge branch 'devel' into enerhess
1azyking Nov 11, 2024
c472c9c
Merge branch 'devel' into enerhess
1azyking Nov 12, 2024
23107fa
Merge branch 'devel' into enerhess
1azyking Nov 13, 2024
44ee0b8
Merge branch 'devel' into enerhess
1azyking Nov 15, 2024
32d793c
Merge branch 'devel' into enerhess
1azyking Nov 23, 2024
05f878d
Merge branch 'devel' into enerhess
1azyking Nov 27, 2024
1c93aed
Merge branch 'devel' into enerhess
1azyking Nov 27, 2024
b3886aa
Merge branch 'devel' into enerhess
1azyking Nov 30, 2024
d947f24
Resolving conversations
1azyking Nov 30, 2024
6d351ed
Update training.py
1azyking Dec 2, 2024
dbfdd37
Merge branch 'devel' into enerhess
1azyking Dec 2, 2024
cf4c5d9
Merge branch 'devel' into enerhess
1azyking Dec 3, 2024
9e16fc4
Merge branch 'devel' into enerhess
1azyking Dec 3, 2024
c699e08
Update env_mat.py
1azyking Dec 6, 2024
d1361ca
Update make_hessian_model.py
1azyking Dec 6, 2024
2780984
Merge branch 'devel' into enerhess
1azyking Dec 9, 2024
9c8700b
Resovling conversations
1azyking Dec 14, 2024
19d0cd3
Merge branch 'devel' into enerhess
1azyking Dec 14, 2024
7dead9c
Update ener_model.py
1azyking Dec 15, 2024
6a31b02
Update test_change_bias.py
1azyking Dec 15, 2024
550178e
Merge branch 'devel' into enerhess
1azyking Dec 23, 2024
8c1656a
Merge branch 'devel' into enerhess
1azyking Dec 25, 2024
5f07d2b
Update train-energy-hessian.md
1azyking Dec 25, 2024
e95c140
Merge branch 'devel' into enerhess
1azyking Dec 25, 2024
d9602a1
Update train-energy-hessian.md
1azyking Dec 25, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion deepmd/calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def calculate(
cell = None
symbols = self.atoms.get_chemical_symbols()
atype = [self.type_dict[k] for k in symbols]
e, f, v = self.dp.eval(coords=coord, cells=cell, atom_types=atype)
e, f, v = self.dp.eval(coords=coord, cells=cell, atom_types=atype)[:3]
self.results["energy"] = e[0][0]
# see https://gitlab.com/ase/ase/-/merge_requests/2485
self.results["free_energy"] = e[0][0]
Expand Down
3 changes: 3 additions & 0 deletions deepmd/dpmodel/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,9 @@ def _get_output_shape(self, odef, nframes, natoms):
# Something wrong here?
# return [nframes, *shape, natoms, 1]
return [nframes, natoms, *odef.shape, 1]
elif odef.category == OutputVariableCategory.DERV_R_DERV_R:
# hessian
return [nframes, 3 * natoms, 3 * natoms]
else:
raise RuntimeError("unknown category")

Expand Down
2 changes: 1 addition & 1 deletion deepmd/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def label(self, data: dict) -> dict:
cell = data["cells"].reshape((nframes, 9))
else:
cell = None
e, f, v = self.dp.eval(coord, cell, atype)
e, f, v = self.dp.eval(coords=coord, cells=cell, atom_types=atype)[:3]
data = data.copy()
data["energies"] = e.reshape((nframes,))
data["forces"] = f.reshape((nframes, natoms, 3))
Expand Down
39 changes: 37 additions & 2 deletions deepmd/entrypoints/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,8 @@ def test_ener(
if dp.has_spin:
data.add("spin", 3, atomic=True, must=True, high_prec=False)
data.add("force_mag", 3, atomic=True, must=False, high_prec=False)
if dp.has_hessian:
data.add("hessian", 1, atomic=True, must=True, high_prec=False)
1azyking marked this conversation as resolved.
Show resolved Hide resolved
1azyking marked this conversation as resolved.
Show resolved Hide resolved

test_data = data.get_test()
mixed_type = data.mixed_type
Expand Down Expand Up @@ -352,6 +354,9 @@ def test_ener(
energy = energy.reshape([numb_test, 1])
force = force.reshape([numb_test, -1])
virial = virial.reshape([numb_test, 9])
if dp.has_hessian:
hessian = ret[3]
hessian = hessian.reshape([numb_test, -1])
1azyking marked this conversation as resolved.
Show resolved Hide resolved
if has_atom_ener:
ae = ret[3]
av = ret[4]
Expand Down Expand Up @@ -415,6 +420,10 @@ def test_ener(
rmse_ea = rmse_e / natoms
mae_va = mae_v / natoms
rmse_va = rmse_v / natoms
if dp.has_hessian:
diff_h = hessian - test_data["hessian"][:numb_test]
mae_h = mae(diff_h)
rmse_h = rmse(diff_h)
if has_atom_ener:
diff_ae = test_data["atom_ener"][:numb_test].reshape([-1]) - ae.reshape([-1])
mae_ae = mae(diff_ae)
Expand Down Expand Up @@ -447,6 +456,9 @@ def test_ener(
if has_atom_ener:
log.info(f"Atomic ener MAE : {mae_ae:e} eV")
log.info(f"Atomic ener RMSE : {rmse_ae:e} eV")
if dp.has_hessian:
log.info(f"Hessian MAE : {mae_h:e} eV/A^2")
log.info(f"Hessian RMSE : {rmse_h:e} eV/A^2")

if detail_file is not None:
detail_path = Path(detail_file)
Expand Down Expand Up @@ -530,8 +542,24 @@ def test_ener(
"pred_vyy pred_vyz pred_vzx pred_vzy pred_vzz",
append=append_detail,
)
if dp.has_hessian:
data_h = test_data["hessian"][:numb_test].reshape(-1, 1)
pred_h = hessian.reshape(-1, 1)
h = np.concatenate(
(
data_h,
pred_h,
),
axis=1,
)
save_txt_file(
detail_path.with_suffix(".h.out"),
h,
header=f"{system}: data_h pred_h (3Na*3Na matrix in row-major order)",
append=append_detail,
)
if not out_put_spin:
return {
dict_to_return = {
"mae_e": (mae_e, energy.size),
"mae_ea": (mae_ea, energy.size),
"mae_f": (mae_f, force.size),
Expand All @@ -544,7 +572,7 @@ def test_ener(
"rmse_va": (rmse_va, virial.size),
}
else:
return {
dict_to_return = {
"mae_e": (mae_e, energy.size),
"mae_ea": (mae_ea, energy.size),
"mae_fr": (mae_fr, force_r.size),
Expand All @@ -558,6 +586,10 @@ def test_ener(
"rmse_v": (rmse_v, virial.size),
"rmse_va": (rmse_va, virial.size),
}
if dp.has_hessian:
dict_to_return["mae_h"] = (mae_h, hessian.size)
dict_to_return["rmse_h"] = (rmse_h, hessian.size)
return dict_to_return


def print_ener_sys_avg(avg: dict[str, float]) -> None:
Expand All @@ -584,6 +616,9 @@ def print_ener_sys_avg(avg: dict[str, float]) -> None:
log.info(f"Virial RMSE : {avg['rmse_v']:e} eV")
log.info(f"Virial MAE/Natoms : {avg['mae_va']:e} eV")
log.info(f"Virial RMSE/Natoms : {avg['rmse_va']:e} eV")
if "rmse_h" in avg.keys():
1azyking marked this conversation as resolved.
Show resolved Hide resolved
log.info(f"Hessian MAE : {avg['mae_h']:e} eV/A^2")
log.info(f"Hessian RMSE : {avg['rmse_h']:e} eV/A^2")
1azyking marked this conversation as resolved.
Show resolved Hide resolved


def test_dos(
Expand Down
10 changes: 10 additions & 0 deletions deepmd/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ class DeepEvalBackend(ABC):
# old models in v1
"global_polar": "global_polar",
"wfc": "wfc",
"energy_derv_r_derv_r": "hessian",
}

@abstractmethod
Expand Down Expand Up @@ -276,6 +277,10 @@ def get_has_spin(self) -> bool:
"""Check if the model has spin atom types."""
return False

def get_has_hessian(self):
"""Check if the model has hessian."""
return False

@abstractmethod
def get_ntypes_spin(self) -> int:
"""Get the number of spin atom types of this model. Only used in old implement."""
Expand Down Expand Up @@ -541,6 +546,11 @@ def has_spin(self) -> bool:
"""Check if the model has spin."""
return self.deep_eval.get_has_spin()

@property
def has_hessian(self) -> bool:
"""Check if the model has hessian."""
return self.deep_eval.get_has_hessian()

def get_ntypes_spin(self) -> int:
"""Get the number of spin atom types of this model. Only used in old implement."""
return self.deep_eval.get_ntypes_spin()
Expand Down
18 changes: 16 additions & 2 deletions deepmd/infer/deep_pot.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ def output_def(self) -> ModelOutputDef:
r_differentiable=True,
c_differentiable=True,
atomic=True,
r_hessian=True,
),
]
)
Expand Down Expand Up @@ -99,7 +100,10 @@ def eval(
aparam: Optional[np.ndarray],
mixed_type: bool,
**kwargs: Any,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
) -> Union[
tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray],
tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray],
]:
pass

@overload
Expand All @@ -113,7 +117,10 @@ def eval(
aparam: Optional[np.ndarray],
mixed_type: bool,
**kwargs: Any,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
) -> Union[
tuple[np.ndarray, np.ndarray, np.ndarray],
tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray],
]:
pass

@overload
Expand Down Expand Up @@ -179,6 +186,8 @@ def eval(
atomic_virial
The atomic virial of the system, in shape (nframes, natoms, 9). Only returned
when atomic is True.
hessian
1azyking marked this conversation as resolved.
Show resolved Hide resolved
The Hessian matrix of the system, in shape (nframes, 3 * natoms, 3 * natoms). Returned when available.
"""
# This method has been used by:
# documentation python.md
Expand Down Expand Up @@ -239,6 +248,11 @@ def eval(
force_mag = results["energy_derv_r_mag"].reshape(nframes, natoms, 3)
mask_mag = results["mask_mag"].reshape(nframes, natoms, 1)
result = (*list(result), force_mag, mask_mag)
if self.deep_eval.get_has_hessian():
hessian = results["energy_derv_r_derv_r"].reshape(
nframes, 3 * natoms, 3 * natoms
)
result = (*list(result), hessian)
return result


Expand Down
3 changes: 3 additions & 0 deletions deepmd/jax/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,6 +411,9 @@ def _get_output_shape(self, odef, nframes, natoms):
elif odef.category == OutputVariableCategory.OUT:
# atom_energy, atom_tensor
return [nframes, natoms, *odef.shape, 1]
elif odef.category == OutputVariableCategory.DERV_R_DERV_R:
# hessian
return [nframes, 3 * natoms, 3 * natoms]
else:
raise RuntimeError("unknown category")

Expand Down
14 changes: 13 additions & 1 deletion deepmd/pt/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,10 @@ def __init__(
] = state_dict[item].clone()
state_dict = state_dict_head
model = get_model(self.input_param).to(DEVICE)
model = torch.jit.script(model)
try:
model = torch.jit.script(model)
except RuntimeError:
Fixed Show fixed Hide fixed
pass
self.dp = ModelWrapper(model)
self.dp.load_state_dict(state_dict)
elif str(self.model_path).endswith(".pth"):
Expand Down Expand Up @@ -160,6 +163,7 @@ def __init__(
self._has_spin = getattr(self.dp.model["Default"], "has_spin", False)
if callable(self._has_spin):
self._has_spin = self._has_spin()
self._has_hessian = self.model_def_script.get("hessian_mode", False)

def get_rcut(self) -> float:
"""Get the cutoff radius of this model."""
Expand Down Expand Up @@ -234,6 +238,10 @@ def get_has_spin(self):
"""Check if the model has spin atom types."""
return self._has_spin

def get_has_hessian(self):
"""Check if the model has hessian."""
return self._has_hessian

def eval(
self,
coords: np.ndarray,
Expand Down Expand Up @@ -339,6 +347,7 @@ def _get_request_defs(self, atomic: bool) -> list[OutputVariableDef]:
OutputVariableCategory.REDU,
OutputVariableCategory.DERV_R,
OutputVariableCategory.DERV_C_REDU,
OutputVariableCategory.DERV_R_DERV_R,
)
]

Expand Down Expand Up @@ -568,6 +577,9 @@ def _get_output_shape(self, odef, nframes, natoms):
# Something wrong here?
# return [nframes, *shape, natoms, 1]
return [nframes, natoms, *odef.shape, 1]
elif odef.category == OutputVariableCategory.DERV_R_DERV_R:
return [nframes, 3 * natoms, 3 * natoms]
# return [nframes, *odef.shape, 3 * natoms, 3 * natoms]
else:
raise RuntimeError("unknown category")

Expand Down
3 changes: 3 additions & 0 deletions deepmd/pt/infer/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@ def __init__(
] = state_dict[item].clone()
state_dict = state_dict_head

model_params.pop(
"hessian_mode", None
) # wrapper Hessian to Energy model due to JIT limit
self.model_params = deepcopy(model_params)
self.model = get_model(model_params).to(DEVICE)

Expand Down
2 changes: 2 additions & 0 deletions deepmd/pt/loss/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
DOSLoss,
)
from .ener import (
EnergyHessianStdLoss,
EnergyStdLoss,
)
from .ener_spin import (
Expand All @@ -24,6 +25,7 @@
__all__ = [
"DOSLoss",
"DenoiseLoss",
"EnergyHessianStdLoss",
"EnergySpinLoss",
"EnergyStdLoss",
"PropertyLoss",
Expand Down
72 changes: 72 additions & 0 deletions deepmd/pt/loss/ener.py
1azyking marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -411,3 +411,75 @@ def label_requirement(self) -> list[DataRequirementItem]:
)
)
return label_requirement


class EnergyHessianStdLoss(EnergyStdLoss):
1azyking marked this conversation as resolved.
Show resolved Hide resolved
def __init__(
self,
start_pref_h=0.0,
limit_pref_h=0.0,
**kwargs,
):
r"""Enable the layer to compute loss on hessian.

Parameters
----------
start_pref_h : float
The prefactor of hessian loss at the start of the training.
limit_pref_h : float
The prefactor of hessian loss at the end of the training.
**kwargs
Other keyword arguments.
"""
super().__init__(**kwargs)
self.has_h = (start_pref_h != 0.0 and limit_pref_h != 0.0) or self.inference

self.start_pref_h = start_pref_h
self.limit_pref_h = limit_pref_h

def forward(self, input_dict, model, label, natoms, learning_rate, mae=False):
model_pred, loss, more_loss = super().forward(
input_dict, model, label, natoms, learning_rate, mae=mae
)
coef = learning_rate / self.starter_learning_rate
pref_h = self.limit_pref_h + (self.start_pref_h - self.limit_pref_h) * coef

if self.has_h and "hessian" in model_pred and "hessian" in label:
find_hessian = label.get("find_hessian", 0.0)
pref_h = pref_h * find_hessian
diff_h = label["hessian"].reshape(
-1,
) - model_pred["hessian"].reshape(
-1,
)
l2_hessian_loss = torch.mean(torch.square(diff_h))
if not self.inference:
more_loss["l2_hessian_loss"] = self.display_if_exist(
l2_hessian_loss.detach(), find_hessian
)
loss += pref_h * l2_hessian_loss
rmse_h = l2_hessian_loss.sqrt()
more_loss["rmse_h"] = self.display_if_exist(rmse_h.detach(), find_hessian)
if mae:
mae_h = torch.mean(torch.abs(diff_h))
more_loss["mae_h"] = self.display_if_exist(mae_h.detach(), find_hessian)

if not self.inference:
more_loss["rmse"] = torch.sqrt(loss.detach())
return model_pred, loss, more_loss

@property
def label_requirement(self) -> list[DataRequirementItem]:
"""Add hessian label requirement needed for this loss calculation."""
label_requirement = super().label_requirement
if self.has_h:
label_requirement.append(
DataRequirementItem(
"hessian",
ndof=1, # 9=3*3 --> 3N*3N=ndof*natoms*natoms
atomic=True,
must=False,
high_prec=False,
)
)
return label_requirement
1azyking marked this conversation as resolved.
Show resolved Hide resolved
9 changes: 6 additions & 3 deletions deepmd/pt/model/descriptor/env_mat.py
1azyking marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,16 @@ def _make_env_mat(
nall = coord.shape[1]
mask = nlist >= 0
# nlist = nlist * mask ## this impl will contribute nans in Hessian calculation.
nlist = torch.where(mask, nlist, nall - 1)
nlist = torch.where(mask, nlist, nall)
coord_l = coord[:, :natoms].view(bsz, -1, 1, 3)
index = nlist.view(bsz, -1).unsqueeze(-1).expand(-1, -1, 3)
coord_r = torch.gather(coord, 1, index)
coord_pad = torch.concat([coord, coord[:, -1:, :] + rcut], dim=1)
1azyking marked this conversation as resolved.
Show resolved Hide resolved
1azyking marked this conversation as resolved.
Show resolved Hide resolved
coord_r = torch.gather(coord_pad, 1, index)
1azyking marked this conversation as resolved.
Show resolved Hide resolved
1azyking marked this conversation as resolved.
Show resolved Hide resolved
coord_r = coord_r.view(bsz, natoms, nnei, 3)
diff = coord_r - coord_l
length = torch.linalg.norm(diff, dim=-1, keepdim=True)
# avoid the possibility that coord[:, -1:, :] + rcut is the same as the coordinate of a real atom
diff_ = torch.where(torch.abs(diff) < 1e-30, torch.full_like(diff, 1e-30), diff)
length = torch.linalg.norm(diff_, dim=-1, keepdim=True)
1azyking marked this conversation as resolved.
Show resolved Hide resolved
# for index 0 nloc atom
length = length + ~mask.unsqueeze(-1)
t0 = 1 / (length + protection)
Expand Down
Loading