diff --git a/deepmd/dpmodel/fitting/polarizability_fitting.py b/deepmd/dpmodel/fitting/polarizability_fitting.py index 2a691e963d..70e52c8e7d 100644 --- a/deepmd/dpmodel/fitting/polarizability_fitting.py +++ b/deepmd/dpmodel/fitting/polarizability_fitting.py @@ -206,7 +206,7 @@ def output_def(self): return FittingOutputDef( [ OutputVariableDef( - self.var_name, + "polarizability", [3, 3], reduciable=True, r_differentiable=False, @@ -280,4 +280,4 @@ def call( # (nframes, nloc, 3, 3) bias = np.expand_dims(bias, axis=-1) * eye out = out + bias - return {self.var_name: out} + return {"polarizability": out} diff --git a/deepmd/entrypoints/test.py b/deepmd/entrypoints/test.py index cad6e12d2b..1540a5a1b9 100644 --- a/deepmd/entrypoints/test.py +++ b/deepmd/entrypoints/test.py @@ -897,7 +897,7 @@ def test_polar( polar = polar.reshape((polar.shape[0], -1, 9))[:, sel_mask, :].reshape( (polar.shape[0], -1) ) - rmse_f = rmse(polar - test_data["atomic_polarizability"][:numb_test]) + rmse_f = rmse(polar - test_data["atom_polarizability"][:numb_test]) log.info(f"# number of test data : {numb_test:d} ") log.info(f"Polarizability RMSE : {rmse_f:e}") @@ -926,7 +926,7 @@ def test_polar( pe = np.concatenate( ( np.reshape( - test_data["atomic_polarizability"][:numb_test], + test_data["atom_polarizability"][:numb_test], [-1, 9 * sel_natoms], ), np.reshape(polar, [-1, 9 * sel_natoms]), @@ -1037,7 +1037,7 @@ def test_dipole( dipole = dipole.reshape((dipole.shape[0], -1, 3))[:, sel_mask, :].reshape( (dipole.shape[0], -1) ) - rmse_f = rmse(dipole - test_data["atomic_dipole"][:numb_test]) + rmse_f = rmse(dipole - test_data["atom_dipole"][:numb_test]) log.info(f"# number of test data : {numb_test:d}") log.info(f"Dipole RMSE : {rmse_f:e}") @@ -1061,7 +1061,7 @@ def test_dipole( pe = np.concatenate( ( np.reshape( - test_data["atomic_dipole"][:numb_test], [-1, 3 * sel_natoms] + test_data["atom_dipole"][:numb_test], [-1, 3 * sel_natoms] ), np.reshape(dipole, [-1, 3 * sel_natoms]), ), diff --git a/deepmd/pt/loss/tensor.py b/deepmd/pt/loss/tensor.py index 3dd91d203e..3dcf21af1d 100644 --- a/deepmd/pt/loss/tensor.py +++ b/deepmd/pt/loss/tensor.py @@ -93,14 +93,14 @@ def forward(self, input_dict, model, label, natoms, learning_rate=0.0, mae=False if ( self.has_local_weight and self.tensor_name in model_pred - and "atomic_" + self.label_name in label + and "atom_" + self.label_name in label ): - find_local = label.get("find_" + "atomic_" + self.label_name, 0.0) + find_local = label.get("find_" + "atom_" + self.label_name, 0.0) local_weight = self.local_weight * find_local local_tensor_pred = model_pred[self.tensor_name].reshape( [-1, natoms, self.tensor_size] ) - local_tensor_label = label["atomic_" + self.label_name].reshape( + local_tensor_label = label["atom_" + self.label_name].reshape( [-1, natoms, self.tensor_size] ) diff = (local_tensor_pred - local_tensor_label).reshape( diff --git a/deepmd/pt/model/atomic_model/polar_atomic_model.py b/deepmd/pt/model/atomic_model/polar_atomic_model.py index 3eb4136b6e..85320210ed 100644 --- a/deepmd/pt/model/atomic_model/polar_atomic_model.py +++ b/deepmd/pt/model/atomic_model/polar_atomic_model.py @@ -24,5 +24,39 @@ def apply_out_stat( ret: Dict[str, torch.Tensor], atype: torch.Tensor, ): - # TODO: migrate bias + """Apply the stat to each atomic output. + + Parameters + ---------- + ret + The returned dict by the forward_atomic method + atype + The atom types. 
nf x nloc + + """ + out_bias, out_std = self._fetch_out_stat(self.bias_keys) + + if self.fitting_net.shift_diag: + nframes, nloc = atype.shape + device = out_bias[self.bias_keys[0]].device + dtype = out_bias[self.bias_keys[0]].dtype + for kk in self.bias_keys: + ntypes = out_bias[kk].shape[0] + temp = torch.zeros(ntypes, dtype=dtype, device=device) + for i in range(ntypes): + temp[i] = torch.mean(torch.diagonal(out_bias[kk][i].reshape(3, 3))) + modified_bias = temp[atype] + + # (nframes, nloc, 1) + modified_bias = ( + modified_bias.unsqueeze(-1) * self.fitting_net.scale[atype] + ) + + eye = torch.eye(3, dtype=dtype, device=device) + eye = eye.repeat(nframes, nloc, 1, 1) + # (nframes, nloc, 3, 3) + modified_bias = modified_bias.unsqueeze(-1) * eye + + # nf x nloc x odims, out_bias: ntypes x odims + ret[kk] = ret[kk] + modified_bias return ret diff --git a/deepmd/pt/model/model/polar_model.py b/deepmd/pt/model/model/polar_model.py index 867afefc27..aa2a90aa4d 100644 --- a/deepmd/pt/model/model/polar_model.py +++ b/deepmd/pt/model/model/polar_model.py @@ -54,8 +54,8 @@ def forward( ) if self.get_fitting_net() is not None: model_predict = {} - model_predict["polar"] = model_ret["polar"] - model_predict["global_polar"] = model_ret["polar_redu"] + model_predict["polar"] = model_ret["polarizability"] + model_predict["global_polar"] = model_ret["polarizability_redu"] if "mask" in model_ret: model_predict["mask"] = model_ret["mask"] else: @@ -85,8 +85,8 @@ def forward_lower( ) if self.get_fitting_net() is not None: model_predict = {} - model_predict["polar"] = model_ret["polar"] - model_predict["global_polar"] = model_ret["polar_redu"] + model_predict["polar"] = model_ret["polarizability"] + model_predict["global_polar"] = model_ret["polarizability_redu"] else: model_predict = model_ret return model_predict diff --git a/deepmd/pt/model/task/dos.py b/deepmd/pt/model/task/dos.py index 196872d17c..c37b05277a 100644 --- a/deepmd/pt/model/task/dos.py +++ b/deepmd/pt/model/task/dos.py @@ -2,13 +2,11 @@ import copy import logging from typing import ( - Callable, List, Optional, Union, ) -import numpy as np import torch from deepmd.dpmodel import ( @@ -30,13 +28,6 @@ from deepmd.pt.utils.utils import ( to_numpy_array, ) -from deepmd.utils.out_stat import ( - compute_stats_from_atomic, - compute_stats_from_redu, -) -from deepmd.utils.path import ( - DPPath, -) from deepmd.utils.version import ( check_version_compatibility, ) @@ -105,63 +96,6 @@ def output_def(self) -> FittingOutputDef: ] ) - def compute_output_stats( - self, - merged: Union[Callable[[], List[dict]], List[dict]], - stat_file_path: Optional[DPPath] = None, - ) -> None: - """ - Compute the output statistics (e.g. dos bias) for the fitting net from packed data. - - Parameters - ---------- - merged : Union[Callable[[], List[dict]], List[dict]] - - List[dict]: A list of data samples from various data systems. - Each element, `merged[i]`, is a data dictionary containing `keys`: `torch.Tensor` - originating from the `i`-th data system. - - Callable[[], List[dict]]: A lazy function that returns data samples in the above format - only when needed. Since the sampling process can be slow and memory-intensive, - the lazy function helps by only sampling once. - stat_file_path : Optional[DPPath] - The path to the stat file. 
- - """ - if stat_file_path is not None: - stat_file_path = stat_file_path / "bias_dos" - if stat_file_path is not None and stat_file_path.is_file(): - bias_dos = stat_file_path.load_numpy() - else: - if callable(merged): - # only get data for once - sampled = merged() - else: - sampled = merged - for sys in range(len(sampled)): - nframs = sampled[sys]["atype"].shape[0] - - if "atom_dos" in sampled[sys]: - bias_dos = compute_stats_from_atomic( - sampled[sys]["atom_dos"].numpy(force=True), - sampled[sys]["atype"].numpy(force=True), - )[0] - else: - sys_type_count = np.zeros( - (nframs, self.ntypes), dtype=env.GLOBAL_NP_FLOAT_PRECISION - ) - for itype in range(self.ntypes): - type_mask = sampled[sys]["atype"] == itype - sys_type_count[:, itype] = type_mask.sum(dim=1).numpy( - force=True - ) - sys_bias_redu = sampled[sys]["dos"].numpy(force=True) - - bias_dos = compute_stats_from_redu( - sys_bias_redu, sys_type_count, rcond=self.rcond - )[0] - if stat_file_path is not None: - stat_file_path.save_numpy(bias_dos) - self.bias_dos = torch.tensor(bias_dos, device=env.DEVICE) - @classmethod def deserialize(cls, data: dict) -> "DOSFittingNet": data = copy.deepcopy(data) diff --git a/deepmd/pt/model/task/invar_fitting.py b/deepmd/pt/model/task/invar_fitting.py index 01e6a8b95d..ea46a552e5 100644 --- a/deepmd/pt/model/task/invar_fitting.py +++ b/deepmd/pt/model/task/invar_fitting.py @@ -2,10 +2,8 @@ import copy import logging from typing import ( - Callable, List, Optional, - Union, ) import torch @@ -24,12 +22,6 @@ from deepmd.pt.utils.env import ( DEFAULT_PRECISION, ) -from deepmd.pt.utils.stat import ( - compute_output_stats, -) -from deepmd.utils.path import ( - DPPath, -) from deepmd.utils.version import ( check_version_compatibility, ) @@ -146,40 +138,6 @@ def deserialize(cls, data: dict) -> "GeneralFitting": check_version_compatibility(data.pop("@version", 1), 1, 1) return super().deserialize(data) - def compute_output_stats( - self, - merged: Union[Callable[[], List[dict]], List[dict]], - stat_file_path: Optional[DPPath] = None, - ): - """ - Compute the output statistics (e.g. energy bias) for the fitting net from packed data. - - Parameters - ---------- - merged : Union[Callable[[], List[dict]], List[dict]] - - List[dict]: A list of data samples from various data systems. - Each element, `merged[i]`, is a data dictionary containing `keys`: `torch.Tensor` - originating from the `i`-th data system. - - Callable[[], List[dict]]: A lazy function that returns data samples in the above format - only when needed. Since the sampling process can be slow and memory-intensive, - the lazy function helps by only sampling once. - stat_file_path : Optional[DPPath] - The path to the stat file. 
- - """ - # [0] to get the mean (bias) - bias_atom_e = compute_output_stats( - merged, - self.ntypes, - keys=[self.var_name], - stat_file_path=stat_file_path, - rcond=self.rcond, - preset_bias={self.var_name: self.atom_ener} - if self.atom_ener is not None - else None, - )[0][self.var_name] - self.bias_atom_e.copy_(bias_atom_e.view([self.ntypes, self.dim_out])) - def output_def(self) -> FittingOutputDef: return FittingOutputDef( [ diff --git a/deepmd/pt/model/task/polarizability.py b/deepmd/pt/model/task/polarizability.py index cd944996be..18cc7e69a0 100644 --- a/deepmd/pt/model/task/polarizability.py +++ b/deepmd/pt/model/task/polarizability.py @@ -2,13 +2,11 @@ import copy import logging from typing import ( - Callable, List, Optional, Union, ) -import numpy as np import torch from deepmd.dpmodel import ( @@ -27,13 +25,6 @@ from deepmd.pt.utils.utils import ( to_numpy_array, ) -from deepmd.utils.out_stat import ( - compute_stats_from_atomic, - compute_stats_from_redu, -) -from deepmd.utils.path import ( - DPPath, -) from deepmd.utils.version import ( check_version_compatibility, ) @@ -185,7 +176,7 @@ def output_def(self) -> FittingOutputDef: return FittingOutputDef( [ OutputVariableDef( - self.var_name, + "polarizability", [3, 3], reduciable=True, r_differentiable=False, @@ -194,82 +185,6 @@ def output_def(self) -> FittingOutputDef: ] ) - def compute_output_stats( - self, - merged: Union[Callable[[], List[dict]], List[dict]], - stat_file_path: Optional[DPPath] = None, - ) -> None: - """ - Compute the output statistics (e.g. energy bias) for the fitting net from packed data. - - Parameters - ---------- - merged : Union[Callable[[], List[dict]], List[dict]] - - List[dict]: A list of data samples from various data systems. - Each element, `merged[i]`, is a data dictionary containing `keys`: `torch.Tensor` - originating from the `i`-th data system. - - Callable[[], List[dict]]: A lazy function that returns data samples in the above format - only when needed. Since the sampling process can be slow and memory-intensive, - the lazy function helps by only sampling once. - stat_file_path : Optional[DPPath] - The path to the stat file. 
- - """ - if self.shift_diag: - if stat_file_path is not None: - stat_file_path = stat_file_path / "constant_matrix" - if stat_file_path is not None and stat_file_path.is_file(): - constant_matrix = stat_file_path.load_numpy() - else: - if callable(merged): - # only get data for once - sampled = merged() - else: - sampled = merged - - sys_constant_matrix = [] - for sys in range(len(sampled)): - nframs = sampled[sys]["atype"].shape[0] - - if sampled[sys]["find_atomic_polarizability"] > 0.0: - sys_atom_polar = compute_stats_from_atomic( - sampled[sys]["atomic_polarizability"].numpy(force=True), - sampled[sys]["atype"].numpy(force=True), - )[0] - else: - if not sampled[sys]["find_polarizability"] > 0.0: - continue - sys_type_count = np.zeros( - (nframs, self.ntypes), dtype=env.GLOBAL_NP_FLOAT_PRECISION - ) - for itype in range(self.ntypes): - type_mask = sampled[sys]["atype"] == itype - sys_type_count[:, itype] = type_mask.sum(dim=1).numpy( - force=True - ) - - sys_bias_redu = sampled[sys]["polarizability"].numpy(force=True) - - sys_atom_polar = compute_stats_from_redu( - sys_bias_redu, sys_type_count, rcond=self.rcond - )[0] - cur_constant_matrix = np.zeros( - self.ntypes, dtype=env.GLOBAL_NP_FLOAT_PRECISION - ) - - for itype in range(self.ntypes): - cur_constant_matrix[itype] = np.mean( - np.diagonal(sys_atom_polar[itype].reshape(3, 3)) - ) - sys_constant_matrix.append(cur_constant_matrix) - constant_matrix = np.stack(sys_constant_matrix).mean(axis=0) - - # handle nan values. - constant_matrix = np.nan_to_num(constant_matrix) - if stat_file_path is not None: - stat_file_path.save_numpy(constant_matrix) - self.constant_matrix = torch.tensor(constant_matrix, device=env.DEVICE) - def forward( self, descriptor: torch.Tensor, @@ -302,19 +217,7 @@ def forward( "bim,bmj->bij", gr.transpose(1, 2), out ) # (nframes * nloc, 3, 3) out = out.view(nframes, nloc, 3, 3) - if self.shift_diag: - bias = self.constant_matrix[atype] - - # (nframes, nloc, 1) - bias = bias.unsqueeze(-1) * self.scale[atype] - - eye = torch.eye(3, device=env.DEVICE) - eye = eye.repeat(nframes, nloc, 1, 1) - # (nframes, nloc, 3, 3) - bias = bias.unsqueeze(-1) * eye - out = out + bias - - return {self.var_name: out.to(env.GLOBAL_PT_FLOAT_PRECISION)} + return {"polarizability": out.to(env.GLOBAL_PT_FLOAT_PRECISION)} # make jit happy with torch 2.0.0 exclude_types: List[int] diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py index fe9b432fb7..5ba4b6f336 100644 --- a/deepmd/pt/train/training.py +++ b/deepmd/pt/train/training.py @@ -303,9 +303,10 @@ def get_loss(loss_params, start_lr, _ntypes, _model): tensor_name ].output_size label_name = tensor_name - if label_name == "polar": - label_name = "polarizability" + if label_name == "polarizability": + label_name = "polar" loss_params["label_name"] = label_name + loss_params["tensor_name"] = label_name return TensorLoss(**loss_params) else: raise NotImplementedError diff --git a/deepmd/pt/utils/stat.py b/deepmd/pt/utils/stat.py index 77da1e01f1..b96310a2e6 100644 --- a/deepmd/pt/utils/stat.py +++ b/deepmd/pt/utils/stat.py @@ -224,6 +224,7 @@ def _fill_stat_with_global( if atomic_stat is None: return global_stat else: + atomic_stat = atomic_stat.reshape(*global_stat.shape) return np.nan_to_num( np.where( np.isnan(atomic_stat) & ~np.isnan(global_stat), global_stat, atomic_stat @@ -535,13 +536,20 @@ def compute_output_stats_atomic( merged_natoms = { kk: to_numpy_array(torch.cat(natoms[kk])) for kk in keys if len(natoms[kk]) > 0 } + # reshape merged data to [nf, nloc, ndim] 
+ merged_output = { + kk: merged_output[kk].reshape((*merged_natoms[kk].shape, -1)) + for kk in merged_output + } if model_pred is None: stats_input = merged_output else: # subtract the model bias and output the delta bias stats_input = { - kk: merged_output[kk] - model_pred[kk] for kk in keys if kk in merged_output + kk: merged_output[kk] - model_pred[kk].reshape(*merged_output[kk].shape) + for kk in keys + if kk in merged_output } bias_atom_e = {} @@ -559,9 +567,8 @@ def compute_output_stats_atomic( nan_padding = np.empty((missing_types, bias_atom_e[kk].shape[1])) nan_padding.fill(np.nan) bias_atom_e[kk] = np.concatenate([bias_atom_e[kk], nan_padding], axis=0) - std_atom_e[kk] = np.concatenate([bias_atom_e[kk], nan_padding], axis=0) + std_atom_e[kk] = np.concatenate([std_atom_e[kk], nan_padding], axis=0) else: # this key does not have atomic labels, skip it. continue - bias_atom_e, std_atom_e = _post_process_stat(bias_atom_e, std_atom_e) return bias_atom_e, std_atom_e diff --git a/deepmd/tf/fit/polar.py b/deepmd/tf/fit/polar.py index c124bd3ef4..460813f309 100644 --- a/deepmd/tf/fit/polar.py +++ b/deepmd/tf/fit/polar.py @@ -193,7 +193,7 @@ def compute_output_stats(self, all_stat): atom_has_polar = [ w for w in all_stat["type"][ss][0] if (w in self.sel_type) ] # select atom with polar - if all_stat["find_atomic_polarizability"][ss] > 0.0: + if all_stat["find_atom_polarizability"][ss] > 0.0: for itype in range( len(self.sel_type) ): # Atomic polar mode, should specify the atoms @@ -208,7 +208,7 @@ def compute_output_stats(self, all_stat): polar_bias.append( np.sum( - all_stat["atomic_polarizability"][ss].reshape( + all_stat["atom_polarizability"][ss].reshape( nframes, len(atom_has_polar), -1 )[:, index_lis, :] / nframes, diff --git a/deepmd/tf/loss/tensor.py b/deepmd/tf/loss/tensor.py index 3be01d3871..6a0eb30a44 100644 --- a/deepmd/tf/loss/tensor.py +++ b/deepmd/tf/loss/tensor.py @@ -70,11 +70,11 @@ def __init__(self, jdata, **kwarg): def build(self, learning_rate, natoms, model_dict, label_dict, suffix): polar_hat = label_dict[self.label_name] - atomic_polar_hat = label_dict["atomic_" + self.label_name] + atomic_polar_hat = label_dict["atom_" + self.label_name] polar = tf.reshape(model_dict[self.tensor_name], [-1]) find_global = label_dict["find_" + self.label_name] - find_atomic = label_dict["find_atomic_" + self.label_name] + find_atomic = label_dict["find_atom_" + self.label_name] # YHT: added for global / local dipole combination l2_loss = global_cvt_2_tf_float(0.0) diff --git a/deepmd/utils/data.py b/deepmd/utils/data.py index 3cf73dc093..d01509d6f5 100644 --- a/deepmd/utils/data.py +++ b/deepmd/utils/data.py @@ -577,6 +577,8 @@ def _load_set(self, set_name: DPPath): else: data["type"] = np.tile(self.atom_type[self.idx_map], (nframes, 1)) + # standardize keys + data = {kk.replace("atomic", "atom"): vv for kk, vv in data.items()} return data def _load_data( diff --git a/source/tests/consistent/fitting/test_polar.py b/source/tests/consistent/fitting/test_polar.py index a6e0e07784..808514ade4 100644 --- a/source/tests/consistent/fitting/test_polar.py +++ b/source/tests/consistent/fitting/test_polar.py @@ -125,7 +125,7 @@ def eval_pt(self, pt_obj: Any) -> Any: torch.from_numpy(self.atype.reshape(1, -1)).to(device=PT_DEVICE), torch.from_numpy(self.gr).to(device=PT_DEVICE), None, - )["polar"] + )["polarizability"] .detach() .cpu() .numpy() @@ -142,7 +142,7 @@ def eval_dp(self, dp_obj: Any) -> Any: self.atype.reshape(1, -1), self.gr, None, - )["polar"] + )["polarizability"] def 
extract_ret(self, ret: Any, backend) -> Tuple[np.ndarray, ...]: if backend == self.RefBackend.TF: diff --git a/source/tests/pt/model/test_atomic_model_atomic_stat.py b/source/tests/pt/model/test_atomic_model_atomic_stat.py index 8f365a09fe..e779eb572c 100644 --- a/source/tests/pt/model/test_atomic_model_atomic_stat.py +++ b/source/tests/pt/model/test_atomic_model_atomic_stat.py @@ -212,6 +212,15 @@ def cvt_ret(x): self.merged_output_stat, stat_file_path=self.stat_file_path ) ret1 = md0.forward_common_atomic(*args) + expected_std = np.ones( + (2, 2, 2), dtype=np.float64 + ) # 2 keys, 2 atypes, 2 max dims. + expected_std[0, :, :1] = np.array([0.0, 0.816496]).reshape( + 2, 1 + ) # updating std for foo based on [5.0, 5.0, 5.0], [5.0, 6.0, 7.0]] + np.testing.assert_almost_equal( + to_numpy_array(md0.out_std), expected_std, decimal=4 + ) ret1 = cvt_ret(ret1) # nt x odim foo_bias = np.array([5.0, 6.0]).reshape(2, 1) @@ -231,6 +240,9 @@ def raise_error(): ret2 = cvt_ret(ret2) for kk in ["foo", "bar"]: np.testing.assert_almost_equal(ret1[kk], ret2[kk]) + np.testing.assert_almost_equal( + to_numpy_array(md0.out_std), expected_std, decimal=4 + ) # 4. test change bias BaseAtomicModel.change_out_bias( @@ -246,7 +258,9 @@ def raise_error(): ] ret3 = md0.forward_common_atomic(*args) ret3 = cvt_ret(ret3) - + expected_std[0, :, :1] = np.array([1.24722, 0.47140]).reshape( + 2, 1 + ) # updating std for foo based on [4.0, 3.0, 2.0], [1.0, 1.0, 1.0]] expected_ret3 = {} # new bias [2.666, 1.333] expected_ret3["foo"] = np.array( @@ -254,6 +268,9 @@ def raise_error(): ).reshape(2, 3, 1) for kk in ["foo"]: np.testing.assert_almost_equal(ret3[kk], expected_ret3[kk], decimal=4) + np.testing.assert_almost_equal( + to_numpy_array(md0.out_std), expected_std, decimal=4 + ) class TestAtomicModelStatMergeGlobalAtomic( diff --git a/source/tests/pt/model/test_atomic_model_global_stat.py b/source/tests/pt/model/test_atomic_model_global_stat.py index ca71b604ce..799948b14f 100644 --- a/source/tests/pt/model/test_atomic_model_global_stat.py +++ b/source/tests/pt/model/test_atomic_model_global_stat.py @@ -193,6 +193,7 @@ def cvt_ret(x): # nf x na x odim ret0 = md0.forward_common_atomic(*args) ret0 = cvt_ret(ret0) + expected_ret0 = {} expected_ret0["foo"] = np.array( [ @@ -221,6 +222,7 @@ def cvt_ret(x): ) ret1 = md0.forward_common_atomic(*args) ret1 = cvt_ret(ret1) + expected_std = np.ones((3, 2, 2)) # 3 keys, 2 atypes, 2 max dims. # nt x odim foo_bias = np.array([1.0, 3.0]).reshape(2, 1) bar_bias = np.array([1.0, 5.0, 3.0, 2.0]).reshape(2, 1, 2) @@ -230,6 +232,7 @@ def cvt_ret(x): expected_ret1["bar"] = ret0["bar"] + bar_bias[at] for kk in ["foo", "pix", "bar"]: np.testing.assert_almost_equal(ret1[kk], expected_ret1[kk]) + np.testing.assert_almost_equal(to_numpy_array(md0.out_std), expected_std) # 3. test bias load from file def raise_error(): @@ -240,6 +243,7 @@ def raise_error(): ret2 = cvt_ret(ret2) for kk in ["foo", "pix", "bar"]: np.testing.assert_almost_equal(ret1[kk], ret2[kk]) + np.testing.assert_almost_equal(to_numpy_array(md0.out_std), expected_std) # 4. test change bias BaseAtomicModel.change_out_bias( @@ -266,6 +270,7 @@ def raise_error(): for kk in ["foo", "pix"]: np.testing.assert_almost_equal(ret3[kk], expected_ret3[kk]) # bar is too complicated to be manually computed. 
+ np.testing.assert_almost_equal(to_numpy_array(md0.out_std), expected_std) def test_preset_bias(self): nf, nloc, nnei = self.nlist.shape diff --git a/source/tests/pt/model/test_descriptor.py b/source/tests/pt/model/test_descriptor.py index 7d21d1c13d..ff1fd0c959 100644 --- a/source/tests/pt/model/test_descriptor.py +++ b/source/tests/pt/model/test_descriptor.py @@ -38,7 +38,7 @@ op_module, ) -from ..test_stat import ( +from ..test_finetune import ( energy_data_requirement, ) from .test_embedding_net import ( diff --git a/source/tests/pt/model/test_embedding_net.py b/source/tests/pt/model/test_embedding_net.py index 63a3534c74..77d14db2a4 100644 --- a/source/tests/pt/model/test_embedding_net.py +++ b/source/tests/pt/model/test_embedding_net.py @@ -39,7 +39,7 @@ ) from deepmd.tf.descriptor import DescrptSeA as DescrptSeA_tf -from ..test_stat import ( +from ..test_finetune import ( energy_data_requirement, ) diff --git a/source/tests/pt/model/test_model.py b/source/tests/pt/model/test_model.py index 493d6e2cc3..71ad64d99d 100644 --- a/source/tests/pt/model/test_model.py +++ b/source/tests/pt/model/test_model.py @@ -51,7 +51,7 @@ LearningRateExp, ) -from ..test_stat import ( +from ..test_finetune import ( energy_data_requirement, ) diff --git a/source/tests/pt/model/test_polar_atomic_model_stat.py b/source/tests/pt/model/test_polar_atomic_model_stat.py new file mode 100644 index 0000000000..d9ddfcd3e6 --- /dev/null +++ b/source/tests/pt/model/test_polar_atomic_model_stat.py @@ -0,0 +1,293 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import tempfile +import unittest +from pathlib import ( + Path, +) +from typing import ( + Optional, +) + +import h5py +import numpy as np +import torch + +from deepmd.pt.model.atomic_model import ( + BaseAtomicModel, + DPPolarAtomicModel, +) +from deepmd.pt.model.descriptor.dpa1 import ( + DescrptDPA1, +) +from deepmd.pt.model.task.polarizability import ( + PolarFittingNet, +) +from deepmd.pt.utils import ( + env, +) +from deepmd.pt.utils.utils import ( + to_numpy_array, + to_torch_tensor, +) +from deepmd.utils.path import ( + DPPath, +) + +from .test_env_mat import ( + TestCaseSingleFrameWithNlist, +) + +dtype = env.GLOBAL_PT_FLOAT_PRECISION + + +class FooFitting(PolarFittingNet): + def forward( + self, + descriptor: torch.Tensor, + atype: torch.Tensor, + gr: Optional[torch.Tensor] = None, + g2: Optional[torch.Tensor] = None, + h2: Optional[torch.Tensor] = None, + fparam: Optional[torch.Tensor] = None, + aparam: Optional[torch.Tensor] = None, + ): + nf, nloc, _ = descriptor.shape + ret = {} + ret["polarizability"] = ( + torch.Tensor( + [ + [ + [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], + [[1.0, 1.0, 1.0], [2.0, 2.0, 2.0], [3.0, 3.0, 3.0]], + [[3.0, 3.0, 3.0], [3.0, 3.0, 3.0], [6.0, 6.0, 6.0]], + ], + [ + [[4.0, 4.0, 4.0], [4.0, 4.0, 4.0], [4.0, 4.0, 4.0]], + [[4.0, 4.0, 4.0], [5.0, 5.0, 5.0], [6.0, 6.0, 6.0]], + [[6.0, 6.0, 6.0], [4.0, 4.0, 4.0], [2.0, 2.0, 2.0]], + ], + ] + ) + .view([nf, nloc, *self.output_def()["polarizability"].shape]) + .to(env.GLOBAL_PT_FLOAT_PRECISION) + .to(env.DEVICE) + ) + + return ret + + +class TestAtomicModelStat(unittest.TestCase, TestCaseSingleFrameWithNlist): + def tearDown(self): + self.tempdir.cleanup() + + def setUp(self): + TestCaseSingleFrameWithNlist.setUp(self) + self.merged_output_stat = [ + { + "coord": to_torch_tensor(np.zeros([2, 3, 3])), + "atype": to_torch_tensor( + np.array([[0, 0, 1], [0, 1, 1]], dtype=np.int32) + ), + "atype_ext": to_torch_tensor( + np.array([[0, 0, 1, 0], [0, 1, 1, 0]], 
dtype=np.int32) + ), + "box": to_torch_tensor(np.zeros([2, 3, 3])), + "natoms": to_torch_tensor( + np.array([[3, 3, 2, 1], [3, 3, 1, 2]], dtype=np.int32) + ), + # bias of foo: 5, 6 + "atom_polarizability": to_torch_tensor( + np.array( + [ + [ + [[5.0, 5.0, 5.0], [5.0, 5.0, 5.0], [5.0, 5.0, 5.0]], + [[5.0, 5.0, 5.0], [5.0, 5.0, 5.0], [5.0, 5.0, 5.0]], + [[5.0, 5.0, 5.0], [5.0, 5.0, 5.0], [5.0, 5.0, 5.0]], + ], + [ + [[5.0, 5.0, 5.0], [5.0, 5.0, 5.0], [5.0, 5.0, 5.0]], + [[6.0, 6.0, 6.0], [6.0, 6.0, 6.0], [6.0, 6.0, 6.0]], + [[7.0, 7.0, 7.0], [7.0, 7.0, 7.0], [7.0, 7.0, 7.0]], + ], + ] + ).reshape(2, 3, 3, 3) + ), + "find_atom_polarizability": np.float32(1.0), + }, + { + "coord": to_torch_tensor(np.zeros([2, 3, 3])), + "atype": to_torch_tensor( + np.array([[0, 0, 1], [0, 1, 1]], dtype=np.int32) + ), + "atype_ext": to_torch_tensor( + np.array([[0, 0, 1, 0], [0, 1, 1, 0]], dtype=np.int32) + ), + "box": to_torch_tensor(np.zeros([2, 3, 3])), + "natoms": to_torch_tensor( + np.array([[3, 3, 2, 1], [3, 3, 1, 2]], dtype=np.int32) + ), + # bias of foo: 5, 6 from atomic label. + "polarizability": to_torch_tensor( + np.array( + [ + [[5.0, 5.0, 5.0], [5.0, 5.0, 5.0], [5.0, 5.0, 5.0]], + [[7.0, 7.0, 7.0], [7.0, 7.0, 7.0], [7.0, 7.0, 7.0]], + ] + ).reshape(2, 3, 3) + ), + "find_polarizability": np.float32(1.0), + }, + ] + self.tempdir = tempfile.TemporaryDirectory() + h5file = str((Path(self.tempdir.name) / "testcase.h5").resolve()) + with h5py.File(h5file, "w") as f: + pass + self.stat_file_path = DPPath(h5file, "a") + + def test_output_stat(self): + nf, nloc, nnei = self.nlist.shape + ds = DescrptDPA1( + self.rcut, + self.rcut_smth, + sum(self.sel), + self.nt, + ).to(env.DEVICE) + ft = FooFitting(self.nt, 1, 1).to(env.DEVICE) + type_map = ["foo", "bar"] + md0 = DPPolarAtomicModel( + ds, + ft, + type_map=type_map, + ).to(env.DEVICE) + args = [ + to_torch_tensor(ii) for ii in [self.coord_ext, self.atype_ext, self.nlist] + ] + # nf x nloc + at = self.atype_ext[:, :nloc] + + def cvt_ret(x): + return {kk: to_numpy_array(vv) for kk, vv in x.items()} + + # 1. test run without bias + # nf x na x odim + ret0 = md0.forward_common_atomic(*args) + ret0 = cvt_ret(ret0) + expected_ret0 = {} + expected_ret0["polarizability"] = np.array( + [ + [ + [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], + [[1.0, 1.0, 1.0], [2.0, 2.0, 2.0], [3.0, 3.0, 3.0]], + [[3.0, 3.0, 3.0], [3.0, 3.0, 3.0], [6.0, 6.0, 6.0]], + ], + [ + [[4.0, 4.0, 4.0], [4.0, 4.0, 4.0], [4.0, 4.0, 4.0]], + [[4.0, 4.0, 4.0], [5.0, 5.0, 5.0], [6.0, 6.0, 6.0]], + [[6.0, 6.0, 6.0], [4.0, 4.0, 4.0], [2.0, 2.0, 2.0]], + ], + ] + ).reshape([nf, nloc, *md0.fitting_output_def()["polarizability"].shape]) + + np.testing.assert_almost_equal( + ret0["polarizability"], expected_ret0["polarizability"] + ) + + # 2. test bias is applied + md0.compute_or_load_out_stat( + self.merged_output_stat, stat_file_path=self.stat_file_path + ) + ret1 = md0.forward_common_atomic(*args) + ret1 = cvt_ret(ret1) + expected_std = np.zeros( + (1, 2, 9), dtype=np.float64 + ) # 1 keys, 2 atypes, 9 max dims. 
+ expected_std[:, 1, :] = np.ones(9, dtype=np.float64) * 0.8164966 # updating std + # nt x odim (dia) + diagnoal_bias = np.array( + [ + [[5.0, 0.0, 0.0], [0.0, 5.0, 0.0], [0.0, 0.0, 5.0]], + [[6.0, 0.0, 0.0], [0.0, 6.0, 0.0], [0.0, 0.0, 6.0]], + ] + ).reshape(2, 3, 3) + expected_ret1 = {} + expected_ret1["polarizability"] = ret0["polarizability"] + diagnoal_bias[at] + np.testing.assert_almost_equal( + ret1["polarizability"], expected_ret1["polarizability"] + ) + np.testing.assert_almost_equal(to_numpy_array(md0.out_std), expected_std) + + # 3. test bias load from file + def raise_error(): + raise RuntimeError + + md0.compute_or_load_out_stat(raise_error, stat_file_path=self.stat_file_path) + ret2 = md0.forward_common_atomic(*args) + ret2 = cvt_ret(ret2) + np.testing.assert_almost_equal(ret1["polarizability"], ret2["polarizability"]) + np.testing.assert_almost_equal(to_numpy_array(md0.out_std), expected_std) + + # 4. test change bias + BaseAtomicModel.change_out_bias( + md0, self.merged_output_stat, bias_adjust_mode="change-by-statistic" + ) + args = [ + to_torch_tensor(ii) + for ii in [ + self.coord_ext, + to_numpy_array(self.merged_output_stat[0]["atype_ext"]), + self.nlist, + ] + ] + ret3 = md0.forward_common_atomic(*args) + ret3 = cvt_ret(ret3) + + expected_ret3 = {} + expected_std = np.array( + [ + [ + [ + 1.4142136, + 1.4142136, + 1.4142136, + 1.2472191, + 1.2472191, + 1.2472191, + 1.2472191, + 1.2472191, + 1.2472191, + ], + [ + 0.4714045, + 0.4714045, + 0.4714045, + 0.8164966, + 0.8164966, + 0.8164966, + 2.6246693, + 2.6246693, + 2.6246693, + ], + ] + ] + ) + # new bias [[[3.0000, -, -, -, 2.6667, -, -, -, 2.3333], + # [1.6667, -, -, -, 2.0000, -, -, -, 1.3333]]] + # which yields [2.667, 1.667] + expected_ret3["polarizability"] = np.array( + [ + [ + [[3.6667, 1.0, 1.0], [1.0, 3.6667, 1.0], [1.0, 1.0, 3.6667]], + [[3.6667, 1.0, 1.0], [2.0, 4.6667, 2.0], [3.0, 3.0, 5.6667]], + [[4.6667, 3.0, 3.0], [3.0, 4.6667, 3.0], [6.0, 6.0, 7.6667]], + ], + [ + [[6.6667, 4.0, 4.0], [4.0, 6.6667, 4.0], [4.0, 4.0, 6.6667]], + [[5.6667, 4.0, 4.0], [5.0, 6.6667, 5.0], [6.0, 6.0, 7.6667]], + [[7.6667, 6.0, 6.0], [4.0, 5.6667, 4.0], [2.0, 2.0, 3.6667]], + ], + ] + ).reshape(2, 3, 3, 3) + np.testing.assert_almost_equal( + ret3["polarizability"], expected_ret3["polarizability"], decimal=4 + ) + np.testing.assert_almost_equal(to_numpy_array(md0.out_std), expected_std) diff --git a/source/tests/pt/model/test_polar_stat.py b/source/tests/pt/model/test_polar_stat.py deleted file mode 100644 index 3d72c6e8fa..0000000000 --- a/source/tests/pt/model/test_polar_stat.py +++ /dev/null @@ -1,76 +0,0 @@ -# SPDX-License-Identifier: LGPL-3.0-or-later -import unittest - -import numpy as np -import torch - -from deepmd.pt.model.task.polarizability import ( - PolarFittingNet, -) -from deepmd.pt.utils import ( - env, -) -from deepmd.pt.utils.utils import ( - to_numpy_array, -) -from deepmd.tf.fit.polar import ( - PolarFittingSeA, -) - - -class TestConsistency(unittest.TestCase): - def setUp(self) -> None: - types = torch.randint(0, 4, (1, 5), device=env.DEVICE) - types = torch.cat((types, types, types), dim=0) - types[:, -1] = 3 - ntypes = 4 - atomic_polarizability = torch.rand((3, 5, 9), device=env.DEVICE) - polarizability = torch.rand((3, 9), device=env.DEVICE) - find_polarizability = torch.rand(1, device=env.DEVICE) - find_atomic_polarizability = torch.rand(1, device=env.DEVICE) - self.sampled = [ - { - "atype": types, - "find_atomic_polarizability": find_atomic_polarizability, - "atomic_polarizability": 
atomic_polarizability, - "polarizability": polarizability, - "find_polarizability": find_polarizability, - } - ] - self.all_stat = { - k: [v.numpy(force=True)] for d in self.sampled for k, v in d.items() - } - self.all_stat["type"] = self.all_stat.pop("atype") - self.tfpolar = PolarFittingSeA( - ntypes=ntypes, - dim_descrpt=1, - embedding_width=1, - sel_type=list(range(ntypes)), - ) - self.ptpolar = PolarFittingNet( - ntypes=ntypes, - dim_descrpt=1, - embedding_width=1, - ) - - def test_atomic_consistency(self): - self.tfpolar.compute_output_stats(self.all_stat) - tfbias = self.tfpolar.constant_matrix - self.ptpolar.compute_output_stats(self.sampled) - ptbias = self.ptpolar.constant_matrix - np.testing.assert_allclose(tfbias, to_numpy_array(ptbias)) - - def test_global_consistency(self): - self.sampled[0]["find_atomic_polarizability"] = -1 - self.sampled[0]["polarizability"] = self.sampled[0][ - "atomic_polarizability" - ].sum(dim=1) - self.all_stat["find_atomic_polarizability"] = [-1] - self.all_stat["polarizability"] = [ - self.all_stat["atomic_polarizability"][0].sum(axis=1) - ] - self.tfpolar.compute_output_stats(self.all_stat) - tfbias = self.tfpolar.constant_matrix - self.ptpolar.compute_output_stats(self.sampled) - ptbias = self.ptpolar.constant_matrix - np.testing.assert_allclose(tfbias, to_numpy_array(ptbias), rtol=1e-5, atol=1e-5) diff --git a/source/tests/pt/model/test_polarizability_fitting.py b/source/tests/pt/model/test_polarizability_fitting.py index b1a5e3f730..a061780f45 100644 --- a/source/tests/pt/model/test_polarizability_fitting.py +++ b/source/tests/pt/model/test_polarizability_fitting.py @@ -39,7 +39,7 @@ dtype = env.GLOBAL_PT_FLOAT_PRECISION -class TestDipoleFitting(unittest.TestCase, TestCaseSingleFrameWithNlist): +class TestPolarFitting(unittest.TestCase, TestCaseSingleFrameWithNlist): def setUp(self): TestCaseSingleFrameWithNlist.setUp(self) self.rng = np.random.default_rng() @@ -112,16 +112,16 @@ def test_consistency( aparam=to_numpy_array(iap), ) np.testing.assert_allclose( - to_numpy_array(ret0["polar"]), - ret1["polar"], + to_numpy_array(ret0["polarizability"]), + ret1["polarizability"], ) np.testing.assert_allclose( - to_numpy_array(ret0["polar"]), - to_numpy_array(ret2["polar"]), + to_numpy_array(ret0["polarizability"]), + to_numpy_array(ret2["polarizability"]), ) np.testing.assert_allclose( - to_numpy_array(ret0["polar"]), - ret3["polar"], + to_numpy_array(ret0["polarizability"]), + ret3["polarizability"], ) def test_jit( @@ -217,7 +217,7 @@ def test_rot(self): ) ret0 = ft0(rd0, extended_atype, gr0, fparam=ifp, aparam=iap) - res.append(ret0["polar"]) + res.append(ret0["polarizability"]) np.testing.assert_allclose( to_numpy_array(res[1]), to_numpy_array( @@ -260,7 +260,7 @@ def test_permu(self): ) ret0 = ft0(rd0, extended_atype, gr0, fparam=None, aparam=None) - res.append(ret0["polar"]) + res.append(ret0["polarizability"]) np.testing.assert_allclose( to_numpy_array(res[0][:, idx_perm]), @@ -304,12 +304,12 @@ def test_trans(self): ) ret0 = ft0(rd0, extended_atype, gr0, fparam=0, aparam=0) - res.append(ret0["polar"]) + res.append(ret0["polarizability"]) np.testing.assert_allclose(to_numpy_array(res[0]), to_numpy_array(res[1])) -class TestDipoleModel(unittest.TestCase): +class TestPolarModel(unittest.TestCase): def setUp(self): self.natoms = 5 self.rcut = 4.0 diff --git a/source/tests/pt/test_finetune.py b/source/tests/pt/test_finetune.py index 8f299ce542..a874d35497 100644 --- a/source/tests/pt/test_finetune.py +++ b/source/tests/pt/test_finetune.py @@ -23,15 
+23,54 @@ from deepmd.pt.utils.utils import ( to_numpy_array, ) +from deepmd.utils.data import ( + DataRequirementItem, +) from .model.test_permutation import ( model_dpa2, model_se_e2_a, model_zbl, ) -from .test_stat import ( - energy_data_requirement, -) + +energy_data_requirement = [ + DataRequirementItem( + "energy", + ndof=1, + atomic=False, + must=False, + high_prec=True, + ), + DataRequirementItem( + "force", + ndof=3, + atomic=True, + must=False, + high_prec=False, + ), + DataRequirementItem( + "virial", + ndof=9, + atomic=False, + must=False, + high_prec=False, + ), + DataRequirementItem( + "atom_ener", + ndof=1, + atomic=True, + must=False, + high_prec=False, + ), + DataRequirementItem( + "atom_pref", + ndof=1, + atomic=True, + must=False, + high_prec=False, + repeat=3, + ), +] class FinetuneTest: diff --git a/source/tests/pt/test_loss.py b/source/tests/pt/test_loss.py index 17b05dadc6..66460dfef1 100644 --- a/source/tests/pt/test_loss.py +++ b/source/tests/pt/test_loss.py @@ -32,7 +32,7 @@ from .model.test_embedding_net import ( get_single_batch, ) -from .test_stat import ( +from .test_finetune import ( energy_data_requirement, ) diff --git a/source/tests/pt/test_stat.py b/source/tests/pt/test_stat.py deleted file mode 100644 index a8519b8f2f..0000000000 --- a/source/tests/pt/test_stat.py +++ /dev/null @@ -1,467 +0,0 @@ -# SPDX-License-Identifier: LGPL-3.0-or-later -import json -import os -import tempfile -import unittest -from abc import ( - ABC, - abstractmethod, -) -from pathlib import ( - Path, -) - -import dpdata -import h5py -import numpy as np -import torch - -from deepmd.pt.model.descriptor import ( - DescrptSeA, -) -from deepmd.pt.model.descriptor.dpa1 import ( - DescrptDPA1, -) -from deepmd.pt.model.task.ener import ( - EnergyFittingNet, -) -from deepmd.pt.utils import ( - env, -) -from deepmd.pt.utils.dataloader import ( - DpLoaderSet, -) -from deepmd.pt.utils.stat import ( - compute_output_stats, -) -from deepmd.pt.utils.stat import make_stat_input -from deepmd.pt.utils.stat import make_stat_input as my_make -from deepmd.pt.utils.utils import ( - to_numpy_array, -) -from deepmd.tf.common import ( - expand_sys_str, -) -from deepmd.tf.descriptor.se_a import DescrptSeA as DescrptSeA_tf -from deepmd.tf.descriptor.se_atten import DescrptSeAtten as DescrptSeAtten_tf -from deepmd.tf.fit.ener import ( - EnerFitting, -) -from deepmd.tf.model.model_stat import make_stat_input as dp_make -from deepmd.tf.model.model_stat import merge_sys_stat as dp_merge -from deepmd.tf.utils import random as tf_random -from deepmd.tf.utils.data_system import ( - DeepmdDataSystem, -) -from deepmd.utils.data import ( - DataRequirementItem, -) -from deepmd.utils.path import ( - DPPath, -) - -CUR_DIR = os.path.dirname(__file__) - -energy_data_requirement = [ - DataRequirementItem( - "energy", - ndof=1, - atomic=False, - must=False, - high_prec=True, - ), - DataRequirementItem( - "force", - ndof=3, - atomic=True, - must=False, - high_prec=False, - ), - DataRequirementItem( - "virial", - ndof=9, - atomic=False, - must=False, - high_prec=False, - ), - DataRequirementItem( - "atom_ener", - ndof=1, - atomic=True, - must=False, - high_prec=False, - ), - DataRequirementItem( - "atom_pref", - ndof=1, - atomic=True, - must=False, - high_prec=False, - repeat=3, - ), -] - - -def compare(ut, base, given): - if isinstance(base, list): - ut.assertEqual(len(base), len(given)) - for idx in range(len(base)): - compare(ut, base[idx], given[idx]) - elif isinstance(base, np.ndarray): - 
ut.assertTrue(np.allclose(base.reshape(-1), given.reshape(-1))) - else: - ut.assertEqual(base, given) - - -class DatasetTest(ABC): - @abstractmethod - def setup_data(self): - pass - - @abstractmethod - def setup_tf(self): - pass - - @abstractmethod - def setup_pt(self): - pass - - @abstractmethod - def tf_compute_input_stats(self): - pass - - def setUp(self): - with open(str(Path(__file__).parent / "water/se_e2_a.json")) as fin: - content = fin.read() - config = json.loads(content) - data_file = [self.setup_data()] - - config["training"]["training_data"]["systems"] = data_file - config["training"]["validation_data"]["systems"] = data_file - model_config = config["model"] - self.rcut = model_config["descriptor"]["rcut"] - self.rcut_smth = model_config["descriptor"]["rcut_smth"] - self.sel = model_config["descriptor"]["sel"] - self.batch_size = config["training"]["training_data"]["batch_size"] - self.systems = config["training"]["validation_data"]["systems"] - if isinstance(self.systems, str): - self.systems = expand_sys_str(self.systems) - self.my_dataset = DpLoaderSet( - self.systems, - self.batch_size, - model_config["type_map"], - seed=10, - ) - self.filter_neuron = model_config["descriptor"]["neuron"] - self.axis_neuron = model_config["descriptor"]["axis_neuron"] - self.data_stat_nbatch = 2 - self.filter_neuron = model_config["descriptor"]["neuron"] - self.axis_neuron = model_config["descriptor"]["axis_neuron"] - self.n_neuron = model_config["fitting_net"]["neuron"] - self.my_dataset.add_data_requirement(energy_data_requirement) - - self.my_sampled = my_make( - self.my_dataset.systems, self.my_dataset.dataloaders, self.data_stat_nbatch - ) - - tf_random.seed(10) - dp_dataset = DeepmdDataSystem(self.systems, self.batch_size, 1, self.rcut) - dp_dataset.add("energy", 1, atomic=False, must=False, high_prec=True) - dp_dataset.add("force", 3, atomic=True, must=False, high_prec=False) - self.dp_sampled = dp_make(dp_dataset, self.data_stat_nbatch, False) - self.dp_merged = dp_merge(self.dp_sampled) - self.dp_mesh = self.dp_merged.pop("default_mesh") - self.dp_d = self.setup_tf() - - def test_stat_output(self): - def my_merge(energy, natoms): - energy_lst = [] - natoms_lst = [] - for i in range(len(energy)): - for j in range(len(energy[i])): - energy_lst.append(torch.tensor(energy[i][j], device="cpu")) - natoms_lst.append( - torch.tensor(natoms[i][j], device="cpu") - .unsqueeze(0) - .expand(energy[i][j].shape[0], -1) - ) - energy_merge = torch.cat(energy_lst) - natoms_merge = torch.cat(natoms_lst) - return energy_merge, natoms_merge - - energy = self.dp_sampled["energy"] - natoms = self.dp_sampled["natoms_vec"] - energy, natoms = my_merge(energy, natoms) - dp_fn = EnerFitting( - self.dp_d.get_ntypes(), self.dp_d.get_dim_out(), self.n_neuron - ) - dp_fn.compute_output_stats(self.dp_sampled, mixed_type=self.mixed_type) - pt_fn = EnergyFittingNet( - self.dp_d.get_ntypes(), self.dp_d.get_dim_out(), self.n_neuron - ) - pt_fn.compute_output_stats(self.my_sampled) - np.testing.assert_allclose( - dp_fn.bias_atom_e, pt_fn.bias_atom_e.detach().cpu().numpy().ravel() - ) - - # temporarily delete this function for performance of seeds in tf and pytorch may be different - """ - def test_stat_input(self): - my_sampled = self.my_sampled - # list of dicts, each dict contains samples from a system - dp_keys = set(self.dp_merged.keys()) # dict of list of batches - self.dp_merged['natoms'] = self.dp_merged['natoms_vec'] - for key in dp_keys: - if not key in my_sampled[0] or key in 'coord': - # coord is 
pre-normalized - continue - lst = [] - for item in my_sampled: - bsz = item['energy'].shape[0]//self.data_stat_nbatch - for j in range(self.data_stat_nbatch): - lst.append(item[key][j*bsz:(j+1)*bsz].cpu().numpy()) - compare(self, self.dp_merged[key], lst) - """ - - def test_descriptor(self): - self.tf_compute_input_stats() - - my_en = self.setup_pt() - sampled = self.my_sampled - for sys in sampled: - for key in [ - "coord", - "atype", - "natoms", - "box", - ]: - if key in sys.keys(): - sys[key] = sys[key].to(env.DEVICE) - stat_dict = my_en.compute_input_stats(sampled) - my_en.mean = my_en.mean - my_en.stddev = my_en.stddev - np.testing.assert_allclose( - self.dp_d.davg.reshape([-1]), - my_en.mean.cpu().reshape([-1]), - rtol=1e-14, - atol=1e-14, - ) - np.testing.assert_allclose( - self.dp_d.dstd.reshape([-1]), - my_en.stddev.cpu().reshape([-1]), - rtol=1e-14, - atol=1e-14, - ) - - -class TestDatasetNoMixed(DatasetTest, unittest.TestCase): - def setup_data(self): - original_data = str(Path(__file__).parent / "water/data/data_0") - picked_data = str(Path(__file__).parent / "picked_data_for_test_stat") - dpdata.LabeledSystem(original_data, fmt="deepmd/npy")[:2].to_deepmd_npy( - picked_data - ) - self.mixed_type = False - return picked_data - - def setup_tf(self): - return DescrptSeA_tf( - rcut=self.rcut, - rcut_smth=self.rcut_smth, - sel=self.sel, - neuron=self.filter_neuron, - axis_neuron=self.axis_neuron, - ) - - def setup_pt(self): - return DescrptSeA( - self.rcut, self.rcut_smth, self.sel, self.filter_neuron, self.axis_neuron - ).sea # get the block who has stat as private vars - - def tf_compute_input_stats(self): - coord = self.dp_merged["coord"] - atype = self.dp_merged["type"] - natoms = self.dp_merged["natoms_vec"] - box = self.dp_merged["box"] - self.dp_d.compute_input_stats(coord, box, atype, natoms, self.dp_mesh, {}) - - -class TestDatasetMixed(DatasetTest, unittest.TestCase): - def setup_data(self): - original_data = str(Path(__file__).parent / "water/data/data_0") - picked_data = str(Path(__file__).parent / "picked_data_for_test_stat") - dpdata.LabeledSystem(original_data, fmt="deepmd/npy")[:2].to_deepmd_npy_mixed( - picked_data - ) - self.mixed_type = True - return picked_data - - def setup_tf(self): - return DescrptSeAtten_tf( - ntypes=2, - rcut=self.rcut, - rcut_smth=self.rcut_smth, - sel=sum(self.sel), - neuron=self.filter_neuron, - axis_neuron=self.axis_neuron, - set_davg_zero=False, - ) - - def setup_pt(self): - return DescrptDPA1( - self.rcut, - self.rcut_smth, - sum(self.sel), - 2, - self.filter_neuron, - self.axis_neuron, - set_davg_zero=False, - ).se_atten - - def tf_compute_input_stats(self): - coord = self.dp_merged["coord"] - atype = self.dp_merged["type"] - natoms = self.dp_merged["natoms_vec"] - box = self.dp_merged["box"] - real_natoms_vec = self.dp_merged["real_natoms_vec"] - - self.dp_d.compute_input_stats( - coord, - box, - atype, - natoms, - self.dp_mesh, - {}, - mixed_type=True, - real_natoms_vec=real_natoms_vec, - ) - - -class TestExcludeTypes(DatasetTest, unittest.TestCase): - def setup_data(self): - original_data = str(Path(__file__).parent / "water/data/data_0") - picked_data = str(Path(__file__).parent / "picked_data_for_test_stat") - dpdata.LabeledSystem(original_data, fmt="deepmd/npy")[:2].to_deepmd_npy( - picked_data - ) - self.mixed_type = False - return picked_data - - def setup_tf(self): - return DescrptSeA_tf( - rcut=self.rcut, - rcut_smth=self.rcut_smth, - sel=self.sel, - neuron=self.filter_neuron, - axis_neuron=self.axis_neuron, - 
exclude_types=[[0, 0], [1, 1]], - ) - - def setup_pt(self): - return DescrptSeA( - self.rcut, - self.rcut_smth, - self.sel, - self.filter_neuron, - self.axis_neuron, - exclude_types=[[0, 0], [1, 1]], - ).sea # get the block who has stat as private vars - - def tf_compute_input_stats(self): - coord = self.dp_merged["coord"] - atype = self.dp_merged["type"] - natoms = self.dp_merged["natoms_vec"] - box = self.dp_merged["box"] - self.dp_d.compute_input_stats(coord, box, atype, natoms, self.dp_mesh, {}) - - -class TestOutputStat(unittest.TestCase): - def setUp(self): - self.data_file = [str(Path(__file__).parent / "water/data/data_0")] - self.type_map = ["O", "H"] # by dataset - self.data = DpLoaderSet( - self.data_file, - batch_size=1, - type_map=self.type_map, - ) - self.data.add_data_requirement(energy_data_requirement) - self.sampled = make_stat_input( - self.data.systems, - self.data.dataloaders, - nbatches=1, - ) - self.tempdir = tempfile.TemporaryDirectory() - h5file = str((Path(self.tempdir.name) / "testcase.h5").resolve()) - with h5py.File(h5file, "w") as f: - pass - self.stat_file_path = DPPath(h5file, "a") - - def tearDown(self): - self.tempdir.cleanup() - - def test_calc_and_load(self): - stat_file_path = self.stat_file_path - type_map = self.type_map - - # compute from sample - ret0, _ = compute_output_stats( - self.sampled, - len(type_map), - keys=["energy"], - stat_file_path=stat_file_path, - preset_bias=None, - model_forward=None, - ) - # ground truth - ntest = 1 - atom_nums = np.tile( - np.bincount(to_numpy_array(self.sampled[0]["atype"][0])), - (ntest, 1), - ) - energy_diff = to_numpy_array(self.sampled[0]["energy"][:ntest]) - ground_truth_shift = np.linalg.lstsq(atom_nums, energy_diff, rcond=None)[0] - - # check values - np.testing.assert_almost_equal( - to_numpy_array(ret0["energy"]), ground_truth_shift, decimal=10 - ) - # self.assertTrue(stat_file_path.is_dir()) - - def raise_error(): - raise RuntimeError - - # hack!!! - # suppose to load stat from file, if from sample, an error will raise. 
- ret1, _ = compute_output_stats( - raise_error, - len(type_map), - keys=["energy"], - stat_file_path=stat_file_path, - preset_bias=None, - model_forward=None, - ) - np.testing.assert_almost_equal( - to_numpy_array(ret0["energy"]), to_numpy_array(ret1["energy"]), decimal=10 - ) - - def test_assigned(self): - atom_ener = {"energy": np.array([3.0, 5.0]).reshape(2, 1)} - stat_file_path = self.stat_file_path - type_map = self.type_map - - # from assigned atom_ener - ret2, _ = compute_output_stats( - self.sampled, - len(type_map), - keys=["energy"], - stat_file_path=stat_file_path, - preset_bias=atom_ener, - model_forward=None, - ) - np.testing.assert_almost_equal( - to_numpy_array(ret2["energy"]), atom_ener["energy"], decimal=10 - ) - - -if __name__ == "__main__": - unittest.main() diff --git a/source/tests/tf/test_deepmd_data.py b/source/tests/tf/test_deepmd_data.py index 94e1f4c571..31dc8235d3 100644 --- a/source/tests/tf/test_deepmd_data.py +++ b/source/tests/tf/test_deepmd_data.py @@ -190,8 +190,8 @@ def test_load_set(self): self.assertEqual(data["type"][ii][1], 1) self.assertEqual(data["find_coord"], 1) self._comp_np_mat2(data["coord"], self.coord) - self.assertEqual(data["find_test_atomic"], 1) - self._comp_np_mat2(data["test_atomic"], self.test_atomic) + self.assertEqual(data["find_test_atom"], 1) + self._comp_np_mat2(data["test_atom"], self.test_atomic) self.assertEqual(data["find_test_frame"], 1) self._comp_np_mat2(data["test_frame"], self.test_frame) self.assertEqual(data["find_test_null"], 0) @@ -207,7 +207,7 @@ def test_shuffle(self): data_bk = copy.deepcopy(data) data, idx = dd._shuffle_data(data) self._comp_np_mat2(data_bk["coord"][idx, :], data["coord"]) - self._comp_np_mat2(data_bk["test_atomic"][idx, :], data["test_atomic"]) + self._comp_np_mat2(data_bk["test_atom"][idx, :], data["test_atom"]) self._comp_np_mat2(data_bk["test_frame"][idx, :], data["test_frame"]) def test_shuffle_with_numb_copy(self): @@ -224,15 +224,15 @@ def test_shuffle_with_numb_copy(self): data, idx = dd._shuffle_data(data) assert idx.size == np.sum(prob) self._comp_np_mat2(data_bk["coord"][idx, :], data["coord"]) - self._comp_np_mat2(data_bk["test_atomic"][idx, :], data["test_atomic"]) + self._comp_np_mat2(data_bk["test_atom"][idx, :], data["test_atom"]) self._comp_np_mat2(data_bk["test_frame"][idx, :], data["test_frame"]) def test_reduce(self): dd = DeepmdData(self.data_name).add("test_atomic", 7, atomic=True, must=True) dd.reduce("redu", "test_atomic") data = dd._load_set(os.path.join(self.data_name, "set.foo")) - self.assertEqual(data["find_test_atomic"], 1) - self._comp_np_mat2(data["test_atomic"], self.test_atomic) + self.assertEqual(data["find_test_atom"], 1) + self._comp_np_mat2(data["test_atom"], self.test_atomic) self.assertEqual(data["find_redu"], 1) self._comp_np_mat2(data["redu"], self.redu_atomic) @@ -240,9 +240,9 @@ def test_reduce_null(self): dd = DeepmdData(self.data_name).add("test_atomic_1", 7, atomic=True, must=False) dd.reduce("redu", "test_atomic_1") data = dd._load_set(os.path.join(self.data_name, "set.foo")) - self.assertEqual(data["find_test_atomic_1"], 0) + self.assertEqual(data["find_test_atom_1"], 0) self._comp_np_mat2( - data["test_atomic_1"], np.zeros([self.nframes, self.natoms * 7]) + data["test_atom_1"], np.zeros([self.nframes, self.natoms * 7]) ) self.assertEqual(data["find_redu"], 0) self._comp_np_mat2(data["redu"], np.zeros([self.nframes, 7]))
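
A minimal NumPy sketch of the shift_diag diagonal-bias construction that this patch keeps in the dpmodel polarizability fitting and moves, on the PyTorch side, from the fitting net into PolarAtomicModel.apply_out_stat. This is an illustrative sketch only, not DeepMD-kit API: the array names (out, out_bias, scale, atype) mirror the patch, but the shapes and values here are made-up assumptions.

    import numpy as np

    nframes, nloc, ntypes = 2, 3, 2
    rng = np.random.default_rng(0)

    out = rng.normal(size=(nframes, nloc, 3, 3))   # raw per-atom polarizability
    out_bias = rng.normal(size=(ntypes, 3, 3))     # per-type statistical bias
    scale = np.ones((ntypes, 1))                   # per-type scale factor
    atype = np.array([[0, 0, 1], [0, 1, 1]])       # atom types, (nframes, nloc)

    # reduce each type's 3x3 bias to the mean of its diagonal -> (ntypes,)
    diag_mean = np.mean(np.diagonal(out_bias, axis1=-2, axis2=-1), axis=-1)

    # broadcast to atoms and apply the per-type scale -> (nframes, nloc, 1)
    bias = diag_mean[atype][..., None] * scale[atype]

    # add the bias only on the diagonal: bias * I -> (nframes, nloc, 3, 3)
    eye = np.eye(3)
    shifted = out + bias[..., None] * eye

Reducing the per-type bias to the mean of its diagonal and adding it as a multiple of the identity shifts only the isotropic part of the predicted polarizability, leaving the off-diagonal (anisotropic) components untouched.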