diff --git a/deepmd/dpmodel/descriptor/__init__.py b/deepmd/dpmodel/descriptor/__init__.py index 8542168d91..765e3069fd 100644 --- a/deepmd/dpmodel/descriptor/__init__.py +++ b/deepmd/dpmodel/descriptor/__init__.py @@ -5,6 +5,9 @@ from .dpa2 import ( DescrptDPA2, ) +from .dpa3 import ( + DescrptDPA3, +) from .hybrid import ( DescrptHybrid, ) @@ -30,6 +33,7 @@ __all__ = [ "DescrptDPA1", "DescrptDPA2", + "DescrptDPA3", "DescrptHybrid", "DescrptSeA", "DescrptSeAttenV2", diff --git a/deepmd/dpmodel/descriptor/dpa3.py b/deepmd/dpmodel/descriptor/dpa3.py new file mode 100644 index 0000000000..3150956d4a --- /dev/null +++ b/deepmd/dpmodel/descriptor/dpa3.py @@ -0,0 +1,647 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + NoReturn, + Optional, + Union, +) + +import array_api_compat +import numpy as np + +from deepmd.dpmodel import ( + NativeOP, +) +from deepmd.dpmodel.common import ( + cast_precision, + to_numpy_array, +) +from deepmd.dpmodel.utils import ( + EnvMat, +) +from deepmd.dpmodel.utils.network import ( + NativeLayer, +) +from deepmd.dpmodel.utils.seed import ( + child_seed, +) +from deepmd.dpmodel.utils.type_embed import ( + TypeEmbedNet, +) +from deepmd.dpmodel.utils.update_sel import ( + UpdateSel, +) +from deepmd.utils.data_system import ( + DeepmdDataSystem, +) +from deepmd.utils.finetune import ( + get_index_between_two_maps, + map_pair_exclude_types, +) +from deepmd.utils.path import ( + DPPath, +) +from deepmd.utils.version import ( + check_version_compatibility, +) + +from .base_descriptor import ( + BaseDescriptor, +) +from .descriptor import ( + extend_descrpt_stat, +) +from .repflows import ( + DescrptBlockRepflows, + RepFlowLayer, +) + + +class RepFlowArgs: + def __init__( + self, + n_dim: int = 128, + e_dim: int = 64, + a_dim: int = 64, + nlayers: int = 6, + e_rcut: float = 6.0, + e_rcut_smth: float = 5.0, + e_sel: int = 120, + a_rcut: float = 4.0, + a_rcut_smth: float = 3.5, + a_sel: int = 20, + a_compress_rate: int = 0, + a_compress_e_rate: int = 1, + a_compress_use_split: bool = False, + n_multi_edge_message: int = 1, + axis_neuron: int = 4, + update_angle: bool = True, + update_style: str = "res_residual", + update_residual: float = 0.1, + update_residual_init: str = "const", + fix_stat_std: float = 0.3, + skip_stat: bool = False, + optim_update: bool = True, + ) -> None: + r"""The constructor for the RepFlowArgs class which defines the parameters of the repflow block in DPA3 descriptor. + + Parameters + ---------- + n_dim : int, optional + The dimension of node representation. + e_dim : int, optional + The dimension of edge representation. + a_dim : int, optional + The dimension of angle representation. + nlayers : int, optional + Number of repflow layers. + e_rcut : float, optional + The edge cut-off radius. + e_rcut_smth : float, optional + Where to start smoothing for edge. For example the 1/r term is smoothed from rcut to rcut_smth. + e_sel : int, optional + Maximally possible number of selected edge neighbors. + a_rcut : float, optional + The angle cut-off radius. + a_rcut_smth : float, optional + Where to start smoothing for angle. For example the 1/r term is smoothed from rcut to rcut_smth. + a_sel : int, optional + Maximally possible number of selected angle neighbors. + a_compress_rate : int, optional + The compression rate for angular messages. The default value is 0, indicating no compression. + If a non-zero integer c is provided, the node and edge dimensions will be compressed + to a_dim/c and a_dim/2c, respectively, within the angular message. + a_compress_e_rate : int, optional + The extra compression rate for edge in angular message compression. The default value is 1. + When using angular message compression with a_compress_rate c and a_compress_e_rate c_e, + the edge dimension will be compressed to (c_e * a_dim / 2c) within the angular message. + a_compress_use_split : bool, optional + Whether to split first sub-vectors instead of linear mapping during angular message compression. + The default value is False. + n_multi_edge_message : int, optional + The head number of multiple edge messages to update node feature. + Default is 1, indicating one head edge message. + axis_neuron : int, optional + The number of dimension of submatrix in the symmetrization ops. + update_angle : bool, optional + Where to update the angle rep. If not, only node and edge rep will be used. + update_style : str, optional + Style to update a representation. + Supported options are: + -'res_avg': Updates a rep `u` with: u = 1/\\sqrt{n+1} (u + u_1 + u_2 + ... + u_n) + -'res_incr': Updates a rep `u` with: u = u + 1/\\sqrt{n} (u_1 + u_2 + ... + u_n) + -'res_residual': Updates a rep `u` with: u = u + (r1*u_1 + r2*u_2 + ... + r3*u_n) + where `r1`, `r2` ... `r3` are residual weights defined by `update_residual` + and `update_residual_init`. + update_residual : float, optional + When update using residual mode, the initial std of residual vector weights. + update_residual_init : str, optional + When update using residual mode, the initialization mode of residual vector weights. + fix_stat_std : float, optional + If non-zero (default is 0.3), use this constant as the normalization standard deviation + instead of computing it from data statistics. + skip_stat : bool, optional + (Deprecated, kept only for compatibility.) This parameter is obsolete and will be removed. + If set to True, it forces fix_stat_std=0.3 for backward compatibility. + Transition to fix_stat_std parameter immediately. + optim_update : bool, optional + Whether to enable the optimized update method. + Uses a more efficient process when enabled. Defaults to True + """ + self.n_dim = n_dim + self.e_dim = e_dim + self.a_dim = a_dim + self.nlayers = nlayers + self.e_rcut = e_rcut + self.e_rcut_smth = e_rcut_smth + self.e_sel = e_sel + self.a_rcut = a_rcut + self.a_rcut_smth = a_rcut_smth + self.a_sel = a_sel + self.a_compress_rate = a_compress_rate + self.n_multi_edge_message = n_multi_edge_message + self.axis_neuron = axis_neuron + self.update_angle = update_angle + self.update_style = update_style + self.update_residual = update_residual + self.update_residual_init = update_residual_init + self.fix_stat_std = ( + fix_stat_std if not skip_stat else 0.3 + ) # backward compatibility + self.skip_stat = skip_stat + self.a_compress_e_rate = a_compress_e_rate + self.a_compress_use_split = a_compress_use_split + self.optim_update = optim_update + + def __getitem__(self, key): + if hasattr(self, key): + return getattr(self, key) + else: + raise KeyError(key) + + def serialize(self) -> dict: + return { + "n_dim": self.n_dim, + "e_dim": self.e_dim, + "a_dim": self.a_dim, + "nlayers": self.nlayers, + "e_rcut": self.e_rcut, + "e_rcut_smth": self.e_rcut_smth, + "e_sel": self.e_sel, + "a_rcut": self.a_rcut, + "a_rcut_smth": self.a_rcut_smth, + "a_sel": self.a_sel, + "a_compress_rate": self.a_compress_rate, + "a_compress_e_rate": self.a_compress_e_rate, + "a_compress_use_split": self.a_compress_use_split, + "n_multi_edge_message": self.n_multi_edge_message, + "axis_neuron": self.axis_neuron, + "update_angle": self.update_angle, + "update_style": self.update_style, + "update_residual": self.update_residual, + "update_residual_init": self.update_residual_init, + "fix_stat_std": self.fix_stat_std, + "optim_update": self.optim_update, + } + + @classmethod + def deserialize(cls, data: dict) -> "RepFlowArgs": + return cls(**data) + + +@BaseDescriptor.register("dpa3") +class DescrptDPA3(NativeOP, BaseDescriptor): + def __init__( + self, + ntypes: int, + # args for repflow + repflow: Union[RepFlowArgs, dict], + # kwargs for descriptor + concat_output_tebd: bool = False, + activation_function: str = "silu", + precision: str = "float64", + exclude_types: list[tuple[int, int]] = [], + env_protection: float = 0.0, + trainable: bool = True, + seed: Optional[Union[int, list[int]]] = None, + use_econf_tebd: bool = False, + use_tebd_bias: bool = False, + type_map: Optional[list[str]] = None, + ) -> None: + r"""The DPA-3 descriptor. + + Parameters + ---------- + repflow : Union[RepFlowArgs, dict] + The arguments used to initialize the repflow block, see docstr in `RepFlowArgs` for details information. + concat_output_tebd : bool, optional + Whether to concat type embedding at the output of the descriptor. + activation_function : str, optional + The activation function in the embedding net. + precision : str, optional + The precision of the embedding net parameters. + exclude_types : list[list[int]], optional + The excluded pairs of types which have no interaction with each other. + For example, `[[0, 1]]` means no interaction between type 0 and type 1. + env_protection : float, optional + Protection parameter to prevent division by zero errors during environment matrix calculations. + For example, when using paddings, there may be zero distances of neighbors, which may make division by zero error during environment matrix calculations without protection. + trainable : bool, optional + If the parameters are trainable. + seed : int, optional + Random seed for parameter initialization. + use_econf_tebd : bool, Optional + Whether to use electronic configuration type embedding. + use_tebd_bias : bool, Optional + Whether to use bias in the type embedding layer. + type_map : list[str], Optional + A list of strings. Give the name to each type of atoms. + + Returns + ------- + descriptor: torch.Tensor + the descriptor of shape nb x nloc x n_dim. + invariant single-atom representation. + g2: torch.Tensor + invariant pair-atom representation. + h2: torch.Tensor + equivariant pair-atom representation. + rot_mat: torch.Tensor + rotation matrix for equivariant fittings + sw: torch.Tensor + The switch function for decaying inverse distance. + + """ + super().__init__() + + def init_subclass_params(sub_data, sub_class): + if isinstance(sub_data, dict): + return sub_class(**sub_data) + elif isinstance(sub_data, sub_class): + return sub_data + else: + raise ValueError( + f"Input args must be a {sub_class.__name__} class or a dict!" + ) + + self.repflow_args = init_subclass_params(repflow, RepFlowArgs) + self.activation_function = activation_function + + self.repflows = DescrptBlockRepflows( + self.repflow_args.e_rcut, + self.repflow_args.e_rcut_smth, + self.repflow_args.e_sel, + self.repflow_args.a_rcut, + self.repflow_args.a_rcut_smth, + self.repflow_args.a_sel, + ntypes, + nlayers=self.repflow_args.nlayers, + n_dim=self.repflow_args.n_dim, + e_dim=self.repflow_args.e_dim, + a_dim=self.repflow_args.a_dim, + a_compress_rate=self.repflow_args.a_compress_rate, + a_compress_e_rate=self.repflow_args.a_compress_e_rate, + a_compress_use_split=self.repflow_args.a_compress_use_split, + n_multi_edge_message=self.repflow_args.n_multi_edge_message, + axis_neuron=self.repflow_args.axis_neuron, + update_angle=self.repflow_args.update_angle, + activation_function=self.activation_function, + update_style=self.repflow_args.update_style, + update_residual=self.repflow_args.update_residual, + update_residual_init=self.repflow_args.update_residual_init, + fix_stat_std=self.repflow_args.fix_stat_std, + optim_update=self.repflow_args.optim_update, + exclude_types=exclude_types, + env_protection=env_protection, + precision=precision, + seed=child_seed(seed, 1), + ) + + self.use_econf_tebd = use_econf_tebd + self.use_tebd_bias = use_tebd_bias + self.type_map = type_map + self.tebd_dim = self.repflow_args.n_dim + self.type_embedding = TypeEmbedNet( + ntypes=ntypes, + neuron=[self.tebd_dim], + padding=True, + activation_function="Linear", + precision=precision, + use_econf_tebd=self.use_econf_tebd, + use_tebd_bias=use_tebd_bias, + type_map=type_map, + seed=child_seed(seed, 2), + ) + self.concat_output_tebd = concat_output_tebd + self.precision = precision + self.exclude_types = exclude_types + self.env_protection = env_protection + self.trainable = trainable + + assert self.repflows.e_rcut >= self.repflows.a_rcut + assert self.repflows.e_sel >= self.repflows.a_sel + + self.rcut = self.repflows.get_rcut() + self.rcut_smth = self.repflows.get_rcut_smth() + self.sel = self.repflows.get_sel() + self.ntypes = ntypes + + def get_rcut(self) -> float: + """Returns the cut-off radius.""" + return self.rcut + + def get_rcut_smth(self) -> float: + """Returns the radius where the neighbor information starts to smoothly decay to 0.""" + return self.rcut_smth + + def get_nsel(self) -> int: + """Returns the number of selected atoms in the cut-off radius.""" + return sum(self.sel) + + def get_sel(self) -> list[int]: + """Returns the number of selected atoms for each type.""" + return self.sel + + def get_ntypes(self) -> int: + """Returns the number of element types.""" + return self.ntypes + + def get_type_map(self) -> list[str]: + """Get the name to each type of atoms.""" + return self.type_map + + def get_dim_out(self) -> int: + """Returns the output dimension of this descriptor.""" + ret = self.repflows.dim_out + if self.concat_output_tebd: + ret += self.tebd_dim + return ret + + def get_dim_emb(self) -> int: + """Returns the embedding dimension of this descriptor.""" + return self.repflows.dim_emb + + def mixed_types(self) -> bool: + """If true, the descriptor + 1. assumes total number of atoms aligned across frames; + 2. requires a neighbor list that does not distinguish different atomic types. + + If false, the descriptor + 1. assumes total number of atoms of each atom type aligned across frames; + 2. requires a neighbor list that distinguishes different atomic types. + + """ + return True + + def has_message_passing(self) -> bool: + """Returns whether the descriptor has message passing.""" + return self.repflows.has_message_passing() + + def need_sorted_nlist_for_lower(self) -> bool: + """Returns whether the descriptor needs sorted nlist when using `forward_lower`.""" + return True + + def get_env_protection(self) -> float: + """Returns the protection of building environment matrix.""" + return self.repflows.get_env_protection() + + def share_params(self, base_class, shared_level, resume=False) -> None: + """ + Share the parameters of self to the base_class with shared_level during multitask training. + If not start from checkpoint (resume is False), + some separated parameters (e.g. mean and stddev) will be re-calculated across different classes. + """ + raise NotImplementedError + + def change_type_map( + self, type_map: list[str], model_with_new_type_stat=None + ) -> None: + """Change the type related params to new ones, according to `type_map` and the original one in the model. + If there are new types in `type_map`, statistics will be updated accordingly to `model_with_new_type_stat` for these new types. + """ + assert self.type_map is not None, ( + "'type_map' must be defined when performing type changing!" + ) + remap_index, has_new_type = get_index_between_two_maps(self.type_map, type_map) + self.type_map = type_map + self.type_embedding.change_type_map(type_map=type_map) + self.exclude_types = map_pair_exclude_types(self.exclude_types, remap_index) + self.ntypes = len(type_map) + repflow = self.repflows + if has_new_type: + # the avg and std of new types need to be updated + extend_descrpt_stat( + repflow, + type_map, + des_with_stat=model_with_new_type_stat.repflows + if model_with_new_type_stat is not None + else None, + ) + repflow.ntypes = self.ntypes + repflow.reinit_exclude(self.exclude_types) + repflow["davg"] = repflow["davg"][remap_index] + repflow["dstd"] = repflow["dstd"][remap_index] + + @property + def dim_out(self): + return self.get_dim_out() + + @property + def dim_emb(self): + """Returns the embedding dimension g2.""" + return self.get_dim_emb() + + def compute_input_stats( + self, merged: list[dict], path: Optional[DPPath] = None + ) -> NoReturn: + """Update mean and stddev for descriptor elements.""" + raise NotImplementedError + + def set_stat_mean_and_stddev( + self, + mean: list[np.ndarray], + stddev: list[np.ndarray], + ) -> None: + """Update mean and stddev for descriptor.""" + descrpt_list = [self.repflows] + for ii, descrpt in enumerate(descrpt_list): + descrpt.mean = mean[ii] + descrpt.stddev = stddev[ii] + + def get_stat_mean_and_stddev(self) -> tuple[list[np.ndarray], list[np.ndarray]]: + """Get mean and stddev for descriptor.""" + mean_list = [self.repflows.mean] + stddev_list = [self.repflows.stddev] + return mean_list, stddev_list + + @cast_precision + def call( + self, + coord_ext: np.ndarray, + atype_ext: np.ndarray, + nlist: np.ndarray, + mapping: Optional[np.ndarray] = None, + ): + """Compute the descriptor. + + Parameters + ---------- + coord_ext + The extended coordinates of atoms. shape: nf x (nallx3) + atype_ext + The extended aotm types. shape: nf x nall + nlist + The neighbor list. shape: nf x nloc x nnei + mapping + The index mapping, mapps extended region index to local region. + + Returns + ------- + node_ebd + The output descriptor. shape: nf x nloc x n_dim (or n_dim + tebd_dim) + rot_mat + The rotationally equivariant and permutationally invariant single particle + representation. shape: nf x nloc x e_dim x 3 + edge_ebd + The edge embedding. + shape: nf x nloc x nnei x e_dim + h2 + The rotationally equivariant pair-partical representation. + shape: nf x nloc x nnei x 3 + sw + The smooth switch function. shape: nf x nloc x nnei + + """ + xp = array_api_compat.array_namespace(coord_ext, atype_ext, nlist) + nframes, nloc, nnei = nlist.shape + nall = xp.reshape(coord_ext, (nframes, -1)).shape[1] // 3 + + type_embedding = self.type_embedding.call() + node_ebd_ext = xp.reshape( + xp.take(type_embedding, xp.reshape(atype_ext, [-1]), axis=0), + (nframes, nall, self.tebd_dim), + ) + node_ebd_inp = node_ebd_ext[:, :nloc, :] + # repflows + node_ebd, edge_ebd, h2, rot_mat, sw = self.repflows( + nlist, + coord_ext, + atype_ext, + node_ebd_ext, + mapping, + ) + if self.concat_output_tebd: + node_ebd = xp.concat([node_ebd, node_ebd_inp], axis=-1) + return node_ebd, rot_mat, edge_ebd, h2, sw + + def serialize(self) -> dict: + repflows = self.repflows + data = { + "@class": "Descriptor", + "type": "dpa3", + "@version": 1, + "ntypes": self.ntypes, + "repflow_args": self.repflow_args.serialize(), + "concat_output_tebd": self.concat_output_tebd, + "activation_function": self.activation_function, + "precision": self.precision, + "exclude_types": self.exclude_types, + "env_protection": self.env_protection, + "trainable": self.trainable, + "use_econf_tebd": self.use_econf_tebd, + "use_tebd_bias": self.use_tebd_bias, + "type_map": self.type_map, + "type_embedding": self.type_embedding.serialize(), + } + repflow_variable = { + "edge_embd": repflows.edge_embd.serialize(), + "angle_embd": repflows.angle_embd.serialize(), + "repflow_layers": [layer.serialize() for layer in repflows.layers], + "env_mat": EnvMat(repflows.rcut, repflows.rcut_smth).serialize(), + "@variables": { + "davg": to_numpy_array(repflows["davg"]), + "dstd": to_numpy_array(repflows["dstd"]), + }, + } + data.update( + { + "repflow_variable": repflow_variable, + } + ) + return data + + @classmethod + def deserialize(cls, data: dict) -> "DescrptDPA3": + data = data.copy() + version = data.pop("@version") + check_version_compatibility(version, 1, 1) + data.pop("@class") + data.pop("type") + repflow_variable = data.pop("repflow_variable").copy() + type_embedding = data.pop("type_embedding") + data["repflow"] = RepFlowArgs(**data.pop("repflow_args")) + obj = cls(**data) + obj.type_embedding = TypeEmbedNet.deserialize(type_embedding) + + # deserialize repflow + statistic_repflows = repflow_variable.pop("@variables") + env_mat = repflow_variable.pop("env_mat") + repflow_layers = repflow_variable.pop("repflow_layers") + obj.repflows.edge_embd = NativeLayer.deserialize( + repflow_variable.pop("edge_embd") + ) + obj.repflows.angle_embd = NativeLayer.deserialize( + repflow_variable.pop("angle_embd") + ) + obj.repflows["davg"] = statistic_repflows["davg"] + obj.repflows["dstd"] = statistic_repflows["dstd"] + obj.repflows.layers = [ + RepFlowLayer.deserialize(layer) for layer in repflow_layers + ] + return obj + + @classmethod + def update_sel( + cls, + train_data: DeepmdDataSystem, + type_map: Optional[list[str]], + local_jdata: dict, + ) -> tuple[dict, Optional[float]]: + """Update the selection and perform neighbor statistics. + + Parameters + ---------- + train_data : DeepmdDataSystem + data used to do neighbor statistics + type_map : list[str], optional + The name of each type of atoms + local_jdata : dict + The local data refer to the current class + + Returns + ------- + dict + The updated local data + float + The minimum distance between two atoms + """ + local_jdata_cpy = local_jdata.copy() + update_sel = UpdateSel() + min_nbor_dist, repflow_e_sel = update_sel.update_one_sel( + train_data, + type_map, + local_jdata_cpy["repflow"]["e_rcut"], + local_jdata_cpy["repflow"]["e_sel"], + True, + ) + local_jdata_cpy["repflow"]["e_sel"] = repflow_e_sel[0] + + min_nbor_dist, repflow_a_sel = update_sel.update_one_sel( + train_data, + type_map, + local_jdata_cpy["repflow"]["a_rcut"], + local_jdata_cpy["repflow"]["a_sel"], + True, + ) + local_jdata_cpy["repflow"]["a_sel"] = repflow_a_sel[0] + + return local_jdata_cpy, min_nbor_dist diff --git a/deepmd/dpmodel/descriptor/repflows.py b/deepmd/dpmodel/descriptor/repflows.py new file mode 100644 index 0000000000..5d4a1b3cde --- /dev/null +++ b/deepmd/dpmodel/descriptor/repflows.py @@ -0,0 +1,1342 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Callable, + NoReturn, + Optional, + Union, +) + +import array_api_compat +import numpy as np + +from deepmd.dpmodel import ( + PRECISION_DICT, + NativeOP, +) +from deepmd.dpmodel.array_api import ( + xp_take_along_axis, +) +from deepmd.dpmodel.common import ( + to_numpy_array, +) +from deepmd.dpmodel.utils import ( + EnvMat, + PairExcludeMask, +) +from deepmd.dpmodel.utils.network import ( + NativeLayer, + get_activation_fn, +) +from deepmd.dpmodel.utils.seed import ( + child_seed, +) +from deepmd.utils.path import ( + DPPath, +) +from deepmd.utils.version import ( + check_version_compatibility, +) + +from .descriptor import ( + DescriptorBlock, +) +from .repformers import ( + _cal_hg, + _make_nei_g1, + get_residual, + symmetrization_op, +) + + +@DescriptorBlock.register("se_repflow") +class DescrptBlockRepflows(NativeOP, DescriptorBlock): + def __init__( + self, + e_rcut, + e_rcut_smth, + e_sel: int, + a_rcut, + a_rcut_smth, + a_sel: int, + ntypes: int, + nlayers: int = 6, + n_dim: int = 128, + e_dim: int = 64, + a_dim: int = 64, + a_compress_rate: int = 0, + a_compress_e_rate: int = 1, + a_compress_use_split: bool = False, + n_multi_edge_message: int = 1, + axis_neuron: int = 4, + update_angle: bool = True, + activation_function: str = "silu", + update_style: str = "res_residual", + update_residual: float = 0.1, + update_residual_init: str = "const", + set_davg_zero: bool = True, + exclude_types: list[tuple[int, int]] = [], + env_protection: float = 0.0, + precision: str = "float64", + fix_stat_std: float = 0.3, + optim_update: bool = True, + seed: Optional[Union[int, list[int]]] = None, + ) -> None: + r""" + The repflow descriptor block. + + Parameters + ---------- + n_dim : int, optional + The dimension of node representation. + e_dim : int, optional + The dimension of edge representation. + a_dim : int, optional + The dimension of angle representation. + nlayers : int, optional + Number of repflow layers. + e_rcut : float, optional + The edge cut-off radius. + e_rcut_smth : float, optional + Where to start smoothing for edge. For example the 1/r term is smoothed from rcut to rcut_smth. + e_sel : int, optional + Maximally possible number of selected edge neighbors. + a_rcut : float, optional + The angle cut-off radius. + a_rcut_smth : float, optional + Where to start smoothing for angle. For example the 1/r term is smoothed from rcut to rcut_smth. + a_sel : int, optional + Maximally possible number of selected angle neighbors. + a_compress_rate : int, optional + The compression rate for angular messages. The default value is 0, indicating no compression. + If a non-zero integer c is provided, the node and edge dimensions will be compressed + to a_dim/c and a_dim/2c, respectively, within the angular message. + a_compress_e_rate : int, optional + The extra compression rate for edge in angular message compression. The default value is 1. + When using angular message compression with a_compress_rate c and a_compress_e_rate c_e, + the edge dimension will be compressed to (c_e * a_dim / 2c) within the angular message. + a_compress_use_split : bool, optional + Whether to split first sub-vectors instead of linear mapping during angular message compression. + The default value is False. + n_multi_edge_message : int, optional + The head number of multiple edge messages to update node feature. + Default is 1, indicating one head edge message. + axis_neuron : int, optional + The number of dimension of submatrix in the symmetrization ops. + update_angle : bool, optional + Where to update the angle rep. If not, only node and edge rep will be used. + update_style : str, optional + Style to update a representation. + Supported options are: + -'res_avg': Updates a rep `u` with: u = 1/\\sqrt{n+1} (u + u_1 + u_2 + ... + u_n) + -'res_incr': Updates a rep `u` with: u = u + 1/\\sqrt{n} (u_1 + u_2 + ... + u_n) + -'res_residual': Updates a rep `u` with: u = u + (r1*u_1 + r2*u_2 + ... + r3*u_n) + where `r1`, `r2` ... `r3` are residual weights defined by `update_residual` + and `update_residual_init`. + update_residual : float, optional + When update using residual mode, the initial std of residual vector weights. + update_residual_init : str, optional + When update using residual mode, the initialization mode of residual vector weights. + fix_stat_std : float, optional + If non-zero (default is 0.3), use this constant as the normalization standard deviation + instead of computing it from data statistics. + optim_update : bool, optional + Whether to enable the optimized update method. + Uses a more efficient process when enabled. Defaults to True + ntypes : int + Number of element types + activation_function : str, optional + The activation function in the embedding net. + set_davg_zero : bool, optional + Set the normalization average to zero. + precision : str, optional + The precision of the embedding net parameters. + exclude_types : list[list[int]], optional + The excluded pairs of types which have no interaction with each other. + For example, `[[0, 1]]` means no interaction between type 0 and type 1. + env_protection : float, optional + Protection parameter to prevent division by zero errors during environment matrix calculations. + For example, when using paddings, there may be zero distances of neighbors, which may make division by zero error during environment matrix calculations without protection. + seed : int, optional + Random seed for parameter initialization. + """ + super().__init__() + self.e_rcut = float(e_rcut) + self.e_rcut_smth = float(e_rcut_smth) + self.e_sel = e_sel + self.a_rcut = float(a_rcut) + self.a_rcut_smth = float(a_rcut_smth) + self.a_sel = a_sel + self.ntypes = ntypes + self.nlayers = nlayers + # for other common desciptor method + sel = [e_sel] if isinstance(e_sel, int) else e_sel + self.nnei = sum(sel) + self.ndescrpt = self.nnei * 4 # use full descriptor. + assert len(sel) == 1 + self.sel = sel + self.rcut = e_rcut + self.rcut_smth = e_rcut_smth + self.sec = self.sel + self.split_sel = self.sel + self.a_compress_rate = a_compress_rate + self.a_compress_e_rate = a_compress_e_rate + self.n_multi_edge_message = n_multi_edge_message + self.axis_neuron = axis_neuron + self.set_davg_zero = set_davg_zero + self.fix_stat_std = fix_stat_std + self.set_stddev_constant = fix_stat_std != 0.0 + self.a_compress_use_split = a_compress_use_split + self.optim_update = optim_update + + self.n_dim = n_dim + self.e_dim = e_dim + self.a_dim = a_dim + self.update_angle = update_angle + + self.activation_function = activation_function + self.update_style = update_style + self.update_residual = update_residual + self.update_residual_init = update_residual_init + self.act = get_activation_fn(self.activation_function) + self.prec = PRECISION_DICT[precision] + + # order matters, placed after the assignment of self.ntypes + self.reinit_exclude(exclude_types) + self.env_protection = env_protection + self.precision = precision + self.epsilon = 1e-4 + self.seed = seed + + self.edge_embd = NativeLayer( + 1, self.e_dim, precision=precision, seed=child_seed(seed, 0) + ) + self.angle_embd = NativeLayer( + 1, self.a_dim, precision=precision, bias=False, seed=child_seed(seed, 1) + ) + layers = [] + for ii in range(nlayers): + layers.append( + RepFlowLayer( + e_rcut=self.e_rcut, + e_rcut_smth=self.e_rcut_smth, + e_sel=self.sel, + a_rcut=self.a_rcut, + a_rcut_smth=self.a_rcut_smth, + a_sel=self.a_sel, + ntypes=self.ntypes, + n_dim=self.n_dim, + e_dim=self.e_dim, + a_dim=self.a_dim, + a_compress_rate=self.a_compress_rate, + a_compress_use_split=self.a_compress_use_split, + a_compress_e_rate=self.a_compress_e_rate, + n_multi_edge_message=self.n_multi_edge_message, + axis_neuron=self.axis_neuron, + update_angle=self.update_angle, + activation_function=self.activation_function, + update_style=self.update_style, + update_residual=self.update_residual, + update_residual_init=self.update_residual_init, + precision=precision, + optim_update=self.optim_update, + seed=child_seed(child_seed(seed, 1), ii), + ) + ) + self.layers = layers + + wanted_shape = (self.ntypes, self.nnei, 4) + self.env_mat_edge = EnvMat( + self.e_rcut, self.e_rcut_smth, protection=self.env_protection + ) + self.env_mat_angle = EnvMat( + self.a_rcut, self.a_rcut_smth, protection=self.env_protection + ) + self.mean = np.zeros(wanted_shape, dtype=PRECISION_DICT[self.precision]) + self.stddev = np.ones(wanted_shape, dtype=PRECISION_DICT[self.precision]) + if self.set_stddev_constant: + self.stddev = self.stddev * self.fix_stat_std + + def get_rcut(self) -> float: + """Returns the cut-off radius.""" + return self.e_rcut + + def get_rcut_smth(self) -> float: + """Returns the radius where the neighbor information starts to smoothly decay to 0.""" + return self.e_rcut_smth + + def get_nsel(self) -> int: + """Returns the number of selected atoms in the cut-off radius.""" + return sum(self.sel) + + def get_sel(self) -> list[int]: + """Returns the number of selected atoms for each type.""" + return self.sel + + def get_ntypes(self) -> int: + """Returns the number of element types.""" + return self.ntypes + + def get_dim_out(self) -> int: + """Returns the output dimension.""" + return self.dim_out + + def get_dim_in(self) -> int: + """Returns the input dimension.""" + return self.dim_in + + def get_dim_emb(self) -> int: + """Returns the embedding dimension e_dim.""" + return self.e_dim + + def __setitem__(self, key, value) -> None: + if key in ("avg", "data_avg", "davg"): + self.mean = value + elif key in ("std", "data_std", "dstd"): + self.stddev = value + else: + raise KeyError(key) + + def __getitem__(self, key): + if key in ("avg", "data_avg", "davg"): + return self.mean + elif key in ("std", "data_std", "dstd"): + return self.stddev + else: + raise KeyError(key) + + def mixed_types(self) -> bool: + """If true, the descriptor + 1. assumes total number of atoms aligned across frames; + 2. requires a neighbor list that does not distinguish different atomic types. + + If false, the descriptor + 1. assumes total number of atoms of each atom type aligned across frames; + 2. requires a neighbor list that distinguishes different atomic types. + + """ + return True + + def get_env_protection(self) -> float: + """Returns the protection of building environment matrix.""" + return self.env_protection + + @property + def dim_out(self): + """Returns the output dimension of this descriptor.""" + return self.n_dim + + @property + def dim_in(self): + """Returns the atomic input dimension of this descriptor.""" + return self.n_dim + + @property + def dim_emb(self): + """Returns the embedding dimension e_dim.""" + return self.get_dim_emb() + + def compute_input_stats( + self, + merged: Union[Callable[[], list[dict]], list[dict]], + path: Optional[DPPath] = None, + ) -> NoReturn: + """Compute the input statistics (e.g. mean and stddev) for the descriptors from packed data.""" + raise NotImplementedError + + def get_stats(self) -> NoReturn: + """Get the statistics of the descriptor.""" + raise NotImplementedError + + def reinit_exclude( + self, + exclude_types: list[tuple[int, int]] = [], + ) -> None: + self.exclude_types = exclude_types + self.emask = PairExcludeMask(self.ntypes, exclude_types=exclude_types) + + def call( + self, + nlist: np.ndarray, + coord_ext: np.ndarray, + atype_ext: np.ndarray, + atype_embd_ext: Optional[np.ndarray] = None, + mapping: Optional[np.ndarray] = None, + ): + xp = array_api_compat.array_namespace(nlist, coord_ext, atype_ext) + nframes, nloc, nnei = nlist.shape + exclude_mask = self.emask.build_type_exclude_mask(nlist, atype_ext) + # nb x nloc x nnei + nlist = xp.where(exclude_mask, nlist, xp.full_like(nlist, -1)) + # nb x nloc x nnei x 4, nb x nloc x nnei x 3, nb x nloc x nnei x 1 + dmatrix, diff, sw = self.env_mat_edge.call( + coord_ext, atype_ext, nlist, self.mean, self.stddev + ) + # nb x nloc x nnei + nlist_mask = nlist != -1 + sw = xp.reshape(sw, (nframes, nloc, nnei)) + # beyond the cutoff sw should be 0.0 + sw = xp.where(nlist_mask, sw, xp.zeros_like(sw)) + + # nb x nloc x tebd_dim + atype_embd = atype_embd_ext[:, :nloc, :] + assert list(atype_embd.shape) == [nframes, nloc, self.n_dim] + + node_ebd = self.act(atype_embd) + # nb x nloc x nnei x 1, nb x nloc x nnei x 3 + # edge_input, h2 = xp.split(dmatrix, [1], axis=-1) + edge_input = dmatrix[:, :, :, :1] + h2 = dmatrix[:, :, :, 1:] + # nb x nloc x nnei x e_dim + edge_ebd = self.act(self.edge_embd(edge_input)) + + # get angle nlist (maybe smaller) + a_dist_mask = (xp.linalg.vector_norm(diff, axis=-1) < self.a_rcut)[ + :, :, : self.a_sel + ] + a_nlist = nlist[:, :, : self.a_sel] + a_nlist = xp.where(a_dist_mask, a_nlist, xp.full_like(a_nlist, -1)) + + _, a_diff, a_sw = self.env_mat_angle.call( + coord_ext, + atype_ext, + a_nlist, + self.mean[:, : self.a_sel, :], + self.stddev[:, : self.a_sel, :], + ) + + # nb x nloc x a_nnei + a_nlist_mask = a_nlist != -1 + a_sw = xp.reshape(a_sw, (nframes, nloc, self.a_sel)) + # beyond the cutoff sw should be 0.0 + a_sw = xp.where(a_nlist_mask, a_sw, xp.zeros_like(a_sw)) + a_nlist = xp.where(a_nlist == -1, xp.zeros_like(a_nlist), a_nlist) + + # nf x nloc x a_nnei x 3 + normalized_diff_i = a_diff / ( + xp.linalg.vector_norm(a_diff, axis=-1, keepdims=True) + 1e-6 + ) + # nf x nloc x 3 x a_nnei + normalized_diff_j = xp.matrix_transpose(normalized_diff_i) + # nf x nloc x a_nnei x a_nnei + # 1 - 1e-6 for torch.acos stability + cosine_ij = xp.matmul(normalized_diff_i, normalized_diff_j) * (1 - 1e-6) + # nf x nloc x a_nnei x a_nnei x 1 + cosine_ij = xp.reshape( + cosine_ij, (nframes, nloc, self.a_sel, self.a_sel, 1) + ) / (xp.pi**0.5) + # nf x nloc x a_nnei x a_nnei x a_dim + angle_ebd = xp.reshape( + self.angle_embd(cosine_ij), + (nframes, nloc, self.a_sel, self.a_sel, self.a_dim), + ) + + # set all padding positions to index of 0 + # if a neighbor is real or not is indicated by nlist_mask + nlist = xp.where(nlist == -1, xp.zeros_like(nlist), nlist) + # nb x nall x n_dim + mapping = xp.tile(xp.reshape(mapping, (nframes, -1, 1)), (1, 1, self.n_dim)) + for idx, ll in enumerate(self.layers): + # node_ebd: nb x nloc x n_dim + # node_ebd_ext: nb x nall x n_dim + node_ebd_ext = xp_take_along_axis(node_ebd, mapping, axis=1) + node_ebd, edge_ebd, angle_ebd = ll.call( + node_ebd_ext, + edge_ebd, + h2, + angle_ebd, + nlist, + nlist_mask, + sw, + a_nlist, + a_nlist_mask, + a_sw, + ) + + # nb x nloc x 3 x e_dim + h2g2 = _cal_hg(edge_ebd, h2, nlist_mask, sw) + # nb x nloc x e_dim x 3 + rot_mat = xp.matrix_transpose(h2g2) + + return ( + node_ebd, + edge_ebd, + h2, + xp.reshape(rot_mat, (nframes, nloc, self.dim_emb, 3)), + sw, + ) + + def has_message_passing(self) -> bool: + """Returns whether the descriptor block has message passing.""" + return True + + def need_sorted_nlist_for_lower(self) -> bool: + """Returns whether the descriptor block needs sorted nlist when using `forward_lower`.""" + return True + + @classmethod + def deserialize(cls, data): + """Deserialize the descriptor block.""" + data = data.copy() + edge_embd = NativeLayer.deserialize(data.pop("edge_embd")) + angle_embd = NativeLayer.deserialize(data.pop("angle_embd")) + layers = [RepFlowLayer.deserialize(dd) for dd in data.pop("repflow_layers")] + env_mat_edge = EnvMat.deserialize(data.pop("env_mat_edge")) + env_mat_angle = EnvMat.deserialize(data.pop("env_mat_angle")) + variables = data.pop("@variables") + davg = variables["davg"] + dstd = variables["dstd"] + obj = cls(**data) + obj.edge_embd = edge_embd + obj.angle_embd = angle_embd + obj.layers = layers + obj.env_mat_edge = env_mat_edge + obj.env_mat_angle = env_mat_angle + obj.mean = davg + obj.stddev = dstd + return obj + + def serialize(self): + """Serialize the descriptor block.""" + return { + "e_rcut": self.e_rcut, + "e_rcut_smth": self.e_rcut_smth, + "e_sel": self.e_sel, + "a_rcut": self.a_rcut, + "a_rcut_smth": self.a_rcut_smth, + "a_sel": self.a_sel, + "ntypes": self.ntypes, + "nlayers": self.nlayers, + "n_dim": self.n_dim, + "e_dim": self.e_dim, + "a_dim": self.a_dim, + "a_compress_rate": self.a_compress_rate, + "a_compress_e_rate": self.a_compress_e_rate, + "a_compress_use_split": self.a_compress_use_split, + "n_multi_edge_message": self.n_multi_edge_message, + "axis_neuron": self.axis_neuron, + "update_angle": self.update_angle, + "activation_function": self.activation_function, + "update_style": self.update_style, + "update_residual": self.update_residual, + "update_residual_init": self.update_residual_init, + "set_davg_zero": self.set_davg_zero, + "exclude_types": self.exclude_types, + "env_protection": self.env_protection, + "precision": self.precision, + "fix_stat_std": self.fix_stat_std, + "optim_update": self.optim_update, + # variables + "edge_embd": self.edge_embd.serialize(), + "angle_embd": self.angle_embd.serialize(), + "repflow_layers": [layer.serialize() for layer in self.layers], + "env_mat_edge": self.env_mat_edge.serialize(), + "env_mat_angle": self.env_mat_angle.serialize(), + "@variables": { + "davg": to_numpy_array(self["davg"]), + "dstd": to_numpy_array(self["dstd"]), + }, + } + + +class RepFlowLayer(NativeOP): + def __init__( + self, + e_rcut: float, + e_rcut_smth: float, + e_sel: int, + a_rcut: float, + a_rcut_smth: float, + a_sel: int, + ntypes: int, + n_dim: int = 128, + e_dim: int = 16, + a_dim: int = 64, + a_compress_rate: int = 0, + a_compress_use_split: bool = False, + a_compress_e_rate: int = 1, + n_multi_edge_message: int = 1, + axis_neuron: int = 4, + update_angle: bool = True, + optim_update: bool = True, + activation_function: str = "silu", + update_style: str = "res_residual", + update_residual: float = 0.1, + update_residual_init: str = "const", + precision: str = "float64", + seed: Optional[Union[int, list[int]]] = None, + ) -> None: + super().__init__() + self.epsilon = 1e-4 # protection of 1./nnei + self.e_rcut = float(e_rcut) + self.e_rcut_smth = float(e_rcut_smth) + self.ntypes = ntypes + e_sel = [e_sel] if isinstance(e_sel, int) else e_sel + self.nnei = sum(e_sel) + assert len(e_sel) == 1 + self.e_sel = e_sel + self.sec = self.e_sel + self.a_rcut = a_rcut + self.a_rcut_smth = a_rcut_smth + self.a_sel = a_sel + self.n_dim = n_dim + self.e_dim = e_dim + self.a_dim = a_dim + self.a_compress_rate = a_compress_rate + if a_compress_rate != 0: + assert (a_dim * a_compress_e_rate) % (2 * a_compress_rate) == 0, ( + f"For a_compress_rate of {a_compress_rate}, a_dim*a_compress_e_rate must be divisible by {2 * a_compress_rate}. " + f"Currently, a_dim={a_dim} and a_compress_e_rate={a_compress_e_rate} is not valid." + ) + self.n_multi_edge_message = n_multi_edge_message + assert self.n_multi_edge_message >= 1, "n_multi_edge_message must >= 1!" + self.axis_neuron = axis_neuron + self.update_angle = update_angle + self.activation_function = activation_function + self.act = get_activation_fn(self.activation_function) + self.update_style = update_style + self.update_residual = update_residual + self.update_residual_init = update_residual_init + self.a_compress_e_rate = a_compress_e_rate + self.a_compress_use_split = a_compress_use_split + self.precision = precision + self.seed = seed + self.prec = PRECISION_DICT[precision] + self.optim_update = optim_update + + assert update_residual_init in [ + "norm", + "const", + ], "'update_residual_init' only support 'norm' or 'const'!" + + self.update_residual = update_residual + self.update_residual_init = update_residual_init + self.n_residual = [] + self.e_residual = [] + self.a_residual = [] + self.edge_info_dim = self.n_dim * 2 + self.e_dim + + # node self mlp + self.node_self_mlp = NativeLayer( + n_dim, + n_dim, + precision=precision, + seed=child_seed(seed, 0), + ) + if self.update_style == "res_residual": + self.n_residual.append( + get_residual( + n_dim, + self.update_residual, + self.update_residual_init, + precision=precision, + seed=child_seed(seed, 1), + ) + ) + + # node sym (grrg + drrd) + self.n_sym_dim = n_dim * self.axis_neuron + e_dim * self.axis_neuron + self.node_sym_linear = NativeLayer( + self.n_sym_dim, + n_dim, + precision=precision, + seed=child_seed(seed, 2), + ) + if self.update_style == "res_residual": + self.n_residual.append( + get_residual( + n_dim, + self.update_residual, + self.update_residual_init, + precision=precision, + seed=child_seed(seed, 3), + ) + ) + + # node edge message + self.node_edge_linear = NativeLayer( + self.edge_info_dim, + self.n_multi_edge_message * n_dim, + precision=precision, + seed=child_seed(seed, 4), + ) + if self.update_style == "res_residual": + for head_index in range(self.n_multi_edge_message): + self.n_residual.append( + get_residual( + n_dim, + self.update_residual, + self.update_residual_init, + precision=precision, + seed=child_seed(child_seed(seed, 5), head_index), + ) + ) + + # edge self message + self.edge_self_linear = NativeLayer( + self.edge_info_dim, + e_dim, + precision=precision, + seed=child_seed(seed, 6), + ) + if self.update_style == "res_residual": + self.e_residual.append( + get_residual( + e_dim, + self.update_residual, + self.update_residual_init, + precision=precision, + seed=child_seed(seed, 7), + ) + ) + + if self.update_angle: + self.angle_dim = self.a_dim + if self.a_compress_rate == 0: + # angle + node + edge * 2 + self.angle_dim += self.n_dim + 2 * self.e_dim + self.a_compress_n_linear = None + self.a_compress_e_linear = None + self.e_a_compress_dim = e_dim + self.n_a_compress_dim = n_dim + else: + # angle + a_dim/c + a_dim/2c * 2 * e_rate + self.angle_dim += (1 + self.a_compress_e_rate) * ( + self.a_dim // self.a_compress_rate + ) + self.e_a_compress_dim = ( + self.a_dim // (2 * self.a_compress_rate) * self.a_compress_e_rate + ) + self.n_a_compress_dim = self.a_dim // self.a_compress_rate + if not self.a_compress_use_split: + self.a_compress_n_linear = NativeLayer( + self.n_dim, + self.n_a_compress_dim, + precision=precision, + bias=False, + seed=child_seed(seed, 8), + ) + self.a_compress_e_linear = NativeLayer( + self.e_dim, + self.e_a_compress_dim, + precision=precision, + bias=False, + seed=child_seed(seed, 9), + ) + else: + self.a_compress_n_linear = None + self.a_compress_e_linear = None + + # edge angle message + self.edge_angle_linear1 = NativeLayer( + self.angle_dim, + self.e_dim, + precision=precision, + seed=child_seed(seed, 10), + ) + self.edge_angle_linear2 = NativeLayer( + self.e_dim, + self.e_dim, + precision=precision, + seed=child_seed(seed, 11), + ) + if self.update_style == "res_residual": + self.e_residual.append( + get_residual( + self.e_dim, + self.update_residual, + self.update_residual_init, + precision=precision, + seed=child_seed(seed, 12), + ) + ) + + # angle self message + self.angle_self_linear = NativeLayer( + self.angle_dim, + self.a_dim, + precision=precision, + seed=child_seed(seed, 13), + ) + if self.update_style == "res_residual": + self.a_residual.append( + get_residual( + self.a_dim, + self.update_residual, + self.update_residual_init, + precision=precision, + seed=child_seed(seed, 14), + ) + ) + else: + self.angle_self_linear = None + self.edge_angle_linear1 = None + self.edge_angle_linear2 = None + self.a_compress_n_linear = None + self.a_compress_e_linear = None + self.angle_dim = 0 + + def optim_angle_update( + self, + angle_ebd: np.ndarray, + node_ebd: np.ndarray, + edge_ebd: np.ndarray, + feat: str = "edge", + ) -> np.ndarray: + xp = array_api_compat.array_namespace(angle_ebd, node_ebd, edge_ebd) + angle_dim = angle_ebd.shape[-1] + node_dim = node_ebd.shape[-1] + edge_dim = edge_ebd.shape[-1] + sub_angle_idx = (0, angle_dim) + sub_node_idx = (angle_dim, angle_dim + node_dim) + sub_edge_idx_ij = (angle_dim + node_dim, angle_dim + node_dim + edge_dim) + sub_edge_idx_ik = ( + angle_dim + node_dim + edge_dim, + angle_dim + node_dim + 2 * edge_dim, + ) + + if feat == "edge": + matrix, bias = self.edge_angle_linear1.w, self.edge_angle_linear1.b + elif feat == "angle": + matrix, bias = self.angle_self_linear.w, self.angle_self_linear.b + else: + raise NotImplementedError + assert angle_dim + node_dim + 2 * edge_dim == matrix.shape[0] + + # nf * nloc * a_sel * a_sel * angle_dim + sub_angle_update = xp.matmul( + angle_ebd, matrix[sub_angle_idx[0] : sub_angle_idx[1], :] + ) + + # nf * nloc * angle_dim + sub_node_update = xp.matmul( + node_ebd, matrix[sub_node_idx[0] : sub_node_idx[1], :] + ) + + # nf * nloc * a_nnei * angle_dim + sub_edge_update_ij = xp.matmul( + edge_ebd, matrix[sub_edge_idx_ij[0] : sub_edge_idx_ij[1], :] + ) + sub_edge_update_ik = xp.matmul( + edge_ebd, matrix[sub_edge_idx_ik[0] : sub_edge_idx_ik[1], :] + ) + + result_update = ( + sub_angle_update + + sub_node_update[:, :, xp.newaxis, xp.newaxis, :] + + sub_edge_update_ij[:, :, xp.newaxis, :, :] + + sub_edge_update_ik[:, :, :, xp.newaxis, :] + ) + bias + return result_update + + def optim_edge_update( + self, + node_ebd: np.ndarray, + node_ebd_ext: np.ndarray, + edge_ebd: np.ndarray, + nlist: np.ndarray, + feat: str = "node", + ) -> np.ndarray: + xp = array_api_compat.array_namespace(node_ebd, node_ebd_ext, edge_ebd, nlist) + node_dim = node_ebd.shape[-1] + edge_dim = edge_ebd.shape[-1] + sub_node_idx = (0, node_dim) + sub_node_ext_idx = (node_dim, 2 * node_dim) + sub_edge_idx = (2 * node_dim, 2 * node_dim + edge_dim) + + if feat == "node": + matrix, bias = self.node_edge_linear.w, self.node_edge_linear.b + elif feat == "edge": + matrix, bias = self.edge_self_linear.w, self.edge_self_linear.b + else: + raise NotImplementedError + assert 2 * node_dim + edge_dim == matrix.shape[0] + + # nf * nloc * node/edge_dim + sub_node_update = xp.matmul( + node_ebd, matrix[sub_node_idx[0] : sub_node_idx[1], :] + ) + + # nf * nall * node/edge_dim + sub_node_ext_update = xp.matmul( + node_ebd_ext, matrix[sub_node_ext_idx[0] : sub_node_ext_idx[1], :] + ) + # nf * nloc * nnei * node/edge_dim + sub_node_ext_update = _make_nei_g1(sub_node_ext_update, nlist) + + # nf * nloc * nnei * node/edge_dim + sub_edge_update = xp.matmul( + edge_ebd, matrix[sub_edge_idx[0] : sub_edge_idx[1], :] + ) + + result_update = ( + sub_edge_update + sub_node_ext_update + sub_node_update[:, :, xp.newaxis, :] + ) + bias + return result_update + + def call( + self, + node_ebd_ext: np.ndarray, # nf x nall x n_dim + edge_ebd: np.ndarray, # nf x nloc x nnei x e_dim + h2: np.ndarray, # nf x nloc x nnei x 3 + angle_ebd: np.ndarray, # nf x nloc x a_nnei x a_nnei x a_dim + nlist: np.ndarray, # nf x nloc x nnei + nlist_mask: np.ndarray, # nf x nloc x nnei + sw: np.ndarray, # switch func, nf x nloc x nnei + a_nlist: np.ndarray, # nf x nloc x a_nnei + a_nlist_mask: np.ndarray, # nf x nloc x a_nnei + a_sw: np.ndarray, # switch func, nf x nloc x a_nnei + ): + """ + Parameters + ---------- + node_ebd_ext : nf x nall x n_dim + Extended node embedding. + edge_ebd : nf x nloc x nnei x e_dim + Edge embedding. + h2 : nf x nloc x nnei x 3 + Pair-atom channel, equivariant. + angle_ebd : nf x nloc x a_nnei x a_nnei x a_dim + Angle embedding. + nlist : nf x nloc x nnei + Neighbor list. (padded neis are set to 0) + nlist_mask : nf x nloc x nnei + Masks of the neighbor list. real nei 1 otherwise 0 + sw : nf x nloc x nnei + Switch function. + a_nlist : nf x nloc x a_nnei + Neighbor list for angle. (padded neis are set to 0) + a_nlist_mask : nf x nloc x a_nnei + Masks of the neighbor list for angle. real nei 1 otherwise 0 + a_sw : nf x nloc x a_nnei + Switch function for angle. + + Returns + ------- + n_updated: nf x nloc x n_dim + Updated node embedding. + e_updated: nf x nloc x nnei x e_dim + Updated edge embedding. + a_updated : nf x nloc x a_nnei x a_nnei x a_dim + Updated angle embedding. + """ + xp = array_api_compat.array_namespace( + node_ebd_ext, + edge_ebd, + h2, + angle_ebd, + nlist, + nlist_mask, + sw, + a_nlist, + a_nlist_mask, + a_sw, + ) + nb, nloc, nnei, _ = edge_ebd.shape + nall = node_ebd_ext.shape[1] + node_ebd = node_ebd_ext[:, :nloc, :] + assert (nb, nloc) == node_ebd.shape[:2] + assert (nb, nloc, nnei) == h2.shape[:3] + del a_nlist # may be used in the future + + n_update_list: list[np.ndarray] = [node_ebd] + e_update_list: list[np.ndarray] = [edge_ebd] + a_update_list: list[np.ndarray] = [angle_ebd] + + # node self mlp + node_self_mlp = self.act(self.node_self_mlp(node_ebd)) + n_update_list.append(node_self_mlp) + + nei_node_ebd = _make_nei_g1(node_ebd_ext, nlist) + + # node sym (grrg + drrd) + node_sym_list: list[np.ndarray] = [] + node_sym_list.append( + symmetrization_op( + edge_ebd, + h2, + nlist_mask, + sw, + self.axis_neuron, + ) + ) + node_sym_list.append( + symmetrization_op( + nei_node_ebd, + h2, + nlist_mask, + sw, + self.axis_neuron, + ) + ) + node_sym = self.act(self.node_sym_linear(xp.concat(node_sym_list, axis=-1))) + n_update_list.append(node_sym) + + if not self.optim_update: + # nb x nloc x nnei x (n_dim * 2 + e_dim) + edge_info = xp.concat( + [ + xp.tile( + xp.reshape(node_ebd, (nb, nloc, 1, self.n_dim)), + (1, 1, self.nnei, 1), + ), + nei_node_ebd, + edge_ebd, + ], + axis=-1, + ) + else: + edge_info = None + + # node edge message + # nb x nloc x nnei x (h * n_dim) + if not self.optim_update: + assert edge_info is not None + node_edge_update = self.act( + self.node_edge_linear(edge_info) + ) * xp.expand_dims(sw, axis=-1) + else: + node_edge_update = self.act( + self.optim_edge_update( + node_ebd, + node_ebd_ext, + edge_ebd, + nlist, + "node", + ) + ) * xp.expand_dims(sw, axis=-1) + + node_edge_update = xp.sum(node_edge_update, axis=-2) / self.nnei + if self.n_multi_edge_message > 1: + # nb x nloc x nnei x h x n_dim + node_edge_update_mul_head = xp.reshape( + node_edge_update, (nb, nloc, self.n_multi_edge_message, self.n_dim) + ) + for head_index in range(self.n_multi_edge_message): + n_update_list.append(node_edge_update_mul_head[:, :, head_index, :]) + else: + n_update_list.append(node_edge_update) + # update node_ebd + n_updated = self.list_update(n_update_list, "node") + + # edge self message + if not self.optim_update: + assert edge_info is not None + edge_self_update = self.act(self.edge_self_linear(edge_info)) + else: + edge_self_update = self.act( + self.optim_edge_update( + node_ebd, + node_ebd_ext, + edge_ebd, + nlist, + "edge", + ) + ) + e_update_list.append(edge_self_update) + + if self.update_angle: + assert self.angle_self_linear is not None + assert self.edge_angle_linear1 is not None + assert self.edge_angle_linear2 is not None + # get angle info + if self.a_compress_rate != 0: + if not self.a_compress_use_split: + assert self.a_compress_n_linear is not None + assert self.a_compress_e_linear is not None + node_ebd_for_angle = self.a_compress_n_linear(node_ebd) + edge_ebd_for_angle = self.a_compress_e_linear(edge_ebd) + else: + # use the first a_compress_dim dim for node and edge + node_ebd_for_angle = node_ebd[:, :, : self.n_a_compress_dim] + edge_ebd_for_angle = edge_ebd[:, :, :, : self.e_a_compress_dim] + else: + node_ebd_for_angle = node_ebd + edge_ebd_for_angle = edge_ebd + + # nb x nloc x a_nnei x e_dim + edge_for_angle = edge_ebd_for_angle[:, :, : self.a_sel, :] + # nb x nloc x a_nnei x e_dim + edge_for_angle = xp.where( + xp.expand_dims(a_nlist_mask, axis=-1), + edge_for_angle, + xp.zeros_like(edge_for_angle), + ) + if not self.optim_update: + # nb x nloc x a_nnei x a_nnei x n_dim + node_for_angle_info = xp.tile( + xp.reshape( + node_ebd_for_angle, (nb, nloc, 1, 1, self.n_a_compress_dim) + ), + (1, 1, self.a_sel, self.a_sel, 1), + ) + # nb x nloc x (a_nnei) x a_nnei x edge_ebd + edge_for_angle_i = xp.tile( + xp.reshape( + edge_for_angle, (nb, nloc, 1, self.a_sel, self.e_a_compress_dim) + ), + (1, 1, self.a_sel, 1, 1), + ) + # nb x nloc x a_nnei x (a_nnei) x e_dim + edge_for_angle_j = xp.tile( + xp.reshape( + edge_for_angle, (nb, nloc, self.a_sel, 1, self.e_a_compress_dim) + ), + (1, 1, 1, self.a_sel, 1), + ) + # nb x nloc x a_nnei x a_nnei x (e_dim + e_dim) + edge_for_angle_info = xp.concat( + [edge_for_angle_i, edge_for_angle_j], axis=-1 + ) + angle_info_list = [angle_ebd] + angle_info_list.append(node_for_angle_info) + angle_info_list.append(edge_for_angle_info) + # nb x nloc x a_nnei x a_nnei x (a + n_dim + e_dim*2) or (a + a/c + a/c) + angle_info = xp.concat(angle_info_list, axis=-1) + else: + angle_info = None + + # edge angle message + # nb x nloc x a_nnei x a_nnei x e_dim + if not self.optim_update: + assert angle_info is not None + edge_angle_update = self.act(self.edge_angle_linear1(angle_info)) + else: + edge_angle_update = self.act( + self.optim_angle_update( + angle_ebd, + node_ebd_for_angle, + edge_for_angle, + "edge", + ) + ) + + # nb x nloc x a_nnei x a_nnei x e_dim + weighted_edge_angle_update = ( + edge_angle_update + * a_sw[:, :, :, xp.newaxis, xp.newaxis] + * a_sw[:, :, xp.newaxis, :, xp.newaxis] + ) + # nb x nloc x a_nnei x e_dim + reduced_edge_angle_update = xp.sum(weighted_edge_angle_update, axis=-2) / ( + self.a_sel**0.5 + ) + # nb x nloc x nnei x e_dim + padding_edge_angle_update = xp.concat( + [ + reduced_edge_angle_update, + xp.zeros( + (nb, nloc, self.nnei - self.a_sel, self.e_dim), + dtype=edge_ebd.dtype, + ), + ], + axis=2, + ) + full_mask = xp.concat( + [ + a_nlist_mask, + xp.zeros( + (nb, nloc, self.nnei - self.a_sel), + dtype=a_nlist_mask.dtype, + ), + ], + axis=-1, + ) + padding_edge_angle_update = xp.where( + xp.expand_dims(full_mask, axis=-1), padding_edge_angle_update, edge_ebd + ) + e_update_list.append( + self.act(self.edge_angle_linear2(padding_edge_angle_update)) + ) + # update edge_ebd + e_updated = self.list_update(e_update_list, "edge") + + # angle self message + # nb x nloc x a_nnei x a_nnei x dim_a + if not self.optim_update: + assert angle_info is not None + angle_self_update = self.act(self.angle_self_linear(angle_info)) + else: + angle_self_update = self.act( + self.optim_angle_update( + angle_ebd, + node_ebd_for_angle, + edge_for_angle, + "angle", + ) + ) + a_update_list.append(angle_self_update) + else: + # update edge_ebd + e_updated = self.list_update(e_update_list, "edge") + + # update angle_ebd + a_updated = self.list_update(a_update_list, "angle") + return n_updated, e_updated, a_updated + + def list_update_res_avg( + self, + update_list: list[np.ndarray], + ) -> np.ndarray: + nitem = len(update_list) + uu = update_list[0] + for ii in range(1, nitem): + uu = uu + update_list[ii] + return uu / (float(nitem) ** 0.5) + + def list_update_res_incr(self, update_list: list[np.ndarray]) -> np.ndarray: + nitem = len(update_list) + uu = update_list[0] + scale = 1.0 / (float(nitem - 1) ** 0.5) if nitem > 1 else 0.0 + for ii in range(1, nitem): + uu = uu + scale * update_list[ii] + return uu + + def list_update_res_residual( + self, update_list: list[np.ndarray], update_name: str = "node" + ) -> np.ndarray: + nitem = len(update_list) + uu = update_list[0] + if update_name == "node": + for ii, vv in enumerate(self.n_residual): + uu = uu + vv * update_list[ii + 1] + elif update_name == "edge": + for ii, vv in enumerate(self.e_residual): + uu = uu + vv * update_list[ii + 1] + elif update_name == "angle": + for ii, vv in enumerate(self.a_residual): + uu = uu + vv * update_list[ii + 1] + else: + raise NotImplementedError + return uu + + def list_update( + self, update_list: list[np.ndarray], update_name: str = "node" + ) -> np.ndarray: + if self.update_style == "res_avg": + return self.list_update_res_avg(update_list) + elif self.update_style == "res_incr": + return self.list_update_res_incr(update_list) + elif self.update_style == "res_residual": + return self.list_update_res_residual(update_list, update_name=update_name) + else: + raise RuntimeError(f"unknown update style {self.update_style}") + + def serialize(self) -> dict: + """Serialize the networks to a dict. + + Returns + ------- + dict + The serialized networks. + """ + data = { + "@class": "RepformerLayer", + "@version": 1, + "e_rcut": self.e_rcut, + "e_rcut_smth": self.e_rcut_smth, + "e_sel": self.e_sel, + "a_rcut": self.a_rcut, + "a_rcut_smth": self.a_rcut_smth, + "a_sel": self.a_sel, + "ntypes": self.ntypes, + "n_dim": self.n_dim, + "e_dim": self.e_dim, + "a_dim": self.a_dim, + "a_compress_rate": self.a_compress_rate, + "a_compress_e_rate": self.a_compress_e_rate, + "a_compress_use_split": self.a_compress_use_split, + "n_multi_edge_message": self.n_multi_edge_message, + "axis_neuron": self.axis_neuron, + "activation_function": self.activation_function, + "update_angle": self.update_angle, + "update_style": self.update_style, + "update_residual": self.update_residual, + "update_residual_init": self.update_residual_init, + "precision": self.precision, + "optim_update": self.optim_update, + "node_self_mlp": self.node_self_mlp.serialize(), + "node_sym_linear": self.node_sym_linear.serialize(), + "node_edge_linear": self.node_edge_linear.serialize(), + "edge_self_linear": self.edge_self_linear.serialize(), + } + if self.update_angle: + data.update( + { + "edge_angle_linear1": self.edge_angle_linear1.serialize(), + "edge_angle_linear2": self.edge_angle_linear2.serialize(), + "angle_self_linear": self.angle_self_linear.serialize(), + } + ) + if self.a_compress_rate != 0 and not self.a_compress_use_split: + data.update( + { + "a_compress_n_linear": self.a_compress_n_linear.serialize(), + "a_compress_e_linear": self.a_compress_e_linear.serialize(), + } + ) + if self.update_style == "res_residual": + data.update( + { + "@variables": { + "n_residual": [to_numpy_array(t) for t in self.n_residual], + "e_residual": [to_numpy_array(t) for t in self.e_residual], + "a_residual": [to_numpy_array(t) for t in self.a_residual], + } + } + ) + return data + + @classmethod + def deserialize(cls, data: dict) -> "RepFlowLayer": + """Deserialize the networks from a dict. + + Parameters + ---------- + data : dict + The dict to deserialize from. + """ + data = data.copy() + check_version_compatibility(data.pop("@version"), 1, 1) + data.pop("@class") + update_angle = data["update_angle"] + a_compress_rate = data["a_compress_rate"] + a_compress_use_split = data["a_compress_use_split"] + node_self_mlp = data.pop("node_self_mlp") + node_sym_linear = data.pop("node_sym_linear") + node_edge_linear = data.pop("node_edge_linear") + edge_self_linear = data.pop("edge_self_linear") + edge_angle_linear1 = data.pop("edge_angle_linear1", None) + edge_angle_linear2 = data.pop("edge_angle_linear2", None) + angle_self_linear = data.pop("angle_self_linear", None) + a_compress_n_linear = data.pop("a_compress_n_linear", None) + a_compress_e_linear = data.pop("a_compress_e_linear", None) + update_style = data["update_style"] + variables = data.pop("@variables", {}) + n_residual = variables.get("n_residual", data.pop("n_residual", [])) + e_residual = variables.get("e_residual", data.pop("e_residual", [])) + a_residual = variables.get("a_residual", data.pop("a_residual", [])) + + obj = cls(**data) + obj.node_self_mlp = NativeLayer.deserialize(node_self_mlp) + obj.node_sym_linear = NativeLayer.deserialize(node_sym_linear) + obj.node_edge_linear = NativeLayer.deserialize(node_edge_linear) + obj.edge_self_linear = NativeLayer.deserialize(edge_self_linear) + + if update_angle: + assert isinstance(edge_angle_linear1, dict) + assert isinstance(edge_angle_linear2, dict) + assert isinstance(angle_self_linear, dict) + obj.edge_angle_linear1 = NativeLayer.deserialize(edge_angle_linear1) + obj.edge_angle_linear2 = NativeLayer.deserialize(edge_angle_linear2) + obj.angle_self_linear = NativeLayer.deserialize(angle_self_linear) + if a_compress_rate != 0 and not a_compress_use_split: + assert isinstance(a_compress_n_linear, dict) + assert isinstance(a_compress_e_linear, dict) + obj.a_compress_n_linear = NativeLayer.deserialize(a_compress_n_linear) + obj.a_compress_e_linear = NativeLayer.deserialize(a_compress_e_linear) + + if update_style == "res_residual": + obj.n_residual = n_residual + obj.e_residual = e_residual + obj.a_residual = a_residual + return obj diff --git a/deepmd/dpmodel/utils/network.py b/deepmd/dpmodel/utils/network.py index 703cbdd339..f51308e881 100644 --- a/deepmd/dpmodel/utils/network.py +++ b/deepmd/dpmodel/utils/network.py @@ -317,6 +317,14 @@ def fn(x): # generated by GitHub Copilot return 1 / (1 + xp.exp(-x)) + return fn + elif activation_function == "silu": + + def fn(x): + xp = array_api_compat.array_namespace(x) + # generated by GitHub Copilot + return x / (1 + xp.exp(-x)) + return fn elif activation_function.lower() in ("none", "linear"): diff --git a/deepmd/jax/descriptor/__init__.py b/deepmd/jax/descriptor/__init__.py index de6489e6cf..cda2faf24d 100644 --- a/deepmd/jax/descriptor/__init__.py +++ b/deepmd/jax/descriptor/__init__.py @@ -5,6 +5,9 @@ from deepmd.jax.descriptor.dpa2 import ( DescrptDPA2, ) +from deepmd.jax.descriptor.dpa3 import ( + DescrptDPA3, +) from deepmd.jax.descriptor.hybrid import ( DescrptHybrid, ) @@ -27,6 +30,7 @@ __all__ = [ "DescrptDPA1", "DescrptDPA2", + "DescrptDPA3", "DescrptHybrid", "DescrptSeA", "DescrptSeAttenV2", diff --git a/deepmd/jax/descriptor/dpa3.py b/deepmd/jax/descriptor/dpa3.py new file mode 100644 index 0000000000..460ccd385e --- /dev/null +++ b/deepmd/jax/descriptor/dpa3.py @@ -0,0 +1,35 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, +) + +from deepmd.dpmodel.descriptor.dpa3 import DescrptDPA3 as DescrptDPA3DP +from deepmd.jax.common import ( + ArrayAPIVariable, + flax_module, + to_jax_array, +) +from deepmd.jax.descriptor.base_descriptor import ( + BaseDescriptor, +) +from deepmd.jax.descriptor.repflows import ( + DescrptBlockRepflows, +) +from deepmd.jax.utils.type_embed import ( + TypeEmbedNet, +) + + +@BaseDescriptor.register("dpa3") +@flax_module +class DescrptDPA3(DescrptDPA3DP): + def __setattr__(self, name: str, value: Any) -> None: + if name in {"mean", "stddev"}: + value = to_jax_array(value) + if value is not None: + value = ArrayAPIVariable(value) + elif name in {"repflows"}: + value = DescrptBlockRepflows.deserialize(value.serialize()) + elif name in {"type_embedding"}: + value = TypeEmbedNet.deserialize(value.serialize()) + return super().__setattr__(name, value) diff --git a/deepmd/jax/descriptor/repflows.py b/deepmd/jax/descriptor/repflows.py new file mode 100644 index 0000000000..ac32462287 --- /dev/null +++ b/deepmd/jax/descriptor/repflows.py @@ -0,0 +1,61 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, +) + +from deepmd.dpmodel.descriptor.repflows import ( + DescrptBlockRepflows as DescrptBlockRepflowsDP, +) +from deepmd.dpmodel.descriptor.repflows import RepFlowLayer as RepFlowLayerDP +from deepmd.jax.common import ( + ArrayAPIVariable, + flax_module, + to_jax_array, +) +from deepmd.jax.utils.exclude_mask import ( + PairExcludeMask, +) +from deepmd.jax.utils.network import ( + NativeLayer, +) + + +@flax_module +class DescrptBlockRepflows(DescrptBlockRepflowsDP): + def __setattr__(self, name: str, value: Any) -> None: + if name in {"mean", "stddev"}: + value = to_jax_array(value) + if value is not None: + value = ArrayAPIVariable(value) + elif name in {"layers"}: + value = [RepFlowLayer.deserialize(layer.serialize()) for layer in value] + elif name in {"edge_embd", "angle_embd"}: + value = NativeLayer.deserialize(value.serialize()) + elif name in {"env_mat_edge", "env_mat_angle"}: + # env_mat doesn't store any value + pass + elif name == "emask": + value = PairExcludeMask(value.ntypes, value.exclude_types) + + return super().__setattr__(name, value) + + +@flax_module +class RepFlowLayer(RepFlowLayerDP): + def __setattr__(self, name: str, value: Any) -> None: + if name in { + "node_self_mlp", + "node_sym_linear", + "node_edge_linear", + "edge_self_linear", + "a_compress_n_linear", + "a_compress_e_linear", + "edge_angle_linear1", + "edge_angle_linear2", + "angle_self_linear", + }: + if value is not None: + value = NativeLayer.deserialize(value.serialize()) + elif name in {"n_residual", "e_residual", "a_residual"}: + value = [ArrayAPIVariable(to_jax_array(vv)) for vv in value] + return super().__setattr__(name, value) diff --git a/deepmd/pt/model/descriptor/__init__.py b/deepmd/pt/model/descriptor/__init__.py index 4a227918fe..9f3468d1db 100644 --- a/deepmd/pt/model/descriptor/__init__.py +++ b/deepmd/pt/model/descriptor/__init__.py @@ -13,6 +13,9 @@ from .dpa2 import ( DescrptDPA2, ) +from .dpa3 import ( + DescrptDPA3, +) from .env_mat import ( prod_env_mat, ) @@ -49,6 +52,7 @@ "DescrptBlockSeTTebd", "DescrptDPA1", "DescrptDPA2", + "DescrptDPA3", "DescrptHybrid", "DescrptSeA", "DescrptSeAttenV2", diff --git a/deepmd/pt/model/descriptor/dpa3.py b/deepmd/pt/model/descriptor/dpa3.py new file mode 100644 index 0000000000..9cdbce6f26 --- /dev/null +++ b/deepmd/pt/model/descriptor/dpa3.py @@ -0,0 +1,572 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Callable, + Optional, + Union, +) + +import torch + +from deepmd.dpmodel.descriptor.dpa3 import ( + RepFlowArgs, +) +from deepmd.dpmodel.utils import EnvMat as DPEnvMat +from deepmd.dpmodel.utils.seed import ( + child_seed, +) +from deepmd.pt.model.network.mlp import ( + MLPLayer, +) +from deepmd.pt.model.network.network import ( + TypeEmbedNet, + TypeEmbedNetConsistent, +) +from deepmd.pt.utils import ( + env, +) +from deepmd.pt.utils.env import ( + PRECISION_DICT, +) +from deepmd.pt.utils.update_sel import ( + UpdateSel, +) +from deepmd.pt.utils.utils import ( + to_numpy_array, +) +from deepmd.utils.data_system import ( + DeepmdDataSystem, +) +from deepmd.utils.finetune import ( + get_index_between_two_maps, + map_pair_exclude_types, +) +from deepmd.utils.path import ( + DPPath, +) +from deepmd.utils.version import ( + check_version_compatibility, +) + +from .base_descriptor import ( + BaseDescriptor, +) +from .descriptor import ( + extend_descrpt_stat, +) +from .repflow_layer import ( + RepFlowLayer, +) +from .repflows import ( + DescrptBlockRepflows, +) + + +@BaseDescriptor.register("dpa3") +class DescrptDPA3(BaseDescriptor, torch.nn.Module): + def __init__( + self, + ntypes: int, + # args for repflow + repflow: Union[RepFlowArgs, dict], + # kwargs for descriptor + concat_output_tebd: bool = False, + activation_function: str = "silu", + precision: str = "float64", + exclude_types: list[tuple[int, int]] = [], + env_protection: float = 0.0, + trainable: bool = True, + seed: Optional[Union[int, list[int]]] = None, + use_econf_tebd: bool = False, + use_tebd_bias: bool = False, + type_map: Optional[list[str]] = None, + ) -> None: + r"""The DPA-3 descriptor. + + Parameters + ---------- + repflow : Union[RepFlowArgs, dict] + The arguments used to initialize the repflow block, see docstr in `RepFlowArgs` for details information. + concat_output_tebd : bool, optional + Whether to concat type embedding at the output of the descriptor. + activation_function : str, optional + The activation function in the embedding net. + precision : str, optional + The precision of the embedding net parameters. + exclude_types : list[list[int]], optional + The excluded pairs of types which have no interaction with each other. + For example, `[[0, 1]]` means no interaction between type 0 and type 1. + env_protection : float, optional + Protection parameter to prevent division by zero errors during environment matrix calculations. + For example, when using paddings, there may be zero distances of neighbors, which may make division by zero error during environment matrix calculations without protection. + trainable : bool, optional + If the parameters are trainable. + seed : int, optional + Random seed for parameter initialization. + use_econf_tebd : bool, Optional + Whether to use electronic configuration type embedding. + use_tebd_bias : bool, Optional + Whether to use bias in the type embedding layer. + type_map : list[str], Optional + A list of strings. Give the name to each type of atoms. + + Returns + ------- + descriptor: torch.Tensor + the descriptor of shape nb x nloc x n_dim. + invariant single-atom representation. + g2: torch.Tensor + invariant pair-atom representation. + h2: torch.Tensor + equivariant pair-atom representation. + rot_mat: torch.Tensor + rotation matrix for equivariant fittings + sw: torch.Tensor + The switch function for decaying inverse distance. + + """ + super().__init__() + + def init_subclass_params(sub_data, sub_class): + if isinstance(sub_data, dict): + return sub_class(**sub_data) + elif isinstance(sub_data, sub_class): + return sub_data + else: + raise ValueError( + f"Input args must be a {sub_class.__name__} class or a dict!" + ) + + self.repflow_args = init_subclass_params(repflow, RepFlowArgs) + self.activation_function = activation_function + + self.repflows = DescrptBlockRepflows( + self.repflow_args.e_rcut, + self.repflow_args.e_rcut_smth, + self.repflow_args.e_sel, + self.repflow_args.a_rcut, + self.repflow_args.a_rcut_smth, + self.repflow_args.a_sel, + ntypes, + nlayers=self.repflow_args.nlayers, + n_dim=self.repflow_args.n_dim, + e_dim=self.repflow_args.e_dim, + a_dim=self.repflow_args.a_dim, + a_compress_rate=self.repflow_args.a_compress_rate, + a_compress_e_rate=self.repflow_args.a_compress_e_rate, + a_compress_use_split=self.repflow_args.a_compress_use_split, + n_multi_edge_message=self.repflow_args.n_multi_edge_message, + axis_neuron=self.repflow_args.axis_neuron, + update_angle=self.repflow_args.update_angle, + activation_function=self.activation_function, + update_style=self.repflow_args.update_style, + update_residual=self.repflow_args.update_residual, + update_residual_init=self.repflow_args.update_residual_init, + fix_stat_std=self.repflow_args.fix_stat_std, + optim_update=self.repflow_args.optim_update, + exclude_types=exclude_types, + env_protection=env_protection, + precision=precision, + seed=child_seed(seed, 1), + ) + + self.use_econf_tebd = use_econf_tebd + self.use_tebd_bias = use_tebd_bias + self.type_map = type_map + self.tebd_dim = self.repflow_args.n_dim + self.type_embedding = TypeEmbedNet( + ntypes, + self.tebd_dim, + precision=precision, + seed=child_seed(seed, 2), + use_econf_tebd=self.use_econf_tebd, + use_tebd_bias=use_tebd_bias, + type_map=type_map, + ) + self.concat_output_tebd = concat_output_tebd + self.precision = precision + self.prec = PRECISION_DICT[self.precision] + self.exclude_types = exclude_types + self.env_protection = env_protection + self.trainable = trainable + + assert self.repflows.e_rcut >= self.repflows.a_rcut + assert self.repflows.e_sel >= self.repflows.a_sel + + self.rcut = self.repflows.get_rcut() + self.rcut_smth = self.repflows.get_rcut_smth() + self.sel = self.repflows.get_sel() + self.ntypes = ntypes + + # set trainable + for param in self.parameters(): + param.requires_grad = trainable + self.compress = False + + def get_rcut(self) -> float: + """Returns the cut-off radius.""" + return self.rcut + + def get_rcut_smth(self) -> float: + """Returns the radius where the neighbor information starts to smoothly decay to 0.""" + return self.rcut_smth + + def get_nsel(self) -> int: + """Returns the number of selected atoms in the cut-off radius.""" + return sum(self.sel) + + def get_sel(self) -> list[int]: + """Returns the number of selected atoms for each type.""" + return self.sel + + def get_ntypes(self) -> int: + """Returns the number of element types.""" + return self.ntypes + + def get_type_map(self) -> list[str]: + """Get the name to each type of atoms.""" + return self.type_map + + def get_dim_out(self) -> int: + """Returns the output dimension of this descriptor.""" + ret = self.repflows.dim_out + if self.concat_output_tebd: + ret += self.tebd_dim + return ret + + def get_dim_emb(self) -> int: + """Returns the embedding dimension of this descriptor.""" + return self.repflows.dim_emb + + def mixed_types(self) -> bool: + """If true, the descriptor + 1. assumes total number of atoms aligned across frames; + 2. requires a neighbor list that does not distinguish different atomic types. + + If false, the descriptor + 1. assumes total number of atoms of each atom type aligned across frames; + 2. requires a neighbor list that distinguishes different atomic types. + + """ + return True + + def has_message_passing(self) -> bool: + """Returns whether the descriptor has message passing.""" + return self.repflows.has_message_passing() + + def need_sorted_nlist_for_lower(self) -> bool: + """Returns whether the descriptor needs sorted nlist when using `forward_lower`.""" + return True + + def get_env_protection(self) -> float: + """Returns the protection of building environment matrix.""" + return self.repflows.get_env_protection() + + def share_params(self, base_class, shared_level, resume=False) -> None: + """ + Share the parameters of self to the base_class with shared_level during multitask training. + If not start from checkpoint (resume is False), + some separated parameters (e.g. mean and stddev) will be re-calculated across different classes. + """ + assert self.__class__ == base_class.__class__, ( + "Only descriptors of the same type can share params!" + ) + # For DPA3 descriptors, the user-defined share-level + # shared_level: 0 + # share all parameters in type_embedding, repflow + if shared_level == 0: + self._modules["type_embedding"] = base_class._modules["type_embedding"] + self.repflows.share_params(base_class.repflows, 0, resume=resume) + # shared_level: 1 + # share all parameters in type_embedding + elif shared_level == 1: + self._modules["type_embedding"] = base_class._modules["type_embedding"] + # Other shared levels + else: + raise NotImplementedError + + def change_type_map( + self, type_map: list[str], model_with_new_type_stat=None + ) -> None: + """Change the type related params to new ones, according to `type_map` and the original one in the model. + If there are new types in `type_map`, statistics will be updated accordingly to `model_with_new_type_stat` for these new types. + """ + assert self.type_map is not None, ( + "'type_map' must be defined when performing type changing!" + ) + remap_index, has_new_type = get_index_between_two_maps(self.type_map, type_map) + self.type_map = type_map + self.type_embedding.change_type_map(type_map=type_map) + self.exclude_types = map_pair_exclude_types(self.exclude_types, remap_index) + self.ntypes = len(type_map) + repflow = self.repflows + if has_new_type: + # the avg and std of new types need to be updated + extend_descrpt_stat( + repflow, + type_map, + des_with_stat=model_with_new_type_stat.repflows + if model_with_new_type_stat is not None + else None, + ) + repflow.ntypes = self.ntypes + repflow.reinit_exclude(self.exclude_types) + repflow["davg"] = repflow["davg"][remap_index] + repflow["dstd"] = repflow["dstd"][remap_index] + + @property + def dim_out(self): + return self.get_dim_out() + + @property + def dim_emb(self): + """Returns the embedding dimension g2.""" + return self.get_dim_emb() + + def compute_input_stats( + self, + merged: Union[Callable[[], list[dict]], list[dict]], + path: Optional[DPPath] = None, + ) -> None: + """ + Compute the input statistics (e.g. mean and stddev) for the descriptors from packed data. + + Parameters + ---------- + merged : Union[Callable[[], list[dict]], list[dict]] + - list[dict]: A list of data samples from various data systems. + Each element, `merged[i]`, is a data dictionary containing `keys`: `torch.Tensor` + originating from the `i`-th data system. + - Callable[[], list[dict]]: A lazy function that returns data samples in the above format + only when needed. Since the sampling process can be slow and memory-intensive, + the lazy function helps by only sampling once. + path : Optional[DPPath] + The path to the stat file. + + """ + descrpt_list = [self.repflows] + for ii, descrpt in enumerate(descrpt_list): + descrpt.compute_input_stats(merged, path) + + def set_stat_mean_and_stddev( + self, + mean: list[torch.Tensor], + stddev: list[torch.Tensor], + ) -> None: + """Update mean and stddev for descriptor.""" + descrpt_list = [self.repflows] + for ii, descrpt in enumerate(descrpt_list): + descrpt.mean = mean[ii] + descrpt.stddev = stddev[ii] + + def get_stat_mean_and_stddev(self) -> tuple[list[torch.Tensor], list[torch.Tensor]]: + """Get mean and stddev for descriptor.""" + mean_list = [self.repflows.mean] + stddev_list = [self.repflows.stddev] + return mean_list, stddev_list + + def serialize(self) -> dict: + repflows = self.repflows + data = { + "@class": "Descriptor", + "type": "dpa3", + "@version": 1, + "ntypes": self.ntypes, + "repflow_args": self.repflow_args.serialize(), + "concat_output_tebd": self.concat_output_tebd, + "activation_function": self.activation_function, + "precision": self.precision, + "exclude_types": self.exclude_types, + "env_protection": self.env_protection, + "trainable": self.trainable, + "use_econf_tebd": self.use_econf_tebd, + "use_tebd_bias": self.use_tebd_bias, + "type_map": self.type_map, + "type_embedding": self.type_embedding.embedding.serialize(), + } + repflow_variable = { + "edge_embd": repflows.edge_embd.serialize(), + "angle_embd": repflows.angle_embd.serialize(), + "repflow_layers": [layer.serialize() for layer in repflows.layers], + "env_mat": DPEnvMat(repflows.rcut, repflows.rcut_smth).serialize(), + "@variables": { + "davg": to_numpy_array(repflows["davg"]), + "dstd": to_numpy_array(repflows["dstd"]), + }, + } + data.update( + { + "repflow_variable": repflow_variable, + } + ) + return data + + @classmethod + def deserialize(cls, data: dict) -> "DescrptDPA3": + data = data.copy() + version = data.pop("@version") + check_version_compatibility(version, 1, 1) + data.pop("@class") + data.pop("type") + repflow_variable = data.pop("repflow_variable").copy() + type_embedding = data.pop("type_embedding") + data["repflow"] = RepFlowArgs(**data.pop("repflow_args")) + obj = cls(**data) + obj.type_embedding.embedding = TypeEmbedNetConsistent.deserialize( + type_embedding + ) + + def t_cvt(xx): + return torch.tensor(xx, dtype=obj.repflows.prec, device=env.DEVICE) + + # deserialize repflow + statistic_repflows = repflow_variable.pop("@variables") + env_mat = repflow_variable.pop("env_mat") + repflow_layers = repflow_variable.pop("repflow_layers") + obj.repflows.edge_embd = MLPLayer.deserialize(repflow_variable.pop("edge_embd")) + obj.repflows.angle_embd = MLPLayer.deserialize( + repflow_variable.pop("angle_embd") + ) + obj.repflows["davg"] = t_cvt(statistic_repflows["davg"]) + obj.repflows["dstd"] = t_cvt(statistic_repflows["dstd"]) + obj.repflows.layers = torch.nn.ModuleList( + [RepFlowLayer.deserialize(layer) for layer in repflow_layers] + ) + return obj + + def forward( + self, + extended_coord: torch.Tensor, + extended_atype: torch.Tensor, + nlist: torch.Tensor, + mapping: Optional[torch.Tensor] = None, + comm_dict: Optional[dict[str, torch.Tensor]] = None, + ): + """Compute the descriptor. + + Parameters + ---------- + extended_coord + The extended coordinates of atoms. shape: nf x (nallx3) + extended_atype + The extended aotm types. shape: nf x nall + nlist + The neighbor list. shape: nf x nloc x nnei + mapping + The index mapping, mapps extended region index to local region. + comm_dict + The data needed for communication for parallel inference. + + Returns + ------- + node_ebd + The output descriptor. shape: nf x nloc x n_dim (or n_dim + tebd_dim) + rot_mat + The rotationally equivariant and permutationally invariant single particle + representation. shape: nf x nloc x e_dim x 3 + edge_ebd + The edge embedding. + shape: nf x nloc x nnei x e_dim + h2 + The rotationally equivariant pair-partical representation. + shape: nf x nloc x nnei x 3 + sw + The smooth switch function. shape: nf x nloc x nnei + + """ + # cast the input to internal precsion + extended_coord = extended_coord.to(dtype=self.prec) + nframes, nloc, nnei = nlist.shape + nall = extended_coord.view(nframes, -1).shape[1] // 3 + + node_ebd_ext = self.type_embedding(extended_atype) + node_ebd_inp = node_ebd_ext[:, :nloc, :] + # repflows + node_ebd, edge_ebd, h2, rot_mat, sw = self.repflows( + nlist, + extended_coord, + extended_atype, + node_ebd_ext, + mapping, + comm_dict=comm_dict, + ) + if self.concat_output_tebd: + node_ebd = torch.cat([node_ebd, node_ebd_inp], dim=-1) + return ( + node_ebd.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION), + rot_mat.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION), + edge_ebd.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION), + h2.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION), + sw.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION), + ) + + @classmethod + def update_sel( + cls, + train_data: DeepmdDataSystem, + type_map: Optional[list[str]], + local_jdata: dict, + ) -> tuple[dict, Optional[float]]: + """Update the selection and perform neighbor statistics. + + Parameters + ---------- + train_data : DeepmdDataSystem + data used to do neighbor statistics + type_map : list[str], optional + The name of each type of atoms + local_jdata : dict + The local data refer to the current class + + Returns + ------- + dict + The updated local data + float + The minimum distance between two atoms + """ + local_jdata_cpy = local_jdata.copy() + update_sel = UpdateSel() + min_nbor_dist, repflow_e_sel = update_sel.update_one_sel( + train_data, + type_map, + local_jdata_cpy["repflow"]["e_rcut"], + local_jdata_cpy["repflow"]["e_sel"], + True, + ) + local_jdata_cpy["repflow"]["e_sel"] = repflow_e_sel[0] + + min_nbor_dist, repflow_a_sel = update_sel.update_one_sel( + train_data, + type_map, + local_jdata_cpy["repflow"]["a_rcut"], + local_jdata_cpy["repflow"]["a_sel"], + True, + ) + local_jdata_cpy["repflow"]["a_sel"] = repflow_a_sel[0] + + return local_jdata_cpy, min_nbor_dist + + def enable_compression( + self, + min_nbor_dist: float, + table_extrapolate: float = 5, + table_stride_1: float = 0.01, + table_stride_2: float = 0.1, + check_frequency: int = -1, + ) -> None: + """Receive the statistics (distance, max_nbor_size and env_mat_range) of the training data. + + Parameters + ---------- + min_nbor_dist + The nearest distance between atoms + table_extrapolate + The scale of model extrapolation + table_stride_1 + The uniform stride of the first table + table_stride_2 + The uniform stride of the second table + check_frequency + The overflow check frequency + """ + raise NotImplementedError("Compression is unsupported for DPA3.") diff --git a/deepmd/pt/model/descriptor/repflow_layer.py b/deepmd/pt/model/descriptor/repflow_layer.py new file mode 100644 index 0000000000..43cae8c746 --- /dev/null +++ b/deepmd/pt/model/descriptor/repflow_layer.py @@ -0,0 +1,938 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Optional, + Union, +) + +import torch +import torch.nn as nn + +from deepmd.dpmodel.utils.seed import ( + child_seed, +) +from deepmd.pt.model.descriptor.repformer_layer import ( + _apply_nlist_mask, + _apply_switch, + _make_nei_g1, + get_residual, +) +from deepmd.pt.model.network.mlp import ( + MLPLayer, +) +from deepmd.pt.utils.env import ( + PRECISION_DICT, +) +from deepmd.pt.utils.utils import ( + ActivationFn, + to_numpy_array, + to_torch_tensor, +) +from deepmd.utils.version import ( + check_version_compatibility, +) + + +class RepFlowLayer(torch.nn.Module): + def __init__( + self, + e_rcut: float, + e_rcut_smth: float, + e_sel: int, + a_rcut: float, + a_rcut_smth: float, + a_sel: int, + ntypes: int, + n_dim: int = 128, + e_dim: int = 16, + a_dim: int = 64, + a_compress_rate: int = 0, + a_compress_use_split: bool = False, + a_compress_e_rate: int = 1, + n_multi_edge_message: int = 1, + axis_neuron: int = 4, + update_angle: bool = True, + optim_update: bool = True, + activation_function: str = "silu", + update_style: str = "res_residual", + update_residual: float = 0.1, + update_residual_init: str = "const", + precision: str = "float64", + seed: Optional[Union[int, list[int]]] = None, + ) -> None: + super().__init__() + self.epsilon = 1e-4 # protection of 1./nnei + self.e_rcut = float(e_rcut) + self.e_rcut_smth = float(e_rcut_smth) + self.ntypes = ntypes + e_sel = [e_sel] if isinstance(e_sel, int) else e_sel + self.nnei = sum(e_sel) + assert len(e_sel) == 1 + self.e_sel = e_sel + self.sec = self.e_sel + self.a_rcut = a_rcut + self.a_rcut_smth = a_rcut_smth + self.a_sel = a_sel + self.n_dim = n_dim + self.e_dim = e_dim + self.a_dim = a_dim + self.a_compress_rate = a_compress_rate + if a_compress_rate != 0: + assert (a_dim * a_compress_e_rate) % (2 * a_compress_rate) == 0, ( + f"For a_compress_rate of {a_compress_rate}, a_dim*a_compress_e_rate must be divisible by {2 * a_compress_rate}. " + f"Currently, a_dim={a_dim} and a_compress_e_rate={a_compress_e_rate} is not valid." + ) + self.n_multi_edge_message = n_multi_edge_message + assert self.n_multi_edge_message >= 1, "n_multi_edge_message must >= 1!" + self.axis_neuron = axis_neuron + self.update_angle = update_angle + self.activation_function = activation_function + self.act = ActivationFn(activation_function) + self.update_style = update_style + self.update_residual = update_residual + self.update_residual_init = update_residual_init + self.a_compress_e_rate = a_compress_e_rate + self.a_compress_use_split = a_compress_use_split + self.precision = precision + self.seed = seed + self.prec = PRECISION_DICT[precision] + self.optim_update = optim_update + + assert update_residual_init in [ + "norm", + "const", + ], "'update_residual_init' only support 'norm' or 'const'!" + + self.update_residual = update_residual + self.update_residual_init = update_residual_init + self.n_residual = [] + self.e_residual = [] + self.a_residual = [] + self.edge_info_dim = self.n_dim * 2 + self.e_dim + + # node self mlp + self.node_self_mlp = MLPLayer( + n_dim, + n_dim, + precision=precision, + seed=child_seed(seed, 0), + ) + if self.update_style == "res_residual": + self.n_residual.append( + get_residual( + n_dim, + self.update_residual, + self.update_residual_init, + precision=precision, + seed=child_seed(seed, 1), + ) + ) + + # node sym (grrg + drrd) + self.n_sym_dim = n_dim * self.axis_neuron + e_dim * self.axis_neuron + self.node_sym_linear = MLPLayer( + self.n_sym_dim, + n_dim, + precision=precision, + seed=child_seed(seed, 2), + ) + if self.update_style == "res_residual": + self.n_residual.append( + get_residual( + n_dim, + self.update_residual, + self.update_residual_init, + precision=precision, + seed=child_seed(seed, 3), + ) + ) + + # node edge message + self.node_edge_linear = MLPLayer( + self.edge_info_dim, + self.n_multi_edge_message * n_dim, + precision=precision, + seed=child_seed(seed, 4), + ) + if self.update_style == "res_residual": + for head_index in range(self.n_multi_edge_message): + self.n_residual.append( + get_residual( + n_dim, + self.update_residual, + self.update_residual_init, + precision=precision, + seed=child_seed(child_seed(seed, 5), head_index), + ) + ) + + # edge self message + self.edge_self_linear = MLPLayer( + self.edge_info_dim, + e_dim, + precision=precision, + seed=child_seed(seed, 6), + ) + if self.update_style == "res_residual": + self.e_residual.append( + get_residual( + e_dim, + self.update_residual, + self.update_residual_init, + precision=precision, + seed=child_seed(seed, 7), + ) + ) + + if self.update_angle: + self.angle_dim = self.a_dim + if self.a_compress_rate == 0: + # angle + node + edge * 2 + self.angle_dim += self.n_dim + 2 * self.e_dim + self.a_compress_n_linear = None + self.a_compress_e_linear = None + self.e_a_compress_dim = e_dim + self.n_a_compress_dim = n_dim + else: + # angle + a_dim/c + a_dim/2c * 2 * e_rate + self.angle_dim += (1 + self.a_compress_e_rate) * ( + self.a_dim // self.a_compress_rate + ) + self.e_a_compress_dim = ( + self.a_dim // (2 * self.a_compress_rate) * self.a_compress_e_rate + ) + self.n_a_compress_dim = self.a_dim // self.a_compress_rate + if not self.a_compress_use_split: + self.a_compress_n_linear = MLPLayer( + self.n_dim, + self.n_a_compress_dim, + precision=precision, + bias=False, + seed=child_seed(seed, 8), + ) + self.a_compress_e_linear = MLPLayer( + self.e_dim, + self.e_a_compress_dim, + precision=precision, + bias=False, + seed=child_seed(seed, 9), + ) + else: + self.a_compress_n_linear = None + self.a_compress_e_linear = None + + # edge angle message + self.edge_angle_linear1 = MLPLayer( + self.angle_dim, + self.e_dim, + precision=precision, + seed=child_seed(seed, 10), + ) + self.edge_angle_linear2 = MLPLayer( + self.e_dim, + self.e_dim, + precision=precision, + seed=child_seed(seed, 11), + ) + if self.update_style == "res_residual": + self.e_residual.append( + get_residual( + self.e_dim, + self.update_residual, + self.update_residual_init, + precision=precision, + seed=child_seed(seed, 12), + ) + ) + + # angle self message + self.angle_self_linear = MLPLayer( + self.angle_dim, + self.a_dim, + precision=precision, + seed=child_seed(seed, 13), + ) + if self.update_style == "res_residual": + self.a_residual.append( + get_residual( + self.a_dim, + self.update_residual, + self.update_residual_init, + precision=precision, + seed=child_seed(seed, 14), + ) + ) + else: + self.angle_self_linear = None + self.edge_angle_linear1 = None + self.edge_angle_linear2 = None + self.a_compress_n_linear = None + self.a_compress_e_linear = None + self.angle_dim = 0 + + self.n_residual = nn.ParameterList(self.n_residual) + self.e_residual = nn.ParameterList(self.e_residual) + self.a_residual = nn.ParameterList(self.a_residual) + + @staticmethod + def _cal_hg( + edge_ebd: torch.Tensor, + h2: torch.Tensor, + nlist_mask: torch.Tensor, + sw: torch.Tensor, + ) -> torch.Tensor: + """ + Calculate the transposed rotation matrix. + + Parameters + ---------- + edge_ebd + Neighbor-wise/Pair-wise edge embeddings, with shape nb x nloc x nnei x e_dim. + h2 + Neighbor-wise/Pair-wise equivariant rep tensors, with shape nb x nloc x nnei x 3. + nlist_mask + Neighbor list mask, where zero means no neighbor, with shape nb x nloc x nnei. + sw + The switch function, which equals 1 within the rcut_smth range, smoothly decays from 1 to 0 between rcut_smth and rcut, + and remains 0 beyond rcut, with shape nb x nloc x nnei. + + Returns + ------- + hg + The transposed rotation matrix, with shape nb x nloc x 3 x e_dim. + """ + # edge_ebd: nb x nloc x nnei x e_dim + # h2: nb x nloc x nnei x 3 + # msk: nb x nloc x nnei + nb, nloc, nnei, _ = edge_ebd.shape + e_dim = edge_ebd.shape[-1] + # nb x nloc x nnei x e_dim + edge_ebd = _apply_nlist_mask(edge_ebd, nlist_mask) + edge_ebd = _apply_switch(edge_ebd, sw) + invnnei = torch.rsqrt( + float(nnei) + * torch.ones((nb, nloc, 1, 1), dtype=edge_ebd.dtype, device=edge_ebd.device) + ) + # nb x nloc x 3 x e_dim + h2g2 = torch.matmul(torch.transpose(h2, -1, -2), edge_ebd) * invnnei + return h2g2 + + @staticmethod + def _cal_grrg(h2g2: torch.Tensor, axis_neuron: int) -> torch.Tensor: + """ + Calculate the atomic invariant rep. + + Parameters + ---------- + h2g2 + The transposed rotation matrix, with shape nb x nloc x 3 x e_dim. + axis_neuron + Size of the submatrix. + + Returns + ------- + grrg + Atomic invariant rep, with shape nb x nloc x (axis_neuron x e_dim) + """ + # nb x nloc x 3 x e_dim + nb, nloc, _, e_dim = h2g2.shape + # nb x nloc x 3 x axis + h2g2m = h2g2[..., :axis_neuron] + # nb x nloc x axis x e_dim + g1_13 = torch.matmul(torch.transpose(h2g2m, -1, -2), h2g2) / (3.0**1) + # nb x nloc x (axisxng2) + g1_13 = g1_13.view(nb, nloc, axis_neuron * e_dim) + return g1_13 + + def symmetrization_op( + self, + edge_ebd: torch.Tensor, + h2: torch.Tensor, + nlist_mask: torch.Tensor, + sw: torch.Tensor, + axis_neuron: int, + ) -> torch.Tensor: + """ + Symmetrization operator to obtain atomic invariant rep. + + Parameters + ---------- + edge_ebd + Neighbor-wise/Pair-wise invariant rep tensors, with shape nb x nloc x nnei x e_dim. + h2 + Neighbor-wise/Pair-wise equivariant rep tensors, with shape nb x nloc x nnei x 3. + nlist_mask + Neighbor list mask, where zero means no neighbor, with shape nb x nloc x nnei. + sw + The switch function, which equals 1 within the rcut_smth range, smoothly decays from 1 to 0 between rcut_smth and rcut, + and remains 0 beyond rcut, with shape nb x nloc x nnei. + axis_neuron + Size of the submatrix. + + Returns + ------- + grrg + Atomic invariant rep, with shape nb x nloc x (axis_neuron x e_dim) + """ + # edge_ebd: nb x nloc x nnei x e_dim + # h2: nb x nloc x nnei x 3 + # msk: nb x nloc x nnei + nb, nloc, nnei, _ = edge_ebd.shape + # nb x nloc x 3 x e_dim + h2g2 = self._cal_hg( + edge_ebd, + h2, + nlist_mask, + sw, + ) + # nb x nloc x (axisxng2) + g1_13 = self._cal_grrg(h2g2, axis_neuron) + return g1_13 + + def optim_angle_update( + self, + angle_ebd: torch.Tensor, + node_ebd: torch.Tensor, + edge_ebd: torch.Tensor, + feat: str = "edge", + ) -> torch.Tensor: + angle_dim = angle_ebd.shape[-1] + node_dim = node_ebd.shape[-1] + edge_dim = edge_ebd.shape[-1] + sub_angle_idx = (0, angle_dim) + sub_node_idx = (angle_dim, angle_dim + node_dim) + sub_edge_idx_ij = (angle_dim + node_dim, angle_dim + node_dim + edge_dim) + sub_edge_idx_ik = ( + angle_dim + node_dim + edge_dim, + angle_dim + node_dim + 2 * edge_dim, + ) + + if feat == "edge": + matrix, bias = self.edge_angle_linear1.matrix, self.edge_angle_linear1.bias + elif feat == "angle": + matrix, bias = self.angle_self_linear.matrix, self.angle_self_linear.bias + else: + raise NotImplementedError + assert angle_dim + node_dim + 2 * edge_dim == matrix.size()[0] + + # nf * nloc * a_sel * a_sel * angle_dim + sub_angle_update = torch.matmul( + angle_ebd, matrix[sub_angle_idx[0] : sub_angle_idx[1]] + ) + + # nf * nloc * angle_dim + sub_node_update = torch.matmul( + node_ebd, matrix[sub_node_idx[0] : sub_node_idx[1]] + ) + + # nf * nloc * a_nnei * angle_dim + sub_edge_update_ij = torch.matmul( + edge_ebd, matrix[sub_edge_idx_ij[0] : sub_edge_idx_ij[1]] + ) + sub_edge_update_ik = torch.matmul( + edge_ebd, matrix[sub_edge_idx_ik[0] : sub_edge_idx_ik[1]] + ) + + result_update = ( + sub_angle_update + + sub_node_update[:, :, None, None, :] + + sub_edge_update_ij[:, :, None, :, :] + + sub_edge_update_ik[:, :, :, None, :] + ) + bias + return result_update + + def optim_edge_update( + self, + node_ebd: torch.Tensor, + node_ebd_ext: torch.Tensor, + edge_ebd: torch.Tensor, + nlist: torch.Tensor, + feat: str = "node", + ) -> torch.Tensor: + node_dim = node_ebd.shape[-1] + edge_dim = edge_ebd.shape[-1] + sub_node_idx = (0, node_dim) + sub_node_ext_idx = (node_dim, 2 * node_dim) + sub_edge_idx = (2 * node_dim, 2 * node_dim + edge_dim) + + if feat == "node": + matrix, bias = self.node_edge_linear.matrix, self.node_edge_linear.bias + elif feat == "edge": + matrix, bias = self.edge_self_linear.matrix, self.edge_self_linear.bias + else: + raise NotImplementedError + assert 2 * node_dim + edge_dim == matrix.size()[0] + + # nf * nloc * node/edge_dim + sub_node_update = torch.matmul( + node_ebd, matrix[sub_node_idx[0] : sub_node_idx[1]] + ) + + # nf * nall * node/edge_dim + sub_node_ext_update = torch.matmul( + node_ebd_ext, matrix[sub_node_ext_idx[0] : sub_node_ext_idx[1]] + ) + # nf * nloc * nnei * node/edge_dim + sub_node_ext_update = _make_nei_g1(sub_node_ext_update, nlist) + + # nf * nloc * nnei * node/edge_dim + sub_edge_update = torch.matmul( + edge_ebd, matrix[sub_edge_idx[0] : sub_edge_idx[1]] + ) + + result_update = ( + sub_edge_update + sub_node_ext_update + sub_node_update[:, :, None, :] + ) + bias + return result_update + + def forward( + self, + node_ebd_ext: torch.Tensor, # nf x nall x n_dim + edge_ebd: torch.Tensor, # nf x nloc x nnei x e_dim + h2: torch.Tensor, # nf x nloc x nnei x 3 + angle_ebd: torch.Tensor, # nf x nloc x a_nnei x a_nnei x a_dim + nlist: torch.Tensor, # nf x nloc x nnei + nlist_mask: torch.Tensor, # nf x nloc x nnei + sw: torch.Tensor, # switch func, nf x nloc x nnei + a_nlist: torch.Tensor, # nf x nloc x a_nnei + a_nlist_mask: torch.Tensor, # nf x nloc x a_nnei + a_sw: torch.Tensor, # switch func, nf x nloc x a_nnei + ): + """ + Parameters + ---------- + node_ebd_ext : nf x nall x n_dim + Extended node embedding. + edge_ebd : nf x nloc x nnei x e_dim + Edge embedding. + h2 : nf x nloc x nnei x 3 + Pair-atom channel, equivariant. + angle_ebd : nf x nloc x a_nnei x a_nnei x a_dim + Angle embedding. + nlist : nf x nloc x nnei + Neighbor list. (padded neis are set to 0) + nlist_mask : nf x nloc x nnei + Masks of the neighbor list. real nei 1 otherwise 0 + sw : nf x nloc x nnei + Switch function. + a_nlist : nf x nloc x a_nnei + Neighbor list for angle. (padded neis are set to 0) + a_nlist_mask : nf x nloc x a_nnei + Masks of the neighbor list for angle. real nei 1 otherwise 0 + a_sw : nf x nloc x a_nnei + Switch function for angle. + + Returns + ------- + n_updated: nf x nloc x n_dim + Updated node embedding. + e_updated: nf x nloc x nnei x e_dim + Updated edge embedding. + a_updated : nf x nloc x a_nnei x a_nnei x a_dim + Updated angle embedding. + """ + nb, nloc, nnei, _ = edge_ebd.shape + nall = node_ebd_ext.shape[1] + node_ebd = node_ebd_ext[:, :nloc, :] + assert (nb, nloc) == node_ebd.shape[:2] + assert (nb, nloc, nnei) == h2.shape[:3] + del a_nlist # may be used in the future + + n_update_list: list[torch.Tensor] = [node_ebd] + e_update_list: list[torch.Tensor] = [edge_ebd] + a_update_list: list[torch.Tensor] = [angle_ebd] + + # node self mlp + node_self_mlp = self.act(self.node_self_mlp(node_ebd)) + n_update_list.append(node_self_mlp) + + nei_node_ebd = _make_nei_g1(node_ebd_ext, nlist) + + # node sym (grrg + drrd) + node_sym_list: list[torch.Tensor] = [] + node_sym_list.append( + self.symmetrization_op( + edge_ebd, + h2, + nlist_mask, + sw, + self.axis_neuron, + ) + ) + node_sym_list.append( + self.symmetrization_op( + nei_node_ebd, + h2, + nlist_mask, + sw, + self.axis_neuron, + ) + ) + node_sym = self.act(self.node_sym_linear(torch.cat(node_sym_list, dim=-1))) + n_update_list.append(node_sym) + + if not self.optim_update: + # nb x nloc x nnei x (n_dim * 2 + e_dim) + edge_info = torch.cat( + [ + torch.tile(node_ebd.unsqueeze(-2), [1, 1, self.nnei, 1]), + nei_node_ebd, + edge_ebd, + ], + dim=-1, + ) + else: + edge_info = None + + # node edge message + # nb x nloc x nnei x (h * n_dim) + if not self.optim_update: + assert edge_info is not None + node_edge_update = self.act( + self.node_edge_linear(edge_info) + ) * sw.unsqueeze(-1) + else: + node_edge_update = self.act( + self.optim_edge_update( + node_ebd, + node_ebd_ext, + edge_ebd, + nlist, + "node", + ) + ) * sw.unsqueeze(-1) + + node_edge_update = torch.sum(node_edge_update, dim=-2) / self.nnei + if self.n_multi_edge_message > 1: + # nb x nloc x nnei x h x n_dim + node_edge_update_mul_head = node_edge_update.view( + nb, nloc, self.n_multi_edge_message, self.n_dim + ) + for head_index in range(self.n_multi_edge_message): + n_update_list.append(node_edge_update_mul_head[:, :, head_index, :]) + else: + n_update_list.append(node_edge_update) + # update node_ebd + n_updated = self.list_update(n_update_list, "node") + + # edge self message + if not self.optim_update: + assert edge_info is not None + edge_self_update = self.act(self.edge_self_linear(edge_info)) + else: + edge_self_update = self.act( + self.optim_edge_update( + node_ebd, + node_ebd_ext, + edge_ebd, + nlist, + "edge", + ) + ) + e_update_list.append(edge_self_update) + + if self.update_angle: + assert self.angle_self_linear is not None + assert self.edge_angle_linear1 is not None + assert self.edge_angle_linear2 is not None + # get angle info + if self.a_compress_rate != 0: + if not self.a_compress_use_split: + assert self.a_compress_n_linear is not None + assert self.a_compress_e_linear is not None + node_ebd_for_angle = self.a_compress_n_linear(node_ebd) + edge_ebd_for_angle = self.a_compress_e_linear(edge_ebd) + else: + # use the first a_compress_dim dim for node and edge + node_ebd_for_angle = node_ebd[:, :, : self.n_a_compress_dim] + edge_ebd_for_angle = edge_ebd[:, :, :, : self.e_a_compress_dim] + else: + node_ebd_for_angle = node_ebd + edge_ebd_for_angle = edge_ebd + + # nb x nloc x a_nnei x e_dim + edge_for_angle = edge_ebd_for_angle[:, :, : self.a_sel, :] + # nb x nloc x a_nnei x e_dim + edge_for_angle = torch.where( + a_nlist_mask.unsqueeze(-1), edge_for_angle, 0.0 + ) + if not self.optim_update: + # nb x nloc x a_nnei x a_nnei x n_dim + node_for_angle_info = torch.tile( + node_ebd_for_angle.unsqueeze(2).unsqueeze(2), + (1, 1, self.a_sel, self.a_sel, 1), + ) + # nb x nloc x (a_nnei) x a_nnei x edge_ebd + edge_for_angle_i = torch.tile( + edge_for_angle.unsqueeze(2), (1, 1, self.a_sel, 1, 1) + ) + # nb x nloc x a_nnei x (a_nnei) x e_dim + edge_for_angle_j = torch.tile( + edge_for_angle.unsqueeze(3), (1, 1, 1, self.a_sel, 1) + ) + # nb x nloc x a_nnei x a_nnei x (e_dim + e_dim) + edge_for_angle_info = torch.cat( + [edge_for_angle_i, edge_for_angle_j], dim=-1 + ) + angle_info_list = [angle_ebd] + angle_info_list.append(node_for_angle_info) + angle_info_list.append(edge_for_angle_info) + # nb x nloc x a_nnei x a_nnei x (a + n_dim + e_dim*2) or (a + a/c + a/c) + angle_info = torch.cat(angle_info_list, dim=-1) + else: + angle_info = None + + # edge angle message + # nb x nloc x a_nnei x a_nnei x e_dim + if not self.optim_update: + assert angle_info is not None + edge_angle_update = self.act(self.edge_angle_linear1(angle_info)) + else: + edge_angle_update = self.act( + self.optim_angle_update( + angle_ebd, + node_ebd_for_angle, + edge_for_angle, + "edge", + ) + ) + + # nb x nloc x a_nnei x a_nnei x e_dim + weighted_edge_angle_update = ( + edge_angle_update + * a_sw[:, :, :, None, None] + * a_sw[:, :, None, :, None] + ) + # nb x nloc x a_nnei x e_dim + reduced_edge_angle_update = torch.sum( + weighted_edge_angle_update, dim=-2 + ) / (self.a_sel**0.5) + # nb x nloc x nnei x e_dim + padding_edge_angle_update = torch.concat( + [ + reduced_edge_angle_update, + torch.zeros( + [nb, nloc, self.nnei - self.a_sel, self.e_dim], + dtype=edge_ebd.dtype, + device=edge_ebd.device, + ), + ], + dim=2, + ) + full_mask = torch.concat( + [ + a_nlist_mask, + torch.zeros( + [nb, nloc, self.nnei - self.a_sel], + dtype=a_nlist_mask.dtype, + device=a_nlist_mask.device, + ), + ], + dim=-1, + ) + padding_edge_angle_update = torch.where( + full_mask.unsqueeze(-1), padding_edge_angle_update, edge_ebd + ) + e_update_list.append( + self.act(self.edge_angle_linear2(padding_edge_angle_update)) + ) + # update edge_ebd + e_updated = self.list_update(e_update_list, "edge") + + # angle self message + # nb x nloc x a_nnei x a_nnei x dim_a + if not self.optim_update: + assert angle_info is not None + angle_self_update = self.act(self.angle_self_linear(angle_info)) + else: + angle_self_update = self.act( + self.optim_angle_update( + angle_ebd, + node_ebd_for_angle, + edge_for_angle, + "angle", + ) + ) + a_update_list.append(angle_self_update) + else: + # update edge_ebd + e_updated = self.list_update(e_update_list, "edge") + + # update angle_ebd + a_updated = self.list_update(a_update_list, "angle") + return n_updated, e_updated, a_updated + + @torch.jit.export + def list_update_res_avg( + self, + update_list: list[torch.Tensor], + ) -> torch.Tensor: + nitem = len(update_list) + uu = update_list[0] + for ii in range(1, nitem): + uu = uu + update_list[ii] + return uu / (float(nitem) ** 0.5) + + @torch.jit.export + def list_update_res_incr(self, update_list: list[torch.Tensor]) -> torch.Tensor: + nitem = len(update_list) + uu = update_list[0] + scale = 1.0 / (float(nitem - 1) ** 0.5) if nitem > 1 else 0.0 + for ii in range(1, nitem): + uu = uu + scale * update_list[ii] + return uu + + @torch.jit.export + def list_update_res_residual( + self, update_list: list[torch.Tensor], update_name: str = "node" + ) -> torch.Tensor: + nitem = len(update_list) + uu = update_list[0] + # make jit happy + if update_name == "node": + for ii, vv in enumerate(self.n_residual): + uu = uu + vv * update_list[ii + 1] + elif update_name == "edge": + for ii, vv in enumerate(self.e_residual): + uu = uu + vv * update_list[ii + 1] + elif update_name == "angle": + for ii, vv in enumerate(self.a_residual): + uu = uu + vv * update_list[ii + 1] + else: + raise NotImplementedError + return uu + + @torch.jit.export + def list_update( + self, update_list: list[torch.Tensor], update_name: str = "node" + ) -> torch.Tensor: + if self.update_style == "res_avg": + return self.list_update_res_avg(update_list) + elif self.update_style == "res_incr": + return self.list_update_res_incr(update_list) + elif self.update_style == "res_residual": + return self.list_update_res_residual(update_list, update_name=update_name) + else: + raise RuntimeError(f"unknown update style {self.update_style}") + + def serialize(self) -> dict: + """Serialize the networks to a dict. + + Returns + ------- + dict + The serialized networks. + """ + data = { + "@class": "RepformerLayer", + "@version": 1, + "e_rcut": self.e_rcut, + "e_rcut_smth": self.e_rcut_smth, + "e_sel": self.e_sel, + "a_rcut": self.a_rcut, + "a_rcut_smth": self.a_rcut_smth, + "a_sel": self.a_sel, + "ntypes": self.ntypes, + "n_dim": self.n_dim, + "e_dim": self.e_dim, + "a_dim": self.a_dim, + "a_compress_rate": self.a_compress_rate, + "a_compress_e_rate": self.a_compress_e_rate, + "a_compress_use_split": self.a_compress_use_split, + "n_multi_edge_message": self.n_multi_edge_message, + "axis_neuron": self.axis_neuron, + "activation_function": self.activation_function, + "update_angle": self.update_angle, + "update_style": self.update_style, + "update_residual": self.update_residual, + "update_residual_init": self.update_residual_init, + "precision": self.precision, + "optim_update": self.optim_update, + "node_self_mlp": self.node_self_mlp.serialize(), + "node_sym_linear": self.node_sym_linear.serialize(), + "node_edge_linear": self.node_edge_linear.serialize(), + "edge_self_linear": self.edge_self_linear.serialize(), + } + if self.update_angle: + data.update( + { + "edge_angle_linear1": self.edge_angle_linear1.serialize(), + "edge_angle_linear2": self.edge_angle_linear2.serialize(), + "angle_self_linear": self.angle_self_linear.serialize(), + } + ) + if self.a_compress_rate != 0 and not self.a_compress_use_split: + data.update( + { + "a_compress_n_linear": self.a_compress_n_linear.serialize(), + "a_compress_e_linear": self.a_compress_e_linear.serialize(), + } + ) + if self.update_style == "res_residual": + data.update( + { + "@variables": { + "n_residual": [to_numpy_array(t) for t in self.n_residual], + "e_residual": [to_numpy_array(t) for t in self.e_residual], + "a_residual": [to_numpy_array(t) for t in self.a_residual], + } + } + ) + return data + + @classmethod + def deserialize(cls, data: dict) -> "RepFlowLayer": + """Deserialize the networks from a dict. + + Parameters + ---------- + data : dict + The dict to deserialize from. + """ + data = data.copy() + check_version_compatibility(data.pop("@version"), 1, 1) + data.pop("@class") + update_angle = data["update_angle"] + a_compress_rate = data["a_compress_rate"] + a_compress_use_split = data["a_compress_use_split"] + node_self_mlp = data.pop("node_self_mlp") + node_sym_linear = data.pop("node_sym_linear") + node_edge_linear = data.pop("node_edge_linear") + edge_self_linear = data.pop("edge_self_linear") + edge_angle_linear1 = data.pop("edge_angle_linear1", None) + edge_angle_linear2 = data.pop("edge_angle_linear2", None) + angle_self_linear = data.pop("angle_self_linear", None) + a_compress_n_linear = data.pop("a_compress_n_linear", None) + a_compress_e_linear = data.pop("a_compress_e_linear", None) + update_style = data["update_style"] + variables = data.pop("@variables", {}) + n_residual = variables.get("n_residual", data.pop("n_residual", [])) + e_residual = variables.get("e_residual", data.pop("e_residual", [])) + a_residual = variables.get("a_residual", data.pop("a_residual", [])) + + obj = cls(**data) + obj.node_self_mlp = MLPLayer.deserialize(node_self_mlp) + obj.node_sym_linear = MLPLayer.deserialize(node_sym_linear) + obj.node_edge_linear = MLPLayer.deserialize(node_edge_linear) + obj.edge_self_linear = MLPLayer.deserialize(edge_self_linear) + + if update_angle: + assert isinstance(edge_angle_linear1, dict) + assert isinstance(edge_angle_linear2, dict) + assert isinstance(angle_self_linear, dict) + obj.edge_angle_linear1 = MLPLayer.deserialize(edge_angle_linear1) + obj.edge_angle_linear2 = MLPLayer.deserialize(edge_angle_linear2) + obj.angle_self_linear = MLPLayer.deserialize(angle_self_linear) + if a_compress_rate != 0 and not a_compress_use_split: + assert isinstance(a_compress_n_linear, dict) + assert isinstance(a_compress_e_linear, dict) + obj.a_compress_n_linear = MLPLayer.deserialize(a_compress_n_linear) + obj.a_compress_e_linear = MLPLayer.deserialize(a_compress_e_linear) + + if update_style == "res_residual": + for ii, t in enumerate(obj.n_residual): + t.data = to_torch_tensor(n_residual[ii]) + for ii, t in enumerate(obj.e_residual): + t.data = to_torch_tensor(e_residual[ii]) + for ii, t in enumerate(obj.a_residual): + t.data = to_torch_tensor(a_residual[ii]) + return obj diff --git a/deepmd/pt/model/descriptor/repflows.py b/deepmd/pt/model/descriptor/repflows.py new file mode 100644 index 0000000000..7018ff32ba --- /dev/null +++ b/deepmd/pt/model/descriptor/repflows.py @@ -0,0 +1,603 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Callable, + Optional, + Union, +) + +import torch + +from deepmd.dpmodel.utils.seed import ( + child_seed, +) +from deepmd.pt.model.descriptor.descriptor import ( + DescriptorBlock, +) +from deepmd.pt.model.descriptor.env_mat import ( + prod_env_mat, +) +from deepmd.pt.model.network.mlp import ( + MLPLayer, +) +from deepmd.pt.utils import ( + env, +) +from deepmd.pt.utils.env import ( + PRECISION_DICT, +) +from deepmd.pt.utils.env_mat_stat import ( + EnvMatStatSe, +) +from deepmd.pt.utils.exclude_mask import ( + PairExcludeMask, +) +from deepmd.pt.utils.spin import ( + concat_switch_virtual, +) +from deepmd.pt.utils.utils import ( + ActivationFn, +) +from deepmd.utils.env_mat_stat import ( + StatItem, +) +from deepmd.utils.path import ( + DPPath, +) + +from .repflow_layer import ( + RepFlowLayer, +) + +if not hasattr(torch.ops.deepmd, "border_op"): + + def border_op( + argument0, + argument1, + argument2, + argument3, + argument4, + argument5, + argument6, + argument7, + argument8, + ) -> torch.Tensor: + raise NotImplementedError( + "border_op is not available since customized PyTorch OP library is not built when freezing the model. " + "See documentation for DPA-3 for details." + ) + + # Note: this hack cannot actually save a model that can be run using LAMMPS. + torch.ops.deepmd.border_op = border_op + + +@DescriptorBlock.register("se_repflow") +class DescrptBlockRepflows(DescriptorBlock): + def __init__( + self, + e_rcut, + e_rcut_smth, + e_sel: int, + a_rcut, + a_rcut_smth, + a_sel: int, + ntypes: int, + nlayers: int = 6, + n_dim: int = 128, + e_dim: int = 64, + a_dim: int = 64, + a_compress_rate: int = 0, + a_compress_e_rate: int = 1, + a_compress_use_split: bool = False, + n_multi_edge_message: int = 1, + axis_neuron: int = 4, + update_angle: bool = True, + activation_function: str = "silu", + update_style: str = "res_residual", + update_residual: float = 0.1, + update_residual_init: str = "const", + set_davg_zero: bool = True, + exclude_types: list[tuple[int, int]] = [], + env_protection: float = 0.0, + precision: str = "float64", + fix_stat_std: float = 0.3, + optim_update: bool = True, + seed: Optional[Union[int, list[int]]] = None, + ) -> None: + r""" + The repflow descriptor block. + + Parameters + ---------- + n_dim : int, optional + The dimension of node representation. + e_dim : int, optional + The dimension of edge representation. + a_dim : int, optional + The dimension of angle representation. + nlayers : int, optional + Number of repflow layers. + e_rcut : float, optional + The edge cut-off radius. + e_rcut_smth : float, optional + Where to start smoothing for edge. For example the 1/r term is smoothed from rcut to rcut_smth. + e_sel : int, optional + Maximally possible number of selected edge neighbors. + a_rcut : float, optional + The angle cut-off radius. + a_rcut_smth : float, optional + Where to start smoothing for angle. For example the 1/r term is smoothed from rcut to rcut_smth. + a_sel : int, optional + Maximally possible number of selected angle neighbors. + a_compress_rate : int, optional + The compression rate for angular messages. The default value is 0, indicating no compression. + If a non-zero integer c is provided, the node and edge dimensions will be compressed + to a_dim/c and a_dim/2c, respectively, within the angular message. + a_compress_e_rate : int, optional + The extra compression rate for edge in angular message compression. The default value is 1. + When using angular message compression with a_compress_rate c and a_compress_e_rate c_e, + the edge dimension will be compressed to (c_e * a_dim / 2c) within the angular message. + a_compress_use_split : bool, optional + Whether to split first sub-vectors instead of linear mapping during angular message compression. + The default value is False. + n_multi_edge_message : int, optional + The head number of multiple edge messages to update node feature. + Default is 1, indicating one head edge message. + axis_neuron : int, optional + The number of dimension of submatrix in the symmetrization ops. + update_angle : bool, optional + Where to update the angle rep. If not, only node and edge rep will be used. + update_style : str, optional + Style to update a representation. + Supported options are: + -'res_avg': Updates a rep `u` with: u = 1/\\sqrt{n+1} (u + u_1 + u_2 + ... + u_n) + -'res_incr': Updates a rep `u` with: u = u + 1/\\sqrt{n} (u_1 + u_2 + ... + u_n) + -'res_residual': Updates a rep `u` with: u = u + (r1*u_1 + r2*u_2 + ... + r3*u_n) + where `r1`, `r2` ... `r3` are residual weights defined by `update_residual` + and `update_residual_init`. + update_residual : float, optional + When update using residual mode, the initial std of residual vector weights. + update_residual_init : str, optional + When update using residual mode, the initialization mode of residual vector weights. + fix_stat_std : float, optional + If non-zero (default is 0.3), use this constant as the normalization standard deviation + instead of computing it from data statistics. + optim_update : bool, optional + Whether to enable the optimized update method. + Uses a more efficient process when enabled. Defaults to True + ntypes : int + Number of element types + activation_function : str, optional + The activation function in the embedding net. + set_davg_zero : bool, optional + Set the normalization average to zero. + precision : str, optional + The precision of the embedding net parameters. + exclude_types : list[list[int]], optional + The excluded pairs of types which have no interaction with each other. + For example, `[[0, 1]]` means no interaction between type 0 and type 1. + env_protection : float, optional + Protection parameter to prevent division by zero errors during environment matrix calculations. + For example, when using paddings, there may be zero distances of neighbors, which may make division by zero error during environment matrix calculations without protection. + seed : int, optional + Random seed for parameter initialization. + """ + super().__init__() + self.e_rcut = float(e_rcut) + self.e_rcut_smth = float(e_rcut_smth) + self.e_sel = e_sel + self.a_rcut = float(a_rcut) + self.a_rcut_smth = float(a_rcut_smth) + self.a_sel = a_sel + self.ntypes = ntypes + self.nlayers = nlayers + # for other common desciptor method + sel = [e_sel] if isinstance(e_sel, int) else e_sel + self.nnei = sum(sel) + self.ndescrpt = self.nnei * 4 # use full descriptor. + assert len(sel) == 1 + self.sel = sel + self.rcut = e_rcut + self.rcut_smth = e_rcut_smth + self.sec = self.sel + self.split_sel = self.sel + self.a_compress_rate = a_compress_rate + self.a_compress_e_rate = a_compress_e_rate + self.n_multi_edge_message = n_multi_edge_message + self.axis_neuron = axis_neuron + self.set_davg_zero = set_davg_zero + self.fix_stat_std = fix_stat_std + self.set_stddev_constant = fix_stat_std != 0.0 + self.a_compress_use_split = a_compress_use_split + self.optim_update = optim_update + + self.n_dim = n_dim + self.e_dim = e_dim + self.a_dim = a_dim + self.update_angle = update_angle + + self.activation_function = activation_function + self.update_style = update_style + self.update_residual = update_residual + self.update_residual_init = update_residual_init + self.act = ActivationFn(activation_function) + self.prec = PRECISION_DICT[precision] + + # order matters, placed after the assignment of self.ntypes + self.reinit_exclude(exclude_types) + self.env_protection = env_protection + self.precision = precision + self.epsilon = 1e-4 + self.seed = seed + + self.edge_embd = MLPLayer( + 1, self.e_dim, precision=precision, seed=child_seed(seed, 0) + ) + self.angle_embd = MLPLayer( + 1, self.a_dim, precision=precision, bias=False, seed=child_seed(seed, 1) + ) + layers = [] + for ii in range(nlayers): + layers.append( + RepFlowLayer( + e_rcut=self.e_rcut, + e_rcut_smth=self.e_rcut_smth, + e_sel=self.sel, + a_rcut=self.a_rcut, + a_rcut_smth=self.a_rcut_smth, + a_sel=self.a_sel, + ntypes=self.ntypes, + n_dim=self.n_dim, + e_dim=self.e_dim, + a_dim=self.a_dim, + a_compress_rate=self.a_compress_rate, + a_compress_use_split=self.a_compress_use_split, + a_compress_e_rate=self.a_compress_e_rate, + n_multi_edge_message=self.n_multi_edge_message, + axis_neuron=self.axis_neuron, + update_angle=self.update_angle, + activation_function=self.activation_function, + update_style=self.update_style, + update_residual=self.update_residual, + update_residual_init=self.update_residual_init, + precision=precision, + optim_update=self.optim_update, + seed=child_seed(child_seed(seed, 1), ii), + ) + ) + self.layers = torch.nn.ModuleList(layers) + + wanted_shape = (self.ntypes, self.nnei, 4) + mean = torch.zeros(wanted_shape, dtype=self.prec, device=env.DEVICE) + stddev = torch.ones(wanted_shape, dtype=self.prec, device=env.DEVICE) + if self.set_stddev_constant: + stddev = stddev * self.fix_stat_std + self.register_buffer("mean", mean) + self.register_buffer("stddev", stddev) + self.stats = None + + def get_rcut(self) -> float: + """Returns the cut-off radius.""" + return self.e_rcut + + def get_rcut_smth(self) -> float: + """Returns the radius where the neighbor information starts to smoothly decay to 0.""" + return self.e_rcut_smth + + def get_nsel(self) -> int: + """Returns the number of selected atoms in the cut-off radius.""" + return sum(self.sel) + + def get_sel(self) -> list[int]: + """Returns the number of selected atoms for each type.""" + return self.sel + + def get_ntypes(self) -> int: + """Returns the number of element types.""" + return self.ntypes + + def get_dim_out(self) -> int: + """Returns the output dimension.""" + return self.dim_out + + def get_dim_in(self) -> int: + """Returns the input dimension.""" + return self.dim_in + + def get_dim_emb(self) -> int: + """Returns the embedding dimension e_dim.""" + return self.e_dim + + def __setitem__(self, key, value) -> None: + if key in ("avg", "data_avg", "davg"): + self.mean = value + elif key in ("std", "data_std", "dstd"): + self.stddev = value + else: + raise KeyError(key) + + def __getitem__(self, key): + if key in ("avg", "data_avg", "davg"): + return self.mean + elif key in ("std", "data_std", "dstd"): + return self.stddev + else: + raise KeyError(key) + + def mixed_types(self) -> bool: + """If true, the descriptor + 1. assumes total number of atoms aligned across frames; + 2. requires a neighbor list that does not distinguish different atomic types. + + If false, the descriptor + 1. assumes total number of atoms of each atom type aligned across frames; + 2. requires a neighbor list that distinguishes different atomic types. + + """ + return True + + def get_env_protection(self) -> float: + """Returns the protection of building environment matrix.""" + return self.env_protection + + @property + def dim_out(self): + """Returns the output dimension of this descriptor.""" + return self.n_dim + + @property + def dim_in(self): + """Returns the atomic input dimension of this descriptor.""" + return self.n_dim + + @property + def dim_emb(self): + """Returns the embedding dimension e_dim.""" + return self.get_dim_emb() + + def reinit_exclude( + self, + exclude_types: list[tuple[int, int]] = [], + ) -> None: + self.exclude_types = exclude_types + self.emask = PairExcludeMask(self.ntypes, exclude_types=exclude_types) + + def forward( + self, + nlist: torch.Tensor, + extended_coord: torch.Tensor, + extended_atype: torch.Tensor, + extended_atype_embd: Optional[torch.Tensor] = None, + mapping: Optional[torch.Tensor] = None, + comm_dict: Optional[dict[str, torch.Tensor]] = None, + ): + if comm_dict is None: + assert mapping is not None + assert extended_atype_embd is not None + nframes, nloc, nnei = nlist.shape + nall = extended_coord.view(nframes, -1).shape[1] // 3 + atype = extended_atype[:, :nloc] + # nb x nloc x nnei + exclude_mask = self.emask(nlist, extended_atype) + nlist = torch.where(exclude_mask != 0, nlist, -1) + # nb x nloc x nnei x 4, nb x nloc x nnei x 3, nb x nloc x nnei x 1 + dmatrix, diff, sw = prod_env_mat( + extended_coord, + nlist, + atype, + self.mean, + self.stddev, + self.e_rcut, + self.e_rcut_smth, + protection=self.env_protection, + ) + nlist_mask = nlist != -1 + sw = torch.squeeze(sw, -1) + # beyond the cutoff sw should be 0.0 + sw = sw.masked_fill(~nlist_mask, 0.0) + + # [nframes, nloc, tebd_dim] + if comm_dict is None: + assert isinstance(extended_atype_embd, torch.Tensor) # for jit + atype_embd = extended_atype_embd[:, :nloc, :] + assert list(atype_embd.shape) == [nframes, nloc, self.n_dim] + else: + atype_embd = extended_atype_embd + assert isinstance(atype_embd, torch.Tensor) # for jit + node_ebd = self.act(atype_embd) + n_dim = node_ebd.shape[-1] + # nb x nloc x nnei x 1, nb x nloc x nnei x 3 + edge_input, h2 = torch.split(dmatrix, [1, 3], dim=-1) + # nb x nloc x nnei x e_dim + edge_ebd = self.act(self.edge_embd(edge_input)) + + # get angle nlist (maybe smaller) + a_dist_mask = (torch.linalg.norm(diff, dim=-1) < self.a_rcut)[ + :, :, : self.a_sel + ] + a_nlist = nlist[:, :, : self.a_sel] + a_nlist = torch.where(a_dist_mask, a_nlist, -1) + _, a_diff, a_sw = prod_env_mat( + extended_coord, + a_nlist, + atype, + self.mean[:, : self.a_sel], + self.stddev[:, : self.a_sel], + self.a_rcut, + self.a_rcut_smth, + protection=self.env_protection, + ) + a_nlist_mask = a_nlist != -1 + a_sw = torch.squeeze(a_sw, -1) + # beyond the cutoff sw should be 0.0 + a_sw = a_sw.masked_fill(~a_nlist_mask, 0.0) + a_nlist[a_nlist == -1] = 0 + + # nf x nloc x a_nnei x 3 + normalized_diff_i = a_diff / ( + torch.linalg.norm(a_diff, dim=-1, keepdim=True) + 1e-6 + ) + # nf x nloc x 3 x a_nnei + normalized_diff_j = torch.transpose(normalized_diff_i, 2, 3) + # nf x nloc x a_nnei x a_nnei + # 1 - 1e-6 for torch.acos stability + cosine_ij = torch.matmul(normalized_diff_i, normalized_diff_j) * (1 - 1e-6) + # nf x nloc x a_nnei x a_nnei x 1 + cosine_ij = cosine_ij.unsqueeze(-1) / (torch.pi**0.5) + # nf x nloc x a_nnei x a_nnei x a_dim + angle_ebd = self.angle_embd(cosine_ij).reshape( + nframes, nloc, self.a_sel, self.a_sel, self.a_dim + ) + + # set all padding positions to index of 0 + # if the a neighbor is real or not is indicated by nlist_mask + nlist[nlist == -1] = 0 + # nb x nall x n_dim + if comm_dict is None: + assert mapping is not None + mapping = ( + mapping.view(nframes, nall).unsqueeze(-1).expand(-1, -1, self.n_dim) + ) + for idx, ll in enumerate(self.layers): + # node_ebd: nb x nloc x n_dim + # node_ebd_ext: nb x nall x n_dim + if comm_dict is None: + assert mapping is not None + node_ebd_ext = torch.gather(node_ebd, 1, mapping) + else: + has_spin = "has_spin" in comm_dict + if not has_spin: + n_padding = nall - nloc + node_ebd = torch.nn.functional.pad( + node_ebd.squeeze(0), (0, 0, 0, n_padding), value=0.0 + ) + real_nloc = nloc + real_nall = nall + else: + # for spin + real_nloc = nloc // 2 + real_nall = nall // 2 + real_n_padding = real_nall - real_nloc + node_ebd_real, node_ebd_virtual = torch.split( + node_ebd, [real_nloc, real_nloc], dim=1 + ) + # mix_node_ebd: nb x real_nloc x (n_dim * 2) + mix_node_ebd = torch.cat([node_ebd_real, node_ebd_virtual], dim=2) + # nb x real_nall x (n_dim * 2) + node_ebd = torch.nn.functional.pad( + mix_node_ebd.squeeze(0), (0, 0, 0, real_n_padding), value=0.0 + ) + + assert "send_list" in comm_dict + assert "send_proc" in comm_dict + assert "recv_proc" in comm_dict + assert "send_num" in comm_dict + assert "recv_num" in comm_dict + assert "communicator" in comm_dict + ret = torch.ops.deepmd.border_op( + comm_dict["send_list"], + comm_dict["send_proc"], + comm_dict["recv_proc"], + comm_dict["send_num"], + comm_dict["recv_num"], + node_ebd, + comm_dict["communicator"], + torch.tensor( + real_nloc, + dtype=torch.int32, + device=env.DEVICE, + ), # should be int of c++ + torch.tensor( + real_nall - real_nloc, + dtype=torch.int32, + device=env.DEVICE, + ), # should be int of c++ + ) + node_ebd_ext = ret[0].unsqueeze(0) + if has_spin: + node_ebd_real_ext, node_ebd_virtual_ext = torch.split( + node_ebd_ext, [n_dim, n_dim], dim=2 + ) + node_ebd_ext = concat_switch_virtual( + node_ebd_real_ext, node_ebd_virtual_ext, real_nloc + ) + node_ebd, edge_ebd, angle_ebd = ll.forward( + node_ebd_ext, + edge_ebd, + h2, + angle_ebd, + nlist, + nlist_mask, + sw, + a_nlist, + a_nlist_mask, + a_sw, + ) + + # nb x nloc x 3 x e_dim + h2g2 = RepFlowLayer._cal_hg(edge_ebd, h2, nlist_mask, sw) + # (nb x nloc) x e_dim x 3 + rot_mat = torch.permute(h2g2, (0, 1, 3, 2)) + + return node_ebd, edge_ebd, h2, rot_mat.view(nframes, nloc, self.dim_emb, 3), sw + + def compute_input_stats( + self, + merged: Union[Callable[[], list[dict]], list[dict]], + path: Optional[DPPath] = None, + ) -> None: + """ + Compute the input statistics (e.g. mean and stddev) for the descriptors from packed data. + + Parameters + ---------- + merged : Union[Callable[[], list[dict]], list[dict]] + - list[dict]: A list of data samples from various data systems. + Each element, `merged[i]`, is a data dictionary containing `keys`: `torch.Tensor` + originating from the `i`-th data system. + - Callable[[], list[dict]]: A lazy function that returns data samples in the above format + only when needed. Since the sampling process can be slow and memory-intensive, + the lazy function helps by only sampling once. + path : Optional[DPPath] + The path to the stat file. + + """ + if self.set_stddev_constant and self.set_davg_zero: + return + env_mat_stat = EnvMatStatSe(self) + if path is not None: + path = path / env_mat_stat.get_hash() + if path is None or not path.is_dir(): + if callable(merged): + # only get data for once + sampled = merged() + else: + sampled = merged + else: + sampled = [] + env_mat_stat.load_or_compute_stats(sampled, path) + self.stats = env_mat_stat.stats + mean, stddev = env_mat_stat() + if not self.set_davg_zero: + self.mean.copy_( + torch.tensor(mean, device=env.DEVICE, dtype=self.mean.dtype) + ) + if not self.set_stddev_constant: + self.stddev.copy_( + torch.tensor(stddev, device=env.DEVICE, dtype=self.stddev.dtype) + ) + + def get_stats(self) -> dict[str, StatItem]: + """Get the statistics of the descriptor.""" + if self.stats is None: + raise RuntimeError( + "The statistics of the descriptor has not been computed." + ) + return self.stats + + def has_message_passing(self) -> bool: + """Returns whether the descriptor block has message passing.""" + return True + + def need_sorted_nlist_for_lower(self) -> bool: + """Returns whether the descriptor block needs sorted nlist when using `forward_lower`.""" + return True diff --git a/deepmd/pt/utils/utils.py b/deepmd/pt/utils/utils.py index 6ce4f5d6fc..50d378455b 100644 --- a/deepmd/pt/utils/utils.py +++ b/deepmd/pt/utils/utils.py @@ -39,6 +39,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return F.softplus(x) elif self.activation.lower() == "sigmoid": return torch.sigmoid(x) + elif self.activation.lower() == "silu": + return F.silu(x) elif self.activation.lower() == "linear" or self.activation.lower() == "none": return x else: diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index a00cfb047a..32c0766265 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -1354,6 +1354,239 @@ def dpa2_repformer_args(): ] +@descrpt_args_plugin.register("dpa3", doc=doc_only_pt_supported) +def descrpt_dpa3_args(): + # repflow args + doc_repflow = "The arguments used to initialize the repflow block." + # descriptor args + doc_concat_output_tebd = ( + "Whether to concat type embedding at the output of the descriptor." + ) + doc_activation_function = f"The activation function in the embedding net. Supported activation functions are {list_to_doc(ACTIVATION_FN_DICT.keys())}." + doc_precision = f"The precision of the embedding net parameters, supported options are {list_to_doc(PRECISION_DICT.keys())} Default follows the interface precision." + doc_exclude_types = "The excluded pairs of types which have no interaction with each other. For example, `[[0, 1]]` means no interaction between type 0 and type 1." + doc_env_protection = "Protection parameter to prevent division by zero errors during environment matrix calculations. For example, when using paddings, there may be zero distances of neighbors, which may make division by zero error during environment matrix calculations without protection." + doc_trainable = "If the parameters in the embedding net is trainable." + doc_seed = "Random seed for parameter initialization." + doc_use_econf_tebd = "Whether to use electronic configuration type embedding." + doc_use_tebd_bias = "Whether to use bias in the type embedding layer." + return [ + # doc_repflow args + Argument("repflow", dict, dpa3_repflow_args(), doc=doc_repflow), + # descriptor args + Argument( + "concat_output_tebd", + bool, + optional=True, + default=False, + doc=doc_concat_output_tebd, + ), + Argument( + "activation_function", + str, + optional=True, + default="silu", + doc=doc_activation_function, + ), + Argument("precision", str, optional=True, default="default", doc=doc_precision), + Argument( + "exclude_types", + list[list[int]], + optional=True, + default=[], + doc=doc_exclude_types, + ), + Argument( + "env_protection", + float, + optional=True, + default=0.0, + doc=doc_only_pt_supported + doc_env_protection, + ), + Argument("trainable", bool, optional=True, default=True, doc=doc_trainable), + Argument("seed", [int, None], optional=True, doc=doc_seed), + Argument( + "use_econf_tebd", + bool, + optional=True, + default=False, + doc=doc_only_pt_supported + doc_use_econf_tebd, + ), + Argument( + "use_tebd_bias", + bool, + optional=True, + default=False, + doc=doc_use_tebd_bias, + ), + ] + + +# repflow for dpa3 +def dpa3_repflow_args(): + # repflow args + doc_n_dim = "The dimension of node representation." + doc_e_dim = "The dimension of edge representation." + doc_a_dim = "The dimension of angle representation." + doc_nlayers = "The number of repflow layers." + doc_e_rcut = "The edge cut-off radius." + doc_e_rcut_smth = "Where to start smoothing for edge. For example the 1/r term is smoothed from `rcut` to `rcut_smth`." + doc_e_sel = 'Maximally possible number of selected edge neighbors. It can be:\n\n\ + - `int`. The maximum number of neighbor atoms to be considered. We recommend it to be less than 200. \n\n\ + - `str`. Can be "auto:factor" or "auto". "factor" is a float number larger than 1. This option will automatically determine the `sel`. In detail it counts the maximal number of neighbors with in the cutoff radius for each type of neighbor, then multiply the maximum by the "factor". Finally the number is wrapped up to 4 divisible. The option "auto" is equivalent to "auto:1.1".' + doc_a_rcut = "The angle cut-off radius." + doc_a_rcut_smth = "Where to start smoothing for angle. For example the 1/r term is smoothed from `rcut` to `rcut_smth`." + doc_a_sel = 'Maximally possible number of selected angle neighbors. It can be:\n\n\ + - `int`. The maximum number of neighbor atoms to be considered. We recommend it to be less than 200. \n\n\ + - `str`. Can be "auto:factor" or "auto". "factor" is a float number larger than 1. This option will automatically determine the `sel`. In detail it counts the maximal number of neighbors with in the cutoff radius for each type of neighbor, then multiply the maximum by the "factor". Finally the number is wrapped up to 4 divisible. The option "auto" is equivalent to "auto:1.1".' + doc_a_compress_rate = ( + "The compression rate for angular messages. The default value is 0, indicating no compression. " + " If a non-zero integer c is provided, the node and edge dimensions will be compressed " + "to a_dim/c and a_dim/2c, respectively, within the angular message." + ) + doc_a_compress_e_rate = ( + "The extra compression rate for edge in angular message compression. The default value is 1." + "When using angular message compression with a_compress_rate c and a_compress_e_rate c_e, " + "the edge dimension will be compressed to (c_e * a_dim / 2c) within the angular message. " + ) + doc_a_compress_use_split = ( + "Whether to split first sub-vectors instead of linear mapping during angular message compression. " + "The default value is False." + ) + doc_n_multi_edge_message = ( + "The head number of multiple edge messages to update node feature. " + "Default is 1, indicating one head edge message." + ) + doc_axis_neuron = "The number of dimension of submatrix in the symmetrization ops." + doc_fix_stat_std = ( + "If non-zero (default is 0.3), use this constant as the normalization standard deviation " + "instead of computing it from data statistics." + ) + doc_skip_stat = ( + "(Deprecated, kept only for compatibility.) This parameter is obsolete and will be removed. " + "If set to True, it forces fix_stat_std=0.3 for backward compatibility. " + "Transition to fix_stat_std parameter immediately." + ) + doc_update_angle = ( + "Where to update the angle rep. If not, only node and edge rep will be used." + ) + doc_update_style = ( + "Style to update a representation. " + "Supported options are: " + "-'res_avg': Updates a rep `u` with: u = 1/\\sqrt{n+1} (u + u_1 + u_2 + ... + u_n) " + "-'res_incr': Updates a rep `u` with: u = u + 1/\\sqrt{n} (u_1 + u_2 + ... + u_n)" + "-'res_residual': Updates a rep `u` with: u = u + (r1*u_1 + r2*u_2 + ... + r3*u_n) " + "where `r1`, `r2` ... `r3` are residual weights defined by `update_residual` " + "and `update_residual_init`." + ) + doc_update_residual = ( + "When update using residual mode, the initial std of residual vector weights." + ) + doc_update_residual_init = ( + "When update using residual mode, " + "the initialization mode of residual vector weights." + "Supported modes are: ['norm', 'const']." + ) + doc_optim_update = ( + "Whether to enable the optimized update method. " + "Uses a more efficient process when enabled. Defaults to True" + ) + + return [ + # repflow args + Argument("n_dim", int, optional=True, default=128, doc=doc_n_dim), + Argument("e_dim", int, optional=True, default=64, doc=doc_e_dim), + Argument("a_dim", int, optional=True, default=64, doc=doc_a_dim), + Argument("nlayers", int, optional=True, default=6, doc=doc_nlayers), + Argument("e_rcut", float, doc=doc_e_rcut), + Argument("e_rcut_smth", float, doc=doc_e_rcut_smth), + Argument("e_sel", [int, str], doc=doc_e_sel), + Argument("a_rcut", float, doc=doc_a_rcut), + Argument("a_rcut_smth", float, doc=doc_a_rcut_smth), + Argument("a_sel", [int, str], doc=doc_a_sel), + Argument( + "a_compress_rate", int, optional=True, default=0, doc=doc_a_compress_rate + ), + Argument( + "a_compress_e_rate", + int, + optional=True, + default=1, + doc=doc_a_compress_e_rate, + ), + Argument( + "a_compress_use_split", + bool, + optional=True, + default=False, + doc=doc_a_compress_use_split, + ), + Argument( + "n_multi_edge_message", + int, + optional=True, + default=1, + doc=doc_n_multi_edge_message, + ), + Argument( + "axis_neuron", + int, + optional=True, + default=4, + doc=doc_axis_neuron, + ), + Argument( + "fix_stat_std", + float, + optional=True, + default=0.3, + doc=doc_fix_stat_std, + ), + Argument( + "skip_stat", + bool, + optional=True, + default=False, + doc=doc_skip_stat, + ), + Argument( + "update_angle", + bool, + optional=True, + default=True, + doc=doc_update_angle, + ), + Argument( + "update_style", + str, + optional=True, + default="res_residual", + doc=doc_update_style, + ), + Argument( + "update_residual", + float, + optional=True, + default=0.1, + doc=doc_update_residual, + ), + Argument( + "update_residual_init", + str, + optional=True, + default="const", + doc=doc_update_residual_init, + ), + Argument( + "optim_update", + bool, + optional=True, + default=True, + doc=doc_optim_update, + ), + ] + + @descrpt_args_plugin.register( "se_a_ebd_v2", alias=["se_a_tpe_v2"], doc=doc_only_tf_supported ) diff --git a/source/tests/array_api_strict/descriptor/dpa3.py b/source/tests/array_api_strict/descriptor/dpa3.py new file mode 100644 index 0000000000..b0e959436e --- /dev/null +++ b/source/tests/array_api_strict/descriptor/dpa3.py @@ -0,0 +1,31 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, +) + +from deepmd.dpmodel.descriptor.dpa3 import DescrptDPA3 as DescrptDPA3DP + +from ..common import ( + to_array_api_strict_array, +) +from ..utils.type_embed import ( + TypeEmbedNet, +) +from .base_descriptor import ( + BaseDescriptor, +) +from .repflows import ( + DescrptBlockRepflows, +) + + +@BaseDescriptor.register("dpa3") +class DescrptDPA3(DescrptDPA3DP): + def __setattr__(self, name: str, value: Any) -> None: + if name in {"mean", "stddev"}: + value = to_array_api_strict_array(value) + elif name in {"repflows"}: + value = DescrptBlockRepflows.deserialize(value.serialize()) + elif name in {"type_embedding"}: + value = TypeEmbedNet.deserialize(value.serialize()) + return super().__setattr__(name, value) diff --git a/source/tests/array_api_strict/descriptor/repflows.py b/source/tests/array_api_strict/descriptor/repflows.py new file mode 100644 index 0000000000..2a8dbc42f3 --- /dev/null +++ b/source/tests/array_api_strict/descriptor/repflows.py @@ -0,0 +1,56 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, +) + +from deepmd.dpmodel.descriptor.repflows import ( + DescrptBlockRepflows as DescrptBlockRepflowsDP, +) +from deepmd.dpmodel.descriptor.repflows import RepFlowLayer as RepFlowLayerDP + +from ..common import ( + to_array_api_strict_array, +) +from ..utils.exclude_mask import ( + PairExcludeMask, +) +from ..utils.network import ( + NativeLayer, +) + + +class DescrptBlockRepflows(DescrptBlockRepflowsDP): + def __setattr__(self, name: str, value: Any) -> None: + if name in {"mean", "stddev"}: + value = to_array_api_strict_array(value) + elif name in {"layers"}: + value = [RepFlowLayer.deserialize(layer.serialize()) for layer in value] + elif name in {"edge_embd", "angle_embd"}: + value = NativeLayer.deserialize(value.serialize()) + elif name in {"env_mat_edge", "env_mat_angle"}: + # env_mat doesn't store any value + pass + elif name == "emask": + value = PairExcludeMask(value.ntypes, value.exclude_types) + + return super().__setattr__(name, value) + + +class RepFlowLayer(RepFlowLayerDP): + def __setattr__(self, name: str, value: Any) -> None: + if name in { + "node_self_mlp", + "node_sym_linear", + "node_edge_linear", + "edge_self_linear", + "a_compress_n_linear", + "a_compress_e_linear", + "edge_angle_linear1", + "edge_angle_linear2", + "angle_self_linear", + }: + if value is not None: + value = NativeLayer.deserialize(value.serialize()) + elif name in {"n_residual", "e_residual", "a_residual"}: + value = [to_array_api_strict_array(vv) for vv in value] + return super().__setattr__(name, value) diff --git a/source/tests/consistent/descriptor/test_dpa3.py b/source/tests/consistent/descriptor/test_dpa3.py new file mode 100644 index 0000000000..9ebcb15f85 --- /dev/null +++ b/source/tests/consistent/descriptor/test_dpa3.py @@ -0,0 +1,351 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import unittest +from typing import ( + Any, +) + +import numpy as np +from dargs import ( + Argument, +) + +from deepmd.dpmodel.descriptor.dpa3 import DescrptDPA3 as DescrptDPA3DP +from deepmd.env import ( + GLOBAL_NP_FLOAT_PRECISION, +) + +from ..common import ( + INSTALLED_ARRAY_API_STRICT, + INSTALLED_JAX, + INSTALLED_PD, + INSTALLED_PT, + CommonTest, + parameterized, +) +from .common import ( + DescriptorTest, +) + +if INSTALLED_PT: + from deepmd.pt.model.descriptor.dpa3 import DescrptDPA3 as DescrptDPA3PT +else: + DescrptDPA3PT = None + +if INSTALLED_JAX: + from deepmd.jax.descriptor.dpa3 import DescrptDPA3 as DescrptDPA3JAX +else: + DescrptDPA3JAX = None + +if INSTALLED_PD: + # not supported yet + DescrptDPA3PD = None +else: + DescrptDPA3PD = None + +if INSTALLED_ARRAY_API_STRICT: + from ...array_api_strict.descriptor.dpa3 import DescrptDPA3 as DescrptDPA3Strict +else: + DescrptDPA3Strict = None + +# not implemented +DescrptDPA3TF = None + +from deepmd.dpmodel.descriptor.dpa3 import ( + RepFlowArgs, +) +from deepmd.utils.argcheck import ( + descrpt_dpa3_args, +) + + +@parameterized( + ("const",), # update_residual_init + ([], [[0, 1]]), # exclude_types + (True, False), # update_angle + (0, 1), # a_compress_rate + (1, 2), # a_compress_e_rate + (True, False), # a_compress_use_split + (True, False), # optim_update + (0.3, 0.0), # fix_stat_std + (1, 2), # n_multi_edge_message + ("float64",), # precision +) +class TestDPA3(CommonTest, DescriptorTest, unittest.TestCase): + @property + def data(self) -> dict: + ( + update_residual_init, + exclude_types, + update_angle, + a_compress_rate, + a_compress_e_rate, + a_compress_use_split, + optim_update, + fix_stat_std, + n_multi_edge_message, + precision, + ) = self.param + return { + "ntypes": self.ntypes, + # kwargs for repinit + "repflow": RepFlowArgs( + **{ + "n_dim": 20, + "e_dim": 10, + "a_dim": 8, + "nlayers": 3, + "e_rcut": 6.0, + "e_rcut_smth": 5.0, + "e_sel": 10, + "a_rcut": 4.0, + "a_rcut_smth": 3.5, + "a_sel": 8, + "a_compress_rate": a_compress_rate, + "a_compress_e_rate": a_compress_e_rate, + "a_compress_use_split": a_compress_use_split, + "optim_update": optim_update, + "fix_stat_std": fix_stat_std, + "n_multi_edge_message": n_multi_edge_message, + "axis_neuron": 4, + "update_angle": update_angle, + "update_style": "res_residual", + "update_residual": 0.1, + "update_residual_init": update_residual_init, + } + ), + # kwargs for descriptor + "activation_function": "silu", + "precision": precision, + "exclude_types": exclude_types, + "env_protection": 0.0, + "trainable": True, + } + + @property + def skip_pt(self) -> bool: + ( + update_residual_init, + exclude_types, + update_angle, + a_compress_rate, + a_compress_e_rate, + a_compress_use_split, + optim_update, + fix_stat_std, + n_multi_edge_message, + precision, + ) = self.param + return CommonTest.skip_pt + + @property + def skip_pd(self) -> bool: + ( + update_residual_init, + exclude_types, + update_angle, + a_compress_rate, + a_compress_e_rate, + a_compress_use_split, + optim_update, + fix_stat_std, + n_multi_edge_message, + precision, + ) = self.param + # return not INSTALLED_PD or precision == "bfloat16" + return True + + @property + def skip_dp(self) -> bool: + ( + update_residual_init, + exclude_types, + update_angle, + a_compress_rate, + a_compress_e_rate, + a_compress_use_split, + optim_update, + fix_stat_std, + n_multi_edge_message, + precision, + ) = self.param + return CommonTest.skip_dp + + @property + def skip_tf(self) -> bool: + ( + update_residual_init, + exclude_types, + update_angle, + a_compress_rate, + a_compress_e_rate, + a_compress_use_split, + optim_update, + fix_stat_std, + n_multi_edge_message, + precision, + ) = self.param + return True + + skip_jax = not INSTALLED_JAX + skip_array_api_strict = not INSTALLED_ARRAY_API_STRICT + + tf_class = DescrptDPA3TF + dp_class = DescrptDPA3DP + pt_class = DescrptDPA3PT + pd_class = DescrptDPA3PD + jax_class = DescrptDPA3JAX + array_api_strict_class = DescrptDPA3Strict + args = descrpt_dpa3_args().append(Argument("ntypes", int, optional=False)) + + def setUp(self) -> None: + CommonTest.setUp(self) + + self.ntypes = 2 + self.coords = np.array( + [ + 12.83, + 2.56, + 2.18, + 12.09, + 2.87, + 2.74, + 00.25, + 3.32, + 1.68, + 3.36, + 3.00, + 1.81, + 3.51, + 2.51, + 2.60, + 4.27, + 3.22, + 1.56, + ], + dtype=GLOBAL_NP_FLOAT_PRECISION, + ) + self.atype = np.array([0, 1, 1, 0, 1, 1], dtype=np.int32) + self.box = np.array( + [13.0, 0.0, 0.0, 0.0, 13.0, 0.0, 0.0, 0.0, 13.0], + dtype=GLOBAL_NP_FLOAT_PRECISION, + ) + self.natoms = np.array([6, 6, 2, 4], dtype=np.int32) + ( + update_residual_init, + exclude_types, + update_angle, + a_compress_rate, + a_compress_e_rate, + a_compress_use_split, + optim_update, + fix_stat_std, + n_multi_edge_message, + precision, + ) = self.param + + def build_tf(self, obj: Any, suffix: str) -> tuple[list, dict]: + return self.build_tf_descriptor( + obj, + self.natoms, + self.coords, + self.atype, + self.box, + suffix, + ) + + def eval_dp(self, dp_obj: Any) -> Any: + return self.eval_dp_descriptor( + dp_obj, + self.natoms, + self.coords, + self.atype, + self.box, + mixed_types=True, + ) + + def eval_pt(self, pt_obj: Any) -> Any: + return self.eval_pt_descriptor( + pt_obj, + self.natoms, + self.coords, + self.atype, + self.box, + mixed_types=True, + ) + + def eval_pd(self, pd_obj: Any) -> Any: + return self.eval_pd_descriptor( + pd_obj, + self.natoms, + self.coords, + self.atype, + self.box, + mixed_types=True, + ) + + def eval_jax(self, jax_obj: Any) -> Any: + return self.eval_jax_descriptor( + jax_obj, + self.natoms, + self.coords, + self.atype, + self.box, + mixed_types=True, + ) + + def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any: + return self.eval_array_api_strict_descriptor( + array_api_strict_obj, + self.natoms, + self.coords, + self.atype, + self.box, + mixed_types=True, + ) + + def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]: + return (ret[0],) + + @property + def rtol(self) -> float: + """Relative tolerance for comparing the return value.""" + ( + update_residual_init, + exclude_types, + update_angle, + a_compress_rate, + a_compress_e_rate, + a_compress_use_split, + optim_update, + fix_stat_std, + n_multi_edge_message, + precision, + ) = self.param + if precision == "float64": + return 1e-10 + elif precision == "float32": + return 1e-4 + else: + raise ValueError(f"Unknown precision: {precision}") + + @property + def atol(self) -> float: + """Absolute tolerance for comparing the return value.""" + ( + update_residual_init, + exclude_types, + update_angle, + a_compress_rate, + a_compress_e_rate, + a_compress_use_split, + optim_update, + fix_stat_std, + n_multi_edge_message, + precision, + ) = self.param + if precision == "float64": + return 1e-6 # need to fix in the future, see issue https://github.com/deepmodeling/deepmd-kit/issues/3786 + elif precision == "float32": + return 1e-4 + else: + raise ValueError(f"Unknown precision: {precision}") diff --git a/source/tests/pt/model/test_dpa3.py b/source/tests/pt/model/test_dpa3.py new file mode 100644 index 0000000000..240ffad925 --- /dev/null +++ b/source/tests/pt/model/test_dpa3.py @@ -0,0 +1,209 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import itertools +import unittest + +import numpy as np +import torch + +from deepmd.dpmodel.descriptor.dpa3 import DescrptDPA3 as DPDescrptDPA3 +from deepmd.dpmodel.descriptor.dpa3 import ( + RepFlowArgs, +) +from deepmd.pt.model.descriptor import ( + DescrptDPA3, +) +from deepmd.pt.utils import ( + env, +) +from deepmd.pt.utils.env import ( + PRECISION_DICT, +) + +from ...seed import ( + GLOBAL_SEED, +) +from .test_env_mat import ( + TestCaseSingleFrameWithNlist, +) +from .test_mlp import ( + get_tols, +) + +dtype = env.GLOBAL_PT_FLOAT_PRECISION + + +class TestDescrptDPA3(unittest.TestCase, TestCaseSingleFrameWithNlist): + def setUp(self) -> None: + TestCaseSingleFrameWithNlist.setUp(self) + + def test_consistency( + self, + ) -> None: + rng = np.random.default_rng(100) + nf, nloc, nnei = self.nlist.shape + davg = rng.normal(size=(self.nt, nnei, 4)) + dstd = rng.normal(size=(self.nt, nnei, 4)) + dstd = 0.1 + np.abs(dstd) + + for ( + ua, + rus, + ruri, + acr, + acer, + acus, + nme, + prec, + ect, + ) in itertools.product( + [True, False], # update_angle + ["res_residual"], # update_style + ["norm", "const"], # update_residual_init + [0, 1], # a_compress_rate + [1, 2], # a_compress_e_rate + [True, False], # a_compress_use_split + [1, 2], # n_multi_edge_message + ["float64"], # precision + [False], # use_econf_tebd + ): + dtype = PRECISION_DICT[prec] + rtol, atol = get_tols(prec) + if prec == "float64": + atol = 1e-8 # marginal GPU test cases... + + repflow = RepFlowArgs( + n_dim=20, + e_dim=10, + a_dim=8, + nlayers=3, + e_rcut=self.rcut, + e_rcut_smth=self.rcut_smth, + e_sel=nnei, + a_rcut=self.rcut - 0.1, + a_rcut_smth=self.rcut_smth, + a_sel=nnei - 1, + a_compress_rate=acr, + a_compress_e_rate=acer, + a_compress_use_split=acus, + n_multi_edge_message=nme, + axis_neuron=4, + update_angle=ua, + update_style=rus, + update_residual_init=ruri, + ) + + # dpa3 new impl + dd0 = DescrptDPA3( + self.nt, + repflow=repflow, + # kwargs for descriptor + exclude_types=[], + precision=prec, + use_econf_tebd=ect, + type_map=["O", "H"] if ect else None, + seed=GLOBAL_SEED, + ).to(env.DEVICE) + + dd0.repflows.mean = torch.tensor(davg, dtype=dtype, device=env.DEVICE) + dd0.repflows.stddev = torch.tensor(dstd, dtype=dtype, device=env.DEVICE) + rd0, _, _, _, _ = dd0( + torch.tensor(self.coord_ext, dtype=dtype, device=env.DEVICE), + torch.tensor(self.atype_ext, dtype=int, device=env.DEVICE), + torch.tensor(self.nlist, dtype=int, device=env.DEVICE), + torch.tensor(self.mapping, dtype=int, device=env.DEVICE), + ) + # serialization + dd1 = DescrptDPA3.deserialize(dd0.serialize()) + rd1, _, _, _, _ = dd1( + torch.tensor(self.coord_ext, dtype=dtype, device=env.DEVICE), + torch.tensor(self.atype_ext, dtype=int, device=env.DEVICE), + torch.tensor(self.nlist, dtype=int, device=env.DEVICE), + torch.tensor(self.mapping, dtype=int, device=env.DEVICE), + ) + np.testing.assert_allclose( + rd0.detach().cpu().numpy(), + rd1.detach().cpu().numpy(), + rtol=rtol, + atol=atol, + ) + # dp impl + dd2 = DPDescrptDPA3.deserialize(dd0.serialize()) + rd2, _, _, _, _ = dd2.call( + self.coord_ext, self.atype_ext, self.nlist, self.mapping + ) + np.testing.assert_allclose( + rd0.detach().cpu().numpy(), + rd2, + rtol=rtol, + atol=atol, + ) + + def test_jit( + self, + ) -> None: + rng = np.random.default_rng(100) + nf, nloc, nnei = self.nlist.shape + davg = rng.normal(size=(self.nt, nnei, 4)) + dstd = rng.normal(size=(self.nt, nnei, 4)) + dstd = 0.1 + np.abs(dstd) + + for ( + ua, + rus, + ruri, + acr, + acer, + acus, + nme, + prec, + ect, + ) in itertools.product( + [True], # update_angle + ["res_residual"], # update_style + ["const"], # update_residual_init + [0, 1], # a_compress_rate + [2], # a_compress_e_rate + [True], # a_compress_use_split + [1, 2], # n_multi_edge_message + ["float64"], # precision + [False], # use_econf_tebd + ): + dtype = PRECISION_DICT[prec] + rtol, atol = get_tols(prec) + + repflow = RepFlowArgs( + n_dim=20, + e_dim=10, + a_dim=8, + nlayers=3, + e_rcut=self.rcut, + e_rcut_smth=self.rcut_smth, + e_sel=nnei, + a_rcut=self.rcut - 0.1, + a_rcut_smth=self.rcut_smth, + a_sel=nnei - 1, + a_compress_rate=acr, + a_compress_e_rate=acer, + a_compress_use_split=acus, + n_multi_edge_message=nme, + axis_neuron=4, + update_angle=ua, + update_style=rus, + update_residual_init=ruri, + ) + + # dpa3 new impl + dd0 = DescrptDPA3( + self.nt, + repflow=repflow, + # kwargs for descriptor + exclude_types=[], + precision=prec, + use_econf_tebd=ect, + type_map=["O", "H"] if ect else None, + seed=GLOBAL_SEED, + ).to(env.DEVICE) + + dd0.repflows.mean = torch.tensor(davg, dtype=dtype, device=env.DEVICE) + dd0.repflows.stddev = torch.tensor(dstd, dtype=dtype, device=env.DEVICE) + model = torch.jit.script(dd0) diff --git a/source/tests/universal/dpmodel/descriptor/test_descriptor.py b/source/tests/universal/dpmodel/descriptor/test_descriptor.py index 7911cb9395..422b6c3596 100644 --- a/source/tests/universal/dpmodel/descriptor/test_descriptor.py +++ b/source/tests/universal/dpmodel/descriptor/test_descriptor.py @@ -7,6 +7,7 @@ from deepmd.dpmodel.descriptor import ( DescrptDPA1, DescrptDPA2, + DescrptDPA3, DescrptHybrid, DescrptSeA, DescrptSeR, @@ -17,6 +18,9 @@ RepformerArgs, RepinitArgs, ) +from deepmd.dpmodel.descriptor.dpa3 import ( + RepFlowArgs, +) from ....consistent.common import ( parameterize_func, @@ -460,6 +464,90 @@ def DescriptorParamDPA2( DescriptorParamDPA2 = DescriptorParamDPA2List[0] +def DescriptorParamDPA3( + ntypes, + rcut, + rcut_smth, + sel, + type_map, + env_protection=0.0, + exclude_types=[], + update_style="res_residual", + update_residual=0.1, + update_residual_init="const", + update_angle=True, + n_multi_edge_message=1, + a_compress_rate=0, + a_compress_e_rate=1, + a_compress_use_split=False, + optim_update=True, + fix_stat_std=0.3, + precision="float64", +): + input_dict = { + # kwargs for repformer + "repflow": RepFlowArgs( + **{ + "n_dim": 20, + "e_dim": 10, + "a_dim": 8, + "nlayers": 3, + "e_rcut": rcut, + "e_rcut_smth": rcut_smth, + "e_sel": sum(sel), + "a_rcut": rcut / 2, + "a_rcut_smth": rcut_smth / 2, + "a_sel": sum(sel) // 4, + "a_compress_rate": a_compress_rate, + "a_compress_e_rate": a_compress_e_rate, + "a_compress_use_split": a_compress_use_split, + "optim_update": optim_update, + "fix_stat_std": fix_stat_std, + "n_multi_edge_message": n_multi_edge_message, + "axis_neuron": 4, + "update_angle": update_angle, + "update_style": update_style, + "update_residual": update_residual, + "update_residual_init": update_residual_init, + } + ), + "ntypes": ntypes, + "concat_output_tebd": False, + "precision": precision, + "activation_function": "silu", + "exclude_types": exclude_types, + "env_protection": env_protection, + "trainable": True, + "use_econf_tebd": False, + "use_tebd_bias": False, + "type_map": type_map, + "seed": GLOBAL_SEED, + } + return input_dict + + +DescriptorParamDPA3List = parameterize_func( + DescriptorParamDPA3, + OrderedDict( + { + "update_residual_init": ("const",), + "exclude_types": ([], [[0, 1]]), + "update_angle": (True, False), + "a_compress_rate": (0, 1), + "a_compress_e_rate": (2,), + "a_compress_use_split": (True, False), + "optim_update": (True, False), + "fix_stat_std": (0.3,), + "n_multi_edge_message": (1, 2), + "env_protection": (0.0, 1e-8), + "precision": ("float64",), + } + ), +) +# to get name for the default function +DescriptorParamDPA3 = DescriptorParamDPA3List[0] + + def DescriptorParamHybrid(ntypes, rcut, rcut_smth, sel, type_map, **kwargs): ddsub0 = { "type": "se_e2_a", @@ -515,6 +603,7 @@ def DescriptorParamHybridMixedTTebd(ntypes, rcut, rcut_smth, sel, type_map, **kw (DescriptorParamSeTTebd, DescrptSeTTebd), (DescriptorParamDPA1, DescrptDPA1), (DescriptorParamDPA2, DescrptDPA2), + (DescriptorParamDPA3, DescrptDPA3), (DescriptorParamHybrid, DescrptHybrid), (DescriptorParamHybridMixed, DescrptHybrid), (DescriptorParamHybridMixedTTebd, DescrptHybrid), diff --git a/source/tests/universal/pt/descriptor/test_descriptor.py b/source/tests/universal/pt/descriptor/test_descriptor.py index 349eb65588..25c78b43c1 100644 --- a/source/tests/universal/pt/descriptor/test_descriptor.py +++ b/source/tests/universal/pt/descriptor/test_descriptor.py @@ -4,6 +4,7 @@ from deepmd.pt.model.descriptor import ( DescrptDPA1, DescrptDPA2, + DescrptDPA3, DescrptHybrid, DescrptSeA, DescrptSeR, @@ -20,6 +21,7 @@ from ...dpmodel.descriptor.test_descriptor import ( DescriptorParamDPA1, DescriptorParamDPA2, + DescriptorParamDPA3, DescriptorParamHybrid, DescriptorParamHybridMixed, DescriptorParamSeA, @@ -40,6 +42,7 @@ (DescriptorParamSeTTebd, DescrptSeTTebd), (DescriptorParamDPA1, DescrptDPA1), (DescriptorParamDPA2, DescrptDPA2), + (DescriptorParamDPA3, DescrptDPA3), (DescriptorParamHybrid, DescrptHybrid), (DescriptorParamHybridMixed, DescrptHybrid), ) # class_param & class diff --git a/source/tests/universal/pt/model/test_model.py b/source/tests/universal/pt/model/test_model.py index 3eb1484c45..867fa48b87 100644 --- a/source/tests/universal/pt/model/test_model.py +++ b/source/tests/universal/pt/model/test_model.py @@ -10,6 +10,7 @@ from deepmd.pt.model.descriptor import ( DescrptDPA1, DescrptDPA2, + DescrptDPA3, DescrptHybrid, DescrptSeA, DescrptSeR, @@ -55,6 +56,8 @@ DescriptorParamDPA1List, DescriptorParamDPA2, DescriptorParamDPA2List, + DescriptorParamDPA3, + DescriptorParamDPA3List, DescriptorParamHybrid, DescriptorParamHybridMixed, DescriptorParamHybridMixedTTebd, @@ -93,6 +96,7 @@ DescriptorParamSeTTebd, DescriptorParamDPA1, DescriptorParamDPA2, + DescriptorParamDPA3, DescriptorParamHybrid, DescriptorParamHybridMixed, ] @@ -117,6 +121,7 @@ ], *[(param_func, DescrptDPA1) for param_func in DescriptorParamDPA1List], *[(param_func, DescrptDPA2) for param_func in DescriptorParamDPA2List], + *[(param_func, DescrptDPA3) for param_func in DescriptorParamDPA3List], (DescriptorParamHybrid, DescrptHybrid), (DescriptorParamHybridMixed, DescrptHybrid), (DescriptorParamHybridMixedTTebd, DescrptHybrid), @@ -131,6 +136,7 @@ (DescriptorParamSeTTebd, DescrptSeTTebd), (DescriptorParamDPA1, DescrptDPA1), (DescriptorParamDPA2, DescrptDPA2), + (DescriptorParamDPA3, DescrptDPA3), ), # descrpt_class_param & class ( *[(param_func, EnergyFittingNet) for param_func in FittingParamEnergyList], @@ -219,6 +225,7 @@ def setUpClass(cls) -> None: ], *[(param_func, DescrptDPA1) for param_func in DescriptorParamDPA1List], *[(param_func, DescrptDPA2) for param_func in DescriptorParamDPA2List], + *[(param_func, DescrptDPA3) for param_func in DescriptorParamDPA3List], (DescriptorParamHybrid, DescrptHybrid), (DescriptorParamHybridMixed, DescrptHybrid), (DescriptorParamHybridMixedTTebd, DescrptHybrid), @@ -233,6 +240,7 @@ def setUpClass(cls) -> None: (DescriptorParamSeTTebd, DescrptSeTTebd), (DescriptorParamDPA1, DescrptDPA1), (DescriptorParamDPA2, DescrptDPA2), + (DescriptorParamDPA3, DescrptDPA3), ), # descrpt_class_param & class ( *[(param_func, DOSFittingNet) for param_func in FittingParamDosList], @@ -316,6 +324,7 @@ def setUpClass(cls) -> None: *[(param_func, DescrptSeA) for param_func in DescriptorParamSeAList], *[(param_func, DescrptDPA1) for param_func in DescriptorParamDPA1List], *[(param_func, DescrptDPA2) for param_func in DescriptorParamDPA2List], + *[(param_func, DescrptDPA3) for param_func in DescriptorParamDPA3List], (DescriptorParamHybrid, DescrptHybrid), (DescriptorParamHybridMixed, DescrptHybrid), ), # descrpt_class_param & class @@ -326,6 +335,7 @@ def setUpClass(cls) -> None: (DescriptorParamSeA, DescrptSeA), (DescriptorParamDPA1, DescrptDPA1), (DescriptorParamDPA2, DescrptDPA2), + (DescriptorParamDPA3, DescrptDPA3), ), # descrpt_class_param & class ( *[(param_func, DipoleFittingNet) for param_func in FittingParamDipoleList], @@ -409,6 +419,7 @@ def setUpClass(cls) -> None: *[(param_func, DescrptSeA) for param_func in DescriptorParamSeAList], *[(param_func, DescrptDPA1) for param_func in DescriptorParamDPA1List], *[(param_func, DescrptDPA2) for param_func in DescriptorParamDPA2List], + *[(param_func, DescrptDPA3) for param_func in DescriptorParamDPA3List], (DescriptorParamHybrid, DescrptHybrid), (DescriptorParamHybridMixed, DescrptHybrid), ), # descrpt_class_param & class @@ -419,6 +430,7 @@ def setUpClass(cls) -> None: (DescriptorParamSeA, DescrptSeA), (DescriptorParamDPA1, DescrptDPA1), (DescriptorParamDPA2, DescrptDPA2), + (DescriptorParamDPA3, DescrptDPA3), ), # descrpt_class_param & class ( *[(param_func, PolarFittingNet) for param_func in FittingParamPolarList], @@ -645,6 +657,7 @@ def setUpClass(cls) -> None: (FittingParam, Fitting) = cls.param[1] cls.epsilon_dict["test_smooth"] = 1e-6 cls.aprec_dict["test_smooth"] = 5e-5 + cls.aprec_dict["test_rot"] = 1e-10 # for test stability # set special precision if Descrpt in [DescrptDPA2, DescrptHybrid]: cls.epsilon_dict["test_smooth"] = 1e-8 @@ -721,6 +734,7 @@ def setUpClass(cls) -> None: *[(param_func, DescrptSeA) for param_func in DescriptorParamSeAList], *[(param_func, DescrptDPA1) for param_func in DescriptorParamDPA1List], *[(param_func, DescrptDPA2) for param_func in DescriptorParamDPA2List], + *[(param_func, DescrptDPA3) for param_func in DescriptorParamDPA3List], (DescriptorParamHybrid, DescrptHybrid), (DescriptorParamHybridMixed, DescrptHybrid), ), # descrpt_class_param & class @@ -731,6 +745,7 @@ def setUpClass(cls) -> None: (DescriptorParamSeA, DescrptSeA), (DescriptorParamDPA1, DescrptDPA1), (DescriptorParamDPA2, DescrptDPA2), + (DescriptorParamDPA3, DescrptDPA3), ), # descrpt_class_param & class ( *[ @@ -812,6 +827,7 @@ def setUpClass(cls) -> None: ( *[(param_func, DescrptDPA1) for param_func in DescriptorParamDPA1List], *[(param_func, DescrptDPA2) for param_func in DescriptorParamDPA2List], + *[(param_func, DescrptDPA3) for param_func in DescriptorParamDPA3List], (DescriptorParamHybridMixed, DescrptHybrid), (DescriptorParamHybridMixedTTebd, DescrptHybrid), ), # descrpt_class_param & class @@ -821,6 +837,7 @@ def setUpClass(cls) -> None: ( (DescriptorParamDPA1, DescrptDPA1), (DescriptorParamDPA2, DescrptDPA2), + (DescriptorParamDPA3, DescrptDPA3), ), # descrpt_class_param & class ( *[(param_func, EnergyFittingNet) for param_func in FittingParamEnergyList],