diff --git a/deepmd/pt/model/descriptor/dpa2.py b/deepmd/pt/model/descriptor/dpa2.py index ad5167c572..77e9f1d936 100644 --- a/deepmd/pt/model/descriptor/dpa2.py +++ b/deepmd/pt/model/descriptor/dpa2.py @@ -31,10 +31,14 @@ build_multiple_neighbor_list, get_multiple_nlist_key, ) +from deepmd.pt.utils.tabulate import ( + DPTabulate, +) from deepmd.pt.utils.update_sel import ( UpdateSel, ) from deepmd.pt.utils.utils import ( + ActivationFn, to_numpy_array, ) from deepmd.utils.data_system import ( @@ -306,6 +310,7 @@ def init_subclass_params(sub_data, sub_class): # set trainable for param in self.parameters(): param.requires_grad = trainable + self.compress = False def get_rcut(self) -> float: """Returns the cut-off radius.""" @@ -859,3 +864,85 @@ def update_sel( ) local_jdata_cpy["repformer"]["nsel"] = repformer_sel[0] return local_jdata_cpy, min_nbor_dist + + def enable_compression( + self, + min_nbor_dist: float, + table_extrapolate: float = 5, + table_stride_1: float = 0.01, + table_stride_2: float = 0.1, + check_frequency: int = -1, + ) -> None: + """Receive the statistics (distance, max_nbor_size and env_mat_range) of the training data. + + Parameters + ---------- + min_nbor_dist + The nearest distance between atoms + table_extrapolate + The scale of model extrapolation + table_stride_1 + The uniform stride of the first table + table_stride_2 + The uniform stride of the second table + check_frequency + The overflow check frequency + """ + # do some checks before the mocel compression process + if self.compress: + raise ValueError("Compression is already enabled.") + assert ( + not self.repinit.resnet_dt + ), "Model compression error: repinit resnet_dt must be false!" + for tt in self.repinit.exclude_types: + if (tt[0] not in range(self.repinit.ntypes)) or ( + tt[1] not in range(self.repinit.ntypes) + ): + raise RuntimeError( + "Repinit exclude types" + + str(tt) + + " must within the number of atomic types " + + str(self.repinit.ntypes) + + "!" + ) + if ( + self.repinit.ntypes * self.repinit.ntypes - len(self.repinit.exclude_types) + == 0 + ): + raise RuntimeError( + "Repinit empty embedding-nets are not supported in model compression!" + ) + + if self.repinit.attn_layer != 0: + raise RuntimeError( + "Cannot compress model when repinit attention layer is not 0." + ) + + if self.repinit.tebd_input_mode != "strip": + raise RuntimeError( + "Cannot compress model when repinit tebd_input_mode == 'concat'" + ) + + # repinit doesn't have a serialize method + data = self.serialize() + self.table = DPTabulate( + self, + data["repinit_args"]["neuron"], + data["repinit_args"]["type_one_side"], + data["exclude_types"], + ActivationFn(data["repinit_args"]["activation_function"]), + ) + self.table_config = [ + table_extrapolate, + table_stride_1, + table_stride_2, + check_frequency, + ] + self.lower, self.upper = self.table.build( + min_nbor_dist, table_extrapolate, table_stride_1, table_stride_2 + ) + + self.repinit.enable_compression( + self.table.data, self.table_config, self.lower, self.upper + ) + self.compress = True diff --git a/deepmd/pt/utils/tabulate.py b/deepmd/pt/utils/tabulate.py index 796f7dcd52..e21d2ec9a6 100644 --- a/deepmd/pt/utils/tabulate.py +++ b/deepmd/pt/utils/tabulate.py @@ -95,11 +95,14 @@ def __init__( raise RuntimeError("Unknown activation function type!") self.activation_fn = activation_fn - self.davg = self.descrpt.serialize()["@variables"]["davg"] - self.dstd = self.descrpt.serialize()["@variables"]["dstd"] - self.ntypes = self.descrpt.get_ntypes() + serialized = self.descrpt.serialize() + if isinstance(self.descrpt, deepmd.pt.model.descriptor.DescrptDPA2): + serialized = serialized["repinit_variable"] + self.davg = serialized["@variables"]["davg"] + self.dstd = serialized["@variables"]["dstd"] + self.embedding_net_nodes = serialized["embeddings"]["networks"] - self.embedding_net_nodes = self.descrpt.serialize()["embeddings"]["networks"] + self.ntypes = self.descrpt.get_ntypes() self.layer_size = self._get_layer_size() self.table_size = self._get_table_size() @@ -291,7 +294,13 @@ def _layer_1(self, x, w, b): return t, self.activation_fn(torch.matmul(x, w) + b) + t def _get_descrpt_type(self): - if isinstance(self.descrpt, deepmd.pt.model.descriptor.DescrptDPA1): + if isinstance( + self.descrpt, + ( + deepmd.pt.model.descriptor.DescrptDPA1, + deepmd.pt.model.descriptor.DescrptDPA2, + ), + ): return "Atten" elif isinstance(self.descrpt, deepmd.pt.model.descriptor.DescrptSeA): return "A" diff --git a/source/tests/pt/model/test_compressed_descriptor_dpa2.py b/source/tests/pt/model/test_compressed_descriptor_dpa2.py new file mode 100644 index 0000000000..05b1143eb1 --- /dev/null +++ b/source/tests/pt/model/test_compressed_descriptor_dpa2.py @@ -0,0 +1,149 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import unittest +from typing import ( + Any, +) + +import numpy as np +import torch + +from deepmd.dpmodel.descriptor.dpa2 import ( + RepformerArgs, + RepinitArgs, +) +from deepmd.env import ( + GLOBAL_NP_FLOAT_PRECISION, +) +from deepmd.pt.model.descriptor.dpa2 import ( + DescrptDPA2, +) +from deepmd.pt.utils.env import DEVICE as PT_DEVICE +from deepmd.pt.utils.nlist import build_neighbor_list as build_neighbor_list_pt +from deepmd.pt.utils.nlist import ( + extend_coord_with_ghosts as extend_coord_with_ghosts_pt, +) + +from ...consistent.common import ( + parameterized, +) + + +def eval_pt_descriptor( + pt_obj: Any, natoms, coords, atype, box, mixed_types: bool = False +) -> Any: + ext_coords, ext_atype, mapping = extend_coord_with_ghosts_pt( + torch.from_numpy(coords).to(PT_DEVICE).reshape(1, -1, 3), + torch.from_numpy(atype).to(PT_DEVICE).reshape(1, -1), + torch.from_numpy(box).to(PT_DEVICE).reshape(1, 3, 3), + pt_obj.get_rcut(), + ) + nlist = build_neighbor_list_pt( + ext_coords, + ext_atype, + natoms[0], + pt_obj.get_rcut(), + pt_obj.get_sel(), + distinguish_types=(not mixed_types), + ) + result, _, _, _, _ = pt_obj(ext_coords, ext_atype, nlist, mapping=mapping) + return result + + +@parameterized(("float32", "float64"), (True, False)) +class TestDescriptorDPA2(unittest.TestCase): + def setUp(self): + (self.dtype, self.type_one_side) = self.param + if self.dtype == "float32": + self.skipTest("FP32 has bugs:") + # ../../../../deepmd/pt/model/descriptor/repformer_layer.py:521: in forward + # torch.matmul(attnw.unsqueeze(-2), gg1v).squeeze(-2).view(nb, nloc, nh * ni) + # E RuntimeError: expected scalar type Float but found Double + if self.dtype == "float32": + self.atol = 1e-5 + elif self.dtype == "float64": + self.atol = 1e-10 + self.seed = 21 + self.sel = [10] + self.rcut_smth = 5.80 + self.rcut = 6.00 + self.neuron = [6, 12, 24] + self.axis_neuron = 3 + self.ntypes = 2 + self.coords = np.array( + [ + 12.83, + 2.56, + 2.18, + 12.09, + 2.87, + 2.74, + 00.25, + 3.32, + 1.68, + 3.36, + 3.00, + 1.81, + 3.51, + 2.51, + 2.60, + 4.27, + 3.22, + 1.56, + ], + dtype=GLOBAL_NP_FLOAT_PRECISION, + ) + self.atype = np.array([0, 1, 1, 0, 1, 1], dtype=np.int32) + self.box = np.array( + [13.0, 0.0, 0.0, 0.0, 13.0, 0.0, 0.0, 0.0, 13.0], + dtype=GLOBAL_NP_FLOAT_PRECISION, + ) + self.natoms = np.array([6, 6, 2, 4], dtype=np.int32) + + repinit = RepinitArgs( + rcut=self.rcut, + rcut_smth=self.rcut_smth, + nsel=10, + tebd_input_mode="strip", + type_one_side=self.type_one_side, + ) + repformer = RepformerArgs( + rcut=self.rcut - 1, + rcut_smth=self.rcut_smth - 1, + nsel=9, + ) + + self.descriptor = DescrptDPA2( + ntypes=self.ntypes, + repinit=repinit, + repformer=repformer, + precision=self.dtype, + ) + + def test_compressed_forward(self): + result_pt = eval_pt_descriptor( + self.descriptor, + self.natoms, + self.coords, + self.atype, + self.box, + ) + self.descriptor.enable_compression(0.5) + result_pt_compressed = eval_pt_descriptor( + self.descriptor, + self.natoms, + self.coords, + self.atype, + self.box, + ) + + self.assertEqual(result_pt.shape, result_pt_compressed.shape) + torch.testing.assert_close( + result_pt, + result_pt_compressed, + atol=self.atol, + rtol=self.atol, + ) + + +if __name__ == "__main__": + unittest.main()