Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Chore: migrate and refactor polar and dos bias #3662

Merged
merged 28 commits into from
Apr 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
774f2ce
chore: try rename to atom_
anyangml Apr 10, 2024
2aff780
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 10, 2024
d6a6571
fix: UTs
anyangml Apr 10, 2024
bd75e48
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 10, 2024
14a002c
Merge branch 'devel' into chore/migrate-bias
anyangml Apr 11, 2024
f0baf2e
fix: data shape
anyangml Apr 12, 2024
5300d98
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 12, 2024
90ede06
fix: var name
anyangml Apr 12, 2024
8f9dc5b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 12, 2024
d8d3f16
fix: var_name
anyangml Apr 12, 2024
ed5c543
fix: loss name
anyangml Apr 15, 2024
88da7ce
fix: dp var name
anyangml Apr 15, 2024
7176a39
fix: dp var name
anyangml Apr 15, 2024
3136c10
Merge branch 'devel' into chore/migrate-bias
anyangml Apr 16, 2024
88e41e5
chore: remove bias in fitting
anyangml Apr 16, 2024
c94608a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 16, 2024
23c7fdf
chore: remove UTs
anyangml Apr 16, 2024
ead2a38
fix: UT import
anyangml Apr 16, 2024
ec89624
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 16, 2024
5698bfc
chore: move polar bias
anyangml Apr 17, 2024
3f11f7b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 17, 2024
d7036b8
feat: add UT on out_std
anyangml Apr 17, 2024
09d775d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 17, 2024
cd2e78e
Merge branch 'devel' into chore/migrate-bias
anyangml Apr 17, 2024
0eacfe9
fix: UTs
anyangml Apr 18, 2024
c0c08ef
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 18, 2024
ae709c4
fix: UTs
anyangml Apr 18, 2024
e57dd7d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 18, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions deepmd/dpmodel/fitting/polarizability_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ def output_def(self):
return FittingOutputDef(
[
OutputVariableDef(
self.var_name,
"polarizability",
[3, 3],
reduciable=True,
r_differentiable=False,
Expand Down Expand Up @@ -280,4 +280,4 @@ def call(
# (nframes, nloc, 3, 3)
bias = np.expand_dims(bias, axis=-1) * eye
out = out + bias
return {self.var_name: out}
return {"polarizability": out}
8 changes: 4 additions & 4 deletions deepmd/entrypoints/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -897,7 +897,7 @@ def test_polar(
polar = polar.reshape((polar.shape[0], -1, 9))[:, sel_mask, :].reshape(
(polar.shape[0], -1)
)
rmse_f = rmse(polar - test_data["atomic_polarizability"][:numb_test])
rmse_f = rmse(polar - test_data["atom_polarizability"][:numb_test])

log.info(f"# number of test data : {numb_test:d} ")
log.info(f"Polarizability RMSE : {rmse_f:e}")
Expand Down Expand Up @@ -926,7 +926,7 @@ def test_polar(
pe = np.concatenate(
(
np.reshape(
test_data["atomic_polarizability"][:numb_test],
test_data["atom_polarizability"][:numb_test],
[-1, 9 * sel_natoms],
),
np.reshape(polar, [-1, 9 * sel_natoms]),
Expand Down Expand Up @@ -1037,7 +1037,7 @@ def test_dipole(
dipole = dipole.reshape((dipole.shape[0], -1, 3))[:, sel_mask, :].reshape(
(dipole.shape[0], -1)
)
rmse_f = rmse(dipole - test_data["atomic_dipole"][:numb_test])
rmse_f = rmse(dipole - test_data["atom_dipole"][:numb_test])

log.info(f"# number of test data : {numb_test:d}")
log.info(f"Dipole RMSE : {rmse_f:e}")
Expand All @@ -1061,7 +1061,7 @@ def test_dipole(
pe = np.concatenate(
(
np.reshape(
test_data["atomic_dipole"][:numb_test], [-1, 3 * sel_natoms]
test_data["atom_dipole"][:numb_test], [-1, 3 * sel_natoms]
),
np.reshape(dipole, [-1, 3 * sel_natoms]),
),
Expand Down
6 changes: 3 additions & 3 deletions deepmd/pt/loss/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,14 +93,14 @@ def forward(self, input_dict, model, label, natoms, learning_rate=0.0, mae=False
if (
self.has_local_weight
and self.tensor_name in model_pred
and "atomic_" + self.label_name in label
and "atom_" + self.label_name in label
):
find_local = label.get("find_" + "atomic_" + self.label_name, 0.0)
find_local = label.get("find_" + "atom_" + self.label_name, 0.0)
local_weight = self.local_weight * find_local
local_tensor_pred = model_pred[self.tensor_name].reshape(
[-1, natoms, self.tensor_size]
)
local_tensor_label = label["atomic_" + self.label_name].reshape(
local_tensor_label = label["atom_" + self.label_name].reshape(
[-1, natoms, self.tensor_size]
)
diff = (local_tensor_pred - local_tensor_label).reshape(
Expand Down
36 changes: 35 additions & 1 deletion deepmd/pt/model/atomic_model/polar_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,5 +24,39 @@ def apply_out_stat(
ret: Dict[str, torch.Tensor],
atype: torch.Tensor,
):
# TODO: migrate bias
"""Apply the stat to each atomic output.

Parameters
----------
ret
The returned dict by the forward_atomic method
atype
The atom types. nf x nloc

"""
out_bias, out_std = self._fetch_out_stat(self.bias_keys)

if self.fitting_net.shift_diag:
nframes, nloc = atype.shape
device = out_bias[self.bias_keys[0]].device
dtype = out_bias[self.bias_keys[0]].dtype
for kk in self.bias_keys:
ntypes = out_bias[kk].shape[0]
temp = torch.zeros(ntypes, dtype=dtype, device=device)
for i in range(ntypes):
temp[i] = torch.mean(torch.diagonal(out_bias[kk][i].reshape(3, 3)))
modified_bias = temp[atype]

# (nframes, nloc, 1)
modified_bias = (
modified_bias.unsqueeze(-1) * self.fitting_net.scale[atype]
)

eye = torch.eye(3, dtype=dtype, device=device)
eye = eye.repeat(nframes, nloc, 1, 1)
# (nframes, nloc, 3, 3)
modified_bias = modified_bias.unsqueeze(-1) * eye

# nf x nloc x odims, out_bias: ntypes x odims
ret[kk] = ret[kk] + modified_bias
return ret
8 changes: 4 additions & 4 deletions deepmd/pt/model/model/polar_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,8 @@ def forward(
)
if self.get_fitting_net() is not None:
model_predict = {}
model_predict["polar"] = model_ret["polar"]
model_predict["global_polar"] = model_ret["polar_redu"]
model_predict["polar"] = model_ret["polarizability"]
model_predict["global_polar"] = model_ret["polarizability_redu"]
if "mask" in model_ret:
model_predict["mask"] = model_ret["mask"]
else:
Expand Down Expand Up @@ -85,8 +85,8 @@ def forward_lower(
)
if self.get_fitting_net() is not None:
model_predict = {}
model_predict["polar"] = model_ret["polar"]
model_predict["global_polar"] = model_ret["polar_redu"]
model_predict["polar"] = model_ret["polarizability"]
model_predict["global_polar"] = model_ret["polarizability_redu"]
else:
model_predict = model_ret
return model_predict
66 changes: 0 additions & 66 deletions deepmd/pt/model/task/dos.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,11 @@
import copy
import logging
from typing import (
Callable,
List,
Optional,
Union,
)

import numpy as np
import torch

from deepmd.dpmodel import (
Expand All @@ -30,13 +28,6 @@
from deepmd.pt.utils.utils import (
to_numpy_array,
)
from deepmd.utils.out_stat import (
compute_stats_from_atomic,
compute_stats_from_redu,
)
from deepmd.utils.path import (
DPPath,
)
from deepmd.utils.version import (
check_version_compatibility,
)
Expand Down Expand Up @@ -105,63 +96,6 @@ def output_def(self) -> FittingOutputDef:
]
)

def compute_output_stats(
self,
merged: Union[Callable[[], List[dict]], List[dict]],
stat_file_path: Optional[DPPath] = None,
) -> None:
"""
Compute the output statistics (e.g. dos bias) for the fitting net from packed data.

Parameters
----------
merged : Union[Callable[[], List[dict]], List[dict]]
- List[dict]: A list of data samples from various data systems.
Each element, `merged[i]`, is a data dictionary containing `keys`: `torch.Tensor`
originating from the `i`-th data system.
- Callable[[], List[dict]]: A lazy function that returns data samples in the above format
only when needed. Since the sampling process can be slow and memory-intensive,
the lazy function helps by only sampling once.
stat_file_path : Optional[DPPath]
The path to the stat file.

"""
if stat_file_path is not None:
stat_file_path = stat_file_path / "bias_dos"
if stat_file_path is not None and stat_file_path.is_file():
bias_dos = stat_file_path.load_numpy()
else:
if callable(merged):
# only get data for once
sampled = merged()
else:
sampled = merged
for sys in range(len(sampled)):
nframs = sampled[sys]["atype"].shape[0]

if "atom_dos" in sampled[sys]:
bias_dos = compute_stats_from_atomic(
sampled[sys]["atom_dos"].numpy(force=True),
sampled[sys]["atype"].numpy(force=True),
)[0]
else:
sys_type_count = np.zeros(
(nframs, self.ntypes), dtype=env.GLOBAL_NP_FLOAT_PRECISION
)
for itype in range(self.ntypes):
type_mask = sampled[sys]["atype"] == itype
sys_type_count[:, itype] = type_mask.sum(dim=1).numpy(
force=True
)
sys_bias_redu = sampled[sys]["dos"].numpy(force=True)

bias_dos = compute_stats_from_redu(
sys_bias_redu, sys_type_count, rcond=self.rcond
)[0]
if stat_file_path is not None:
stat_file_path.save_numpy(bias_dos)
self.bias_dos = torch.tensor(bias_dos, device=env.DEVICE)

@classmethod
def deserialize(cls, data: dict) -> "DOSFittingNet":
data = copy.deepcopy(data)
Expand Down
42 changes: 0 additions & 42 deletions deepmd/pt/model/task/invar_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,8 @@
import copy
import logging
from typing import (
Callable,
List,
Optional,
Union,
)

import torch
Expand All @@ -24,12 +22,6 @@
from deepmd.pt.utils.env import (
DEFAULT_PRECISION,
)
from deepmd.pt.utils.stat import (
compute_output_stats,
)
from deepmd.utils.path import (
DPPath,
)
from deepmd.utils.version import (
check_version_compatibility,
)
Expand Down Expand Up @@ -146,40 +138,6 @@ def deserialize(cls, data: dict) -> "GeneralFitting":
check_version_compatibility(data.pop("@version", 1), 1, 1)
return super().deserialize(data)

def compute_output_stats(
self,
merged: Union[Callable[[], List[dict]], List[dict]],
anyangml marked this conversation as resolved.
Show resolved Hide resolved
stat_file_path: Optional[DPPath] = None,
):
"""
Compute the output statistics (e.g. energy bias) for the fitting net from packed data.

Parameters
----------
merged : Union[Callable[[], List[dict]], List[dict]]
- List[dict]: A list of data samples from various data systems.
Each element, `merged[i]`, is a data dictionary containing `keys`: `torch.Tensor`
originating from the `i`-th data system.
- Callable[[], List[dict]]: A lazy function that returns data samples in the above format
only when needed. Since the sampling process can be slow and memory-intensive,
the lazy function helps by only sampling once.
stat_file_path : Optional[DPPath]
The path to the stat file.

"""
# [0] to get the mean (bias)
bias_atom_e = compute_output_stats(
merged,
self.ntypes,
keys=[self.var_name],
stat_file_path=stat_file_path,
rcond=self.rcond,
preset_bias={self.var_name: self.atom_ener}
if self.atom_ener is not None
else None,
)[0][self.var_name]
self.bias_atom_e.copy_(bias_atom_e.view([self.ntypes, self.dim_out]))

def output_def(self) -> FittingOutputDef:
return FittingOutputDef(
[
Expand Down
Loading