diff --git a/.github/workflows/test_cuda.yml b/.github/workflows/test_cuda.yml index 049fb95e3a..f164758304 100644 --- a/.github/workflows/test_cuda.yml +++ b/.github/workflows/test_cuda.yml @@ -37,7 +37,7 @@ jobs: run: python -m pip config --user set global.index-url https://mirrors.aliyun.com/pypi/simple/ - run: python -m pip install -U "pip>=21.3.1,!=23.0.0" - run: python -m pip install "tensorflow>=2.15.0rc0" - - run: python -m pip install -v -e .[gpu,test,lmp,cu12] "ase @ https://gitlab.com/ase/ase/-/archive/8c5aa5fd6448c5cfb517a014dccf2b214a9dfa8f/ase-8c5aa5fd6448c5cfb517a014dccf2b214a9dfa8f.tar.gz" + - run: python -m pip install -v -e .[gpu,test,lmp,cu12,torch] "ase @ https://gitlab.com/ase/ase/-/archive/8c5aa5fd6448c5cfb517a014dccf2b214a9dfa8f/ase-8c5aa5fd6448c5cfb517a014dccf2b214a9dfa8f.tar.gz" env: DP_BUILD_TESTING: 1 DP_VARIANT: cuda diff --git a/.github/workflows/test_python.yml b/.github/workflows/test_python.yml index 55ef041532..091a2a61f8 100644 --- a/.github/workflows/test_python.yml +++ b/.github/workflows/test_python.yml @@ -9,12 +9,12 @@ jobs: strategy: matrix: include: - - python: 3.7 - tf: 1.14 - python: 3.8 tf: + torch: - python: "3.11" tf: + torch: steps: - uses: actions/checkout@v4 @@ -23,22 +23,25 @@ jobs: python-version: ${{ matrix.python }} cache: 'pip' - uses: mpi4py/setup-mpi@v1 - if: ${{ matrix.tf == '' }} with: mpi: openmpi # https://github.com/pypa/pip/issues/11770 - run: python -m pip install -U "pip>=21.3.1,!=23.0.0" - - run: pip install -e .[cpu,test] + - run: python -m pip install -U "torch==${{ matrix.torch }}" "numpy<1.20" + if: matrix.torch != '' + - run: pip install -e .[cpu,test,torch] env: TENSORFLOW_VERSION: ${{ matrix.tf }} DP_BUILD_TESTING: 1 - run: pip install horovod mpi4py - if: ${{ matrix.tf == '' }} env: HOROVOD_WITH_TENSORFLOW: 1 + HOROVOD_WITHOUT_PYTORCH: 1 HOROVOD_WITHOUT_GLOO: 1 - run: dp --version - run: pytest --cov=deepmd source/tests --durations=0 + env: + NUM_WORKERS: 0 - uses: codecov/codecov-action@v3 with: gcov: true diff --git a/.gitignore b/.gitignore index 82d3e4a7da..5e30cf3167 100644 --- a/.gitignore +++ b/.gitignore @@ -13,6 +13,7 @@ *.bz2 *.pyc *.pb +*.DS_Store tmp* CMakeCache.txt CMakeFiles diff --git a/README.md b/README.md index e61c18dbcb..2076e11f1b 100644 --- a/README.md +++ b/README.md @@ -93,7 +93,6 @@ The code is organized as follows: See [DeePMD-kit Contributing Guide](CONTRIBUTING.md) to become a contributor! 🤓 - [1]: https://arxiv.org/abs/1707.01478 [2]: https://journals.aps.org/prl/abstract/10.1103/PhysRevLett.120.143001 [3]: https://arxiv.org/abs/1805.09003 diff --git a/backend/dynamic_metadata.py b/backend/dynamic_metadata.py index 72dfcaef45..e30c97bd98 100644 --- a/backend/dynamic_metadata.py +++ b/backend/dynamic_metadata.py @@ -88,4 +88,8 @@ def dynamic_metadata( "nvidia-cudnn-cu12", "nvidia-cuda-nvcc-cu12", ], + "torch": [ + "torch>=2a", + "tqdm", + ], } diff --git a/deepmd/pt/__init__.py b/deepmd/pt/__init__.py new file mode 100644 index 0000000000..6ceb116d85 --- /dev/null +++ b/deepmd/pt/__init__.py @@ -0,0 +1 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later diff --git a/deepmd/pt/entrypoints/__init__.py b/deepmd/pt/entrypoints/__init__.py new file mode 100644 index 0000000000..6ceb116d85 --- /dev/null +++ b/deepmd/pt/entrypoints/__init__.py @@ -0,0 +1 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later diff --git a/deepmd/pt/entrypoints/main.py b/deepmd/pt/entrypoints/main.py new file mode 100644 index 0000000000..f1cd7ae210 --- /dev/null +++ b/deepmd/pt/entrypoints/main.py @@ -0,0 +1,396 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import argparse +import json +import logging +import os + +import torch +import torch.distributed as dist +from torch.distributed.elastic.multiprocessing.errors import ( + record, +) + +from deepmd import ( + __version__, +) +from deepmd.pt.infer import ( + inference, +) +from deepmd.pt.model.descriptor import ( + Descriptor, +) +from deepmd.pt.train import ( + training, +) +from deepmd.pt.utils import ( + env, +) +from deepmd.pt.utils.dataloader import ( + DpLoaderSet, +) +from deepmd.pt.utils.finetune import ( + change_finetune_model_params, +) +from deepmd.pt.utils.multi_task import ( + preprocess_shared_params, +) +from deepmd.pt.utils.stat import ( + make_stat_input, +) + + +def get_trainer( + config, + init_model=None, + restart_model=None, + finetune_model=None, + model_branch="", + force_load=False, +): + # Initialize DDP + local_rank = os.environ.get("LOCAL_RANK") + if local_rank is not None: + local_rank = int(local_rank) + assert dist.is_nccl_available() + dist.init_process_group(backend="nccl") + + multi_task = "model_dict" in config["model"] + ckpt = init_model if init_model is not None else restart_model + config["model"] = change_finetune_model_params( + ckpt, + finetune_model, + config["model"], + multi_task=multi_task, + model_branch=model_branch, + ) + config["model"]["resuming"] = (finetune_model is not None) or (ckpt is not None) + shared_links = None + if multi_task: + config["model"], shared_links = preprocess_shared_params(config["model"]) + + def prepare_trainer_input_single( + model_params_single, data_dict_single, loss_dict_single, suffix="" + ): + training_dataset_params = data_dict_single["training_data"] + type_split = False + if model_params_single["descriptor"]["type"] in ["se_e2_a"]: + type_split = True + validation_dataset_params = data_dict_single["validation_data"] + training_systems = training_dataset_params["systems"] + validation_systems = validation_dataset_params["systems"] + + # noise params + noise_settings = None + if loss_dict_single.get("type", "ener") == "denoise": + noise_settings = { + "noise_type": loss_dict_single.pop("noise_type", "uniform"), + "noise": loss_dict_single.pop("noise", 1.0), + "noise_mode": loss_dict_single.pop("noise_mode", "fix_num"), + "mask_num": loss_dict_single.pop("mask_num", 8), + "mask_prob": loss_dict_single.pop("mask_prob", 0.15), + "same_mask": loss_dict_single.pop("same_mask", False), + "mask_coord": loss_dict_single.pop("mask_coord", False), + "mask_type": loss_dict_single.pop("mask_type", False), + "max_fail_num": loss_dict_single.pop("max_fail_num", 10), + "mask_type_idx": len(model_params_single["type_map"]) - 1, + } + # noise_settings = None + + # stat files + hybrid_descrpt = model_params_single["descriptor"]["type"] == "hybrid" + has_stat_file_path = True + if not hybrid_descrpt: + ### this design requires "rcut", "rcut_smth" and "sel" in the descriptor + ### VERY BAD DESIGN!!!! + ### not all descriptors provides these parameter in their constructor + default_stat_file_name = Descriptor.get_stat_name( + model_params_single["descriptor"] + ) + model_params_single["stat_file_dir"] = data_dict_single.get( + "stat_file_dir", f"stat_files{suffix}" + ) + model_params_single["stat_file"] = data_dict_single.get( + "stat_file", default_stat_file_name + ) + model_params_single["stat_file_path"] = os.path.join( + model_params_single["stat_file_dir"], model_params_single["stat_file"] + ) + if not os.path.exists(model_params_single["stat_file_path"]): + has_stat_file_path = False + else: ### need to remove this + default_stat_file_name = [] + for descrpt in model_params_single["descriptor"]["list"]: + default_stat_file_name.append( + f'stat_file_rcut{descrpt["rcut"]:.2f}_' + f'smth{descrpt["rcut_smth"]:.2f}_' + f'sel{descrpt["sel"]}_{descrpt["type"]}.npz' + ) + model_params_single["stat_file_dir"] = data_dict_single.get( + "stat_file_dir", f"stat_files{suffix}" + ) + model_params_single["stat_file"] = data_dict_single.get( + "stat_file", default_stat_file_name + ) + assert isinstance( + model_params_single["stat_file"], list + ), "Stat file of hybrid descriptor must be a list!" + stat_file_path = [] + for stat_file_path_item in model_params_single["stat_file"]: + single_file_path = os.path.join( + model_params_single["stat_file_dir"], stat_file_path_item + ) + stat_file_path.append(single_file_path) + if not os.path.exists(single_file_path): + has_stat_file_path = False + model_params_single["stat_file_path"] = stat_file_path + + # validation and training data + validation_data_single = DpLoaderSet( + validation_systems, + validation_dataset_params["batch_size"], + model_params_single, + type_split=type_split, + noise_settings=noise_settings, + ) + if ckpt or finetune_model or has_stat_file_path: + train_data_single = DpLoaderSet( + training_systems, + training_dataset_params["batch_size"], + model_params_single, + type_split=type_split, + noise_settings=noise_settings, + ) + sampled_single = None + else: + train_data_single = DpLoaderSet( + training_systems, + training_dataset_params["batch_size"], + model_params_single, + type_split=type_split, + ) + data_stat_nbatch = model_params_single.get("data_stat_nbatch", 10) + sampled_single = make_stat_input( + train_data_single.systems, + train_data_single.dataloaders, + data_stat_nbatch, + ) + if noise_settings is not None: + train_data_single = DpLoaderSet( + training_systems, + training_dataset_params["batch_size"], + model_params_single, + type_split=type_split, + noise_settings=noise_settings, + ) + return train_data_single, validation_data_single, sampled_single + + if not multi_task: + train_data, validation_data, sampled = prepare_trainer_input_single( + config["model"], config["training"], config["loss"] + ) + else: + train_data, validation_data, sampled = {}, {}, {} + for model_key in config["model"]["model_dict"]: + ( + train_data[model_key], + validation_data[model_key], + sampled[model_key], + ) = prepare_trainer_input_single( + config["model"]["model_dict"][model_key], + config["training"]["data_dict"][model_key], + config["loss_dict"][model_key], + suffix=f"_{model_key}", + ) + + trainer = training.Trainer( + config, + train_data, + sampled, + validation_data=validation_data, + init_model=init_model, + restart_model=restart_model, + finetune_model=finetune_model, + force_load=force_load, + shared_links=shared_links, + ) + return trainer + + +def train(FLAGS): + logging.info("Configuration path: %s", FLAGS.INPUT) + with open(FLAGS.INPUT) as fin: + config = json.load(fin) + trainer = get_trainer( + config, + FLAGS.init_model, + FLAGS.restart, + FLAGS.finetune, + FLAGS.model_branch, + FLAGS.force_load, + ) + trainer.run() + + +def test(FLAGS): + trainer = inference.Tester( + FLAGS.model, + input_script=FLAGS.input_script, + system=FLAGS.system, + datafile=FLAGS.datafile, + numb_test=FLAGS.numb_test, + detail_file=FLAGS.detail_file, + shuffle_test=FLAGS.shuffle_test, + head=FLAGS.head, + ) + trainer.run() + + +def freeze(FLAGS): + model = torch.jit.script( + inference.Tester(FLAGS.model, numb_test=1, head=FLAGS.head).model + ) + torch.jit.save( + model, + FLAGS.output, + { + # TODO: _extra_files + }, + ) + + +# avoid logger conflicts of tf version +def clean_loggers(): + logger = logging.getLogger() + while logger.hasHandlers(): + logger.removeHandler(logger.handlers[0]) + + +@record +def main(args=None): + clean_loggers() + logging.basicConfig( + level=logging.WARNING if env.LOCAL_RANK else logging.INFO, + format=f"%(asctime)-15s {os.environ.get('RANK') or ''} [%(filename)s:%(lineno)d] %(levelname)s %(message)s", + ) + logging.info("DeepMD version: %s", __version__) + parser = argparse.ArgumentParser( + description="A tool to manager deep models of potential energy surface." + ) + subparsers = parser.add_subparsers(dest="command") + train_parser = subparsers.add_parser("train", help="Train a model.") + train_parser.add_argument("INPUT", help="A Json-format configuration file.") + parser_train_subgroup = train_parser.add_mutually_exclusive_group() + parser_train_subgroup.add_argument( + "-i", + "--init-model", + type=str, + default=None, + help="Initialize the model by the provided checkpoint.", + ) + parser_train_subgroup.add_argument( + "-r", + "--restart", + type=str, + default=None, + help="Restart the training from the provided checkpoint.", + ) + parser_train_subgroup.add_argument( + "-t", + "--finetune", + type=str, + default=None, + help="Finetune the frozen pretrained model.", + ) + train_parser.add_argument( + "-m", + "--model-branch", + type=str, + default="", + help="Model branch chosen for fine-tuning if multi-task. If not specified, it will re-init the fitting net.", + ) + train_parser.add_argument( + "--force-load", + action="store_true", + help="Force load from ckpt, other missing tensors will init from scratch", + ) + + test_parser = subparsers.add_parser("test", help="Test a model.") + test_parser_subgroup = test_parser.add_mutually_exclusive_group() + test_parser_subgroup.add_argument( + "-s", + "--system", + default=None, + type=str, + help="The system dir. Recursively detect systems in this directory", + ) + test_parser_subgroup.add_argument( + "-f", + "--datafile", + default=None, + type=str, + help="The path to file of test list.", + ) + test_parser_subgroup.add_argument( + "-i", + "--input-script", + default=None, + type=str, + help="The path to the input script, the validation systems will be tested.", + ) + test_parser.add_argument( + "-m", + "--model", + default="model.pt", + type=str, + help="Model checkpoint to import", + ) + test_parser.add_argument( + "--head", + default=None, + type=str, + help="Task head to test if in multi-task mode.", + ) + test_parser.add_argument( + "-n", "--numb-test", default=100, type=int, help="The number of data for test" + ) + test_parser.add_argument( + "-d", + "--detail-file", + type=str, + default=None, + help="The prefix to files where details of energy, force and virial accuracy/accuracy per atom will be written", + ) + test_parser.add_argument( + "--shuffle-test", action="store_true", default=False, help="Shuffle test data" + ) + + freeze_parser = subparsers.add_parser("freeze", help="Freeze a model.") + freeze_parser.add_argument("model", help="Resumes from checkpoint.") + freeze_parser.add_argument( + "-o", + "--output", + type=str, + default="frozen_model.pth", + help="The frozen model path", + ) + freeze_parser.add_argument( + "--head", + default=None, + type=str, + help="Task head to freeze if in multi-task mode.", + ) + + FLAGS = parser.parse_args(args) + if FLAGS.command == "train": + train(FLAGS) + elif FLAGS.command == "test": + test(FLAGS) + elif FLAGS.command == "freeze": + freeze(FLAGS) + else: + logging.error("Invalid command!") + parser.print_help() + + +if __name__ == "__main__": + main() diff --git a/deepmd/pt/infer/__init__.py b/deepmd/pt/infer/__init__.py new file mode 100644 index 0000000000..6ceb116d85 --- /dev/null +++ b/deepmd/pt/infer/__init__.py @@ -0,0 +1 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later diff --git a/deepmd/pt/infer/deep_eval.py b/deepmd/pt/infer/deep_eval.py new file mode 100644 index 0000000000..79772b47ae --- /dev/null +++ b/deepmd/pt/infer/deep_eval.py @@ -0,0 +1,412 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from pathlib import ( + Path, +) +from typing import ( + Callable, + List, + Optional, + Tuple, + Union, +) + +import numpy as np +import torch + +from deepmd.infer.deep_pot import DeepPot as DeepPotBase +from deepmd.pt.model.model import ( + get_model, +) +from deepmd.pt.train.wrapper import ( + ModelWrapper, +) +from deepmd.pt.utils import ( + env, +) +from deepmd.pt.utils.auto_batch_size import ( + AutoBatchSize, +) +from deepmd.pt.utils.env import ( + DEVICE, + GLOBAL_PT_FLOAT_PRECISION, +) + + +class DeepEval: + def __init__( + self, + model_file: "Path", + auto_batch_size: Union[bool, int, AutoBatchSize] = True, + ): + self.model_path = model_file + state_dict = torch.load(model_file, map_location=env.DEVICE) + if "model" in state_dict: + state_dict = state_dict["model"] + self.input_param = state_dict["_extra_state"]["model_params"] + self.input_param["resuming"] = True + self.multi_task = "model_dict" in self.input_param + assert not self.multi_task, "multitask mode currently not supported!" + self.type_split = self.input_param["descriptor"]["type"] in ["se_e2_a"] + self.type_map = self.input_param["type_map"] + self.dp = ModelWrapper(get_model(self.input_param, None).to(DEVICE)) + self.dp.load_state_dict(state_dict) + self.rcut = self.dp.model["Default"].descriptor.get_rcut() + self.sec = np.cumsum(self.dp.model["Default"].descriptor.get_sel()) + if isinstance(auto_batch_size, bool): + if auto_batch_size: + self.auto_batch_size = AutoBatchSize() + else: + self.auto_batch_size = None + elif isinstance(auto_batch_size, int): + self.auto_batch_size = AutoBatchSize(auto_batch_size) + elif isinstance(auto_batch_size, AutoBatchSize): + self.auto_batch_size = auto_batch_size + else: + raise TypeError("auto_batch_size should be bool, int, or AutoBatchSize") + + def eval( + self, + coords: Union[np.ndarray, torch.Tensor], + cells: Optional[Union[np.ndarray, torch.Tensor]], + atom_types: Union[np.ndarray, torch.Tensor, List[int]], + atomic: bool = False, + ): + raise NotImplementedError + + +class DeepPot(DeepEval, DeepPotBase): + def __init__( + self, + model_file: "Path", + auto_batch_size: Union[bool, int, AutoBatchSize] = True, + neighbor_list=None, + ): + if neighbor_list is not None: + raise NotImplementedError + super().__init__( + model_file, + auto_batch_size=auto_batch_size, + ) + + def eval( + self, + coords: np.ndarray, + cells: np.ndarray, + atom_types: List[int], + atomic: bool = False, + fparam: Optional[np.ndarray] = None, + aparam: Optional[np.ndarray] = None, + efield: Optional[np.ndarray] = None, + mixed_type: bool = False, + ): + if fparam is not None or aparam is not None or efield is not None: + raise NotImplementedError + # convert all of the input to numpy array + atom_types = np.array(atom_types, dtype=np.int32) + coords = np.array(coords) + if cells is not None: + cells = np.array(cells) + natoms, numb_test = self._get_natoms_and_nframes( + coords, atom_types, len(atom_types.shape) > 1 + ) + return self._eval_func(self._eval_model, numb_test, natoms)( + coords, cells, atom_types, atomic + ) + + def _eval_func(self, inner_func: Callable, numb_test: int, natoms: int) -> Callable: + """Wrapper method with auto batch size. + + Parameters + ---------- + inner_func : Callable + the method to be wrapped + numb_test : int + number of tests + natoms : int + number of atoms + + Returns + ------- + Callable + the wrapper + """ + if self.auto_batch_size is not None: + + def eval_func(*args, **kwargs): + return self.auto_batch_size.execute_all( + inner_func, numb_test, natoms, *args, **kwargs + ) + + else: + eval_func = inner_func + return eval_func + + def _get_natoms_and_nframes( + self, + coords: np.ndarray, + atom_types: Union[List[int], np.ndarray], + mixed_type: bool = False, + ) -> Tuple[int, int]: + if mixed_type: + natoms = len(atom_types[0]) + else: + natoms = len(atom_types) + if natoms == 0: + assert coords.size == 0 + else: + coords = np.reshape(np.array(coords), [-1, natoms * 3]) + nframes = coords.shape[0] + return natoms, nframes + + def _eval_model( + self, + coords: np.ndarray, + cells: Optional[np.ndarray], + atom_types: np.ndarray, + atomic: bool = False, + ): + model = self.dp.to(DEVICE) + energy_out = None + atomic_energy_out = None + force_out = None + virial_out = None + atomic_virial_out = None + + nframes = coords.shape[0] + if len(atom_types.shape) == 1: + natoms = len(atom_types) + atom_types = np.tile(atom_types, nframes).reshape(nframes, -1) + else: + natoms = len(atom_types[0]) + + coord_input = torch.tensor( + coords.reshape([-1, natoms, 3]), dtype=GLOBAL_PT_FLOAT_PRECISION + ).to(DEVICE) + type_input = torch.tensor(atom_types, dtype=torch.long).to(DEVICE) + if cells is not None: + box_input = torch.tensor( + cells.reshape([-1, 3, 3]), dtype=GLOBAL_PT_FLOAT_PRECISION + ).to(DEVICE) + else: + box_input = None + + batch_output = model( + coord_input, type_input, box=box_input, do_atomic_virial=atomic + ) + if isinstance(batch_output, tuple): + batch_output = batch_output[0] + energy_out = batch_output["energy"].detach().cpu().numpy() + if "atom_energy" in batch_output: + atomic_energy_out = batch_output["atom_energy"].detach().cpu().numpy() + force_out = batch_output["force"].detach().cpu().numpy() + virial_out = batch_output["virial"].detach().cpu().numpy() + if "atomic_virial" in batch_output: + atomic_virial_out = batch_output["atomic_virial"].detach().cpu().numpy() + + if not atomic: + return energy_out, force_out, virial_out + else: + return ( + energy_out, + force_out, + virial_out, + atomic_energy_out, + atomic_virial_out, + ) + + def get_ntypes(self) -> int: + """Get the number of atom types of this model.""" + return len(self.type_map) + + def get_type_map(self) -> List[str]: + """Get the type map (element name of the atom types) of this model.""" + return self.type_map + + def get_dim_fparam(self) -> int: + """Get the number (dimension) of frame parameters of this DP.""" + return 0 + + def get_dim_aparam(self) -> int: + """Get the number (dimension) of atomic parameters of this DP.""" + return 0 + + +# For tests only +def eval_model( + model, + coords: Union[np.ndarray, torch.Tensor], + cells: Optional[Union[np.ndarray, torch.Tensor]], + atom_types: Union[np.ndarray, torch.Tensor, List[int]], + atomic: bool = False, + infer_batch_size: int = 2, + denoise: bool = False, +): + model = model.to(DEVICE) + energy_out = [] + atomic_energy_out = [] + force_out = [] + virial_out = [] + atomic_virial_out = [] + updated_coord_out = [] + logits_out = [] + err_msg = ( + f"All inputs should be the same format, " + f"but found {type(coords)}, {type(cells)}, {type(atom_types)} instead! " + ) + return_tensor = True + if isinstance(coords, torch.Tensor): + if cells is not None: + assert isinstance(cells, torch.Tensor), err_msg + assert isinstance(atom_types, torch.Tensor) or isinstance(atom_types, list) + atom_types = torch.tensor(atom_types, dtype=torch.long).to(DEVICE) + elif isinstance(coords, np.ndarray): + if cells is not None: + assert isinstance(cells, np.ndarray), err_msg + assert isinstance(atom_types, np.ndarray) or isinstance(atom_types, list) + atom_types = np.array(atom_types, dtype=np.int32) + return_tensor = False + + nframes = coords.shape[0] + if len(atom_types.shape) == 1: + natoms = len(atom_types) + if isinstance(atom_types, torch.Tensor): + atom_types = torch.tile(atom_types.unsqueeze(0), [nframes, 1]).reshape( + nframes, -1 + ) + else: + atom_types = np.tile(atom_types, nframes).reshape(nframes, -1) + else: + natoms = len(atom_types[0]) + + coord_input = torch.tensor( + coords.reshape([-1, natoms, 3]), dtype=GLOBAL_PT_FLOAT_PRECISION + ).to(DEVICE) + type_input = torch.tensor(atom_types, dtype=torch.long).to(DEVICE) + box_input = None + if cells is None: + pbc = False + else: + pbc = True + box_input = torch.tensor( + cells.reshape([-1, 3, 3]), dtype=GLOBAL_PT_FLOAT_PRECISION + ).to(DEVICE) + num_iter = int((nframes + infer_batch_size - 1) / infer_batch_size) + + for ii in range(num_iter): + batch_coord = coord_input[ii * infer_batch_size : (ii + 1) * infer_batch_size] + batch_atype = type_input[ii * infer_batch_size : (ii + 1) * infer_batch_size] + batch_box = None + if pbc: + batch_box = box_input[ii * infer_batch_size : (ii + 1) * infer_batch_size] + batch_output = model(batch_coord, batch_atype, box=batch_box) + if isinstance(batch_output, tuple): + batch_output = batch_output[0] + if not return_tensor: + if "energy" in batch_output: + energy_out.append(batch_output["energy"].detach().cpu().numpy()) + if "atom_energy" in batch_output: + atomic_energy_out.append( + batch_output["atom_energy"].detach().cpu().numpy() + ) + if "force" in batch_output: + force_out.append(batch_output["force"].detach().cpu().numpy()) + if "virial" in batch_output: + virial_out.append(batch_output["virial"].detach().cpu().numpy()) + if "atomic_virial" in batch_output: + atomic_virial_out.append( + batch_output["atomic_virial"].detach().cpu().numpy() + ) + if "updated_coord" in batch_output: + updated_coord_out.append( + batch_output["updated_coord"].detach().cpu().numpy() + ) + if "logits" in batch_output: + logits_out.append(batch_output["logits"].detach().cpu().numpy()) + else: + if "energy" in batch_output: + energy_out.append(batch_output["energy"]) + if "atom_energy" in batch_output: + atomic_energy_out.append(batch_output["atom_energy"]) + if "force" in batch_output: + force_out.append(batch_output["force"]) + if "virial" in batch_output: + virial_out.append(batch_output["virial"]) + if "atomic_virial" in batch_output: + atomic_virial_out.append(batch_output["atomic_virial"]) + if "updated_coord" in batch_output: + updated_coord_out.append(batch_output["updated_coord"]) + if "logits" in batch_output: + logits_out.append(batch_output["logits"]) + if not return_tensor: + energy_out = ( + np.concatenate(energy_out) if energy_out else np.zeros([nframes, 1]) + ) + atomic_energy_out = ( + np.concatenate(atomic_energy_out) + if atomic_energy_out + else np.zeros([nframes, natoms, 1]) + ) + force_out = ( + np.concatenate(force_out) if force_out else np.zeros([nframes, natoms, 3]) + ) + virial_out = ( + np.concatenate(virial_out) if virial_out else np.zeros([nframes, 3, 3]) + ) + atomic_virial_out = ( + np.concatenate(atomic_virial_out) + if atomic_virial_out + else np.zeros([nframes, natoms, 3, 3]) + ) + updated_coord_out = ( + np.concatenate(updated_coord_out) if updated_coord_out else None + ) + logits_out = np.concatenate(logits_out) if logits_out else None + else: + energy_out = ( + torch.cat(energy_out) + if energy_out + else torch.zeros([nframes, 1], dtype=GLOBAL_PT_FLOAT_PRECISION).to(DEVICE) + ) + atomic_energy_out = ( + torch.cat(atomic_energy_out) + if atomic_energy_out + else torch.zeros([nframes, natoms, 1], dtype=GLOBAL_PT_FLOAT_PRECISION).to( + DEVICE + ) + ) + force_out = ( + torch.cat(force_out) + if force_out + else torch.zeros([nframes, natoms, 3], dtype=GLOBAL_PT_FLOAT_PRECISION).to( + DEVICE + ) + ) + virial_out = ( + torch.cat(virial_out) + if virial_out + else torch.zeros([nframes, 3, 3], dtype=GLOBAL_PT_FLOAT_PRECISION).to( + DEVICE + ) + ) + atomic_virial_out = ( + torch.cat(atomic_virial_out) + if atomic_virial_out + else torch.zeros( + [nframes, natoms, 3, 3], dtype=GLOBAL_PT_FLOAT_PRECISION + ).to(DEVICE) + ) + updated_coord_out = torch.cat(updated_coord_out) if updated_coord_out else None + logits_out = torch.cat(logits_out) if logits_out else None + if denoise: + return updated_coord_out, logits_out + else: + if not atomic: + return energy_out, force_out, virial_out + else: + return ( + energy_out, + force_out, + virial_out, + atomic_energy_out, + atomic_virial_out, + ) diff --git a/deepmd/pt/infer/inference.py b/deepmd/pt/infer/inference.py new file mode 100644 index 0000000000..4906bb7a46 --- /dev/null +++ b/deepmd/pt/infer/inference.py @@ -0,0 +1,417 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import json +import logging +import math +from copy import ( + deepcopy, +) +from pathlib import ( + Path, +) + +import numpy as np +import torch +from torch.utils.data import ( + DataLoader, + RandomSampler, +) + +from deepmd.common import ( + expand_sys_str, +) +from deepmd.pt.loss import ( + DenoiseLoss, + EnergyStdLoss, +) +from deepmd.pt.model.model import ( + get_model, +) +from deepmd.pt.train.wrapper import ( + ModelWrapper, +) +from deepmd.pt.utils.dataloader import ( + DpLoaderSet, +) +from deepmd.pt.utils.env import ( + DEVICE, + JIT, + NUM_WORKERS, +) + +if torch.__version__.startswith("2"): + import torch._dynamo + + +class Tester: + def __init__( + self, + model_ckpt, + input_script=None, + system=None, + datafile=None, + numb_test=100, + detail_file=None, + shuffle_test=False, + head=None, + ): + """Construct a DeePMD tester. + + Args: + - config: The Dict-like configuration with training options. + """ + self.numb_test = numb_test + self.detail_file = detail_file + self.shuffle_test = shuffle_test + # Model + state_dict = torch.load(model_ckpt, map_location=DEVICE) + if "model" in state_dict: + state_dict = state_dict["model"] + model_params = state_dict["_extra_state"]["model_params"] + self.multi_task = "model_dict" in model_params + if self.multi_task: + assert head is not None, "Head must be specified in multitask mode!" + self.head = head + assert head in model_params["model_dict"], ( + f"Specified head {head} not found in model {model_ckpt}! " + f"Available ones are {list(model_params['model_dict'].keys())}." + ) + model_params = model_params["model_dict"][head] + state_dict_head = {"_extra_state": state_dict["_extra_state"]} + for item in state_dict: + if f"model.{head}." in item: + state_dict_head[ + item.replace(f"model.{head}.", "model.Default.") + ] = state_dict[item].clone() + state_dict = state_dict_head + + # Data + if input_script is not None: + with open(input_script) as fin: + self.input_script = json.load(fin) + training_params = self.input_script["training"] + if not self.multi_task: + assert ( + "validation_data" in training_params + ), f"Validation systems not found in {input_script}!" + self.systems = training_params["validation_data"]["systems"] + self.batchsize = training_params["validation_data"]["batch_size"] + logging.info( + f"Testing validation systems in input script: {input_script}" + ) + else: + assert ( + "data_dict" in training_params + ), f"Input script {input_script} is not in multi-task mode!" + assert head in training_params["data_dict"], ( + f"Specified head {head} not found in input script {input_script}! " + f"Available ones are {list(training_params['data_dict'].keys())}." + ) + assert ( + "validation_data" in training_params["data_dict"][head] + ), f"Validation systems not found in head {head} of {input_script}!" + self.systems = training_params["data_dict"][head]["validation_data"][ + "systems" + ] + self.batchsize = training_params["data_dict"][head]["validation_data"][ + "batch_size" + ] + logging.info( + f"Testing validation systems in head {head} of input script: {input_script}" + ) + elif system is not None: + self.systems = expand_sys_str(system) + self.batchsize = "auto" + logging.info("Testing systems in path: %s", system) + elif datafile is not None: + with open(datafile) as fin: + self.systems = fin.read().splitlines() + self.batchsize = "auto" + logging.info("Testing systems in file: %s", datafile) + else: + self.systems = None + self.batchsize = None + + self.type_split = False + if model_params["descriptor"]["type"] in ["se_e2_a"]: + self.type_split = True + self.model_params = deepcopy(model_params) + model_params["resuming"] = True + self.model = get_model(model_params).to(DEVICE) + + # Model Wrapper + self.wrapper = ModelWrapper(self.model) # inference only + if JIT: + self.wrapper = torch.jit.script(self.wrapper) + self.wrapper.load_state_dict(state_dict) + + # Loss + if "fitting_net" not in model_params: + assert ( + input_script is not None + ), "Denoise model must use --input-script mode!" + loss_params = self.input_script["loss"] + loss_type = loss_params.pop("type", "ener") + assert ( + loss_type == "denoise" + ), "Models without fitting_net only support denoise test!" + self.noise_settings = { + "noise_type": loss_params.pop("noise_type", "uniform"), + "noise": loss_params.pop("noise", 1.0), + "noise_mode": loss_params.pop("noise_mode", "fix_num"), + "mask_num": loss_params.pop("mask_num", 8), + "same_mask": loss_params.pop("same_mask", False), + "mask_coord": loss_params.pop("mask_coord", False), + "mask_type": loss_params.pop("mask_type", False), + "mask_type_idx": len(model_params["type_map"]) - 1, + } + loss_params["ntypes"] = len(model_params["type_map"]) + self.loss = DenoiseLoss(**loss_params) + else: + self.noise_settings = None + self.loss = EnergyStdLoss(inference=True) + + @staticmethod + def get_data(data): + batch_data = next(iter(data)) + for key in batch_data.keys(): + if key == "sid" or key == "fid": + continue + elif not isinstance(batch_data[key], list): + if batch_data[key] is not None: + batch_data[key] = batch_data[key].to(DEVICE) + else: + batch_data[key] = [item.to(DEVICE) for item in batch_data[key]] + input_dict = {} + for item in [ + "coord", + "atype", + "box", + ]: + if item in batch_data: + input_dict[item] = batch_data[item] + else: + input_dict[item] = None + label_dict = {} + for item in [ + "energy", + "force", + "virial", + "clean_coord", + "clean_type", + "coord_mask", + "type_mask", + ]: + if item in batch_data: + label_dict[item] = batch_data[item] + return input_dict, label_dict + + def run(self): + systems = self.systems + system_results = {} + global_sum_natoms = 0 + for cc, system in enumerate(systems): + logging.info("# ---------------output of dp test--------------- ") + logging.info(f"# testing system : {system}") + system_pred = [] + system_label = [] + dataset = DpLoaderSet( + [system], + self.batchsize, + self.model_params, + type_split=self.type_split, + noise_settings=self.noise_settings, + shuffle=self.shuffle_test, + ) + sampler = RandomSampler( + dataset, replacement=True, num_samples=dataset.total_batch + ) + if sampler is None: + logging.warning( + "Sampler not specified!" + ) # None sampler will lead to a premature stop iteration. Replacement should be True in attribute of the sampler to produce expected number of items in one iteration. + dataloader = DataLoader( + dataset, + sampler=sampler, + batch_size=None, + num_workers=min( + NUM_WORKERS, 1 + ), # setting to 0 diverges the behavior of its iterator; should be >=1 + drop_last=False, + ) + data = iter(dataloader) + + single_results = {} + sum_natoms = 0 + sys_natoms = None + for ii in range(self.numb_test): + try: + input_dict, label_dict = self.get_data(data) + except StopIteration: + if ( + ii < dataset.total_batch + ): # Unexpected stop iteration.(test step < total batch) + raise StopIteration + else: + break + model_pred, _, _ = self.wrapper(**input_dict) + system_pred.append( + { + item: model_pred[item].detach().cpu().numpy() + for item in model_pred + } + ) + system_label.append( + { + item: label_dict[item].detach().cpu().numpy() + for item in label_dict + } + ) + natoms = int(input_dict["atype"].shape[-1]) + _, more_loss = self.loss( + model_pred, label_dict, natoms, 1.0, mae=True + ) # TODO: lr here is useless + if sys_natoms is None: + sys_natoms = natoms + else: + assert ( + sys_natoms == natoms + ), "Frames in one system must be the same!" + sum_natoms += natoms + for k, v in more_loss.items(): + if "mae" in k: + single_results[k] = single_results.get(k, 0.0) + v * natoms + else: + single_results[k] = single_results.get(k, 0.0) + v**2 * natoms + if self.detail_file is not None: + save_detail_file( + Path(self.detail_file), + system_pred, + system_label, + sys_natoms, + system_name=system, + append=(cc != 0), + ) + results = { + k: v / sum_natoms if "mae" in k else math.sqrt(v / sum_natoms) + for k, v in single_results.items() + } + for item in sorted(results.keys()): + logging.info(f"{item}: {results[item]:.4f}") + logging.info("# ----------------------------------------------- ") + for k, v in single_results.items(): + system_results[k] = system_results.get(k, 0.0) + v + global_sum_natoms += sum_natoms + + global_results = { + k: v / global_sum_natoms if "mae" in k else math.sqrt(v / global_sum_natoms) + for k, v in system_results.items() + } + logging.info("# ----------weighted average of errors----------- ") + if not self.multi_task: + logging.info(f"# number of systems : {len(systems)}") + else: + logging.info(f"# number of systems for {self.head}: {len(systems)}") + for item in sorted(global_results.keys()): + logging.info(f"{item}: {global_results[item]:.4f}") + logging.info("# ----------------------------------------------- ") + return global_results + + +def save_txt_file( + fname: Path, data: np.ndarray, header: str = "", append: bool = False +): + """Save numpy array to test file. + + Parameters + ---------- + fname : str + filename + data : np.ndarray + data to save to disk + header : str, optional + header string to use in file, by default "" + append : bool, optional + if true file will be appended insted of overwriting, by default False + """ + flags = "ab" if append else "w" + with fname.open(flags) as fp: + np.savetxt(fp, data, header=header) + + +def save_detail_file( + detail_path, system_pred, system_label, natoms, system_name, append=False +): + ntest = len(system_pred) + data_e = np.concatenate([item["energy"] for item in system_label]).reshape([-1, 1]) + pred_e = np.concatenate([item["energy"] for item in system_pred]).reshape([-1, 1]) + pe = np.concatenate( + ( + data_e, + pred_e, + ), + axis=1, + ) + save_txt_file( + detail_path.with_suffix(".e.out"), + pe, + header="%s: data_e pred_e" % system_name, + append=append, + ) + pe_atom = pe / natoms + save_txt_file( + detail_path.with_suffix(".e_peratom.out"), + pe_atom, + header="%s: data_e pred_e" % system_name, + append=append, + ) + if "force" in system_pred[0]: + data_f = np.concatenate([item["force"] for item in system_label]).reshape( + [-1, 3] + ) + pred_f = np.concatenate([item["force"] for item in system_pred]).reshape( + [-1, 3] + ) + pf = np.concatenate( + ( + data_f, + pred_f, + ), + axis=1, + ) + save_txt_file( + detail_path.with_suffix(".f.out"), + pf, + header="%s: data_fx data_fy data_fz pred_fx pred_fy pred_fz" % system_name, + append=append, + ) + if "virial" in system_pred[0]: + data_v = np.concatenate([item["virial"] for item in system_label]).reshape( + [-1, 9] + ) + pred_v = np.concatenate([item["virial"] for item in system_pred]).reshape( + [-1, 9] + ) + pv = np.concatenate( + ( + data_v, + pred_v, + ), + axis=1, + ) + save_txt_file( + detail_path.with_suffix(".v.out"), + pv, + header=f"{system_name}: data_vxx data_vxy data_vxz data_vyx data_vyy " + "data_vyz data_vzx data_vzy data_vzz pred_vxx pred_vxy pred_vxz pred_vyx " + "pred_vyy pred_vyz pred_vzx pred_vzy pred_vzz", + append=append, + ) + pv_atom = pv / natoms + save_txt_file( + detail_path.with_suffix(".v_peratom.out"), + pv_atom, + header=f"{system_name}: data_vxx data_vxy data_vxz data_vyx data_vyy " + "data_vyz data_vzx data_vzy data_vzz pred_vxx pred_vxy pred_vxz pred_vyx " + "pred_vyy pred_vyz pred_vzx pred_vzy pred_vzz", + append=append, + ) diff --git a/deepmd/pt/loss/__init__.py b/deepmd/pt/loss/__init__.py new file mode 100644 index 0000000000..d3a095ce13 --- /dev/null +++ b/deepmd/pt/loss/__init__.py @@ -0,0 +1,16 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from .denoise import ( + DenoiseLoss, +) +from .ener import ( + EnergyStdLoss, +) +from .loss import ( + TaskLoss, +) + +__all__ = [ + "DenoiseLoss", + "EnergyStdLoss", + "TaskLoss", +] diff --git a/deepmd/pt/loss/denoise.py b/deepmd/pt/loss/denoise.py new file mode 100644 index 0000000000..cd12e70bb1 --- /dev/null +++ b/deepmd/pt/loss/denoise.py @@ -0,0 +1,109 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import torch +import torch.nn.functional as F + +from deepmd.pt.loss.loss import ( + TaskLoss, +) +from deepmd.pt.utils import ( + env, +) + + +class DenoiseLoss(TaskLoss): + def __init__( + self, + ntypes, + masked_token_loss=1.0, + masked_coord_loss=1.0, + norm_loss=0.01, + use_l1=True, + beta=1.00, + mask_loss_coord=True, + mask_loss_token=True, + **kwargs, + ): + """Construct a layer to compute loss on coord, and type reconstruction.""" + super().__init__() + self.ntypes = ntypes + self.masked_token_loss = masked_token_loss + self.masked_coord_loss = masked_coord_loss + self.norm_loss = norm_loss + self.has_coord = self.masked_coord_loss > 0.0 + self.has_token = self.masked_token_loss > 0.0 + self.has_norm = self.norm_loss > 0.0 + self.use_l1 = use_l1 + self.beta = beta + self.frac_beta = 1.00 / self.beta + self.mask_loss_coord = mask_loss_coord + self.mask_loss_token = mask_loss_token + + def forward(self, model_pred, label, natoms, learning_rate, mae=False): + """Return loss on coord and type denoise. + + Returns + ------- + - loss: Loss to minimize. + """ + updated_coord = model_pred["updated_coord"] + logits = model_pred["logits"] + clean_coord = label["clean_coord"] + clean_type = label["clean_type"] + coord_mask = label["coord_mask"] + type_mask = label["type_mask"] + + loss = torch.tensor(0.0, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE) + more_loss = {} + if self.has_coord: + if self.mask_loss_coord: + masked_updated_coord = updated_coord[coord_mask] + masked_clean_coord = clean_coord[coord_mask] + if masked_updated_coord.size(0) > 0: + coord_loss = F.smooth_l1_loss( + masked_updated_coord.view(-1, 3), + masked_clean_coord.view(-1, 3), + reduction="mean", + beta=self.beta, + ) + else: + coord_loss = torch.tensor( + 0.0, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE + ) + else: + coord_loss = F.smooth_l1_loss( + updated_coord.view(-1, 3), + clean_coord.view(-1, 3), + reduction="mean", + beta=self.beta, + ) + loss += self.masked_coord_loss * coord_loss + more_loss["coord_l1_error"] = coord_loss.detach() + if self.has_token: + if self.mask_loss_token: + masked_logits = logits[type_mask] + masked_target = clean_type[type_mask] + if masked_logits.size(0) > 0: + token_loss = F.nll_loss( + F.log_softmax(masked_logits, dim=-1), + masked_target, + reduction="mean", + ) + else: + token_loss = torch.tensor( + 0.0, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE + ) + else: + token_loss = F.nll_loss( + F.log_softmax(logits.view(-1, self.ntypes - 1), dim=-1), + clean_type.view(-1), + reduction="mean", + ) + loss += self.masked_token_loss * token_loss + more_loss["token_error"] = token_loss.detach() + if self.has_norm: + norm_x = model_pred["norm_x"] + norm_delta_pair_rep = model_pred["norm_delta_pair_rep"] + loss += self.norm_loss * (norm_x + norm_delta_pair_rep) + more_loss["norm_loss"] = norm_x.detach() + norm_delta_pair_rep.detach() + + return loss, more_loss diff --git a/deepmd/pt/loss/ener.py b/deepmd/pt/loss/ener.py new file mode 100644 index 0000000000..4ed765cf69 --- /dev/null +++ b/deepmd/pt/loss/ener.py @@ -0,0 +1,155 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import torch +import torch.nn.functional as F + +from deepmd.pt.loss.loss import ( + TaskLoss, +) +from deepmd.pt.utils import ( + env, +) +from deepmd.pt.utils.env import ( + GLOBAL_PT_FLOAT_PRECISION, +) + + +class EnergyStdLoss(TaskLoss): + def __init__( + self, + starter_learning_rate=1.0, + start_pref_e=0.0, + limit_pref_e=0.0, + start_pref_f=0.0, + limit_pref_f=0.0, + start_pref_v=0.0, + limit_pref_v=0.0, + use_l1_all: bool = False, + inference=False, + **kwargs, + ): + """Construct a layer to compute loss on energy, force and virial.""" + super().__init__() + self.starter_learning_rate = starter_learning_rate + self.has_e = (start_pref_e != 0.0 and limit_pref_e != 0.0) or inference + self.has_f = (start_pref_f != 0.0 and limit_pref_f != 0.0) or inference + self.has_v = (start_pref_v != 0.0 and limit_pref_v != 0.0) or inference + self.start_pref_e = start_pref_e + self.limit_pref_e = limit_pref_e + self.start_pref_f = start_pref_f + self.limit_pref_f = limit_pref_f + self.start_pref_v = start_pref_v + self.limit_pref_v = limit_pref_v + self.use_l1_all = use_l1_all + self.inference = inference + + def forward(self, model_pred, label, natoms, learning_rate, mae=False): + """Return loss on loss and force. + + Args: + - natoms: Tell atom count. + - p_energy: Predicted energy of all atoms. + - p_force: Predicted force per atom. + - l_energy: Actual energy of all atoms. + - l_force: Actual force per atom. + + Returns + ------- + - loss: Loss to minimize. + """ + coef = learning_rate / self.starter_learning_rate + pref_e = self.limit_pref_e + (self.start_pref_e - self.limit_pref_e) * coef + pref_f = self.limit_pref_f + (self.start_pref_f - self.limit_pref_f) * coef + pref_v = self.limit_pref_v + (self.start_pref_v - self.limit_pref_v) * coef + loss = torch.tensor(0.0, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE) + more_loss = {} + # more_loss['log_keys'] = [] # showed when validation on the fly + # more_loss['test_keys'] = [] # showed when doing dp test + atom_norm = 1.0 / natoms + if self.has_e and "energy" in model_pred and "energy" in label: + if not self.use_l1_all: + l2_ener_loss = torch.mean( + torch.square(model_pred["energy"] - label["energy"]) + ) + if not self.inference: + more_loss["l2_ener_loss"] = l2_ener_loss.detach() + loss += atom_norm * (pref_e * l2_ener_loss) + rmse_e = l2_ener_loss.sqrt() * atom_norm + more_loss["rmse_e"] = rmse_e.detach() + # more_loss['log_keys'].append('rmse_e') + else: # use l1 and for all atoms + l1_ener_loss = F.l1_loss( + model_pred["energy"].reshape(-1), + label["energy"].reshape(-1), + reduction="sum", + ) + loss += pref_e * l1_ener_loss + more_loss["mae_e"] = F.l1_loss( + model_pred["energy"].reshape(-1), + label["energy"].reshape(-1), + reduction="mean", + ).detach() + # more_loss['log_keys'].append('rmse_e') + if mae: + mae_e = ( + torch.mean(torch.abs(model_pred["energy"] - label["energy"])) + * atom_norm + ) + more_loss["mae_e"] = mae_e.detach() + mae_e_all = torch.mean( + torch.abs(model_pred["energy"] - label["energy"]) + ) + more_loss["mae_e_all"] = mae_e_all.detach() + + if self.has_f and "force" in model_pred and "force" in label: + if "force_target_mask" in model_pred: + force_target_mask = model_pred["force_target_mask"] + else: + force_target_mask = None + if not self.use_l1_all: + if force_target_mask is not None: + diff_f = (label["force"] - model_pred["force"]) * force_target_mask + force_cnt = force_target_mask.squeeze(-1).sum(-1) + l2_force_loss = torch.mean( + torch.square(diff_f).mean(-1).sum(-1) / force_cnt + ) + else: + diff_f = label["force"] - model_pred["force"] + l2_force_loss = torch.mean(torch.square(diff_f)) + if not self.inference: + more_loss["l2_force_loss"] = l2_force_loss.detach() + loss += (pref_f * l2_force_loss).to(GLOBAL_PT_FLOAT_PRECISION) + rmse_f = l2_force_loss.sqrt() + more_loss["rmse_f"] = rmse_f.detach() + else: + l1_force_loss = F.l1_loss( + label["force"], model_pred["force"], reduction="none" + ) + if force_target_mask is not None: + l1_force_loss *= force_target_mask + force_cnt = force_target_mask.squeeze(-1).sum(-1) + more_loss["mae_f"] = ( + l1_force_loss.mean(-1).sum(-1) / force_cnt + ).mean() + l1_force_loss = (l1_force_loss.sum(-1).sum(-1) / force_cnt).sum() + else: + more_loss["mae_f"] = l1_force_loss.mean().detach() + l1_force_loss = l1_force_loss.sum(-1).mean(-1).sum() + loss += (pref_f * l1_force_loss).to(GLOBAL_PT_FLOAT_PRECISION) + if mae: + mae_f = torch.mean(torch.abs(diff_f)) + more_loss["mae_f"] = mae_f.detach() + + if self.has_v and "virial" in model_pred and "virial" in label: + diff_v = label["virial"] - model_pred["virial"].reshape(-1, 9) + l2_virial_loss = torch.mean(torch.square(diff_v)) + if not self.inference: + more_loss["l2_virial_loss"] = l2_virial_loss.detach() + loss += atom_norm * (pref_v * l2_virial_loss) + rmse_v = l2_virial_loss.sqrt() * atom_norm + more_loss["rmse_v"] = rmse_v.detach() + if mae: + mae_v = torch.mean(torch.abs(diff_v)) * atom_norm + more_loss["mae_v"] = mae_v.detach() + if not self.inference: + more_loss["rmse"] = torch.sqrt(loss.detach()) + return loss, more_loss diff --git a/deepmd/pt/loss/loss.py b/deepmd/pt/loss/loss.py new file mode 100644 index 0000000000..9f2c3a7ed7 --- /dev/null +++ b/deepmd/pt/loss/loss.py @@ -0,0 +1,12 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import torch + + +class TaskLoss(torch.nn.Module): + def __init__(self, **kwargs): + """Construct loss.""" + super().__init__() + + def forward(self, model_pred, label, natoms, learning_rate): + """Return loss .""" + raise NotImplementedError diff --git a/deepmd/pt/model/__init__.py b/deepmd/pt/model/__init__.py new file mode 100644 index 0000000000..6ceb116d85 --- /dev/null +++ b/deepmd/pt/model/__init__.py @@ -0,0 +1 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later diff --git a/deepmd/pt/model/backbone/__init__.py b/deepmd/pt/model/backbone/__init__.py new file mode 100644 index 0000000000..a76bdb2a2d --- /dev/null +++ b/deepmd/pt/model/backbone/__init__.py @@ -0,0 +1,12 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from .backbone import ( + BackBone, +) +from .evoformer2b import ( + Evoformer2bBackBone, +) + +__all__ = [ + "BackBone", + "Evoformer2bBackBone", +] diff --git a/deepmd/pt/model/backbone/backbone.py b/deepmd/pt/model/backbone/backbone.py new file mode 100644 index 0000000000..ddeedfeff5 --- /dev/null +++ b/deepmd/pt/model/backbone/backbone.py @@ -0,0 +1,12 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import torch + + +class BackBone(torch.nn.Module): + def __init__(self, **kwargs): + """BackBone base method.""" + super().__init__() + + def forward(self, **kwargs): + """Calculate backBone.""" + raise NotImplementedError diff --git a/deepmd/pt/model/backbone/evoformer2b.py b/deepmd/pt/model/backbone/evoformer2b.py new file mode 100644 index 0000000000..1146b3a298 --- /dev/null +++ b/deepmd/pt/model/backbone/evoformer2b.py @@ -0,0 +1,103 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from deepmd.pt.model.backbone import ( + BackBone, +) +from deepmd.pt.model.network.network import ( + Evoformer2bEncoder, +) + + +class Evoformer2bBackBone(BackBone): + def __init__( + self, + nnei, + layer_num=6, + attn_head=8, + atomic_dim=1024, + pair_dim=100, + feature_dim=1024, + ffn_dim=2048, + post_ln=False, + final_layer_norm=True, + final_head_layer_norm=False, + emb_layer_norm=False, + atomic_residual=False, + evo_residual=False, + residual_factor=1.0, + activation_function="gelu", + **kwargs, + ): + """Construct an evoformer backBone.""" + super().__init__() + self.nnei = nnei + self.layer_num = layer_num + self.attn_head = attn_head + self.atomic_dim = atomic_dim + self.pair_dim = pair_dim + self.feature_dim = feature_dim + self.head_dim = feature_dim // attn_head + assert ( + feature_dim % attn_head == 0 + ), f"feature_dim {feature_dim} must be divided by attn_head {attn_head}!" + self.ffn_dim = ffn_dim + self.post_ln = post_ln + self.final_layer_norm = final_layer_norm + self.final_head_layer_norm = final_head_layer_norm + self.emb_layer_norm = emb_layer_norm + self.activation_function = activation_function + self.atomic_residual = atomic_residual + self.evo_residual = evo_residual + self.residual_factor = float(residual_factor) + self.encoder = Evoformer2bEncoder( + nnei=self.nnei, + layer_num=self.layer_num, + attn_head=self.attn_head, + atomic_dim=self.atomic_dim, + pair_dim=self.pair_dim, + feature_dim=self.feature_dim, + ffn_dim=self.ffn_dim, + post_ln=self.post_ln, + final_layer_norm=self.final_layer_norm, + final_head_layer_norm=self.final_head_layer_norm, + emb_layer_norm=self.emb_layer_norm, + atomic_residual=self.atomic_residual, + evo_residual=self.evo_residual, + residual_factor=self.residual_factor, + activation_function=self.activation_function, + ) + + def forward(self, atomic_rep, pair_rep, nlist, nlist_type, nlist_mask): + """Encoder the atomic and pair representations. + + Args: + - atomic_rep: Atomic representation with shape [nframes, nloc, atomic_dim]. + - pair_rep: Pair representation with shape [nframes, nloc, nnei, pair_dim]. + - nlist: Neighbor list with shape [nframes, nloc, nnei]. + - nlist_type: Neighbor types with shape [nframes, nloc, nnei]. + - nlist_mask: Neighbor mask with shape [nframes, nloc, nnei], `False` if blank. + + Returns + ------- + - atomic_rep: Atomic representation after encoder with shape [nframes, nloc, feature_dim]. + - transformed_atomic_rep: Transformed atomic representation after encoder with shape [nframes, nloc, atomic_dim]. + - pair_rep: Pair representation after encoder with shape [nframes, nloc, nnei, attn_head]. + - delta_pair_rep: Delta pair representation after encoder with shape [nframes, nloc, nnei, attn_head]. + - norm_x: Normalization loss of atomic_rep. + - norm_delta_pair_rep: Normalization loss of delta_pair_rep. + """ + ( + atomic_rep, + transformed_atomic_rep, + pair_rep, + delta_pair_rep, + norm_x, + norm_delta_pair_rep, + ) = self.encoder(atomic_rep, pair_rep, nlist, nlist_type, nlist_mask) + return ( + atomic_rep, + transformed_atomic_rep, + pair_rep, + delta_pair_rep, + norm_x, + norm_delta_pair_rep, + ) diff --git a/deepmd/pt/model/descriptor/__init__.py b/deepmd/pt/model/descriptor/__init__.py new file mode 100644 index 0000000000..4252e34905 --- /dev/null +++ b/deepmd/pt/model/descriptor/__init__.py @@ -0,0 +1,46 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from .descriptor import ( + Descriptor, + DescriptorBlock, + compute_std, + make_default_type_embedding, +) +from .dpa1 import ( + DescrptBlockSeAtten, + DescrptDPA1, +) +from .dpa2 import ( + DescrptDPA2, +) +from .env_mat import ( + prod_env_mat_se_a, +) +from .gaussian_lcc import ( + DescrptGaussianLcc, +) +from .hybrid import ( + DescrptBlockHybrid, +) +from .repformers import ( + DescrptBlockRepformers, +) +from .se_a import ( + DescrptBlockSeA, + DescrptSeA, +) + +__all__ = [ + "Descriptor", + "DescriptorBlock", + "compute_std", + "make_default_type_embedding", + "DescrptBlockSeA", + "DescrptBlockSeAtten", + "DescrptSeA", + "DescrptDPA1", + "DescrptDPA2", + "prod_env_mat_se_a", + "DescrptGaussianLcc", + "DescrptBlockHybrid", + "DescrptBlockRepformers", +] diff --git a/deepmd/pt/model/descriptor/descriptor.py b/deepmd/pt/model/descriptor/descriptor.py new file mode 100644 index 0000000000..bb98e8dc15 --- /dev/null +++ b/deepmd/pt/model/descriptor/descriptor.py @@ -0,0 +1,272 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from abc import ( + ABC, + abstractmethod, +) +from typing import ( + Callable, + List, + Optional, +) + +import numpy as np +import torch + +from deepmd.pt.model.network.network import ( + TypeEmbedNet, +) +from deepmd.pt.utils.plugin import ( + Plugin, +) + + +class Descriptor(torch.nn.Module, ABC): + """The descriptor. + Given the atomic coordinates, atomic types and neighbor list, + calculate the descriptor. + """ + + __plugins = Plugin() + local_cluster = False + + @abstractmethod + def get_rcut(self) -> float: + """Returns the cut-off radius.""" + raise NotImplementedError + + @abstractmethod + def get_nsel(self) -> int: + """Returns the number of selected atoms in the cut-off radius.""" + raise NotImplementedError + + @abstractmethod + def get_sel(self) -> List[int]: + """Returns the number of selected atoms for each type.""" + raise NotImplementedError + + @abstractmethod + def get_ntype(self) -> int: + """Returns the number of element types.""" + raise NotImplementedError + + @abstractmethod + def get_dim_out(self) -> int: + """Returns the output dimension.""" + raise NotImplementedError + + @abstractmethod + def compute_input_stats(self, merged): + """Update mean and stddev for descriptor elements.""" + raise NotImplementedError + + @abstractmethod + def init_desc_stat(self, sumr, suma, sumn, sumr2, suma2): + """Initialize the model bias by the statistics.""" + raise NotImplementedError + + @abstractmethod + def forward( + self, + extended_coord, + extended_atype, + nlist, + mapping: Optional[torch.Tensor] = None, + ): + """Calculate descriptor.""" + raise NotImplementedError + + @staticmethod + def register(key: str) -> Callable: + """Register a descriptor plugin. + + Parameters + ---------- + key : str + the key of a descriptor + + Returns + ------- + Descriptor + the registered descriptor + + Examples + -------- + >>> @Descriptor.register("some_descrpt") + class SomeDescript(Descriptor): + pass + """ + return Descriptor.__plugins.register(key) + + @classmethod + def get_stat_name(cls, config): + descrpt_type = config["type"] + return Descriptor.__plugins.plugins[descrpt_type].get_stat_name(config) + + @classmethod + def get_data_process_key(cls, config): + descrpt_type = config["type"] + return Descriptor.__plugins.plugins[descrpt_type].get_data_process_key(config) + + def __new__(cls, *args, **kwargs): + if cls is Descriptor: + try: + descrpt_type = kwargs["type"] + except KeyError: + raise KeyError("the type of descriptor should be set by `type`") + if descrpt_type in Descriptor.__plugins.plugins: + cls = Descriptor.__plugins.plugins[descrpt_type] + else: + raise RuntimeError("Unknown descriptor type: " + descrpt_type) + return super().__new__(cls) + + +class DescriptorBlock(torch.nn.Module, ABC): + """The building block of descriptor. + Given the input descriptor, provide with the atomic coordinates, + atomic types and neighbor list, calculate the new descriptor. + """ + + __plugins = Plugin() + local_cluster = False + + @staticmethod + def register(key: str) -> Callable: + """Register a DescriptorBlock plugin. + + Parameters + ---------- + key : str + the key of a DescriptorBlock + + Returns + ------- + DescriptorBlock + the registered DescriptorBlock + + Examples + -------- + >>> @DescriptorBlock.register("some_descrpt") + class SomeDescript(DescriptorBlock): + pass + """ + return DescriptorBlock.__plugins.register(key) + + def __new__(cls, *args, **kwargs): + if cls is DescriptorBlock: + try: + descrpt_type = kwargs["type"] + except KeyError: + raise KeyError("the type of DescriptorBlock should be set by `type`") + if descrpt_type in DescriptorBlock.__plugins.plugins: + cls = DescriptorBlock.__plugins.plugins[descrpt_type] + else: + raise RuntimeError("Unknown DescriptorBlock type: " + descrpt_type) + return super().__new__(cls) + + @abstractmethod + def get_rcut(self) -> float: + """Returns the cut-off radius.""" + raise NotImplementedError + + @abstractmethod + def get_nsel(self) -> int: + """Returns the number of selected atoms in the cut-off radius.""" + raise NotImplementedError + + @abstractmethod + def get_sel(self) -> List[int]: + """Returns the number of selected atoms for each type.""" + raise NotImplementedError + + @abstractmethod + def get_ntype(self) -> int: + """Returns the number of element types.""" + raise NotImplementedError + + @abstractmethod + def get_dim_out(self) -> int: + """Returns the output dimension.""" + raise NotImplementedError + + @abstractmethod + def get_dim_in(self) -> int: + """Returns the output dimension.""" + raise NotImplementedError + + @abstractmethod + def compute_input_stats(self, merged): + """Update mean and stddev for DescriptorBlock elements.""" + raise NotImplementedError + + @abstractmethod + def init_desc_stat(self, sumr, suma, sumn, sumr2, suma2): + """Initialize the model bias by the statistics.""" + raise NotImplementedError + + def share_params(self, base_class, shared_level, resume=False): + assert ( + self.__class__ == base_class.__class__ + ), "Only descriptors of the same type can share params!" + if shared_level == 0: + # link buffers + if hasattr(self, "mean") and not resume: + # in case of change params during resume + sumr_base, suma_base, sumn_base, sumr2_base, suma2_base = ( + base_class.sumr, + base_class.suma, + base_class.sumn, + base_class.sumr2, + base_class.suma2, + ) + sumr, suma, sumn, sumr2, suma2 = ( + self.sumr, + self.suma, + self.sumn, + self.sumr2, + self.suma2, + ) + base_class.init_desc_stat( + sumr_base + sumr, + suma_base + suma, + sumn_base + sumn, + sumr2_base + sumr2, + suma2_base + suma2, + ) + self.mean = base_class.mean + self.stddev = base_class.stddev + # self.load_state_dict(base_class.state_dict()) # this does not work, because it only inits the model + # the following will successfully link all the params except buffers + for item in self._modules: + self._modules[item] = base_class._modules[item] + else: + raise NotImplementedError + + @abstractmethod + def forward( + self, + nlist: torch.Tensor, + extended_coord: torch.Tensor, + extended_atype: torch.Tensor, + extended_atype_embd: Optional[torch.Tensor] = None, + mapping: Optional[torch.Tensor] = None, + ): + """Calculate DescriptorBlock.""" + raise NotImplementedError + + +def compute_std(sumv2, sumv, sumn, rcut_r): + """Compute standard deviation.""" + if sumn == 0: + return 1.0 / rcut_r + val = np.sqrt(sumv2 / sumn - np.multiply(sumv / sumn, sumv / sumn)) + if np.abs(val) < 1e-2: + val = 1e-2 + return val + + +def make_default_type_embedding( + ntypes, +): + aux = {} + aux["tebd_dim"] = 8 + return TypeEmbedNet(ntypes, aux["tebd_dim"]), aux diff --git a/deepmd/pt/model/descriptor/dpa1.py b/deepmd/pt/model/descriptor/dpa1.py new file mode 100644 index 0000000000..dd34b815c9 --- /dev/null +++ b/deepmd/pt/model/descriptor/dpa1.py @@ -0,0 +1,152 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + List, + Optional, +) + +import torch + +from deepmd.pt.model.descriptor import ( + Descriptor, +) +from deepmd.pt.model.network.network import ( + TypeEmbedNet, +) + +from .se_atten import ( + DescrptBlockSeAtten, +) + + +@Descriptor.register("dpa1") +@Descriptor.register("se_atten") +class DescrptDPA1(Descriptor): + def __init__( + self, + rcut, + rcut_smth, + sel, + ntypes: int, + neuron: list = [25, 50, 100], + axis_neuron: int = 16, + tebd_dim: int = 8, + tebd_input_mode: str = "concat", + # set_davg_zero: bool = False, + set_davg_zero: bool = True, # TODO + attn: int = 128, + attn_layer: int = 2, + attn_dotr: bool = True, + attn_mask: bool = False, + post_ln=True, + ffn=False, + ffn_embed_dim=1024, + activation="tanh", + scaling_factor=1.0, + head_num=1, + normalize=True, + temperature=None, + return_rot=False, + concat_output_tebd: bool = True, + type: Optional[str] = None, + ): + super().__init__() + del type + self.se_atten = DescrptBlockSeAtten( + rcut, + rcut_smth, + sel, + ntypes, + neuron=neuron, + axis_neuron=axis_neuron, + tebd_dim=tebd_dim, + tebd_input_mode=tebd_input_mode, + set_davg_zero=set_davg_zero, + attn=attn, + attn_layer=attn_layer, + attn_dotr=attn_dotr, + attn_mask=attn_mask, + post_ln=post_ln, + ffn=ffn, + ffn_embed_dim=ffn_embed_dim, + activation=activation, + scaling_factor=scaling_factor, + head_num=head_num, + normalize=normalize, + temperature=temperature, + return_rot=return_rot, + ) + self.type_embedding = TypeEmbedNet(ntypes, tebd_dim) + self.tebd_dim = tebd_dim + self.concat_output_tebd = concat_output_tebd + + def get_rcut(self) -> float: + """Returns the cut-off radius.""" + return self.se_atten.get_rcut() + + def get_nsel(self) -> int: + """Returns the number of selected atoms in the cut-off radius.""" + return self.se_atten.get_nsel() + + def get_sel(self) -> List[int]: + """Returns the number of selected atoms for each type.""" + return self.se_atten.get_sel() + + def get_ntype(self) -> int: + """Returns the number of element types.""" + return self.se_atten.get_ntype() + + def get_dim_out(self) -> int: + """Returns the output dimension.""" + ret = self.se_atten.get_dim_out() + if self.concat_output_tebd: + ret += self.tebd_dim + return ret + + @property + def dim_out(self): + return self.get_dim_out() + + @property + def dim_emb(self): + return self.se_atten.dim_emb + + def compute_input_stats(self, merged): + return self.se_atten.compute_input_stats(merged) + + def init_desc_stat(self, sumr, suma, sumn, sumr2, suma2): + self.se_atten.init_desc_stat(sumr, suma, sumn, sumr2, suma2) + + @classmethod + def get_stat_name(cls, config): + descrpt_type = config["type"] + assert descrpt_type in ["dpa1", "se_atten"] + return f'stat_file_dpa1_rcut{config["rcut"]:.2f}_smth{config["rcut_smth"]:.2f}_sel{config["sel"]}.npz' + + @classmethod + def get_data_process_key(cls, config): + descrpt_type = config["type"] + assert descrpt_type in ["dpa1", "se_atten"] + return {"sel": config["sel"], "rcut": config["rcut"]} + + def forward( + self, + extended_coord: torch.Tensor, + extended_atype: torch.Tensor, + nlist: torch.Tensor, + mapping: Optional[torch.Tensor] = None, + ): + del mapping + nframes, nloc, nnei = nlist.shape + nall = extended_coord.view(nframes, -1).shape[1] // 3 + g1_ext = self.type_embedding(extended_atype) + g1_inp = g1_ext[:, :nloc, :] + g1, env_mat, diff, rot_mat, sw = self.se_atten( + nlist, + extended_coord, + extended_atype, + g1_ext, + mapping=None, + ) + if self.concat_output_tebd: + g1 = torch.cat([g1, g1_inp], dim=-1) + return g1, env_mat, diff, rot_mat, sw diff --git a/deepmd/pt/model/descriptor/dpa2.py b/deepmd/pt/model/descriptor/dpa2.py new file mode 100644 index 0000000000..fbdbc91dd9 --- /dev/null +++ b/deepmd/pt/model/descriptor/dpa2.py @@ -0,0 +1,375 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + List, + Optional, +) + +import torch + +from deepmd.pt.model.descriptor import ( + Descriptor, +) +from deepmd.pt.model.network.network import ( + Identity, + Linear, + TypeEmbedNet, +) +from deepmd.pt.utils.nlist import ( + build_multiple_neighbor_list, + get_multiple_nlist_key, +) + +from .repformers import ( + DescrptBlockRepformers, +) +from .se_atten import ( + DescrptBlockSeAtten, +) + + +@Descriptor.register("dpa2") +class DescrptDPA2(Descriptor): + def __init__( + self, + ntypes: int, + repinit_rcut: float, + repinit_rcut_smth: float, + repinit_nsel: int, + repformer_rcut: float, + repformer_rcut_smth: float, + repformer_nsel: int, + # kwargs + tebd_dim: int = 8, + concat_output_tebd: bool = True, + repinit_neuron: List[int] = [25, 50, 100], + repinit_axis_neuron: int = 16, + repinit_set_davg_zero: bool = True, # TODO + repinit_activation="tanh", + # repinit still unclear: + # ffn, ffn_embed_dim, scaling_factor, normalize, + repformer_nlayers: int = 3, + repformer_g1_dim: int = 128, + repformer_g2_dim: int = 16, + repformer_axis_dim: int = 4, + repformer_do_bn_mode: str = "no", + repformer_bn_momentum: float = 0.1, + repformer_update_g1_has_conv: bool = True, + repformer_update_g1_has_drrd: bool = True, + repformer_update_g1_has_grrg: bool = True, + repformer_update_g1_has_attn: bool = True, + repformer_update_g2_has_g1g1: bool = True, + repformer_update_g2_has_attn: bool = True, + repformer_update_h2: bool = False, + repformer_attn1_hidden: int = 64, + repformer_attn1_nhead: int = 4, + repformer_attn2_hidden: int = 16, + repformer_attn2_nhead: int = 4, + repformer_attn2_has_gate: bool = False, + repformer_activation: str = "tanh", + repformer_update_style: str = "res_avg", + repformer_set_davg_zero: bool = True, # TODO + repformer_add_type_ebd_to_seq: bool = False, + type: Optional[ + str + ] = None, # work around the bad design in get_trainer and DpLoaderSet! + rcut: Optional[ + float + ] = None, # work around the bad design in get_trainer and DpLoaderSet! + rcut_smth: Optional[ + float + ] = None, # work around the bad design in get_trainer and DpLoaderSet! + sel: Optional[ + int + ] = None, # work around the bad design in get_trainer and DpLoaderSet! + ): + r"""The DPA-2 descriptor. see https://arxiv.org/abs/2312.15492. + + Parameters + ---------- + ntypes : int + Number of atom types + repinit_rcut : float + The cut-off radius of the repinit block + repinit_rcut_smth : float + From this position the inverse distance smoothly decays + to 0 at the cut-off. Use in the repinit block. + repinit_nsel : int + Maximally possible number of neighbors for repinit block. + repformer_rcut : float + The cut-off radius of the repformer block + repformer_rcut_smth : float + From this position the inverse distance smoothly decays + to 0 at the cut-off. Use in the repformer block. + repformer_nsel : int + Maximally possible number of neighbors for repformer block. + tebd_dim : int + The dimension of atom type embedding + concat_output_tebd : bool + Whether to concat type embedding at the output of the descriptor. + repinit_neuron : List[int] + repinit block: the number of neurons in the embedding net. + repinit_axis_neuron : int + repinit block: the number of dimension of split in the + symmetrization op. + repinit_activation : str + repinit block: the activation function in the embedding net + repformer_nlayers : int + repformers block: the number of repformer layers + repformer_g1_dim : int + repformers block: the dimension of single-atom rep + repformer_g2_dim : int + repformers block: the dimension of invariant pair-atom rep + repformer_axis_dim : int + repformers block: the number of dimension of split in the + symmetrization ops. + repformer_do_bn_mode : bool + repformers block: do batch norm in the repformer layers + repformer_bn_momentum : float + repformers block: moment in the batch normalization + repformer_update_g1_has_conv : bool + repformers block: update the g1 rep with convolution term + repformer_update_g1_has_drrd : bool + repformers block: update the g1 rep with the drrd term + repformer_update_g1_has_grrg : bool + repformers block: update the g1 rep with the grrg term + repformer_update_g1_has_attn : bool + repformers block: update the g1 rep with the localized + self-attention + repformer_update_g2_has_g1g1 : bool + repformers block: update the g2 rep with the g1xg1 term + repformer_update_g2_has_attn : bool + repformers block: update the g2 rep with the gated self-attention + repformer_update_h2 : bool + repformers block: update the h2 rep + repformer_attn1_hidden : int + repformers block: the hidden dimension of localized self-attention + repformer_attn1_nhead : int + repformers block: the number of heads in localized self-attention + repformer_attn2_hidden : int + repformers block: the hidden dimension of gated self-attention + repformer_attn2_nhead : int + repformers block: the number of heads in gated self-attention + repformer_attn2_has_gate : bool + repformers block: has gate in the gated self-attention + repformer_activation : str + repformers block: the activation function in the MLPs. + repformer_update_style : str + repformers block: style of update a rep. + can be res_avg or res_incr. + res_avg updates a rep `u` with: + u = 1/\sqrt{n+1} (u + u_1 + u_2 + ... + u_n) + res_incr updates a rep `u` with: + u = u + 1/\sqrt{n} (u_1 + u_2 + ... + u_n) + repformer_set_davg_zero : bool + repformers block: set the avg to zero in statistics + repformer_add_type_ebd_to_seq : bool + repformers block: concatenate the type embedding at the output. + + Returns + ------- + descriptor: torch.Tensor + the descriptor of shape nb x nloc x g1_dim. + invariant single-atom representation. + g2: torch.Tensor + invariant pair-atom representation. + h2: torch.Tensor + equivariant pair-atom representation. + rot_mat: torch.Tensor + rotation matrix for equivariant fittings + sw: torch.Tensor + The switch function for decaying inverse distance. + + """ + super().__init__() + del type, rcut, rcut_smth, sel + self.repinit = DescrptBlockSeAtten( + repinit_rcut, + repinit_rcut_smth, + repinit_nsel, + ntypes, + attn_layer=0, + neuron=repinit_neuron, + axis_neuron=repinit_axis_neuron, + tebd_dim=tebd_dim, + tebd_input_mode="concat", + # tebd_input_mode='dot_residual_s', + set_davg_zero=repinit_set_davg_zero, + activation=repinit_activation, + ) + self.repformers = DescrptBlockRepformers( + repformer_rcut, + repformer_rcut_smth, + repformer_nsel, + ntypes, + nlayers=repformer_nlayers, + g1_dim=repformer_g1_dim, + g2_dim=repformer_g2_dim, + axis_dim=repformer_axis_dim, + direct_dist=False, + do_bn_mode=repformer_do_bn_mode, + bn_momentum=repformer_bn_momentum, + update_g1_has_conv=repformer_update_g1_has_conv, + update_g1_has_drrd=repformer_update_g1_has_drrd, + update_g1_has_grrg=repformer_update_g1_has_grrg, + update_g1_has_attn=repformer_update_g1_has_attn, + update_g2_has_g1g1=repformer_update_g2_has_g1g1, + update_g2_has_attn=repformer_update_g2_has_attn, + update_h2=repformer_update_h2, + attn1_hidden=repformer_attn1_hidden, + attn1_nhead=repformer_attn1_nhead, + attn2_hidden=repformer_attn2_hidden, + attn2_nhead=repformer_attn2_nhead, + attn2_has_gate=repformer_attn2_has_gate, + activation=repformer_activation, + update_style=repformer_update_style, + set_davg_zero=repformer_set_davg_zero, + smooth=True, + add_type_ebd_to_seq=repformer_add_type_ebd_to_seq, + ) + self.type_embedding = TypeEmbedNet(ntypes, tebd_dim) + if self.repinit.dim_out == self.repformers.dim_in: + self.g1_shape_tranform = Identity() + else: + self.g1_shape_tranform = Linear( + self.repinit.dim_out, + self.repformers.dim_in, + bias=False, + init="glorot", + ) + assert self.repinit.rcut > self.repformers.rcut + assert self.repinit.sel[0] > self.repformers.sel[0] + self.concat_output_tebd = concat_output_tebd + self.tebd_dim = tebd_dim + self.rcut = self.repinit.get_rcut() + self.ntypes = ntypes + self.sel = self.repinit.sel + + def get_rcut(self) -> float: + """Returns the cut-off radius.""" + return self.rcut + + def get_nsel(self) -> int: + """Returns the number of selected atoms in the cut-off radius.""" + return sum(self.sel) + + def get_sel(self) -> List[int]: + """Returns the number of selected atoms for each type.""" + return self.sel + + def get_ntype(self) -> int: + """Returns the number of element types.""" + return self.ntypes + + def get_dim_out(self) -> int: + """Returns the output dimension of this descriptor.""" + ret = self.repformers.dim_out + if self.concat_output_tebd: + ret += self.tebd_dim + return ret + + @property + def dim_out(self): + return self.get_dim_out() + + @property + def dim_emb(self): + """Returns the embedding dimension g2.""" + return self.repformers.dim_emb + + def compute_input_stats(self, merged): + sumr, suma, sumn, sumr2, suma2 = [], [], [], [], [] + for ii, descrpt in enumerate([self.repinit, self.repformers]): + merged_tmp = [ + { + key: item[key] if not isinstance(item[key], list) else item[key][ii] + for key in item + } + for item in merged + ] + ( + sumr_tmp, + suma_tmp, + sumn_tmp, + sumr2_tmp, + suma2_tmp, + ) = descrpt.compute_input_stats(merged_tmp) + sumr.append(sumr_tmp) + suma.append(suma_tmp) + sumn.append(sumn_tmp) + sumr2.append(sumr2_tmp) + suma2.append(suma2_tmp) + return sumr, suma, sumn, sumr2, suma2 + + def init_desc_stat(self, sumr, suma, sumn, sumr2, suma2): + for ii, descrpt in enumerate([self.repinit, self.repformers]): + descrpt.init_desc_stat(sumr[ii], suma[ii], sumn[ii], sumr2[ii], suma2[ii]) + + @classmethod + def get_stat_name(cls, config): + descrpt_type = config["type"] + assert descrpt_type in ["dpa2"] + return ( + f'stat_file_dpa2_repinit_rcut{config["repinit_rcut"]:.2f}_smth{config["repinit_rcut_smth"]:.2f}_sel{config["repinit_nsel"]}' + f'_repformer_rcut{config["repformer_rcut"]:.2f}_smth{config["repformer_rcut_smth"]:.2f}_sel{config["repformer_nsel"]}.npz' + ) + + @classmethod + def get_data_process_key(cls, config): + descrpt_type = config["type"] + assert descrpt_type in ["dpa2"] + return { + "sel": [config["repinit_nsel"], config["repformer_nsel"]], + "rcut": [config["repinit_rcut"], config["repformer_rcut"]], + } + + def forward( + self, + extended_coord: torch.Tensor, + extended_atype: torch.Tensor, + nlist: torch.Tensor, + mapping: Optional[torch.Tensor] = None, + ): + nframes, nloc, nnei = nlist.shape + nall = extended_coord.view(nframes, -1).shape[1] // 3 + # nlists + nlist_dict = build_multiple_neighbor_list( + extended_coord, + nlist, + [self.repformers.get_rcut(), self.repinit.get_rcut()], + [self.repformers.get_nsel(), self.repinit.get_nsel()], + ) + # repinit + g1_ext = self.type_embedding(extended_atype) + g1_inp = g1_ext[:, :nloc, :] + g1, _, _, _, _ = self.repinit( + nlist_dict[ + get_multiple_nlist_key(self.repinit.get_rcut(), self.repinit.get_nsel()) + ], + extended_coord, + extended_atype, + g1_ext, + mapping, + ) + # linear to change shape + g1 = self.g1_shape_tranform(g1) + # mapping g1 + assert mapping is not None + mapping_ext = ( + mapping.view(nframes, nall).unsqueeze(-1).expand(-1, -1, g1.shape[-1]) + ) + g1_ext = torch.gather(g1, 1, mapping_ext) + # repformer + g1, g2, h2, rot_mat, sw = self.repformers( + nlist_dict[ + get_multiple_nlist_key( + self.repformers.get_rcut(), self.repformers.get_nsel() + ) + ], + extended_coord, + extended_atype, + g1_ext, + mapping, + ) + if self.concat_output_tebd: + g1 = torch.cat([g1, g1_inp], dim=-1) + return g1, g2, h2, rot_mat, sw diff --git a/deepmd/pt/model/descriptor/env_mat.py b/deepmd/pt/model/descriptor/env_mat.py new file mode 100644 index 0000000000..63181388df --- /dev/null +++ b/deepmd/pt/model/descriptor/env_mat.py @@ -0,0 +1,57 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import torch + +from deepmd.pt.utils.preprocess import ( + compute_smooth_weight, +) + + +def _make_env_mat_se_a(nlist, coord, rcut: float, ruct_smth: float): + """Make smooth environment matrix.""" + bsz, natoms, nnei = nlist.shape + coord = coord.view(bsz, -1, 3) + mask = nlist >= 0 + nlist = nlist * mask + coord_l = coord[:, :natoms].view(bsz, -1, 1, 3) + index = nlist.view(bsz, -1).unsqueeze(-1).expand(-1, -1, 3) + coord_r = torch.gather(coord, 1, index) + coord_r = coord_r.view(bsz, natoms, nnei, 3) + diff = coord_r - coord_l + length = torch.linalg.norm(diff, dim=-1, keepdim=True) + # for index 0 nloc atom + length = length + ~mask.unsqueeze(-1) + t0 = 1 / length + t1 = diff / length**2 + weight = compute_smooth_weight(length, ruct_smth, rcut) + env_mat_se_a = torch.cat([t0, t1], dim=-1) * weight * mask.unsqueeze(-1) + return env_mat_se_a, diff * mask.unsqueeze(-1), weight + + +def prod_env_mat_se_a( + extended_coord, nlist, atype, mean, stddev, rcut: float, rcut_smth: float +): + """Generate smooth environment matrix from atom coordinates and other context. + + Args: + - extended_coord: Copied atom coordinates with shape [nframes, nall*3]. + - atype: Atom types with shape [nframes, nloc]. + - natoms: Batched atom statisics with shape [len(sec)+2]. + - box: Batched simulation box with shape [nframes, 9]. + - mean: Average value of descriptor per element type with shape [len(sec), nnei, 4]. + - stddev: Standard deviation of descriptor per element type with shape [len(sec), nnei, 4]. + - deriv_stddev: StdDev of descriptor derivative per element type with shape [len(sec), nnei, 4, 3]. + - rcut: Cut-off radius. + - rcut_smth: Smooth hyper-parameter for pair force & energy. + + Returns + ------- + - env_mat_se_a: Shape is [nframes, natoms[1]*nnei*4]. + """ + nframes = extended_coord.shape[0] + _env_mat_se_a, diff, switch = _make_env_mat_se_a( + nlist, extended_coord, rcut, rcut_smth + ) # shape [n_atom, dim, 4] + t_avg = mean[atype] # [n_atom, dim, 4] + t_std = stddev[atype] # [n_atom, dim, 4] + env_mat_se_a = (_env_mat_se_a - t_avg) / t_std + return env_mat_se_a, diff, switch diff --git a/deepmd/pt/model/descriptor/gaussian_lcc.py b/deepmd/pt/model/descriptor/gaussian_lcc.py new file mode 100644 index 0000000000..26ec1175b8 --- /dev/null +++ b/deepmd/pt/model/descriptor/gaussian_lcc.py @@ -0,0 +1,315 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import torch +import torch.nn as nn + +from deepmd.pt.model.descriptor import ( + Descriptor, +) +from deepmd.pt.model.network.network import ( + Evoformer3bEncoder, + GaussianEmbedding, + TypeEmbedNet, +) +from deepmd.pt.utils import ( + env, +) + + +class DescrptGaussianLcc(Descriptor): + def __init__( + self, + rcut, + rcut_smth, + sel: int, + ntypes: int, + num_pair: int, + embed_dim: int = 768, + kernel_num: int = 128, + pair_embed_dim: int = 64, + num_block: int = 1, + layer_num: int = 12, + attn_head: int = 48, + pair_hidden_dim: int = 16, + ffn_embedding_dim: int = 768, + dropout: float = 0.0, + droppath_prob: float = 0.1, + pair_dropout: float = 0.25, + attention_dropout: float = 0.1, + activation_dropout: float = 0.1, + pre_ln: bool = True, + do_tag_embedding: bool = False, + tag_ener_pref: bool = False, + atomic_sum_gbf: bool = False, + pre_add_seq: bool = True, + tri_update: bool = True, + **kwargs, + ): + """Construct a descriptor of Gaussian Based Local Cluster. + + Args: + - rcut: Cut-off radius. + - rcut_smth: Smooth hyper-parameter for pair force & energy. **Not used in this descriptor**. + - sel: For each element type, how many atoms is selected as neighbors. + - ntypes: Number of atom types. + - num_pair: Number of atom type pairs. Default is 2 * ntypes. + - kernel_num: Number of gaussian kernels. + - embed_dim: Dimension of atomic representation. + - pair_embed_dim: Dimension of pair representation. + - num_block: Number of evoformer blocks. + - layer_num: Number of attention layers. + - attn_head: Number of attention heads. + - pair_hidden_dim: Hidden dimension of pair representation during attention process. + - ffn_embedding_dim: Dimension during feed forward network. + - dropout: Dropout probability of atomic representation. + - droppath_prob: If not zero, it will use drop paths (Stochastic Depth) per sample and ignore `dropout`. + - pair_dropout: Dropout probability of pair representation during triangular update. + - attention_dropout: Dropout probability during attetion process. + - activation_dropout: Dropout probability of pair feed forward network. + - pre_ln: Do previous layer norm or not. + - do_tag_embedding: Add tag embedding to atomic and pair representations. (`tags`, `tags2`, `tags3` must exist) + - atomic_sum_gbf: Add sum of gaussian outputs to atomic representation or not. + - pre_add_seq: Add output of other descriptor (if has) to the atomic representation before attention. + """ + super().__init__() + self.rcut = rcut + self.rcut_smth = rcut_smth + self.embed_dim = embed_dim + self.num_pair = num_pair + self.kernel_num = kernel_num + self.pair_embed_dim = pair_embed_dim + self.num_block = num_block + self.layer_num = layer_num + self.attention_heads = attn_head + self.pair_hidden_dim = pair_hidden_dim + self.ffn_embedding_dim = ffn_embedding_dim + self.dropout = dropout + self.droppath_prob = droppath_prob + self.pair_dropout = pair_dropout + self.attention_dropout = attention_dropout + self.activation_dropout = activation_dropout + self.pre_ln = pre_ln + self.do_tag_embedding = do_tag_embedding + self.tag_ener_pref = tag_ener_pref + self.atomic_sum_gbf = atomic_sum_gbf + self.local_cluster = True + self.pre_add_seq = pre_add_seq + self.tri_update = tri_update + + if isinstance(sel, int): + sel = [sel] + + self.ntypes = ntypes + self.sec = torch.tensor(sel) + self.nnei = sum(sel) + + if self.do_tag_embedding: + self.tag_encoder = nn.Embedding(3, self.embed_dim) + self.tag_encoder2 = nn.Embedding(2, self.embed_dim) + self.tag_type_embedding = TypeEmbedNet(10, pair_embed_dim) + self.edge_type_embedding = nn.Embedding( + (ntypes + 1) * (ntypes + 1), + pair_embed_dim, + padding_idx=(ntypes + 1) * (ntypes + 1) - 1, + dtype=env.GLOBAL_PT_FLOAT_PRECISION, + ) + self.gaussian_encoder = GaussianEmbedding( + rcut, + kernel_num, + num_pair, + embed_dim, + pair_embed_dim, + sel, + ntypes, + atomic_sum_gbf, + ) + self.backbone = Evoformer3bEncoder( + self.nnei, + layer_num=self.layer_num, + attn_head=self.attention_heads, + atomic_dim=self.embed_dim, + pair_dim=self.pair_embed_dim, + pair_hidden_dim=self.pair_hidden_dim, + ffn_embedding_dim=self.ffn_embedding_dim, + dropout=self.dropout, + droppath_prob=self.droppath_prob, + pair_dropout=self.pair_dropout, + attention_dropout=self.attention_dropout, + activation_dropout=self.activation_dropout, + pre_ln=self.pre_ln, + tri_update=self.tri_update, + ) + + @property + def dim_out(self): + """Returns the output dimension of atomic representation.""" + return self.embed_dim + + @property + def dim_in(self): + """Returns the atomic input dimension of this descriptor.""" + return self.embed_dim + + @property + def dim_emb(self): + """Returns the output dimension of pair representation.""" + return self.pair_embed_dim + + def compute_input_stats(self, merged): + """Update mean and stddev for descriptor elements.""" + return [], [], [], [], [] + + def init_desc_stat(self, sumr, suma, sumn, sumr2, suma2): + pass + + def forward( + self, + extended_coord, + nlist, + atype, + nlist_type, + nlist_loc=None, + atype_tebd=None, + nlist_tebd=None, + seq_input=None, + ): + """Calculate the atomic and pair representations of this descriptor. + + Args: + - extended_coord: Copied atom coordinates with shape [nframes, nall, 3]. + - nlist: Neighbor list with shape [nframes, nloc, nnei]. + - atype: Atom type with shape [nframes, nloc]. + - nlist_type: Atom type of neighbors with shape [nframes, nloc, nnei]. + - nlist_loc: Local index of neighbor list with shape [nframes, nloc, nnei]. + - atype_tebd: Atomic type embedding with shape [nframes, nloc, tebd_dim]. + - nlist_tebd: Type embeddings of neighbor with shape [nframes, nloc, nnei, tebd_dim]. + - seq_input: The sequential input from other descriptor with + shape [nframes, nloc, tebd_dim] or [nframes * nloc, 1 + nnei, tebd_dim] + + Returns + ------- + - result: descriptor with shape [nframes, nloc, self.filter_neuron[-1] * self.axis_neuron]. + - ret: environment matrix with shape [nframes, nloc, self.neei, out_size] + """ + nframes, nloc = nlist.shape[:2] + nall = extended_coord.shape[1] + nlist2 = torch.cat( + [ + torch.arange(0, nloc, device=nlist.device) + .reshape(1, nloc, 1) + .expand(nframes, -1, -1), + nlist, + ], + dim=-1, + ) + nlist_loc2 = torch.cat( + [ + torch.arange(0, nloc, device=nlist_loc.device) + .reshape(1, nloc, 1) + .expand(nframes, -1, -1), + nlist_loc, + ], + dim=-1, + ) + nlist_type2 = torch.cat([atype.reshape(nframes, nloc, 1), nlist_type], dim=-1) + nnei2_mask = nlist2 != -1 + padding_mask = nlist2 == -1 + nlist2 = nlist2 * nnei2_mask + nlist_loc2 = nlist_loc2 * nnei2_mask + + # nframes x nloc x (1 + nnei2) x (1 + nnei2) + pair_mask = nnei2_mask.unsqueeze(-1) * nnei2_mask.unsqueeze(-2) + # nframes x nloc x (1 + nnei2) x (1 + nnei2) x head + attn_mask = torch.zeros( + [nframes, nloc, 1 + self.nnei, 1 + self.nnei, self.attention_heads], + device=nlist.device, + dtype=extended_coord.dtype, + ) + attn_mask.masked_fill_(padding_mask.unsqueeze(2).unsqueeze(-1), float("-inf")) + # (nframes x nloc) x head x (1 + nnei2) x (1 + nnei2) + attn_mask = ( + attn_mask.reshape( + nframes * nloc, 1 + self.nnei, 1 + self.nnei, self.attention_heads + ) + .permute(0, 3, 1, 2) + .contiguous() + ) + + # Atomic feature + # [(nframes x nloc) x (1 + nnei2) x tebd_dim] + atom_feature = torch.gather( + atype_tebd, + dim=1, + index=nlist_loc2.reshape(nframes, -1) + .unsqueeze(-1) + .expand(-1, -1, self.embed_dim), + ).reshape(nframes * nloc, 1 + self.nnei, self.embed_dim) + if self.pre_add_seq and seq_input is not None: + first_dim = seq_input.shape[0] + if first_dim == nframes * nloc: + atom_feature += seq_input + elif first_dim == nframes: + atom_feature_seq = torch.gather( + seq_input, + dim=1, + index=nlist_loc2.reshape(nframes, -1) + .unsqueeze(-1) + .expand(-1, -1, self.embed_dim), + ).reshape(nframes * nloc, 1 + self.nnei, self.embed_dim) + atom_feature += atom_feature_seq + else: + raise RuntimeError + atom_feature = atom_feature * nnei2_mask.reshape( + nframes * nloc, 1 + self.nnei, 1 + ) + + # Pair feature + # [(nframes x nloc) x (1 + nnei2)] + nlist_type2_reshape = nlist_type2.reshape(nframes * nloc, 1 + self.nnei) + # [(nframes x nloc) x (1 + nnei2) x (1 + nnei2)] + edge_type = nlist_type2_reshape.unsqueeze(-1) * ( + self.ntypes + 1 + ) + nlist_type2_reshape.unsqueeze(-2) + # [(nframes x nloc) x (1 + nnei2) x (1 + nnei2) x pair_dim] + edge_feature = self.edge_type_embedding(edge_type) + + # [(nframes x nloc) x (1 + nnei2) x (1 + nnei2) x 2] + edge_type_2dim = torch.cat( + [ + nlist_type2_reshape.view(nframes * nloc, 1 + self.nnei, 1, 1).expand( + -1, -1, 1 + self.nnei, -1 + ), + nlist_type2_reshape.view(nframes * nloc, 1, 1 + self.nnei, 1).expand( + -1, 1 + self.nnei, -1, -1 + ) + + self.ntypes, + ], + dim=-1, + ) + # [(nframes x nloc) x (1 + nnei2) x 3] + coord_selected = torch.gather( + extended_coord.unsqueeze(1) + .expand(-1, nloc, -1, -1) + .reshape(nframes * nloc, nall, 3), + dim=1, + index=nlist2.reshape(nframes * nloc, 1 + self.nnei, 1).expand(-1, -1, 3), + ) + + # Update pair features (or and atomic features) with gbf features + # delta_pos: [(nframes x nloc) x (1 + nnei2) x (1 + nnei2) x 3]. + atomic_feature, pair_feature, delta_pos = self.gaussian_encoder( + coord_selected, atom_feature, edge_type_2dim, edge_feature + ) + # [(nframes x nloc) x (1 + nnei2) x (1 + nnei2) x pair_dim] + attn_bias = pair_feature + + # output: [(nframes x nloc) x (1 + nnei2) x tebd_dim] + # pair: [(nframes x nloc) x (1 + nnei2) x (1 + nnei2) x pair_dim] + output, pair = self.backbone( + atomic_feature, + pair=attn_bias, + attn_mask=attn_mask, + pair_mask=pair_mask, + atom_mask=nnei2_mask.reshape(nframes * nloc, 1 + self.nnei), + ) + + return output, pair, delta_pos, None diff --git a/deepmd/pt/model/descriptor/hybrid.py b/deepmd/pt/model/descriptor/hybrid.py new file mode 100644 index 0000000000..11bbc80729 --- /dev/null +++ b/deepmd/pt/model/descriptor/hybrid.py @@ -0,0 +1,257 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + List, + Optional, +) + +import torch + +from deepmd.pt.model.descriptor import ( + DescriptorBlock, +) +from deepmd.pt.model.network.network import ( + Identity, + Linear, +) + + +@DescriptorBlock.register("hybrid") +class DescrptBlockHybrid(DescriptorBlock): + def __init__( + self, + list, + ntypes: int, + tebd_dim: int = 8, + tebd_input_mode: str = "concat", + hybrid_mode: str = "concat", + **kwargs, + ): + """Construct a hybrid descriptor. + + Args: + - descriptor_list: list of descriptors. + - descriptor_param: descriptor configs. + """ + super().__init__() + supported_descrpt = ["se_atten", "se_uni"] + descriptor_list = [] + for descriptor_param_item in list: + descriptor_type_tmp = descriptor_param_item["type"] + assert ( + descriptor_type_tmp in supported_descrpt + ), f"Only descriptors in {supported_descrpt} are supported for `hybrid` descriptor!" + descriptor_param_item["ntypes"] = ntypes + if descriptor_type_tmp == "se_atten": + descriptor_param_item["tebd_dim"] = tebd_dim + descriptor_param_item["tebd_input_mode"] = tebd_input_mode + descriptor_list.append(DescriptorBlock(**descriptor_param_item)) + self.descriptor_list = torch.nn.ModuleList(descriptor_list) + self.descriptor_param = list + self.rcut = [descrpt.rcut for descrpt in self.descriptor_list] + self.sec = [descrpt.sec for descrpt in self.descriptor_list] + self.sel = [descrpt.sel for descrpt in self.descriptor_list] + self.split_sel = [sum(ii) for ii in self.sel] + self.local_cluster_list = [ + descrpt.local_cluster for descrpt in self.descriptor_list + ] + self.local_cluster = True in self.local_cluster_list + self.hybrid_mode = hybrid_mode + self.tebd_dim = tebd_dim + assert self.hybrid_mode in ["concat", "sequential"] + sequential_transform = [] + if self.hybrid_mode == "sequential": + for ii in range(len(descriptor_list) - 1): + if descriptor_list[ii].dim_out == descriptor_list[ii + 1].dim_in: + sequential_transform.append(Identity()) + else: + sequential_transform.append( + Linear( + descriptor_list[ii].dim_out, + descriptor_list[ii + 1].dim_in, + bias=False, + init="glorot", + ) + ) + sequential_transform.append(Identity()) + self.sequential_transform = torch.nn.ModuleList(sequential_transform) + self.ntypes = ntypes + + def get_rcut(self) -> float: + """Returns the cut-off radius.""" + return self.rcut + + def get_nsel(self) -> int: + """Returns the number of selected atoms in the cut-off radius.""" + return [sum(ii) for ii in self.get_sel()] + + def get_sel(self) -> List[int]: + """Returns the number of selected atoms for each type.""" + return self.sel + + def get_ntype(self) -> int: + """Returns the number of element types.""" + return self.ntypes + + def get_dim_out(self) -> int: + """Returns the output dimension.""" + return self.dim_out + + def get_dim_in(self) -> int: + """Returns the input dimension.""" + return self.dim_in + + @property + def dim_out(self): + """Returns the output dimension of this descriptor.""" + if self.hybrid_mode == "concat": + return sum([descrpt.dim_out for descrpt in self.descriptor_list]) + elif self.hybrid_mode == "sequential": + return self.descriptor_list[-1].dim_out + else: + raise RuntimeError + + @property + def dim_emb_list(self) -> List[int]: + """Returns the output dimension list of embeddings.""" + return [descrpt.dim_emb for descrpt in self.descriptor_list] + + @property + def dim_emb(self): + """Returns the output dimension of embedding.""" + if self.hybrid_mode == "concat": + return sum(self.dim_emb_list) + elif self.hybrid_mode == "sequential": + return self.descriptor_list[-1].dim_emb + else: + raise RuntimeError + + def share_params(self, base_class, shared_level, resume=False): + assert ( + self.__class__ == base_class.__class__ + ), "Only descriptors of the same type can share params!" + if shared_level == 0: + for ii, des in enumerate(self.descriptor_list): + self.descriptor_list[ii].share_params( + base_class.descriptor_list[ii], shared_level, resume=resume + ) + if self.hybrid_mode == "sequential": + self.sequential_transform = base_class.sequential_transform + else: + raise NotImplementedError + + def compute_input_stats(self, merged): + """Update mean and stddev for descriptor elements.""" + sumr, suma, sumn, sumr2, suma2 = [], [], [], [], [] + for ii, descrpt in enumerate(self.descriptor_list): + merged_tmp = [ + { + key: item[key] if not isinstance(item[key], list) else item[key][ii] + for key in item + } + for item in merged + ] + ( + sumr_tmp, + suma_tmp, + sumn_tmp, + sumr2_tmp, + suma2_tmp, + ) = descrpt.compute_input_stats(merged_tmp) + sumr.append(sumr_tmp) + suma.append(suma_tmp) + sumn.append(sumn_tmp) + sumr2.append(sumr2_tmp) + suma2.append(suma2_tmp) + return sumr, suma, sumn, sumr2, suma2 + + def init_desc_stat(self, sumr, suma, sumn, sumr2, suma2): + for ii, descrpt in enumerate(self.descriptor_list): + descrpt.init_desc_stat(sumr[ii], suma[ii], sumn[ii], sumr2[ii], suma2[ii]) + + def forward( + self, + nlist: torch.Tensor, + extended_coord: torch.Tensor, + extended_atype: torch.Tensor, + extended_atype_embd: Optional[torch.Tensor] = None, + mapping: Optional[torch.Tensor] = None, + ): + """Calculate decoded embedding for each atom. + + Args: + - extended_coord: Tell atom coordinates with shape [nframes, natoms[1]*3]. + - nlist: Tell atom types with shape [nframes, natoms[1]]. + - atype: Tell atom count and element count. Its shape is [2+self.ntypes]. + - nlist_type: Tell simulation box with shape [nframes, 9]. + - atype_tebd: Tell simulation box with shape [nframes, 9]. + - nlist_tebd: Tell simulation box with shape [nframes, 9]. + + Returns + ------- + - result: descriptor with shape [nframes, nloc, self.filter_neuron[-1] * self.axis_neuron]. + - ret: environment matrix with shape [nframes, nloc, self.neei, out_size] + """ + nlist_list = list(torch.split(nlist, self.split_sel, -1)) + nframes, nloc, nnei = nlist.shape + concat_rot_mat = True + if self.hybrid_mode == "concat": + out_descriptor = [] + # out_env_mat = [] + out_rot_mat_list = [] + # out_diff = [] + for ii, descrpt in enumerate(self.descriptor_list): + descriptor, env_mat, diff, rot_mat, sw = descrpt( + nlist_list[ii], + extended_coord, + extended_atype, + extended_atype_embd, + mapping, + ) + if descriptor.shape[0] == nframes * nloc: + # [nframes * nloc, 1 + nnei, emb_dim] + descriptor = descriptor[:, 0, :].reshape(nframes, nloc, -1) + out_descriptor.append(descriptor) + # out_env_mat.append(env_mat) + # out_diff.append(diff) + out_rot_mat_list.append(rot_mat) + if rot_mat is None: + concat_rot_mat = False + out_descriptor = torch.concat(out_descriptor, dim=-1) + if concat_rot_mat: + out_rot_mat = torch.concat(out_rot_mat_list, dim=-2) + else: + out_rot_mat = None + return out_descriptor, None, None, out_rot_mat, sw + elif self.hybrid_mode == "sequential": + assert extended_atype_embd is not None + assert mapping is not None + nframes, nloc, nnei = nlist.shape + nall = extended_coord.view(nframes, -1).shape[1] // 3 + seq_input_ext = extended_atype_embd + seq_input = ( + seq_input_ext[:, :nloc, :] if len(self.descriptor_list) == 0 else None + ) + env_mat, diff, rot_mat, sw = None, None, None, None + env_mat_list, diff_list = [], [] + for ii, (descrpt, seq_transform) in enumerate( + zip(self.descriptor_list, self.sequential_transform) + ): + seq_output, env_mat, diff, rot_mat, sw = descrpt( + nlist_list[ii], + extended_coord, + extended_atype, + seq_input_ext, + mapping, + ) + seq_input = seq_transform(seq_output) + mapping_ext = ( + mapping.view(nframes, nall) + .unsqueeze(-1) + .expand(-1, -1, seq_input.shape[-1]) + ) + seq_input_ext = torch.gather(seq_input, 1, mapping_ext) + env_mat_list.append(env_mat) + diff_list.append(diff) + return seq_input, env_mat_list, diff_list, rot_mat, sw + else: + raise RuntimeError diff --git a/deepmd/pt/model/descriptor/repformer_layer.py b/deepmd/pt/model/descriptor/repformer_layer.py new file mode 100644 index 0000000000..21ae0ff6f3 --- /dev/null +++ b/deepmd/pt/model/descriptor/repformer_layer.py @@ -0,0 +1,749 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Callable, + List, +) + +import torch + +from deepmd.pt.model.network.network import ( + SimpleLinear, +) +from deepmd.pt.utils import ( + env, +) +from deepmd.pt.utils.utils import ( + get_activation_fn, +) + + +def torch_linear(*args, **kwargs): + return torch.nn.Linear( + *args, **kwargs, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE + ) + + +def _make_nei_g1( + g1_ext: torch.Tensor, + nlist: torch.Tensor, +) -> torch.Tensor: + # nlist: nb x nloc x nnei + nb, nloc, nnei = nlist.shape + # g1_ext: nb x nall x ng1 + ng1 = g1_ext.shape[-1] + # index: nb x (nloc x nnei) x ng1 + index = nlist.reshape(nb, nloc * nnei).unsqueeze(-1).expand(-1, -1, ng1) + # gg1 : nb x (nloc x nnei) x ng1 + gg1 = torch.gather(g1_ext, dim=1, index=index) + # gg1 : nb x nloc x nnei x ng1 + gg1 = gg1.view(nb, nloc, nnei, ng1) + return gg1 + + +def _apply_nlist_mask( + gg: torch.Tensor, + nlist_mask: torch.Tensor, +) -> torch.Tensor: + # gg: nf x nloc x nnei x ng + # msk: nf x nloc x nnei + return gg.masked_fill(~nlist_mask.unsqueeze(-1), 0.0) + + +def _apply_switch(gg: torch.Tensor, sw: torch.Tensor) -> torch.Tensor: + # gg: nf x nloc x nnei x ng + # sw: nf x nloc x nnei + return gg * sw.unsqueeze(-1) + + +def _apply_h_norm( + hh: torch.Tensor, # nf x nloc x nnei x 3 +) -> torch.Tensor: + """Normalize h by the std of vector length. + do not have an idea if this is a good way. + """ + nf, nl, nnei, _ = hh.shape + # nf x nloc x nnei + normh = torch.linalg.norm(hh, dim=-1) + # nf x nloc + std = torch.std(normh, dim=-1) + # nf x nloc x nnei x 3 + hh = hh[:, :, :, :] / (1.0 + std[:, :, None, None]) + return hh + + +class Atten2Map(torch.nn.Module): + def __init__( + self, + ni: int, + nd: int, + nh: int, + has_gate: bool = False, # apply gate to attn map + smooth: bool = True, + attnw_shift: float = 20.0, + ): + super().__init__() + self.ni = ni + self.nd = nd + self.nh = nh + self.mapqk = SimpleLinear(ni, nd * 2 * nh, bias=False) + self.has_gate = has_gate + self.smooth = smooth + self.attnw_shift = attnw_shift + + def forward( + self, + g2: torch.Tensor, # nb x nloc x nnei x ng2 + h2: torch.Tensor, # nb x nloc x nnei x 3 + nlist_mask: torch.Tensor, # nb x nloc x nnei + sw: torch.Tensor, # nb x nloc x nnei + ) -> torch.Tensor: + ( + nb, + nloc, + nnei, + _, + ) = g2.shape + nd, nh = self.nd, self.nh + # nb x nloc x nnei x nd x (nh x 2) + g2qk = self.mapqk(g2).view(nb, nloc, nnei, nd, nh * 2) + # nb x nloc x (nh x 2) x nnei x nd + g2qk = torch.permute(g2qk, (0, 1, 4, 2, 3)) + # nb x nloc x nh x nnei x nd + g2q, g2k = torch.split(g2qk, nh, dim=2) + # g2q = torch.nn.functional.normalize(g2q, dim=-1) + # g2k = torch.nn.functional.normalize(g2k, dim=-1) + # nb x nloc x nh x nnei x nnei + attnw = torch.matmul(g2q, torch.transpose(g2k, -1, -2)) / nd**0.5 + if self.has_gate: + gate = torch.matmul(h2, torch.transpose(h2, -1, -2)).unsqueeze(-3) + attnw = attnw * gate + # mask the attenmap, nb x nloc x 1 x 1 x nnei + attnw_mask = ~nlist_mask.unsqueeze(2).unsqueeze(2) + # mask the attenmap, nb x nloc x 1 x nnei x 1 + attnw_mask_c = ~nlist_mask.unsqueeze(2).unsqueeze(-1) + if self.smooth: + attnw = (attnw + self.attnw_shift) * sw[:, :, None, :, None] * sw[ + :, :, None, None, : + ] - self.attnw_shift + else: + attnw = attnw.masked_fill( + attnw_mask, + float("-inf"), + ) + attnw = torch.softmax(attnw, dim=-1) + attnw = attnw.masked_fill( + attnw_mask, + 0.0, + ) + # nb x nloc x nh x nnei x nnei + attnw = attnw.masked_fill( + attnw_mask_c, + 0.0, + ) + if self.smooth: + attnw = attnw * sw[:, :, None, :, None] * sw[:, :, None, None, :] + # nb x nloc x nnei x nnei + h2h2t = torch.matmul(h2, torch.transpose(h2, -1, -2)) / 3.0**0.5 + # nb x nloc x nh x nnei x nnei + ret = attnw * h2h2t[:, :, None, :, :] + # ret = torch.softmax(g2qk, dim=-1) + # nb x nloc x nnei x nnei x nh + ret = torch.permute(ret, (0, 1, 3, 4, 2)) + return ret + + +class Atten2MultiHeadApply(torch.nn.Module): + def __init__( + self, + ni: int, + nh: int, + ): + super().__init__() + self.ni = ni + self.nh = nh + self.mapv = SimpleLinear(ni, ni * nh, bias=False) + self.head_map = SimpleLinear(ni * nh, ni) + + def forward( + self, + AA: torch.Tensor, # nf x nloc x nnei x nnei x nh + g2: torch.Tensor, # nf x nloc x nnei x ng2 + ) -> torch.Tensor: + nf, nloc, nnei, ng2 = g2.shape + nh = self.nh + # nf x nloc x nnei x ng2 x nh + g2v = self.mapv(g2).view(nf, nloc, nnei, ng2, nh) + # nf x nloc x nh x nnei x ng2 + g2v = torch.permute(g2v, (0, 1, 4, 2, 3)) + # g2v = torch.nn.functional.normalize(g2v, dim=-1) + # nf x nloc x nh x nnei x nnei + AA = torch.permute(AA, (0, 1, 4, 2, 3)) + # nf x nloc x nh x nnei x ng2 + ret = torch.matmul(AA, g2v) + # nf x nloc x nnei x ng2 x nh + ret = torch.permute(ret, (0, 1, 3, 4, 2)).reshape(nf, nloc, nnei, (ng2 * nh)) + # nf x nloc x nnei x ng2 + return self.head_map(ret) + + +class Atten2EquiVarApply(torch.nn.Module): + def __init__( + self, + ni: int, + nh: int, + ): + super().__init__() + self.ni = ni + self.nh = nh + self.head_map = SimpleLinear(nh, 1, bias=False) + + def forward( + self, + AA: torch.Tensor, # nf x nloc x nnei x nnei x nh + h2: torch.Tensor, # nf x nloc x nnei x 3 + ) -> torch.Tensor: + nf, nloc, nnei, _ = h2.shape + nh = self.nh + # nf x nloc x nh x nnei x nnei + AA = torch.permute(AA, (0, 1, 4, 2, 3)) + h2m = torch.unsqueeze(h2, dim=2) + # nf x nloc x nh x nnei x 3 + h2m = torch.tile(h2m, [1, 1, nh, 1, 1]) + # nf x nloc x nh x nnei x 3 + ret = torch.matmul(AA, h2m) + # nf x nloc x nnei x 3 x nh + ret = torch.permute(ret, (0, 1, 3, 4, 2)).view(nf, nloc, nnei, 3, nh) + # nf x nloc x nnei x 3 + return torch.squeeze(self.head_map(ret), dim=-1) + + +class LocalAtten(torch.nn.Module): + def __init__( + self, + ni: int, + nd: int, + nh: int, + smooth: bool = True, + attnw_shift: float = 20.0, + ): + super().__init__() + self.ni = ni + self.nd = nd + self.nh = nh + self.mapq = SimpleLinear(ni, nd * 1 * nh, bias=False) + self.mapkv = SimpleLinear(ni, (nd + ni) * nh, bias=False) + self.head_map = SimpleLinear(ni * nh, ni) + self.smooth = smooth + self.attnw_shift = attnw_shift + + def forward( + self, + g1: torch.Tensor, # nb x nloc x ng1 + gg1: torch.Tensor, # nb x nloc x nnei x ng1 + nlist_mask: torch.Tensor, # nb x nloc x nnei + sw: torch.Tensor, # nb x nloc x nnei + ) -> torch.Tensor: + nb, nloc, nnei = nlist_mask.shape + ni, nd, nh = self.ni, self.nd, self.nh + assert ni == g1.shape[-1] + assert ni == gg1.shape[-1] + # nb x nloc x nd x nh + g1q = self.mapq(g1).view(nb, nloc, nd, nh) + # nb x nloc x nh x nd + g1q = torch.permute(g1q, (0, 1, 3, 2)) + # nb x nloc x nnei x (nd+ni) x nh + gg1kv = self.mapkv(gg1).view(nb, nloc, nnei, nd + ni, nh) + gg1kv = torch.permute(gg1kv, (0, 1, 4, 2, 3)) + # nb x nloc x nh x nnei x nd, nb x nloc x nh x nnei x ng1 + gg1k, gg1v = torch.split(gg1kv, [nd, ni], dim=-1) + + # nb x nloc x nh x 1 x nnei + attnw = torch.matmul(g1q.unsqueeze(-2), torch.transpose(gg1k, -1, -2)) / nd**0.5 + # nb x nloc x nh x nnei + attnw = attnw.squeeze(-2) + # mask the attenmap, nb x nloc x 1 x nnei + attnw_mask = ~nlist_mask.unsqueeze(-2) + # nb x nloc x nh x nnei + if self.smooth: + attnw = (attnw + self.attnw_shift) * sw.unsqueeze(-2) - self.attnw_shift + else: + attnw = attnw.masked_fill( + attnw_mask, + float("-inf"), + ) + attnw = torch.softmax(attnw, dim=-1) + attnw = attnw.masked_fill( + attnw_mask, + 0.0, + ) + if self.smooth: + attnw = attnw * sw.unsqueeze(-2) + + # nb x nloc x nh x ng1 + ret = ( + torch.matmul(attnw.unsqueeze(-2), gg1v).squeeze(-2).view(nb, nloc, nh * ni) + ) + # nb x nloc x ng1 + ret = self.head_map(ret) + return ret + + +class RepformerLayer(torch.nn.Module): + def __init__( + self, + rcut, + rcut_smth, + sel: int, + ntypes: int, + g1_dim=128, + g2_dim=16, + axis_dim: int = 4, + update_chnnl_2: bool = True, + do_bn_mode: str = "no", + bn_momentum: float = 0.1, + update_g1_has_conv: bool = True, + update_g1_has_drrd: bool = True, + update_g1_has_grrg: bool = True, + update_g1_has_attn: bool = True, + update_g2_has_g1g1: bool = True, + update_g2_has_attn: bool = True, + update_h2: bool = False, + attn1_hidden: int = 64, + attn1_nhead: int = 4, + attn2_hidden: int = 16, + attn2_nhead: int = 4, + attn2_has_gate: bool = False, + activation: str = "tanh", + update_style: str = "res_avg", + set_davg_zero: bool = True, # TODO + smooth: bool = True, + ): + super().__init__() + self.epsilon = 1e-4 # protection of 1./nnei + self.rcut = rcut + self.rcut_smth = rcut_smth + self.ntypes = ntypes + sel = [sel] if isinstance(sel, int) else sel + self.nnei = sum(sel) + assert len(sel) == 1 + self.sel = torch.tensor(sel) + self.sec = self.sel + self.axis_dim = axis_dim + self.set_davg_zero = set_davg_zero + self.do_bn_mode = do_bn_mode + self.bn_momentum = bn_momentum + self.act = get_activation_fn(activation) + self.update_g1_has_grrg = update_g1_has_grrg + self.update_g1_has_drrd = update_g1_has_drrd + self.update_g1_has_conv = update_g1_has_conv + self.update_g1_has_attn = update_g1_has_attn + self.update_chnnl_2 = update_chnnl_2 + self.update_g2_has_g1g1 = update_g2_has_g1g1 if self.update_chnnl_2 else False + self.update_g2_has_attn = update_g2_has_attn if self.update_chnnl_2 else False + self.update_h2 = update_h2 if self.update_chnnl_2 else False + del update_g2_has_g1g1, update_g2_has_attn, update_h2 + self.update_style = update_style + self.smooth = smooth + self.g1_dim = g1_dim + self.g2_dim = g2_dim + + g1_in_dim = self.cal_1_dim(g1_dim, g2_dim, self.axis_dim) + self.linear1 = SimpleLinear(g1_in_dim, g1_dim) + self.linear2 = None + self.proj_g1g2 = None + self.proj_g1g1g2 = None + self.attn2g_map = None + self.attn2_mh_apply = None + self.attn2_lm = None + self.attn2h_map = None + self.attn2_ev_apply = None + self.loc_attn = None + + if self.update_chnnl_2: + self.linear2 = SimpleLinear(g2_dim, g2_dim) + if self.update_g1_has_conv: + self.proj_g1g2 = SimpleLinear(g1_dim, g2_dim, bias=False) + if self.update_g2_has_g1g1: + self.proj_g1g1g2 = SimpleLinear(g1_dim, g2_dim, bias=False) + if self.update_g2_has_attn: + self.attn2g_map = Atten2Map( + g2_dim, attn2_hidden, attn2_nhead, attn2_has_gate, self.smooth + ) + self.attn2_mh_apply = Atten2MultiHeadApply(g2_dim, attn2_nhead) + self.attn2_lm = torch.nn.LayerNorm( + g2_dim, + elementwise_affine=True, + device=env.DEVICE, + dtype=env.GLOBAL_PT_FLOAT_PRECISION, + ) + if self.update_h2: + self.attn2h_map = Atten2Map( + g2_dim, attn2_hidden, attn2_nhead, attn2_has_gate, self.smooth + ) + self.attn2_ev_apply = Atten2EquiVarApply(g2_dim, attn2_nhead) + if self.update_g1_has_attn: + self.loc_attn = LocalAtten(g1_dim, attn1_hidden, attn1_nhead, self.smooth) + + if self.do_bn_mode == "uniform": + self.bn1 = self._bn_layer() + self.bn2 = self._bn_layer() + elif self.do_bn_mode == "component": + self.bn1 = self._bn_layer(nf=g1_dim) + self.bn2 = self._bn_layer(nf=g2_dim) + elif self.do_bn_mode == "no": + self.bn1, self.bn2 = None, None + else: + raise RuntimeError(f"unknown bn_mode {self.do_bn_mode}") + + def cal_1_dim(self, g1d: int, g2d: int, ax: int) -> int: + ret = g1d + if self.update_g1_has_grrg: + ret += g2d * ax + if self.update_g1_has_drrd: + ret += g1d * ax + if self.update_g1_has_conv: + ret += g2d + return ret + + def _update_h2( + self, + g2: torch.Tensor, + h2: torch.Tensor, + nlist_mask: torch.Tensor, + sw: torch.Tensor, + ) -> torch.Tensor: + assert self.attn2h_map is not None + assert self.attn2_ev_apply is not None + nb, nloc, nnei, _ = g2.shape + # # nb x nloc x nnei x nh2 + # h2_1 = self.attn2_ev_apply(AA, h2) + # h2_update.append(h2_1) + # nb x nloc x nnei x nnei x nh + AAh = self.attn2h_map(g2, h2, nlist_mask, sw) + # nb x nloc x nnei x nh2 + h2_1 = self.attn2_ev_apply(AAh, h2) + return h2_1 + + def _update_g1_conv( + self, + gg1: torch.Tensor, + g2: torch.Tensor, + nlist_mask: torch.Tensor, + sw: torch.Tensor, + ) -> torch.Tensor: + assert self.proj_g1g2 is not None + nb, nloc, nnei, _ = g2.shape + ng1 = gg1.shape[-1] + ng2 = g2.shape[-1] + # gg1 : nb x nloc x nnei x ng2 + gg1 = self.proj_g1g2(gg1).view(nb, nloc, nnei, ng2) + # nb x nloc x nnei x ng2 + gg1 = _apply_nlist_mask(gg1, nlist_mask) + if not self.smooth: + # normalized by number of neighbors, not smooth + # nb x nloc x 1 + invnnei = 1.0 / (self.epsilon + torch.sum(nlist_mask, dim=-1)).unsqueeze(-1) + else: + gg1 = _apply_switch(gg1, sw) + invnnei = (1.0 / float(nnei)) * torch.ones( + (nb, nloc, 1), dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE + ) + # nb x nloc x ng2 + g1_11 = torch.sum(g2 * gg1, dim=2) * invnnei + return g1_11 + + def _cal_h2g2( + self, + g2: torch.Tensor, + h2: torch.Tensor, + nlist_mask: torch.Tensor, + sw: torch.Tensor, + ) -> torch.Tensor: + # g2: nf x nloc x nnei x ng2 + # h2: nf x nloc x nnei x 3 + # msk: nf x nloc x nnei + nb, nloc, nnei, _ = g2.shape + ng2 = g2.shape[-1] + # nb x nloc x nnei x ng2 + g2 = _apply_nlist_mask(g2, nlist_mask) + if not self.smooth: + # nb x nloc + invnnei = 1.0 / (self.epsilon + torch.sum(nlist_mask, dim=-1)) + # nb x nloc x 1 x 1 + invnnei = invnnei.unsqueeze(-1).unsqueeze(-1) + else: + g2 = _apply_switch(g2, sw) + invnnei = (1.0 / float(nnei)) * torch.ones( + (nb, nloc, 1, 1), dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE + ) + # nb x nloc x 3 x ng2 + h2g2 = torch.matmul(torch.transpose(h2, -1, -2), g2) * invnnei + return h2g2 + + def _cal_grrg(self, h2g2: torch.Tensor) -> torch.Tensor: + # nb x nloc x 3 x ng2 + nb, nloc, _, ng2 = h2g2.shape + # nb x nloc x 3 x axis + h2g2m = torch.split(h2g2, self.axis_dim, dim=-1)[0] + # nb x nloc x axis x ng2 + g1_13 = torch.matmul(torch.transpose(h2g2m, -1, -2), h2g2) / (3.0**1) + # nb x nloc x (axisxng2) + g1_13 = g1_13.view(nb, nloc, self.axis_dim * ng2) + return g1_13 + + def _update_g1_grrg( + self, + g2: torch.Tensor, + h2: torch.Tensor, + nlist_mask: torch.Tensor, + sw: torch.Tensor, + ) -> torch.Tensor: + # g2: nf x nloc x nnei x ng2 + # h2: nf x nloc x nnei x 3 + # msk: nf x nloc x nnei + nb, nloc, nnei, _ = g2.shape + ng2 = g2.shape[-1] + # nb x nloc x 3 x ng2 + h2g2 = self._cal_h2g2(g2, h2, nlist_mask, sw) + # nb x nloc x (axisxng2) + g1_13 = self._cal_grrg(h2g2) + return g1_13 + + def _update_g2_g1g1( + self, + g1: torch.Tensor, # nb x nloc x ng1 + gg1: torch.Tensor, # nb x nloc x nnei x ng1 + nlist_mask: torch.Tensor, # nb x nloc x nnei + sw: torch.Tensor, # nb x nloc x nnei + ) -> torch.Tensor: + ret = g1.unsqueeze(-2) * gg1 + # nb x nloc x nnei x ng1 + ret = _apply_nlist_mask(ret, nlist_mask) + if self.smooth: + ret = _apply_switch(ret, sw) + return ret + + def _apply_bn( + self, + bn_number: int, + gg: torch.Tensor, + ): + if self.do_bn_mode == "uniform": + return self._apply_bn_uni(bn_number, gg) + elif self.do_bn_mode == "component": + return self._apply_bn_comp(bn_number, gg) + else: + return gg + + def _apply_nb_1(self, bn_number: int, gg: torch.Tensor) -> torch.Tensor: + nb, nl, nf = gg.shape + gg = gg.view([nb, 1, nl * nf]) + if bn_number == 1: + assert self.bn1 is not None + gg = self.bn1(gg) + else: + assert self.bn2 is not None + gg = self.bn2(gg) + return gg.view([nb, nl, nf]) + + def _apply_nb_2( + self, + bn_number: int, + gg: torch.Tensor, + ) -> torch.Tensor: + nb, nl, nnei, nf = gg.shape + gg = gg.view([nb, 1, nl * nnei * nf]) + if bn_number == 1: + assert self.bn1 is not None + gg = self.bn1(gg) + else: + assert self.bn2 is not None + gg = self.bn2(gg) + return gg.view([nb, nl, nnei, nf]) + + def _apply_bn_uni( + self, + bn_number: int, + gg: torch.Tensor, + mode: str = "1", + ) -> torch.Tensor: + if len(gg.shape) == 3: + return self._apply_nb_1(bn_number, gg) + elif len(gg.shape) == 4: + return self._apply_nb_2(bn_number, gg) + else: + raise RuntimeError(f"unsupported input shape {gg.shape}") + + def _apply_bn_comp( + self, + bn_number: int, + gg: torch.Tensor, + ) -> torch.Tensor: + ss = gg.shape + nf = ss[-1] + gg = gg.view([-1, nf]) + if bn_number == 1: + assert self.bn1 is not None + gg = self.bn1(gg).view(ss) + else: + assert self.bn2 is not None + gg = self.bn2(gg).view(ss) + return gg + + def forward( + self, + g1_ext: torch.Tensor, # nf x nall x ng1 + g2: torch.Tensor, # nf x nloc x nnei x ng2 + h2: torch.Tensor, # nf x nloc x nnei x 3 + nlist: torch.Tensor, # nf x nloc x nnei + nlist_mask: torch.Tensor, # nf x nloc x nnei + sw: torch.Tensor, # switch func, nf x nloc x nnei + ): + """ + Parameters + ---------- + g1_ext : nf x nall x ng1 extended single-atom chanel + g2 : nf x nloc x nnei x ng2 pair-atom channel, invariant + h2 : nf x nloc x nnei x 3 pair-atom channel, equivariant + nlist : nf x nloc x nnei neighbor list (padded neis are set to 0) + nlist_mask : nf x nloc x nnei masks of the neighbor list. real nei 1 otherwise 0 + sw : nf x nloc x nnei switch function + + Returns + ------- + g1: nf x nloc x ng1 updated single-atom chanel + g2: nf x nloc x nnei x ng2 updated pair-atom channel, invariant + h2: nf x nloc x nnei x 3 updated pair-atom channel, equivariant + """ + cal_gg1 = ( + self.update_g1_has_drrd + or self.update_g1_has_conv + or self.update_g1_has_attn + or self.update_g2_has_g1g1 + ) + + nb, nloc, nnei, _ = g2.shape + nall = g1_ext.shape[1] + g1, _ = torch.split(g1_ext, [nloc, nall - nloc], dim=1) + assert (nb, nloc) == g1.shape[:2] + assert (nb, nloc, nnei) == h2.shape[:3] + ng1 = g1.shape[-1] + ng2 = g2.shape[-1] + nh2 = h2.shape[-1] + + if self.bn1 is not None: + g1 = self._apply_bn(1, g1) + if self.bn2 is not None: + g2 = self._apply_bn(2, g2) + if self.update_h2: + h2 = _apply_h_norm(h2) + + g2_update: List[torch.Tensor] = [g2] + h2_update: List[torch.Tensor] = [h2] + g1_update: List[torch.Tensor] = [g1] + g1_mlp: List[torch.Tensor] = [g1] + + if cal_gg1: + gg1 = _make_nei_g1(g1_ext, nlist) + else: + gg1 = None + + if self.update_chnnl_2: + # nb x nloc x nnei x ng2 + assert self.linear2 is not None + g2_1 = self.act(self.linear2(g2)) + g2_update.append(g2_1) + + if self.update_g2_has_g1g1: + assert gg1 is not None + assert self.proj_g1g1g2 is not None + g2_update.append( + self.proj_g1g1g2(self._update_g2_g1g1(g1, gg1, nlist_mask, sw)) + ) + + if self.update_g2_has_attn: + assert self.attn2g_map is not None + assert self.attn2_mh_apply is not None + assert self.attn2_lm is not None + # nb x nloc x nnei x nnei x nh + AAg = self.attn2g_map(g2, h2, nlist_mask, sw) + # nb x nloc x nnei x ng2 + g2_2 = self.attn2_mh_apply(AAg, g2) + g2_2 = self.attn2_lm(g2_2) + g2_update.append(g2_2) + + if self.update_h2: + h2_update.append(self._update_h2(g2, h2, nlist_mask, sw)) + + if self.update_g1_has_conv: + assert gg1 is not None + g1_mlp.append(self._update_g1_conv(gg1, g2, nlist_mask, sw)) + + if self.update_g1_has_grrg: + g1_mlp.append(self._update_g1_grrg(g2, h2, nlist_mask, sw)) + + if self.update_g1_has_drrd: + assert gg1 is not None + g1_mlp.append(self._update_g1_grrg(gg1, h2, nlist_mask, sw)) + + # nb x nloc x [ng1+ng2+(axisxng2)+(axisxng1)] + # conv grrg drrd + g1_1 = self.act(self.linear1(torch.cat(g1_mlp, dim=-1))) + g1_update.append(g1_1) + + if self.update_g1_has_attn: + assert gg1 is not None + assert self.loc_attn is not None + g1_update.append(self.loc_attn(g1, gg1, nlist_mask, sw)) + + # update + if self.update_chnnl_2: + g2_new = self.list_update(g2_update) + h2_new = self.list_update(h2_update) + else: + g2_new, h2_new = g2, h2 + g1_new = self.list_update(g1_update) + return g1_new, g2_new, h2_new + + @torch.jit.export + def list_update_res_avg( + self, + update_list: List[torch.Tensor], + ) -> torch.Tensor: + nitem = len(update_list) + uu = update_list[0] + for ii in range(1, nitem): + uu = uu + update_list[ii] + return uu / (float(nitem) ** 0.5) + + @torch.jit.export + def list_update_res_incr(self, update_list: List[torch.Tensor]) -> torch.Tensor: + nitem = len(update_list) + uu = update_list[0] + scale = 1.0 / (float(nitem - 1) ** 0.5) if nitem > 1 else 0.0 + for ii in range(1, nitem): + uu = uu + scale * update_list[ii] + return uu + + @torch.jit.export + def list_update(self, update_list: List[torch.Tensor]) -> torch.Tensor: + if self.update_style == "res_avg": + return self.list_update_res_avg(update_list) + elif self.update_style == "res_incr": + return self.list_update_res_incr(update_list) + else: + raise RuntimeError(f"unknown update style {self.update_style}") + + def _bn_layer( + self, + nf: int = 1, + ) -> Callable: + return torch.nn.BatchNorm1d( + nf, + eps=1e-5, + momentum=self.bn_momentum, + affine=False, + track_running_stats=True, + device=env.DEVICE, + dtype=env.GLOBAL_PT_FLOAT_PRECISION, + ) diff --git a/deepmd/pt/model/descriptor/repformers.py b/deepmd/pt/model/descriptor/repformers.py new file mode 100644 index 0000000000..26887b1b75 --- /dev/null +++ b/deepmd/pt/model/descriptor/repformers.py @@ -0,0 +1,348 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + List, + Optional, +) + +import numpy as np +import torch + +from deepmd.pt.model.descriptor.descriptor import ( + DescriptorBlock, + compute_std, +) +from deepmd.pt.model.descriptor.env_mat import ( + prod_env_mat_se_a, +) +from deepmd.pt.model.network.network import ( + SimpleLinear, +) +from deepmd.pt.utils import ( + env, +) +from deepmd.pt.utils.nlist import ( + build_neighbor_list, +) +from deepmd.pt.utils.utils import ( + get_activation_fn, +) + +from .repformer_layer import ( + RepformerLayer, +) +from .se_atten import ( + analyze_descrpt, +) + +mydtype = env.GLOBAL_PT_FLOAT_PRECISION +mydev = env.DEVICE + + +def torch_linear(*args, **kwargs): + return torch.nn.Linear(*args, **kwargs, dtype=mydtype, device=mydev) + + +simple_linear = SimpleLinear +mylinear = simple_linear + + +@DescriptorBlock.register("se_repformer") +@DescriptorBlock.register("se_uni") +class DescrptBlockRepformers(DescriptorBlock): + def __init__( + self, + rcut, + rcut_smth, + sel: int, + ntypes: int, + nlayers: int = 3, + g1_dim=128, + g2_dim=16, + axis_dim: int = 4, + direct_dist: bool = False, + do_bn_mode: str = "no", + bn_momentum: float = 0.1, + update_g1_has_conv: bool = True, + update_g1_has_drrd: bool = True, + update_g1_has_grrg: bool = True, + update_g1_has_attn: bool = True, + update_g2_has_g1g1: bool = True, + update_g2_has_attn: bool = True, + update_h2: bool = False, + attn1_hidden: int = 64, + attn1_nhead: int = 4, + attn2_hidden: int = 16, + attn2_nhead: int = 4, + attn2_has_gate: bool = False, + activation: str = "tanh", + update_style: str = "res_avg", + set_davg_zero: bool = True, # TODO + smooth: bool = True, + add_type_ebd_to_seq: bool = False, + type: Optional[str] = None, + ): + """ + smooth: + If strictly smooth, cannot be used with update_g1_has_attn + add_type_ebd_to_seq: + At the presence of seq_input (optional input to forward), + whether or not add an type embedding to seq_input. + If no seq_input is given, it has no effect. + """ + super().__init__() + del type + self.epsilon = 1e-4 # protection of 1./nnei + self.rcut = rcut + self.rcut_smth = rcut_smth + self.ntypes = ntypes + self.nlayers = nlayers + sel = [sel] if isinstance(sel, int) else sel + self.nnei = sum(sel) + assert len(sel) == 1 + self.sel = sel + self.sec = self.sel + self.split_sel = self.sel + self.axis_dim = axis_dim + self.set_davg_zero = set_davg_zero + self.g1_dim = g1_dim + self.g2_dim = g2_dim + self.act = get_activation_fn(activation) + self.direct_dist = direct_dist + self.add_type_ebd_to_seq = add_type_ebd_to_seq + + self.g2_embd = mylinear(1, self.g2_dim) + layers = [] + for ii in range(nlayers): + layers.append( + RepformerLayer( + rcut, + rcut_smth, + sel, + ntypes, + self.g1_dim, + self.g2_dim, + axis_dim=self.axis_dim, + update_chnnl_2=(ii != nlayers - 1), + do_bn_mode=do_bn_mode, + bn_momentum=bn_momentum, + update_g1_has_conv=update_g1_has_conv, + update_g1_has_drrd=update_g1_has_drrd, + update_g1_has_grrg=update_g1_has_grrg, + update_g1_has_attn=update_g1_has_attn, + update_g2_has_g1g1=update_g2_has_g1g1, + update_g2_has_attn=update_g2_has_attn, + update_h2=update_h2, + attn1_hidden=attn1_hidden, + attn1_nhead=attn1_nhead, + attn2_has_gate=attn2_has_gate, + attn2_hidden=attn2_hidden, + attn2_nhead=attn2_nhead, + activation=activation, + update_style=update_style, + smooth=smooth, + ) + ) + self.layers = torch.nn.ModuleList(layers) + + sshape = (self.ntypes, self.nnei, 4) + mean = torch.zeros(sshape, dtype=mydtype, device=mydev) + stddev = torch.ones(sshape, dtype=mydtype, device=mydev) + self.register_buffer("mean", mean) + self.register_buffer("stddev", stddev) + + def get_rcut(self) -> float: + """Returns the cut-off radius.""" + return self.rcut + + def get_nsel(self) -> int: + """Returns the number of selected atoms in the cut-off radius.""" + return sum(self.sel) + + def get_sel(self) -> List[int]: + """Returns the number of selected atoms for each type.""" + return self.sel + + def get_ntype(self) -> int: + """Returns the number of element types.""" + return self.ntypes + + def get_dim_out(self) -> int: + """Returns the output dimension.""" + return self.dim_out + + def get_dim_in(self) -> int: + """Returns the input dimension.""" + return self.dim_in + + @property + def dim_out(self): + """Returns the output dimension of this descriptor.""" + return self.g1_dim + + @property + def dim_in(self): + """Returns the atomic input dimension of this descriptor.""" + return self.g1_dim + + @property + def dim_emb(self): + """Returns the embedding dimension g2.""" + return self.g2_dim + + def forward( + self, + nlist: torch.Tensor, + extended_coord: torch.Tensor, + extended_atype: torch.Tensor, + extended_atype_embd: Optional[torch.Tensor] = None, + mapping: Optional[torch.Tensor] = None, + ): + assert mapping is not None + assert extended_atype_embd is not None + nframes, nloc, nnei = nlist.shape + nall = extended_coord.view(nframes, -1).shape[1] // 3 + atype = extended_atype[:, :nloc] + # nb x nloc x nnei x 4, nb x nloc x nnei x 3, nb x nloc x nnei x 1 + dmatrix, diff, sw = prod_env_mat_se_a( + extended_coord, + nlist, + atype, + self.mean, + self.stddev, + self.rcut, + self.rcut_smth, + ) + nlist_mask = nlist != -1 + sw = torch.squeeze(sw, -1) + # beyond the cutoff sw should be 0.0 + sw = sw.masked_fill(~nlist_mask, 0.0) + + # [nframes, nloc, tebd_dim] + atype_embd = extended_atype_embd[:, :nloc, :] + assert list(atype_embd.shape) == [nframes, nloc, self.g1_dim] + + g1 = self.act(atype_embd) + # nb x nloc x nnei x 1, nb x nloc x nnei x 3 + if not self.direct_dist: + g2, h2 = torch.split(dmatrix, [1, 3], dim=-1) + else: + g2, h2 = torch.linalg.norm(diff, dim=-1, keepdim=True), diff + g2 = g2 / self.rcut + h2 = h2 / self.rcut + # nb x nloc x nnei x ng2 + g2 = self.act(self.g2_embd(g2)) + + # set all padding positions to index of 0 + # if the a neighbor is real or not is indicated by nlist_mask + nlist[nlist == -1] = 0 + # nb x nall x ng1 + mapping = mapping.view(nframes, nall).unsqueeze(-1).expand(-1, -1, self.g1_dim) + for idx, ll in enumerate(self.layers): + # g1: nb x nloc x ng1 + # g1_ext: nb x nall x ng1 + g1_ext = torch.gather(g1, 1, mapping) + g1, g2, h2 = ll.forward( + g1_ext, + g2, + h2, + nlist, + nlist_mask, + sw, + ) + + # uses the last layer. + # nb x nloc x 3 x ng2 + h2g2 = ll._cal_h2g2(g2, h2, nlist_mask, sw) + # (nb x nloc) x ng2 x 3 + rot_mat = torch.permute(h2g2, (0, 1, 3, 2)) + + return g1, g2, h2, rot_mat.view(-1, self.dim_emb, 3), sw + + def compute_input_stats(self, merged): + """Update mean and stddev for descriptor elements.""" + ndescrpt = self.nnei * 4 + sumr = [] + suma = [] + sumn = [] + sumr2 = [] + suma2 = [] + mixed_type = "real_natoms_vec" in merged[0] + for system in merged: + index = system["mapping"].unsqueeze(-1).expand(-1, -1, 3) + extended_coord = torch.gather(system["coord"], dim=1, index=index) + extended_coord = extended_coord - system["shift"] + index = system["mapping"] + extended_atype = torch.gather(system["atype"], dim=1, index=index) + nloc = system["atype"].shape[-1] + ####################################################### + # dirty hack here! the interface of dataload should be + # redesigned to support descriptors like dpa2 + ####################################################### + nlist = build_neighbor_list( + extended_coord, + extended_atype, + nloc, + self.rcut, + self.get_sel(), + distinguish_types=False, + ) + env_mat, _, _ = prod_env_mat_se_a( + extended_coord, + nlist, + system["atype"], + self.mean, + self.stddev, + self.rcut, + self.rcut_smth, + ) + if not mixed_type: + sysr, sysr2, sysa, sysa2, sysn = analyze_descrpt( + env_mat.detach().cpu().numpy(), ndescrpt, system["natoms"] + ) + else: + sysr, sysr2, sysa, sysa2, sysn = analyze_descrpt( + env_mat.detach().cpu().numpy(), + ndescrpt, + system["real_natoms_vec"], + mixed_type=mixed_type, + real_atype=system["atype"].detach().cpu().numpy(), + ) + sumr.append(sysr) + suma.append(sysa) + sumn.append(sysn) + sumr2.append(sysr2) + suma2.append(sysa2) + sumr = np.sum(sumr, axis=0) + suma = np.sum(suma, axis=0) + sumn = np.sum(sumn, axis=0) + sumr2 = np.sum(sumr2, axis=0) + suma2 = np.sum(suma2, axis=0) + return sumr, suma, sumn, sumr2, suma2 + + def init_desc_stat(self, sumr, suma, sumn, sumr2, suma2): + all_davg = [] + all_dstd = [] + for type_i in range(self.ntypes): + davgunit = [[sumr[type_i] / (sumn[type_i] + 1e-15), 0, 0, 0]] + dstdunit = [ + [ + compute_std(sumr2[type_i], sumr[type_i], sumn[type_i], self.rcut), + compute_std(suma2[type_i], suma[type_i], sumn[type_i], self.rcut), + compute_std(suma2[type_i], suma[type_i], sumn[type_i], self.rcut), + compute_std(suma2[type_i], suma[type_i], sumn[type_i], self.rcut), + ] + ] + davg = np.tile(davgunit, [self.nnei, 1]) + dstd = np.tile(dstdunit, [self.nnei, 1]) + all_davg.append(davg) + all_dstd.append(dstd) + self.sumr = sumr + self.suma = suma + self.sumn = sumn + self.sumr2 = sumr2 + self.suma2 = suma2 + if not self.set_davg_zero: + mean = np.stack(all_davg) + self.mean.copy_(torch.tensor(mean, device=env.DEVICE)) + stddev = np.stack(all_dstd) + self.stddev.copy_(torch.tensor(stddev, device=env.DEVICE)) diff --git a/deepmd/pt/model/descriptor/se_a.py b/deepmd/pt/model/descriptor/se_a.py new file mode 100644 index 0000000000..10aa66311e --- /dev/null +++ b/deepmd/pt/model/descriptor/se_a.py @@ -0,0 +1,478 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + ClassVar, + List, + Optional, +) + +import numpy as np +import torch + +from deepmd.pt.model.descriptor import ( + Descriptor, + DescriptorBlock, + compute_std, + prod_env_mat_se_a, +) +from deepmd.pt.utils import ( + env, +) +from deepmd.pt.utils.env import ( + PRECISION_DICT, +) + +try: + from typing import ( + Final, + ) +except ImportError: + from torch.jit import Final + +from deepmd.model_format import EnvMat as DPEnvMat +from deepmd.pt.model.network.mlp import ( + EmbeddingNet, + NetworkCollection, +) +from deepmd.pt.model.network.network import ( + TypeFilter, +) + + +@Descriptor.register("se_e2_a") +class DescrptSeA(Descriptor): + def __init__( + self, + rcut, + rcut_smth, + sel, + neuron=[25, 50, 100], + axis_neuron=16, + set_davg_zero: bool = False, + activation_function: str = "tanh", + precision: str = "float64", + resnet_dt: bool = False, + old_impl: bool = False, + **kwargs, + ): + super().__init__() + self.sea = DescrptBlockSeA( + rcut, + rcut_smth, + sel, + neuron, + axis_neuron, + set_davg_zero, + activation_function, + precision, + resnet_dt, + old_impl, + **kwargs, + ) + + def get_rcut(self) -> float: + """Returns the cut-off radius.""" + return self.sea.get_rcut() + + def get_nsel(self) -> int: + """Returns the number of selected atoms in the cut-off radius.""" + return self.sea.get_nsel() + + def get_sel(self) -> List[int]: + """Returns the number of selected atoms for each type.""" + return self.sea.get_sel() + + def get_ntype(self) -> int: + """Returns the number of element types.""" + return self.sea.get_ntype() + + def get_dim_out(self) -> int: + """Returns the output dimension.""" + return self.sea.get_dim_out() + + @property + def dim_out(self): + """Returns the output dimension of this descriptor.""" + return self.sea.dim_out + + def compute_input_stats(self, merged): + """Update mean and stddev for descriptor elements.""" + return self.sea.compute_input_stats(merged) + + def init_desc_stat(self, sumr, suma, sumn, sumr2, suma2): + self.sea.init_desc_stat(sumr, suma, sumn, sumr2, suma2) + + @classmethod + def get_stat_name(cls, config): + descrpt_type = config["type"] + assert descrpt_type in ["se_e2_a"] + return f'stat_file_sea_rcut{config["rcut"]:.2f}_smth{config["rcut_smth"]:.2f}_sel{config["sel"]}.npz' + + @classmethod + def get_data_process_key(cls, config): + descrpt_type = config["type"] + assert descrpt_type in ["se_e2_a"] + return {"sel": config["sel"], "rcut": config["rcut"]} + + def forward( + self, + extended_coord: torch.Tensor, + extended_atype: torch.Tensor, + nlist: torch.Tensor, + mapping: Optional[torch.Tensor] = None, + ): + return self.sea.forward(nlist, extended_coord, extended_atype, None, mapping) + + def set_stat_mean_and_stddev( + self, + mean: torch.Tensor, + stddev: torch.Tensor, + ) -> None: + self.sea.mean = mean + self.sea.stddev = stddev + + def serialize(self) -> dict: + obj = self.sea + return { + "rcut": obj.rcut, + "rcut_smth": obj.rcut_smth, + "sel": obj.sel, + "neuron": obj.neuron, + "axis_neuron": obj.axis_neuron, + "resnet_dt": obj.resnet_dt, + "set_davg_zero": obj.set_davg_zero, + "activation_function": obj.activation_function, + "precision": obj.precision, + "embeddings": obj.filter_layers.serialize(), + "env_mat": DPEnvMat(obj.rcut, obj.rcut_smth).serialize(), + "@variables": { + "davg": obj["davg"].detach().cpu().numpy(), + "dstd": obj["dstd"].detach().cpu().numpy(), + }, + ## to be updated when the options are supported. + "trainable": True, + "type_one_side": True, + "exclude_types": [], + "spin": None, + } + + @classmethod + def deserialize(cls, data: dict) -> "DescrptSeA": + variables = data.pop("@variables") + embeddings = data.pop("embeddings") + env_mat = data.pop("env_mat") + obj = cls(**data) + + def t_cvt(xx): + return torch.tensor(xx, dtype=obj.sea.prec, device=env.DEVICE) + + obj.sea["davg"] = t_cvt(variables["davg"]) + obj.sea["dstd"] = t_cvt(variables["dstd"]) + obj.sea.filter_layers = NetworkCollection.deserialize(embeddings) + return obj + + +@DescriptorBlock.register("se_e2_a") +class DescrptBlockSeA(DescriptorBlock): + ndescrpt: Final[int] + __constants__: ClassVar[list] = ["ndescrpt"] + + def __init__( + self, + rcut, + rcut_smth, + sel, + neuron=[25, 50, 100], + axis_neuron=16, + set_davg_zero: bool = False, + activation_function: str = "tanh", + precision: str = "float64", + resnet_dt: bool = False, + old_impl: bool = False, + **kwargs, + ): + """Construct an embedding net of type `se_a`. + + Args: + - rcut: Cut-off radius. + - rcut_smth: Smooth hyper-parameter for pair force & energy. + - sel: For each element type, how many atoms is selected as neighbors. + - filter_neuron: Number of neurons in each hidden layers of the embedding net. + - axis_neuron: Number of columns of the sub-matrix of the embedding matrix. + """ + super().__init__() + self.rcut = rcut + self.rcut_smth = rcut_smth + self.neuron = neuron + self.filter_neuron = self.neuron + self.axis_neuron = axis_neuron + self.set_davg_zero = set_davg_zero + self.activation_function = activation_function + self.precision = precision + self.prec = PRECISION_DICT[self.precision] + self.resnet_dt = resnet_dt + self.old_impl = old_impl + + self.ntypes = len(sel) + self.sel = sel + self.sec = torch.tensor( + np.append([0], np.cumsum(self.sel)), dtype=int, device=env.DEVICE + ) + self.split_sel = self.sel + self.nnei = sum(sel) + self.ndescrpt = self.nnei * 4 + + wanted_shape = (self.ntypes, self.nnei, 4) + mean = torch.zeros(wanted_shape, dtype=self.prec, device=env.DEVICE) + stddev = torch.ones(wanted_shape, dtype=self.prec, device=env.DEVICE) + self.register_buffer("mean", mean) + self.register_buffer("stddev", stddev) + self.filter_layers_old = None + self.filter_layers = None + + if self.old_impl: + filter_layers = [] + # TODO: remove + start_index = 0 + for type_i in range(self.ntypes): + one = TypeFilter(start_index, sel[type_i], self.filter_neuron) + filter_layers.append(one) + start_index += sel[type_i] + self.filter_layers_old = torch.nn.ModuleList(filter_layers) + else: + filter_layers = NetworkCollection( + ndim=1, ntypes=len(sel), network_type="embedding_network" + ) + # TODO: ndim=2 if type_one_side=False + for ii in range(self.ntypes): + filter_layers[(ii,)] = EmbeddingNet( + 1, + self.filter_neuron, + activation_function=self.activation_function, + precision=self.precision, + resnet_dt=self.resnet_dt, + ) + self.filter_layers = filter_layers + + def get_rcut(self) -> float: + """Returns the cut-off radius.""" + return self.rcut + + def get_nsel(self) -> int: + """Returns the number of selected atoms in the cut-off radius.""" + return sum(self.sel) + + def get_sel(self) -> List[int]: + """Returns the number of selected atoms for each type.""" + return self.sel + + def get_ntype(self) -> int: + """Returns the number of element types.""" + return self.ntypes + + def get_dim_out(self) -> int: + """Returns the output dimension.""" + return self.dim_out + + def get_dim_in(self) -> int: + """Returns the input dimension.""" + return self.dim_in + + @property + def dim_out(self): + """Returns the output dimension of this descriptor.""" + return self.filter_neuron[-1] * self.axis_neuron + + @property + def dim_in(self): + """Returns the atomic input dimension of this descriptor.""" + return 0 + + def __setitem__(self, key, value): + if key in ("avg", "data_avg", "davg"): + self.mean = value + elif key in ("std", "data_std", "dstd"): + self.stddev = value + else: + raise KeyError(key) + + def __getitem__(self, key): + if key in ("avg", "data_avg", "davg"): + return self.mean + elif key in ("std", "data_std", "dstd"): + return self.stddev + else: + raise KeyError(key) + + def compute_input_stats(self, merged): + """Update mean and stddev for descriptor elements.""" + sumr = [] + suma = [] + sumn = [] + sumr2 = [] + suma2 = [] + for system in merged: + index = system["mapping"].unsqueeze(-1).expand(-1, -1, 3) + extended_coord = torch.gather(system["coord"], dim=1, index=index) + extended_coord = extended_coord - system["shift"] + env_mat, _, _ = prod_env_mat_se_a( + extended_coord, + system["nlist"], + system["atype"], + self.mean, + self.stddev, + self.rcut, + self.rcut_smth, + ) + sysr, sysr2, sysa, sysa2, sysn = analyze_descrpt( + env_mat.detach().cpu().numpy(), self.ndescrpt, system["natoms"] + ) + sumr.append(sysr) + suma.append(sysa) + sumn.append(sysn) + sumr2.append(sysr2) + suma2.append(sysa2) + sumr = np.sum(sumr, axis=0) + suma = np.sum(suma, axis=0) + sumn = np.sum(sumn, axis=0) + sumr2 = np.sum(sumr2, axis=0) + suma2 = np.sum(suma2, axis=0) + return sumr, suma, sumn, sumr2, suma2 + + def init_desc_stat(self, sumr, suma, sumn, sumr2, suma2): + all_davg = [] + all_dstd = [] + for type_i in range(self.ntypes): + davgunit = [[sumr[type_i] / (sumn[type_i] + 1e-15), 0, 0, 0]] + dstdunit = [ + [ + compute_std(sumr2[type_i], sumr[type_i], sumn[type_i], self.rcut), + compute_std(suma2[type_i], suma[type_i], sumn[type_i], self.rcut), + compute_std(suma2[type_i], suma[type_i], sumn[type_i], self.rcut), + compute_std(suma2[type_i], suma[type_i], sumn[type_i], self.rcut), + ] + ] + davg = np.tile(davgunit, [self.nnei, 1]) + dstd = np.tile(dstdunit, [self.nnei, 1]) + all_davg.append(davg) + all_dstd.append(dstd) + self.sumr = sumr + self.suma = suma + self.sumn = sumn + self.sumr2 = sumr2 + self.suma2 = suma2 + if not self.set_davg_zero: + mean = np.stack(all_davg) + self.mean.copy_(torch.tensor(mean, device=env.DEVICE)) + stddev = np.stack(all_dstd) + self.stddev.copy_(torch.tensor(stddev, device=env.DEVICE)) + + def forward( + self, + nlist: torch.Tensor, + extended_coord: torch.Tensor, + extended_atype: torch.Tensor, + extended_atype_embd: Optional[torch.Tensor] = None, + mapping: Optional[torch.Tensor] = None, + ): + """Calculate decoded embedding for each atom. + + Args: + - coord: Tell atom coordinates with shape [nframes, natoms[1]*3]. + - atype: Tell atom types with shape [nframes, natoms[1]]. + - natoms: Tell atom count and element count. Its shape is [2+self.ntypes]. + - box: Tell simulation box with shape [nframes, 9]. + + Returns + ------- + - `torch.Tensor`: descriptor matrix with shape [nframes, natoms[0]*self.filter_neuron[-1]*self.axis_neuron]. + """ + del extended_atype_embd, mapping + nloc = nlist.shape[1] + atype = extended_atype[:, :nloc] + dmatrix, diff, _ = prod_env_mat_se_a( + extended_coord, + nlist, + atype, + self.mean, + self.stddev, + self.rcut, + self.rcut_smth, + ) + + if self.old_impl: + assert self.filter_layers_old is not None + dmatrix = dmatrix.view( + -1, self.ndescrpt + ) # shape is [nframes*nall, self.ndescrpt] + xyz_scatter = torch.empty( + 1, + ) + ret = self.filter_layers_old[0](dmatrix) + xyz_scatter = ret + for ii, transform in enumerate(self.filter_layers_old[1:]): + # shape is [nframes*nall, 4, self.filter_neuron[-1]] + ret = transform.forward(dmatrix) + xyz_scatter = xyz_scatter + ret + else: + assert self.filter_layers is not None + dmatrix = dmatrix.view(-1, self.nnei, 4) + nfnl = dmatrix.shape[0] + # pre-allocate a shape to pass jit + xyz_scatter = torch.zeros( + [nfnl, 4, self.filter_neuron[-1]], dtype=self.prec, device=env.DEVICE + ) + for ii, ll in enumerate(self.filter_layers.networks): + # nfnl x nt x 4 + rr = dmatrix[:, self.sec[ii] : self.sec[ii + 1], :] + ss = rr[:, :, :1] + # nfnl x nt x ng + gg = ll.forward(ss) + # nfnl x 4 x ng + gr = torch.matmul(rr.permute(0, 2, 1), gg) + xyz_scatter += gr + + xyz_scatter /= self.nnei + xyz_scatter_1 = xyz_scatter.permute(0, 2, 1) + rot_mat = xyz_scatter_1[:, :, 1:4] + xyz_scatter_2 = xyz_scatter[:, :, 0 : self.axis_neuron] + result = torch.matmul( + xyz_scatter_1, xyz_scatter_2 + ) # shape is [nframes*nall, self.filter_neuron[-1], self.axis_neuron] + return ( + result.view(-1, nloc, self.filter_neuron[-1] * self.axis_neuron), + None, + None, + None, + None, + ) + + +def analyze_descrpt(matrix, ndescrpt, natoms): + """Collect avg, square avg and count of descriptors in a batch.""" + ntypes = natoms.shape[1] - 2 + start_index = 0 + sysr = [] + sysa = [] + sysn = [] + sysr2 = [] + sysa2 = [] + for type_i in range(ntypes): + end_index = start_index + natoms[0, 2 + type_i] + dd = matrix[:, start_index:end_index] # all descriptors for this element + start_index = end_index + dd = np.reshape( + dd, [-1, 4] + ) # Shape is [nframes*natoms[2+type_id]*self.nnei, 4] + ddr = dd[:, :1] + dda = dd[:, 1:] + sumr = np.sum(ddr) + suma = np.sum(dda) / 3.0 + sumn = dd.shape[0] # Value is nframes*natoms[2+type_id]*self.nnei + sumr2 = np.sum(np.multiply(ddr, ddr)) + suma2 = np.sum(np.multiply(dda, dda)) / 3.0 + sysr.append(sumr) + sysa.append(suma) + sysn.append(sumn) + sysr2.append(sumr2) + sysa2.append(suma2) + return sysr, sysr2, sysa, sysa2, sysn diff --git a/deepmd/pt/model/descriptor/se_atten.py b/deepmd/pt/model/descriptor/se_atten.py new file mode 100644 index 0000000000..0c932f42f2 --- /dev/null +++ b/deepmd/pt/model/descriptor/se_atten.py @@ -0,0 +1,392 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + List, + Optional, +) + +import numpy as np +import torch + +from deepmd.pt.model.descriptor.descriptor import ( + DescriptorBlock, + compute_std, +) +from deepmd.pt.model.descriptor.env_mat import ( + prod_env_mat_se_a, +) +from deepmd.pt.model.network.network import ( + NeighborWiseAttention, + TypeFilter, +) +from deepmd.pt.utils import ( + env, +) + + +@DescriptorBlock.register("se_atten") +class DescrptBlockSeAtten(DescriptorBlock): + def __init__( + self, + rcut, + rcut_smth, + sel, + ntypes: int, + neuron: list = [25, 50, 100], + axis_neuron: int = 16, + tebd_dim: int = 8, + tebd_input_mode: str = "concat", + # set_davg_zero: bool = False, + set_davg_zero: bool = True, # TODO + attn: int = 128, + attn_layer: int = 2, + attn_dotr: bool = True, + attn_mask: bool = False, + post_ln=True, + ffn=False, + ffn_embed_dim=1024, + activation="tanh", + scaling_factor=1.0, + head_num=1, + normalize=True, + temperature=None, + return_rot=False, + type: Optional[str] = None, + ): + """Construct an embedding net of type `se_atten`. + + Args: + - rcut: Cut-off radius. + - rcut_smth: Smooth hyper-parameter for pair force & energy. + - sel: For each element type, how many atoms is selected as neighbors. + - filter_neuron: Number of neurons in each hidden layers of the embedding net. + - axis_neuron: Number of columns of the sub-matrix of the embedding matrix. + """ + super().__init__() + del type + self.rcut = rcut + self.rcut_smth = rcut_smth + self.filter_neuron = neuron + self.axis_neuron = axis_neuron + self.tebd_dim = tebd_dim + self.tebd_input_mode = tebd_input_mode + self.set_davg_zero = set_davg_zero + self.attn_dim = attn + self.attn_layer = attn_layer + self.attn_dotr = attn_dotr + self.attn_mask = attn_mask + self.post_ln = post_ln + self.ffn = ffn + self.ffn_embed_dim = ffn_embed_dim + self.activation = activation + self.scaling_factor = scaling_factor + self.head_num = head_num + self.normalize = normalize + self.temperature = temperature + self.return_rot = return_rot + + if isinstance(sel, int): + sel = [sel] + + self.ntypes = ntypes + self.sel = sel + self.sec = self.sel + self.split_sel = self.sel + self.nnei = sum(sel) + self.ndescrpt = self.nnei * 4 + self.dpa1_attention = NeighborWiseAttention( + self.attn_layer, + self.nnei, + self.filter_neuron[-1], + self.attn_dim, + dotr=self.attn_dotr, + do_mask=self.attn_mask, + post_ln=self.post_ln, + ffn=self.ffn, + ffn_embed_dim=self.ffn_embed_dim, + activation=self.activation, + scaling_factor=self.scaling_factor, + head_num=self.head_num, + normalize=self.normalize, + temperature=self.temperature, + ) + + wanted_shape = (self.ntypes, self.nnei, 4) + mean = torch.zeros( + wanted_shape, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE + ) + stddev = torch.ones( + wanted_shape, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE + ) + self.register_buffer("mean", mean) + self.register_buffer("stddev", stddev) + + filter_layers = [] + one = TypeFilter( + 0, + self.nnei, + self.filter_neuron, + return_G=True, + tebd_dim=self.tebd_dim, + use_tebd=True, + tebd_mode=self.tebd_input_mode, + ) + filter_layers.append(one) + self.filter_layers = torch.nn.ModuleList(filter_layers) + + def get_rcut(self) -> float: + """Returns the cut-off radius.""" + return self.rcut + + def get_nsel(self) -> int: + """Returns the number of selected atoms in the cut-off radius.""" + return sum(self.sel) + + def get_sel(self) -> List[int]: + """Returns the number of selected atoms for each type.""" + return self.sel + + def get_ntype(self) -> int: + """Returns the number of element types.""" + return self.ntypes + + def get_dim_in(self) -> int: + """Returns the output dimension.""" + return self.dim_in + + def get_dim_out(self) -> int: + """Returns the output dimension.""" + return self.dim_out + + @property + def dim_out(self): + """Returns the output dimension of this descriptor.""" + return self.filter_neuron[-1] * self.axis_neuron + + @property + def dim_in(self): + """Returns the atomic input dimension of this descriptor.""" + return self.tebd_dim + + @property + def dim_emb(self): + """Returns the output dimension of embedding.""" + return self.filter_neuron[-1] + + def compute_input_stats(self, merged): + """Update mean and stddev for descriptor elements.""" + sumr = [] + suma = [] + sumn = [] + sumr2 = [] + suma2 = [] + mixed_type = "real_natoms_vec" in merged[0] + for system in merged: + index = system["mapping"].unsqueeze(-1).expand(-1, -1, 3) + extended_coord = torch.gather(system["coord"], dim=1, index=index) + extended_coord = extended_coord - system["shift"] + env_mat, _, _ = prod_env_mat_se_a( + extended_coord, + system["nlist"], + system["atype"], + self.mean, + self.stddev, + self.rcut, + self.rcut_smth, + ) + if not mixed_type: + sysr, sysr2, sysa, sysa2, sysn = analyze_descrpt( + env_mat.detach().cpu().numpy(), self.ndescrpt, system["natoms"] + ) + else: + sysr, sysr2, sysa, sysa2, sysn = analyze_descrpt( + env_mat.detach().cpu().numpy(), + self.ndescrpt, + system["real_natoms_vec"], + mixed_type=mixed_type, + real_atype=system["atype"].detach().cpu().numpy(), + ) + sumr.append(sysr) + suma.append(sysa) + sumn.append(sysn) + sumr2.append(sysr2) + suma2.append(sysa2) + sumr = np.sum(sumr, axis=0) + suma = np.sum(suma, axis=0) + sumn = np.sum(sumn, axis=0) + sumr2 = np.sum(sumr2, axis=0) + suma2 = np.sum(suma2, axis=0) + return sumr, suma, sumn, sumr2, suma2 + + def init_desc_stat(self, sumr, suma, sumn, sumr2, suma2): + all_davg = [] + all_dstd = [] + for type_i in range(self.ntypes): + davgunit = [[sumr[type_i] / (sumn[type_i] + 1e-15), 0, 0, 0]] + dstdunit = [ + [ + compute_std(sumr2[type_i], sumr[type_i], sumn[type_i], self.rcut), + compute_std(suma2[type_i], suma[type_i], sumn[type_i], self.rcut), + compute_std(suma2[type_i], suma[type_i], sumn[type_i], self.rcut), + compute_std(suma2[type_i], suma[type_i], sumn[type_i], self.rcut), + ] + ] + davg = np.tile(davgunit, [self.nnei, 1]) + dstd = np.tile(dstdunit, [self.nnei, 1]) + all_davg.append(davg) + all_dstd.append(dstd) + self.sumr = sumr + self.suma = suma + self.sumn = sumn + self.sumr2 = sumr2 + self.suma2 = suma2 + if not self.set_davg_zero: + mean = np.stack(all_davg) + self.mean.copy_(torch.tensor(mean, device=env.DEVICE)) + stddev = np.stack(all_dstd) + self.stddev.copy_(torch.tensor(stddev, device=env.DEVICE)) + + def forward( + self, + nlist: torch.Tensor, + extended_coord: torch.Tensor, + extended_atype: torch.Tensor, + extended_atype_embd: Optional[torch.Tensor] = None, + mapping: Optional[torch.Tensor] = None, + ) -> List[torch.Tensor]: + """Calculate decoded embedding for each atom. + + Args: + - coord: Tell atom coordinates with shape [nframes, natoms[1]*3]. + - atype: Tell atom types with shape [nframes, natoms[1]]. + - natoms: Tell atom count and element count. Its shape is [2+self.ntypes]. + - box: Tell simulation box with shape [nframes, 9]. + + Returns + ------- + - result: descriptor with shape [nframes, nloc, self.filter_neuron[-1] * self.axis_neuron]. + - ret: environment matrix with shape [nframes, nloc, self.neei, out_size] + """ + del mapping + assert extended_atype_embd is not None + nframes, nloc, nnei = nlist.shape + atype = extended_atype[:, :nloc] + nb = nframes + nall = extended_coord.view(nb, -1, 3).shape[1] + dmatrix, diff, sw = prod_env_mat_se_a( + extended_coord, + nlist, + atype, + self.mean, + self.stddev, + self.rcut, + self.rcut_smth, + ) + dmatrix = dmatrix.view( + -1, self.ndescrpt + ) # shape is [nframes*nall, self.ndescrpt] + nlist_mask = nlist != -1 + nlist[nlist == -1] = 0 + sw = torch.squeeze(sw, -1) + # beyond the cutoff sw should be 0.0 + sw = sw.masked_fill(~nlist_mask, 0.0) + # nf x nloc x nt -> nf x nloc x nnei x nt + atype_tebd = extended_atype_embd[:, :nloc, :] + atype_tebd_nnei = atype_tebd.unsqueeze(2).expand(-1, -1, self.nnei, -1) + # nf x nall x nt + nt = extended_atype_embd.shape[-1] + atype_tebd_ext = extended_atype_embd + # nb x (nloc x nnei) x nt + index = nlist.reshape(nb, nloc * nnei).unsqueeze(-1).expand(-1, -1, nt) + # nb x (nloc x nnei) x nt + atype_tebd_nlist = torch.gather(atype_tebd_ext, dim=1, index=index) + # nb x nloc x nnei x nt + atype_tebd_nlist = atype_tebd_nlist.view(nb, nloc, nnei, nt) + ret = self.filter_layers[0]( + dmatrix, + atype_tebd=atype_tebd_nnei, + nlist_tebd=atype_tebd_nlist, + ) # shape is [nframes*nall, self.neei, out_size] + input_r = torch.nn.functional.normalize( + dmatrix.reshape(-1, self.nnei, 4)[:, :, 1:4], dim=-1 + ) + ret = self.dpa1_attention( + ret, nlist_mask, input_r=input_r, sw=sw + ) # shape is [nframes*nloc, self.neei, out_size] + inputs_reshape = dmatrix.view(-1, self.nnei, 4).permute( + 0, 2, 1 + ) # shape is [nframes*natoms[0], 4, self.neei] + xyz_scatter = torch.matmul( + inputs_reshape, ret + ) # shape is [nframes*natoms[0], 4, out_size] + xyz_scatter = xyz_scatter / self.nnei + xyz_scatter_1 = xyz_scatter.permute(0, 2, 1) + rot_mat = xyz_scatter_1[:, :, 1:4] + xyz_scatter_2 = xyz_scatter[:, :, 0 : self.axis_neuron] + result = torch.matmul( + xyz_scatter_1, xyz_scatter_2 + ) # shape is [nframes*nloc, self.filter_neuron[-1], self.axis_neuron] + return ( + result.view(-1, nloc, self.filter_neuron[-1] * self.axis_neuron), + ret.view(-1, nloc, self.nnei, self.filter_neuron[-1]), + diff, + rot_mat.view(-1, self.filter_neuron[-1], 3), + sw, + ) + + +def analyze_descrpt(matrix, ndescrpt, natoms, mixed_type=False, real_atype=None): + """Collect avg, square avg and count of descriptors in a batch.""" + ntypes = natoms.shape[1] - 2 + if not mixed_type: + sysr = [] + sysa = [] + sysn = [] + sysr2 = [] + sysa2 = [] + start_index = 0 + for type_i in range(ntypes): + end_index = start_index + natoms[0, 2 + type_i] + dd = matrix[:, start_index:end_index] + start_index = end_index + dd = np.reshape( + dd, [-1, 4] + ) # Shape is [nframes*natoms[2+type_id]*self.nnei, 4] + ddr = dd[:, :1] + dda = dd[:, 1:] + sumr = np.sum(ddr) + suma = np.sum(dda) / 3.0 + sumn = dd.shape[0] # Value is nframes*natoms[2+type_id]*self.nnei + sumr2 = np.sum(np.multiply(ddr, ddr)) + suma2 = np.sum(np.multiply(dda, dda)) / 3.0 + sysr.append(sumr) + sysa.append(suma) + sysn.append(sumn) + sysr2.append(sumr2) + sysa2.append(suma2) + else: + sysr = [0.0 for i in range(ntypes)] + sysa = [0.0 for i in range(ntypes)] + sysn = [0 for i in range(ntypes)] + sysr2 = [0.0 for i in range(ntypes)] + sysa2 = [0.0 for i in range(ntypes)] + for frame_item in range(matrix.shape[0]): + dd_ff = matrix[frame_item] + atype_frame = real_atype[frame_item] + for type_i in range(ntypes): + type_idx = atype_frame == type_i + dd = dd_ff[type_idx] + dd = np.reshape(dd, [-1, 4]) # typen_atoms * nnei, 4 + ddr = dd[:, :1] + dda = dd[:, 1:] + sumr = np.sum(ddr) + suma = np.sum(dda) / 3.0 + sumn = dd.shape[0] + sumr2 = np.sum(np.multiply(ddr, ddr)) + suma2 = np.sum(np.multiply(dda, dda)) / 3.0 + sysr[type_i] += sumr + sysa[type_i] += suma + sysn[type_i] += sumn + sysr2[type_i] += sumr2 + sysa2[type_i] += suma2 + + return sysr, sysr2, sysa, sysa2, sysn diff --git a/deepmd/pt/model/model/__init__.py b/deepmd/pt/model/model/__init__.py new file mode 100644 index 0000000000..a3db3dbdec --- /dev/null +++ b/deepmd/pt/model/model/__init__.py @@ -0,0 +1,27 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from .ener import ( + EnergyModel, +) +from .model import ( + BaseModel, +) + + +def get_model(model_params, sampled=None): + return EnergyModel( + descriptor=model_params["descriptor"], + fitting_net=model_params.get("fitting_net", None), + type_map=model_params["type_map"], + type_embedding=model_params.get("type_embedding", None), + resuming=model_params.get("resuming", False), + stat_file_dir=model_params.get("stat_file_dir", None), + stat_file_path=model_params.get("stat_file_path", None), + sampled=sampled, + ) + + +__all__ = [ + "BaseModel", + "EnergyModel", + "get_model", +] diff --git a/deepmd/pt/model/model/atomic_model.py b/deepmd/pt/model/model/atomic_model.py new file mode 100644 index 0000000000..47fd463fc9 --- /dev/null +++ b/deepmd/pt/model/model/atomic_model.py @@ -0,0 +1,77 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from abc import ( + ABC, + abstractmethod, +) +from typing import ( + Dict, + List, + Optional, +) + +import torch + +from deepmd.model_format import ( + FittingOutputDef, +) +from deepmd.pt.model.task import ( + Fitting, +) + + +class AtomicModel(ABC): + @abstractmethod + def get_fitting_net(self) -> Fitting: + raise NotImplementedError + + @abstractmethod + def get_fitting_output_def(self) -> FittingOutputDef: + raise NotImplementedError + + @abstractmethod + def get_rcut(self) -> float: + raise NotImplementedError + + @abstractmethod + def get_sel(self) -> List[int]: + raise NotImplementedError + + @abstractmethod + def distinguish_types(self) -> bool: + raise NotImplementedError + + @abstractmethod + def forward_atomic( + self, + extended_coord, + extended_atype, + nlist, + mapping: Optional[torch.Tensor] = None, + do_atomic_virial: bool = False, + ) -> Dict[str, torch.Tensor]: + raise NotImplementedError + + def do_grad( + self, + var_name: Optional[str] = None, + ) -> bool: + """Tell if the output variable `var_name` is differentiable. + if var_name is None, returns if any of the variable is differentiable. + + """ + odef = self.get_fitting_output_def() + if var_name is None: + require: List[bool] = [] + for vv in odef.keys(): + require.append(self.do_grad_(vv)) + return any(require) + else: + return self.do_grad_(var_name) + + def do_grad_( + self, + var_name: str, + ) -> bool: + """Tell if the output variable `var_name` is differentiable.""" + assert var_name is not None + return self.get_fitting_output_def()[var_name].differentiable diff --git a/deepmd/pt/model/model/dp_atomic_model.py b/deepmd/pt/model/model/dp_atomic_model.py new file mode 100644 index 0000000000..ffeeeda660 --- /dev/null +++ b/deepmd/pt/model/model/dp_atomic_model.py @@ -0,0 +1,214 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Dict, + List, + Optional, +) + +import torch + +from deepmd.model_format import ( + FittingOutputDef, +) +from deepmd.pt.model.descriptor.descriptor import ( + Descriptor, +) +from deepmd.pt.model.task import ( + DenoiseNet, + Fitting, +) + +from .atomic_model import ( + AtomicModel, +) +from .model import ( + BaseModel, +) + + +class DPAtomicModel(BaseModel, AtomicModel): + """Model give atomic prediction of some physical property. + + Parameters + ---------- + descriptor + Descriptor + fitting_net + Fitting net + type_map + Mapping atom type to the name (str) of the type. + For example `type_map[1]` gives the name of the type 1. + type_embedding + Type embedding net + resuming + Whether to resume/fine-tune from checkpoint or not. + stat_file_dir + The directory to the state files. + stat_file_path + The path to the state files. + sampled + Sampled frames to compute the statistics. + """ + + def __init__( + self, + descriptor: dict, + fitting_net: dict, + type_map: Optional[List[str]], + type_embedding: Optional[dict] = None, + resuming: bool = False, + stat_file_dir=None, + stat_file_path=None, + sampled=None, + **kwargs, + ): + super().__init__() + # Descriptor + Type Embedding Net (Optional) + ntypes = len(type_map) + self.type_map = type_map + self.ntypes = ntypes + descriptor["ntypes"] = ntypes + self.combination = descriptor.get("combination", False) + if self.combination: + self.prefactor = descriptor.get("prefactor", [0.5, 0.5]) + self.descriptor_type = descriptor["type"] + + self.type_split = True + if self.descriptor_type not in ["se_e2_a"]: + self.type_split = False + + self.descriptor = Descriptor(**descriptor) + self.rcut = self.descriptor.get_rcut() + self.sel = self.descriptor.get_sel() + self.split_nlist = False + + # Statistics + self.compute_or_load_stat( + fitting_net, + ntypes, + resuming=resuming, + type_map=type_map, + stat_file_dir=stat_file_dir, + stat_file_path=stat_file_path, + sampled=sampled, + ) + + # Fitting + if fitting_net: + fitting_net["type"] = fitting_net.get("type", "ener") + if self.descriptor_type not in ["se_e2_a"]: + fitting_net["ntypes"] = 1 + else: + fitting_net["ntypes"] = self.descriptor.get_ntype() + fitting_net["use_tebd"] = False + fitting_net["embedding_width"] = self.descriptor.dim_out + + self.grad_force = "direct" not in fitting_net["type"] + if not self.grad_force: + fitting_net["out_dim"] = self.descriptor.dim_emb + if "ener" in fitting_net["type"]: + fitting_net["return_energy"] = True + self.fitting_net = Fitting(**fitting_net) + else: + self.fitting_net = None + self.grad_force = False + if not self.split_nlist: + self.coord_denoise_net = DenoiseNet( + self.descriptor.dim_out, self.ntypes - 1, self.descriptor.dim_emb + ) + elif self.combination: + self.coord_denoise_net = DenoiseNet( + self.descriptor.dim_out, + self.ntypes - 1, + self.descriptor.dim_emb_list, + self.prefactor, + ) + else: + self.coord_denoise_net = DenoiseNet( + self.descriptor.dim_out, self.ntypes - 1, self.descriptor.dim_emb + ) + + def get_fitting_net(self) -> Fitting: + """Get the fitting net.""" + return ( + self.fitting_net if self.fitting_net is not None else self.coord_denoise_net + ) + + def get_fitting_output_def(self) -> FittingOutputDef: + """Get the output def of the fitting net.""" + return ( + self.fitting_net.output_def() + if self.fitting_net is not None + else self.coord_denoise_net.output_def() + ) + + def get_rcut(self) -> float: + """Get the cut-off radius.""" + return self.rcut + + def get_sel(self) -> List[int]: + """Get the neighbor selection.""" + return self.sel + + def distinguish_types(self) -> bool: + """If distinguish different types by sorting.""" + return self.type_split + + def forward_atomic( + self, + extended_coord, + extended_atype, + nlist, + mapping: Optional[torch.Tensor] = None, + ) -> Dict[str, torch.Tensor]: + """Return atomic prediction. + + Parameters + ---------- + extended_coord + coodinates in extended region + extended_atype + atomic type in extended region + nlist + neighbor list. nf x nloc x nsel + mapping + mapps the extended indices to local indices + + Returns + ------- + result_dict + the result dict, defined by the fitting net output def. + + """ + nframes, nloc, nnei = nlist.shape + atype = extended_atype[:, :nloc] + if self.do_grad(): + extended_coord.requires_grad_(True) + descriptor, env_mat, diff, rot_mat, sw = self.descriptor( + extended_coord, + extended_atype, + nlist, + mapping=mapping, + ) + assert descriptor is not None + # energy, force + if self.fitting_net is not None: + fit_ret = self.fitting_net( + descriptor, atype, atype_tebd=None, rot_mat=rot_mat + ) + # denoise + else: + nlist_list = [nlist] + if not self.split_nlist: + nnei_mask = nlist != -1 + elif self.combination: + nnei_mask = [] + for item in nlist_list: + nnei_mask_item = item != -1 + nnei_mask.append(nnei_mask_item) + else: + env_mat = env_mat[-1] + diff = diff[-1] + nnei_mask = nlist_list[-1] != -1 + fit_ret = self.coord_denoise_net(env_mat, diff, nnei_mask, descriptor, sw) + return fit_ret diff --git a/deepmd/pt/model/model/ener.py b/deepmd/pt/model/model/ener.py new file mode 100644 index 0000000000..c316c99a86 --- /dev/null +++ b/deepmd/pt/model/model/ener.py @@ -0,0 +1,151 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Dict, + List, + Optional, +) + +import torch + +from .dp_atomic_model import ( + DPAtomicModel, +) +from .make_model import ( + make_model, +) + +DPModel = make_model(DPAtomicModel) + + +class EnergyModel(DPModel): + model_type = "ener" + + def __init__( + self, + *args, + **kwargs, + ): + super().__init__(*args, **kwargs) + + def forward( + self, + coord, + atype, + box: Optional[torch.Tensor] = None, + do_atomic_virial: bool = False, + ) -> Dict[str, torch.Tensor]: + model_ret = self.forward_common( + coord, atype, box, do_atomic_virial=do_atomic_virial + ) + if self.fitting_net is not None: + model_predict = {} + model_predict["atom_energy"] = model_ret["energy"] + model_predict["energy"] = model_ret["energy_redu"] + if self.do_grad("energy"): + model_predict["force"] = model_ret["energy_derv_r"].squeeze(-2) + if do_atomic_virial: + model_predict["atomic_virial"] = model_ret["energy_derv_c"].squeeze( + -3 + ) + model_predict["virial"] = model_ret["energy_derv_c_redu"].squeeze(-3) + else: + model_predict["force"] = model_ret["dforce"] + else: + model_predict = model_ret + model_predict["updated_coord"] += coord + return model_predict + + def forward_lower( + self, + extended_coord, + extended_atype, + nlist, + mapping: Optional[torch.Tensor] = None, + do_atomic_virial: bool = False, + ): + model_ret = self.common_forward_lower( + extended_coord, + extended_atype, + nlist, + mapping, + do_atomic_virial=do_atomic_virial, + ) + if self.fitting_net is not None: + model_predict = {} + model_predict["atom_energy"] = model_ret["energy"] + model_predict["energy"] = model_ret["energy_redu"] + if self.do_grad("energy"): + model_predict["extended_force"] = model_ret["energy_derv_r"].squeeze(-2) + if do_atomic_virial: + model_predict["extended_virial"] = model_ret[ + "energy_derv_c" + ].squeeze(-3) + else: + assert model_ret["dforce"] is not None + model_predict["dforce"] = model_ret["dforce"] + else: + model_predict = model_ret + return model_predict + + +# should be a stand-alone function!!!! +def process_nlist( + nlist, + extended_atype, + mapping: Optional[torch.Tensor] = None, +): + # process the nlist_type and nlist_loc + nframes, nloc = nlist.shape[:2] + nmask = nlist == -1 + nlist[nmask] = 0 + if mapping is not None: + nlist_loc = torch.gather( + mapping, + dim=1, + index=nlist.reshape(nframes, -1), + ).reshape(nframes, nloc, -1) + nlist_loc[nmask] = -1 + else: + nlist_loc = None + nlist_type = torch.gather( + extended_atype, + dim=1, + index=nlist.reshape(nframes, -1), + ).reshape(nframes, nloc, -1) + nlist_type[nmask] = -1 + nlist[nmask] = -1 + return nlist_loc, nlist_type, nframes, nloc + + +def process_nlist_gathered( + nlist, + extended_atype, + split_sel: List[int], + mapping: Optional[torch.Tensor] = None, +): + nlist_list = list(torch.split(nlist, split_sel, -1)) + nframes, nloc = nlist_list[0].shape[:2] + nlist_type_list = [] + nlist_loc_list = [] + for nlist_item in nlist_list: + nmask = nlist_item == -1 + nlist_item[nmask] = 0 + if mapping is not None: + nlist_loc_item = torch.gather( + mapping, dim=1, index=nlist_item.reshape(nframes, -1) + ).reshape(nframes, nloc, -1) + nlist_loc_item[nmask] = -1 + nlist_loc_list.append(nlist_loc_item) + nlist_type_item = torch.gather( + extended_atype, dim=1, index=nlist_item.reshape(nframes, -1) + ).reshape(nframes, nloc, -1) + nlist_type_item[nmask] = -1 + nlist_type_list.append(nlist_type_item) + nlist_item[nmask] = -1 + + if mapping is not None: + nlist_loc = torch.cat(nlist_loc_list, -1) + else: + nlist_loc = None + nlist_type = torch.cat(nlist_type_list, -1) + return nlist_loc, nlist_type, nframes, nloc diff --git a/deepmd/pt/model/model/make_model.py b/deepmd/pt/model/model/make_model.py new file mode 100644 index 0000000000..3ddd21fbb8 --- /dev/null +++ b/deepmd/pt/model/model/make_model.py @@ -0,0 +1,136 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Dict, + Optional, +) + +import torch + +from deepmd.model_format import ( + ModelOutputDef, +) +from deepmd.pt.model.model.transform_output import ( + communicate_extended_output, + fit_output_to_model_output, +) +from deepmd.pt.utils.nlist import ( + build_neighbor_list, + extend_coord_with_ghosts, +) +from deepmd.pt.utils.region import ( + normalize_coord, +) + + +def make_model(T_AtomicModel): + class CM(T_AtomicModel): + def __init__( + self, + *args, + **kwargs, + ): + super().__init__( + *args, + **kwargs, + ) + + def get_model_output_def(self): + return ModelOutputDef(self.get_fitting_output_def()) + + # cannot use the name forward. torch script does not work + def forward_common( + self, + coord, + atype, + box: Optional[torch.Tensor] = None, + do_atomic_virial: bool = False, + ) -> Dict[str, torch.Tensor]: + """Return total energy of the system. + Args: + - coord: Atom coordinates with shape [nframes, natoms[1]*3]. + - atype: Atom types with shape [nframes, natoms[1]]. + - natoms: Atom statisics with shape [self.ntypes+2]. + - box: Simulation box with shape [nframes, 9]. + - atomic_virial: Whether or not compoute the atomic virial. + + Returns + ------- + - energy: Energy per atom. + - force: XYZ force per atom. + """ + nframes, nloc = atype.shape[:2] + if box is not None: + coord_normalized = normalize_coord(coord, box.reshape(-1, 3, 3)) + else: + coord_normalized = coord.clone() + extended_coord, extended_atype, mapping = extend_coord_with_ghosts( + coord_normalized, atype, box, self.get_rcut() + ) + nlist = build_neighbor_list( + extended_coord, + extended_atype, + nloc, + self.get_rcut(), + self.get_sel(), + distinguish_types=self.distinguish_types(), + ) + extended_coord = extended_coord.reshape(nframes, -1, 3) + model_predict_lower = self.forward_common_lower( + extended_coord, + extended_atype, + nlist, + mapping, + do_atomic_virial=do_atomic_virial, + ) + model_predict = communicate_extended_output( + model_predict_lower, + self.get_model_output_def(), + mapping, + do_atomic_virial=do_atomic_virial, + ) + return model_predict + + def forward_common_lower( + self, + extended_coord, + extended_atype, + nlist, + mapping: Optional[torch.Tensor] = None, + do_atomic_virial: bool = False, + ): + """Return model prediction. + + Parameters + ---------- + extended_coord + coodinates in extended region + extended_atype + atomic type in extended region + nlist + neighbor list. nf x nloc x nsel + mapping + mapps the extended indices to local indices + do_atomic_virial + whether do atomic virial + + Returns + ------- + result_dict + the result dict, defined by the fitting net output def. + + """ + atomic_ret = self.forward_atomic( + extended_coord, + extended_atype, + nlist, + mapping=mapping, + ) + model_predict = fit_output_to_model_output( + atomic_ret, + self.get_fitting_output_def(), + extended_coord, + do_atomic_virial=do_atomic_virial, + ) + return model_predict + + return CM diff --git a/deepmd/pt/model/model/model.py b/deepmd/pt/model/model/model.py new file mode 100644 index 0000000000..139744c1e9 --- /dev/null +++ b/deepmd/pt/model/model/model.py @@ -0,0 +1,150 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import logging +import os + +import numpy as np +import torch + +from deepmd.pt.utils import ( + env, +) +from deepmd.pt.utils.stat import ( + compute_output_stats, +) + + +class BaseModel(torch.nn.Module): + def __init__(self): + """Construct a basic model for different tasks.""" + super().__init__() + + def forward(self, *args, **kwargs): + """Model output.""" + raise NotImplementedError + + def compute_or_load_stat( + self, + fitting_param, + ntypes, + resuming=False, + type_map=None, + stat_file_dir=None, + stat_file_path=None, + sampled=None, + ): + if fitting_param is None: + fitting_param = {} + if not resuming: + if sampled is not None: # compute stat + for sys in sampled: + for key in sys: + if isinstance(sys[key], list): + sys[key] = [item.to(env.DEVICE) for item in sys[key]] + else: + if sys[key] is not None: + sys[key] = sys[key].to(env.DEVICE) + sumr, suma, sumn, sumr2, suma2 = self.descriptor.compute_input_stats( + sampled + ) + + energy = [item["energy"] for item in sampled] + mixed_type = "real_natoms_vec" in sampled[0] + if mixed_type: + input_natoms = [item["real_natoms_vec"] for item in sampled] + else: + input_natoms = [item["natoms"] for item in sampled] + tmp = compute_output_stats(energy, input_natoms) + fitting_param["bias_atom_e"] = tmp[:, 0] + if stat_file_path is not None: + if not os.path.exists(stat_file_dir): + os.mkdir(stat_file_dir) + if not isinstance(stat_file_path, list): + logging.info(f"Saving stat file to {stat_file_path}") + np.savez_compressed( + stat_file_path, + sumr=sumr, + suma=suma, + sumn=sumn, + sumr2=sumr2, + suma2=suma2, + bias_atom_e=fitting_param["bias_atom_e"], + type_map=type_map, + ) + else: + for ii, file_path in enumerate(stat_file_path): + logging.info(f"Saving stat file to {file_path}") + np.savez_compressed( + file_path, + sumr=sumr[ii], + suma=suma[ii], + sumn=sumn[ii], + sumr2=sumr2[ii], + suma2=suma2[ii], + bias_atom_e=fitting_param["bias_atom_e"], + type_map=type_map, + ) + else: # load stat + target_type_map = type_map + if not isinstance(stat_file_path, list): + logging.info(f"Loading stat file from {stat_file_path}") + stats = np.load(stat_file_path) + stat_type_map = list(stats["type_map"]) + missing_type = [ + i for i in target_type_map if i not in stat_type_map + ] + assert not missing_type, f"These type are not in stat file {stat_file_path}: {missing_type}! Please change the stat file path!" + idx_map = [stat_type_map.index(i) for i in target_type_map] + if stats["sumr"].size: + sumr, suma, sumn, sumr2, suma2 = ( + stats["sumr"][idx_map], + stats["suma"][idx_map], + stats["sumn"][idx_map], + stats["sumr2"][idx_map], + stats["suma2"][idx_map], + ) + else: + sumr, suma, sumn, sumr2, suma2 = [], [], [], [], [] + fitting_param["bias_atom_e"] = stats["bias_atom_e"][idx_map] + else: + sumr, suma, sumn, sumr2, suma2 = [], [], [], [], [] + id_bias_atom_e = None + for ii, file_path in enumerate(stat_file_path): + logging.info(f"Loading stat file from {file_path}") + stats = np.load(file_path) + stat_type_map = list(stats["type_map"]) + missing_type = [ + i for i in target_type_map if i not in stat_type_map + ] + assert not missing_type, f"These type are not in stat file {file_path}: {missing_type}! Please change the stat file path!" + idx_map = [stat_type_map.index(i) for i in target_type_map] + if stats["sumr"].size: + sumr_tmp, suma_tmp, sumn_tmp, sumr2_tmp, suma2_tmp = ( + stats["sumr"][idx_map], + stats["suma"][idx_map], + stats["sumn"][idx_map], + stats["sumr2"][idx_map], + stats["suma2"][idx_map], + ) + else: + sumr_tmp, suma_tmp, sumn_tmp, sumr2_tmp, suma2_tmp = ( + [], + [], + [], + [], + [], + ) + sumr.append(sumr_tmp) + suma.append(suma_tmp) + sumn.append(sumn_tmp) + sumr2.append(sumr2_tmp) + suma2.append(suma2_tmp) + fitting_param["bias_atom_e"] = stats["bias_atom_e"][idx_map] + if id_bias_atom_e is None: + id_bias_atom_e = fitting_param["bias_atom_e"] + else: + assert ( + id_bias_atom_e == fitting_param["bias_atom_e"] + ).all(), "bias_atom_e in stat files are not consistent!" + self.descriptor.init_desc_stat(sumr, suma, sumn, sumr2, suma2) + else: # resuming for checkpoint; init model params from scratch + fitting_param["bias_atom_e"] = [0.0] * ntypes diff --git a/deepmd/pt/model/model/transform_output.py b/deepmd/pt/model/model/transform_output.py new file mode 100644 index 0000000000..673491d788 --- /dev/null +++ b/deepmd/pt/model/model/transform_output.py @@ -0,0 +1,214 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Dict, + List, + Optional, +) + +import torch + +from deepmd.model_format import ( + FittingOutputDef, + ModelOutputDef, + OutputVariableDef, + get_deriv_name, + get_reduce_name, +) + + +def atomic_virial_corr( + extended_coord: torch.Tensor, + atom_energy: torch.Tensor, +): + nall = extended_coord.shape[1] + nloc = atom_energy.shape[1] + coord, _ = torch.split(extended_coord, [nloc, nall - nloc], dim=1) + # no derivative with respect to the loc coord. + coord = coord.detach() + ce = coord * atom_energy + sumce0, sumce1, sumce2 = torch.split(torch.sum(ce, dim=1), [1, 1, 1], dim=-1) + faked_grad = torch.ones_like(sumce0) + lst = torch.jit.annotate(List[Optional[torch.Tensor]], [faked_grad]) + extended_virial_corr0 = torch.autograd.grad( + [sumce0], [extended_coord], grad_outputs=lst, create_graph=True + )[0] + assert extended_virial_corr0 is not None + extended_virial_corr1 = torch.autograd.grad( + [sumce1], [extended_coord], grad_outputs=lst, create_graph=True + )[0] + assert extended_virial_corr1 is not None + extended_virial_corr2 = torch.autograd.grad( + [sumce2], [extended_coord], grad_outputs=lst, create_graph=True + )[0] + assert extended_virial_corr2 is not None + extended_virial_corr = torch.concat( + [ + extended_virial_corr0.unsqueeze(-1), + extended_virial_corr1.unsqueeze(-1), + extended_virial_corr2.unsqueeze(-1), + ], + dim=-1, + ) + return extended_virial_corr + + +def task_deriv_one( + atom_energy: torch.Tensor, + energy: torch.Tensor, + extended_coord: torch.Tensor, + do_atomic_virial: bool = False, +): + faked_grad = torch.ones_like(energy) + lst = torch.jit.annotate(List[Optional[torch.Tensor]], [faked_grad]) + extended_force = torch.autograd.grad( + [energy], [extended_coord], grad_outputs=lst, create_graph=True + )[0] + assert extended_force is not None + extended_force = -extended_force + extended_virial = extended_force.unsqueeze(-1) @ extended_coord.unsqueeze(-2) + # the correction sums to zero, which does not contribute to global virial + if do_atomic_virial: + extended_virial_corr = atomic_virial_corr(extended_coord, atom_energy) + extended_virial = extended_virial + extended_virial_corr + return extended_force, extended_virial + + +def get_leading_dims( + vv: torch.Tensor, + vdef: OutputVariableDef, +): + """Get the dimensions of nf x nloc.""" + vshape = vv.shape + return list(vshape[: (len(vshape) - len(vdef.shape))]) + + +def get_atom_axis( + vdef: torch.Tensor, +): + """Get the axis of atoms.""" + atom_axis = -(len(vdef.shape) + 1) + return atom_axis + + +def take_deriv( + vv: torch.Tensor, + svv: torch.Tensor, + vdef: OutputVariableDef, + coord_ext: torch.Tensor, + do_atomic_virial: bool = False, +): + size = 1 + for ii in vdef.shape: + size *= ii + vv1 = vv.view(list(get_leading_dims(vv, vdef)) + [size]) # noqa: RUF005 + svv1 = svv.view(list(get_leading_dims(svv, vdef)) + [size]) # noqa: RUF005 + split_vv1 = torch.split(vv1, [1] * size, dim=-1) + split_svv1 = torch.split(svv1, [1] * size, dim=-1) + split_ff, split_avir = [], [] + for vvi, svvi in zip(split_vv1, split_svv1): + # nf x nloc x 3, nf x nloc x 3 x 3 + ffi, aviri = task_deriv_one( + vvi, svvi, coord_ext, do_atomic_virial=do_atomic_virial + ) + # nf x nloc x 1 x 3, nf x nloc x 1 x 3 x 3 + ffi = ffi.unsqueeze(-2) + aviri = aviri.unsqueeze(-3) + split_ff.append(ffi) + split_avir.append(aviri) + # nf x nloc x v_dim x 3, nf x nloc x v_dim x 3 x 3 + ff = torch.concat(split_ff, dim=-2) + avir = torch.concat(split_avir, dim=-3) + return ff, avir + + +def fit_output_to_model_output( + fit_ret: Dict[str, torch.Tensor], + fit_output_def: FittingOutputDef, + coord_ext: torch.Tensor, + do_atomic_virial: bool = False, +) -> Dict[str, torch.Tensor]: + """Transform the output of the fitting network to + the model output. + + """ + model_ret = dict(fit_ret.items()) + for kk, vv in fit_ret.items(): + vdef = fit_output_def[kk] + shap = vdef.shape + atom_axis = -(len(shap) + 1) + if vdef.reduciable: + kk_redu = get_reduce_name(kk) + model_ret[kk_redu] = torch.sum(vv, dim=atom_axis) + if vdef.differentiable: + kk_derv_r, kk_derv_c = get_deriv_name(kk) + dr, dc = take_deriv( + vv, + model_ret[kk_redu], + vdef, + coord_ext, + do_atomic_virial=do_atomic_virial, + ) + model_ret[kk_derv_r] = dr + model_ret[kk_derv_c] = dc + return model_ret + + +def communicate_extended_output( + model_ret: Dict[str, torch.Tensor], + model_output_def: ModelOutputDef, + mapping: torch.Tensor, # nf x nloc + do_atomic_virial: bool = False, +) -> Dict[str, torch.Tensor]: + """Transform the output of the model network defined on + local and ghost (extended) atoms to local atoms. + + """ + new_ret = {} + for kk in model_output_def.keys_outp(): + vv = model_ret[kk] + vdef = model_output_def[kk] + new_ret[kk] = vv + if vdef.reduciable: + kk_redu = get_reduce_name(kk) + new_ret[kk_redu] = model_ret[kk_redu] + if vdef.differentiable: + # nf x nloc + vldims = get_leading_dims(vv, vdef) + # nf x nall + mldims = list(mapping.shape) + kk_derv_r, kk_derv_c = get_deriv_name(kk) + # vdim x 3 + derv_r_ext_dims = list(vdef.shape) + [3] # noqa:RUF005 + mapping = mapping.view(mldims + [1] * len(derv_r_ext_dims)).expand( + [-1] * len(mldims) + derv_r_ext_dims + ) + force = torch.zeros( + vldims + derv_r_ext_dims, dtype=vv.dtype, device=vv.device + ) + # nf x nloc x 1 x 3 + new_ret[kk_derv_r] = torch.scatter_reduce( + force, + 1, + index=mapping, + src=model_ret[kk_derv_r], + reduce="sum", + ) + mapping = mapping.unsqueeze(-1).expand( + [-1] * (len(mldims) + len(derv_r_ext_dims)) + [3] + ) + virial = torch.zeros( + vldims + derv_r_ext_dims + [3], dtype=vv.dtype, device=vv.device + ) + # nf x nloc x 1 x 3 + new_ret[kk_derv_c] = torch.scatter_reduce( + virial, + 1, + index=mapping, + src=model_ret[kk_derv_c], + reduce="sum", + ) + new_ret[kk_derv_c + "_redu"] = torch.sum(new_ret[kk_derv_c], dim=1) + if not do_atomic_virial: + # pop atomic virial, because it is not correctly calculated. + new_ret.pop(kk_derv_c) + return new_ret diff --git a/deepmd/pt/model/network/__init__.py b/deepmd/pt/model/network/__init__.py new file mode 100644 index 0000000000..6ceb116d85 --- /dev/null +++ b/deepmd/pt/model/network/__init__.py @@ -0,0 +1 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later diff --git a/deepmd/pt/model/network/mlp.py b/deepmd/pt/model/network/mlp.py new file mode 100644 index 0000000000..e3ac0e7bc2 --- /dev/null +++ b/deepmd/pt/model/network/mlp.py @@ -0,0 +1,217 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + ClassVar, + Dict, + Optional, +) + +import numpy as np +import torch +import torch.nn as nn + +from deepmd.pt.utils import ( + env, +) + +device = env.DEVICE + +from deepmd.model_format import ( + NativeLayer, +) +from deepmd.model_format import NetworkCollection as DPNetworkCollection +from deepmd.model_format import ( + make_embedding_network, + make_fitting_network, + make_multilayer_network, +) +from deepmd.pt.utils.env import ( + DEFAULT_PRECISION, + PRECISION_DICT, +) +from deepmd.pt.utils.utils import ( + ActivationFn, +) + +try: + from deepmd._version import version as __version__ +except ImportError: + __version__ = "unknown" + + +def empty_t(shape, precision): + return torch.empty(shape, dtype=precision, device=device) + + +class MLPLayer(nn.Module): + def __init__( + self, + num_in, + num_out, + bias: bool = True, + use_timestep: bool = False, + activation_function: Optional[str] = None, + resnet: bool = False, + bavg: float = 0.0, + stddev: float = 1.0, + precision: str = DEFAULT_PRECISION, + ): + super().__init__() + self.use_timestep = use_timestep + self.activate_name = activation_function + self.activate = ActivationFn(self.activate_name) + self.precision = precision + self.prec = PRECISION_DICT[self.precision] + self.matrix = nn.Parameter(data=empty_t((num_in, num_out), self.prec)) + nn.init.normal_(self.matrix.data, std=stddev / np.sqrt(num_out + num_in)) + if bias: + self.bias = nn.Parameter( + data=empty_t([num_out], self.prec), + ) + nn.init.normal_(self.bias.data, mean=bavg, std=stddev) + else: + self.bias = None + if self.use_timestep: + self.idt = nn.Parameter(data=empty_t([num_out], self.prec)) + nn.init.normal_(self.idt.data, mean=0.1, std=0.001) + else: + self.idt = None + self.resnet = resnet + + def check_type_consistency(self): + precision = self.precision + + def check_var(var): + if var is not None: + # assertion "float64" == "double" would fail + assert PRECISION_DICT[var.dtype.name] is PRECISION_DICT[precision] + + check_var(self.w) + check_var(self.b) + check_var(self.idt) + + def dim_in(self) -> int: + return self.matrix.shape[0] + + def dim_out(self) -> int: + return self.matrix.shape[1] + + def forward( + self, + xx: torch.Tensor, + ) -> torch.Tensor: + """One MLP layer used by DP model. + + Parameters + ---------- + xx : torch.Tensor + The input. + + Returns + ------- + yy: torch.Tensor + The output. + """ + yy = ( + torch.matmul(xx, self.matrix) + self.bias + if self.bias is not None + else torch.matmul(xx, self.matrix) + ) + yy = self.activate(yy).clone() + yy = yy * self.idt if self.idt is not None else yy + if self.resnet: + if xx.shape[-1] == yy.shape[-1]: + yy += xx + elif 2 * xx.shape[-1] == yy.shape[-1]: + yy += torch.concat([xx, xx], dim=-1) + else: + yy = yy + return yy + + def serialize(self) -> dict: + """Serialize the layer to a dict. + + Returns + ------- + dict + The serialized layer. + """ + nl = NativeLayer( + self.matrix.shape[0], + self.matrix.shape[1], + bias=self.bias is not None, + use_timestep=self.idt is not None, + activation_function=self.activate_name, + resnet=self.resnet, + precision=self.precision, + ) + nl.w, nl.b, nl.idt = ( + self.matrix.detach().cpu().numpy(), + self.bias.detach().cpu().numpy() if self.bias is not None else None, + self.idt.detach().cpu().numpy() if self.idt is not None else None, + ) + return nl.serialize() + + @classmethod + def deserialize(cls, data: dict) -> "MLPLayer": + """Deserialize the layer from a dict. + + Parameters + ---------- + data : dict + The dict to deserialize from. + """ + nl = NativeLayer.deserialize(data) + obj = cls( + nl["matrix"].shape[0], + nl["matrix"].shape[1], + bias=nl["bias"] is not None, + use_timestep=nl["idt"] is not None, + activation_function=nl["activation_function"], + resnet=nl["resnet"], + precision=nl["precision"], + ) + prec = PRECISION_DICT[obj.precision] + + def check_load_param(ss): + return ( + nn.Parameter(data=torch.tensor(nl[ss], dtype=prec, device=device)) + if nl[ss] is not None + else None + ) + + obj.matrix = check_load_param("matrix") + obj.bias = check_load_param("bias") + obj.idt = check_load_param("idt") + return obj + + +MLP_ = make_multilayer_network(MLPLayer, nn.Module) + + +class MLP(MLP_): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.layers = torch.nn.ModuleList(self.layers) + + forward = MLP_.call + + +EmbeddingNet = make_embedding_network(MLP, MLPLayer) + +FittingNet = make_fitting_network(EmbeddingNet, MLP, MLPLayer) + + +class NetworkCollection(DPNetworkCollection, nn.Module): + """PyTorch implementation of NetworkCollection.""" + + NETWORK_TYPE_MAP: ClassVar[Dict[str, type]] = { + "network": MLP, + "embedding_network": EmbeddingNet, + # "fitting_network": FittingNet, + } + + def __init__(self, *args, **kwargs): + # init both two base classes + DPNetworkCollection.__init__(self, *args, **kwargs) + nn.Module.__init__(self) + self.networks = self._networks = torch.nn.ModuleList(self._networks) diff --git a/deepmd/pt/model/network/network.py b/deepmd/pt/model/network/network.py new file mode 100644 index 0000000000..8b5b3cf998 --- /dev/null +++ b/deepmd/pt/model/network/network.py @@ -0,0 +1,1897 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Optional, +) + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from deepmd.pt.utils import ( + env, +) + +try: + from typing import ( + Final, + ) +except ImportError: + from torch.jit import Final + +from functools import ( + partial, +) + +import torch.utils.checkpoint + +from deepmd.pt.utils.utils import ( + ActivationFn, + get_activation_fn, +) + + +def Tensor(*shape): + return torch.empty(shape, dtype=env.GLOBAL_PT_FLOAT_PRECISION) + + +class Dropout(nn.Module): + def __init__(self, p): + super().__init__() + self.p = p + + def forward(self, x, inplace: bool = False): + if self.p > 0 and self.training: + return F.dropout(x, p=self.p, training=True, inplace=inplace) + else: + return x + + +class Identity(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return x + + +class DropPath(torch.nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, prob=None): + super().__init__() + self.drop_prob = prob + + def forward(self, x): + if self.drop_prob == 0.0 or not self.training: + return x + keep_prob = 1 - self.drop_prob + shape = (x.shape[0],) + (1,) * ( + x.ndim - 1 + ) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) + random_tensor.floor_() # binarize + output = x.div(keep_prob) * random_tensor + return output + + def extra_repr(self) -> str: + return f"prob={self.drop_prob}" + + +def softmax_dropout( + input_x, dropout_prob, is_training=True, mask=None, bias=None, inplace=True +): + input_x = input_x.contiguous() + if not inplace: + input_x = input_x.clone() + if mask is not None: + input_x += mask + if bias is not None: + input_x += bias + return F.dropout(F.softmax(input_x, dim=-1), p=dropout_prob, training=is_training) + + +def checkpoint_sequential( + functions, + input_x, + enabled=True, +): + def wrap_tuple(a): + return (a,) if type(a) is not tuple else a + + def exec(func, a): + return wrap_tuple(func(*a)) + + def get_wrap_exec(func): + def wrap_exec(*a): + return exec(func, a) + + return wrap_exec + + input_x = wrap_tuple(input_x) + + is_grad_enabled = torch.is_grad_enabled() + + if enabled and is_grad_enabled: + for func in functions: + input_x = torch.utils.checkpoint.checkpoint(get_wrap_exec(func), *input_x) + else: + for func in functions: + input_x = exec(func, input_x) + return input_x + + +class ResidualLinear(nn.Module): + resnet: Final[int] + + def __init__(self, num_in, num_out, bavg=0.0, stddev=1.0, resnet_dt=False): + """Construct a residual linear layer. + + Args: + - num_in: Width of input tensor. + - num_out: Width of output tensor. + - resnet_dt: Using time-step in the ResNet construction. + """ + super().__init__() + self.num_in = num_in + self.num_out = num_out + self.resnet = resnet_dt + + self.matrix = nn.Parameter(data=Tensor(num_in, num_out)) + nn.init.normal_(self.matrix.data, std=stddev / np.sqrt(num_out + num_in)) + self.bias = nn.Parameter(data=Tensor(1, num_out)) + nn.init.normal_(self.bias.data, mean=bavg, std=stddev) + if self.resnet: + self.idt = nn.Parameter(data=Tensor(1, num_out)) + nn.init.normal_(self.idt.data, mean=1.0, std=0.001) + + def forward(self, inputs): + """Return X ?+ X*W+b.""" + xw_plus_b = torch.matmul(inputs, self.matrix) + self.bias + hidden = torch.tanh(xw_plus_b) + if self.resnet: + hidden = hidden * self.idt + if self.num_in == self.num_out: + return inputs + hidden + elif self.num_in * 2 == self.num_out: + return torch.cat([inputs, inputs], dim=1) + hidden + else: + return hidden + + +class TypeFilter(nn.Module): + use_tebd: Final[bool] + tebd_mode: Final[str] + + def __init__( + self, + offset, + length, + neuron, + return_G=False, + tebd_dim=0, + use_tebd=False, + tebd_mode="concat", + ): + """Construct a filter on the given element as neighbor. + + Args: + - offset: Element offset in the descriptor matrix. + - length: Atom count of this element. + - neuron: Number of neurons in each hidden layers of the embedding net. + """ + super().__init__() + self.offset = offset + self.length = length + self.tebd_dim = tebd_dim + self.use_tebd = use_tebd + self.tebd_mode = tebd_mode + supported_tebd_mode = ["concat", "dot", "dot_residual_s", "dot_residual_t"] + assert ( + tebd_mode in supported_tebd_mode + ), f"Unknown tebd_mode {tebd_mode}! Supported are {supported_tebd_mode}." + if use_tebd and tebd_mode == "concat": + self.neuron = [1 + tebd_dim * 2, *neuron] + else: + self.neuron = [1, *neuron] + + deep_layers = [] + for ii in range(1, len(self.neuron)): + one = ResidualLinear(self.neuron[ii - 1], self.neuron[ii]) + deep_layers.append(one) + self.deep_layers = nn.ModuleList(deep_layers) + + deep_layers_t = [] + if use_tebd and tebd_mode in ["dot", "dot_residual_s", "dot_residual_t"]: + self.neuron_t = [tebd_dim * 2, *neuron] + for ii in range(1, len(self.neuron_t)): + one = ResidualLinear(self.neuron_t[ii - 1], self.neuron_t[ii]) + deep_layers_t.append(one) + self.deep_layers_t = nn.ModuleList(deep_layers_t) + + self.return_G = return_G + + def forward( + self, + inputs, + atype_tebd: Optional[torch.Tensor] = None, + nlist_tebd: Optional[torch.Tensor] = None, + ): + """Calculate decoded embedding for each atom. + + Args: + - inputs: Descriptor matrix. Its shape is [nframes*natoms[0], len_descriptor]. + + Returns + ------- + - `torch.Tensor`: Embedding contributed by me. Its shape is [nframes*natoms[0], 4, self.neuron[-1]]. + """ + inputs_i = inputs[:, self.offset * 4 : (self.offset + self.length) * 4] + inputs_reshape = inputs_i.reshape( + -1, 4 + ) # shape is [nframes*natoms[0]*self.length, 4] + xyz_scatter = inputs_reshape[:, 0:1] + + # concat the tebd as input + if self.use_tebd and self.tebd_mode == "concat": + assert nlist_tebd is not None and atype_tebd is not None + nlist_tebd = nlist_tebd.reshape(-1, self.tebd_dim) + atype_tebd = atype_tebd.reshape(-1, self.tebd_dim) + # [nframes * nloc * nnei, 1 + tebd_dim * 2] + xyz_scatter = torch.concat([xyz_scatter, nlist_tebd, atype_tebd], dim=1) + + for linear in self.deep_layers: + xyz_scatter = linear(xyz_scatter) + # [nframes * nloc * nnei, out_size] + + # dot the tebd output + if self.use_tebd and self.tebd_mode in [ + "dot", + "dot_residual_s", + "dot_residual_t", + ]: + assert nlist_tebd is not None and atype_tebd is not None + nlist_tebd = nlist_tebd.reshape(-1, self.tebd_dim) + atype_tebd = atype_tebd.reshape(-1, self.tebd_dim) + # [nframes * nloc * nnei, tebd_dim * 2] + two_side_tebd = torch.concat([nlist_tebd, atype_tebd], dim=1) + for linear in self.deep_layers_t: + two_side_tebd = linear(two_side_tebd) + # [nframes * nloc * nnei, out_size] + if self.tebd_mode == "dot": + xyz_scatter = xyz_scatter * two_side_tebd + elif self.tebd_mode == "dot_residual_s": + xyz_scatter = xyz_scatter * two_side_tebd + xyz_scatter + elif self.tebd_mode == "dot_residual_t": + xyz_scatter = xyz_scatter * two_side_tebd + two_side_tebd + + xyz_scatter = xyz_scatter.view( + -1, self.length, self.neuron[-1] + ) # shape is [nframes*natoms[0], self.length, self.neuron[-1]] + if self.return_G: + return xyz_scatter + else: + # shape is [nframes*natoms[0], 4, self.length] + inputs_reshape = inputs_i.view(-1, self.length, 4).permute(0, 2, 1) + return torch.matmul(inputs_reshape, xyz_scatter) + + +class SimpleLinear(nn.Module): + use_timestep: Final[bool] + + def __init__( + self, + num_in, + num_out, + bavg=0.0, + stddev=1.0, + use_timestep=False, + activate=None, + bias: bool = True, + ): + """Construct a linear layer. + + Args: + - num_in: Width of input tensor. + - num_out: Width of output tensor. + - use_timestep: Apply time-step to weight. + - activate: type of activate func. + """ + super().__init__() + self.num_in = num_in + self.num_out = num_out + self.use_timestep = use_timestep + self.activate = ActivationFn(activate) + + self.matrix = nn.Parameter(data=Tensor(num_in, num_out)) + nn.init.normal_(self.matrix.data, std=stddev / np.sqrt(num_out + num_in)) + if bias: + self.bias = nn.Parameter(data=Tensor(1, num_out)) + nn.init.normal_(self.bias.data, mean=bavg, std=stddev) + else: + self.bias = None + if self.use_timestep: + self.idt = nn.Parameter(data=Tensor(1, num_out)) + nn.init.normal_(self.idt.data, mean=0.1, std=0.001) + + def forward(self, inputs): + """Return X*W+b.""" + xw = torch.matmul(inputs, self.matrix) + hidden = xw + self.bias if self.bias is not None else xw + hidden = self.activate(hidden) + if self.use_timestep: + hidden = hidden * self.idt + return hidden + + +class Linear(nn.Linear): + def __init__( + self, + d_in: int, + d_out: int, + bias: bool = True, + init: str = "default", + ): + super().__init__(d_in, d_out, bias=bias, dtype=env.GLOBAL_PT_FLOAT_PRECISION) + + self.use_bias = bias + + if self.use_bias: + with torch.no_grad(): + self.bias.fill_(0) + + if init == "default": + self._trunc_normal_init(1.0) + elif init == "relu": + self._trunc_normal_init(2.0) + elif init == "glorot": + self._glorot_uniform_init() + elif init == "gating": + self._zero_init(self.use_bias) + elif init == "normal": + self._normal_init() + elif init == "final": + self._zero_init(False) + else: + raise ValueError("Invalid init method.") + + def _trunc_normal_init(self, scale=1.0): + # Constant from scipy.stats.truncnorm.std(a=-2, b=2, loc=0., scale=1.) + TRUNCATED_NORMAL_STDDEV_FACTOR = 0.87962566103423978 + _, fan_in = self.weight.shape + scale = scale / max(1, fan_in) + std = (scale**0.5) / TRUNCATED_NORMAL_STDDEV_FACTOR + nn.init.trunc_normal_(self.weight, mean=0.0, std=std) + + def _glorot_uniform_init(self): + nn.init.xavier_uniform_(self.weight, gain=1) + + def _zero_init(self, use_bias=True): + with torch.no_grad(): + self.weight.fill_(0.0) + if use_bias: + with torch.no_grad(): + self.bias.fill_(1.0) + + def _normal_init(self): + nn.init.kaiming_normal_(self.weight, nonlinearity="linear") + + +class Transition(nn.Module): + def __init__(self, d_in, n, dropout=0.0): + super().__init__() + + self.d_in = d_in + self.n = n + + self.linear_1 = Linear(self.d_in, self.n * self.d_in, init="relu") + self.act = nn.GELU() + self.linear_2 = Linear(self.n * self.d_in, d_in, init="final") + self.dropout = dropout + + def _transition(self, x): + x = self.linear_1(x) + x = self.act(x) + x = F.dropout(x, p=self.dropout, training=self.training) + x = self.linear_2(x) + return x + + def forward( + self, + x: torch.Tensor, + ) -> torch.Tensor: + x = self._transition(x=x) + return x + + +class Embedding(nn.Embedding): + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + padding_idx: Optional[int] = None, + dtype=torch.float64, + ): + super().__init__( + num_embeddings, embedding_dim, padding_idx=padding_idx, dtype=dtype + ) + self._normal_init() + + if padding_idx is not None: + self.weight.data[self.padding_idx].zero_() + + def _normal_init(self, std=0.02): + nn.init.normal_(self.weight, mean=0.0, std=std) + + +class NonLinearHead(nn.Module): + def __init__(self, input_dim, out_dim, activation_fn, hidden=None): + super().__init__() + hidden = input_dim if not hidden else hidden + self.linear1 = SimpleLinear(input_dim, hidden, activate=activation_fn) + self.linear2 = SimpleLinear(hidden, out_dim) + + def forward(self, x): + x = self.linear1(x) + x = self.linear2(x) + return x + + +class NonLinear(nn.Module): + def __init__(self, input, output_size, hidden=None): + super().__init__() + + if hidden is None: + hidden = input + self.layer1 = Linear(input, hidden, init="relu") + self.layer2 = Linear(hidden, output_size, init="final") + + def forward(self, x): + x = F.linear(x, self.layer1.weight) + # x = fused_ops.bias_torch_gelu(x, self.layer1.bias) + x = nn.GELU()(x) + self.layer1.bias + x = self.layer2(x) + return x + + def zero_init(self): + nn.init.zeros_(self.layer2.weight) + nn.init.zeros_(self.layer2.bias) + + +class MaskLMHead(nn.Module): + """Head for masked language modeling.""" + + def __init__(self, embed_dim, output_dim, activation_fn, weight=None): + super().__init__() + self.dense = SimpleLinear(embed_dim, embed_dim) + self.activation_fn = get_activation_fn(activation_fn) + self.layer_norm = nn.LayerNorm(embed_dim, dtype=env.GLOBAL_PT_FLOAT_PRECISION) + + if weight is None: + weight = nn.Linear( + embed_dim, output_dim, bias=False, dtype=env.GLOBAL_PT_FLOAT_PRECISION + ).weight + self.weight = weight + self.bias = nn.Parameter( + torch.zeros(output_dim, dtype=env.GLOBAL_PT_FLOAT_PRECISION) + ) + + def forward(self, features, masked_tokens: Optional[torch.Tensor] = None, **kwargs): + # Only project the masked tokens while training, + # saves both memory and computation + if masked_tokens is not None: + features = features[masked_tokens, :] + + x = self.dense(features) + x = self.activation_fn(x) + x = self.layer_norm(x) + # project back to size of vocabulary with bias + x = F.linear(x, self.weight) + self.bias + return x + + +class ResidualDeep(nn.Module): + def __init__( + self, type_id, embedding_width, neuron, bias_atom_e, out_dim=1, resnet_dt=False + ): + """Construct a filter on the given element as neighbor. + + Args: + - typei: Element ID. + - embedding_width: Embedding width per atom. + - neuron: Number of neurons in each hidden layers of the embedding net. + - resnet_dt: Using time-step in the ResNet construction. + """ + super().__init__() + self.type_id = type_id + self.neuron = [embedding_width, *neuron] + self.out_dim = out_dim + + deep_layers = [] + for ii in range(1, len(self.neuron)): + one = SimpleLinear( + num_in=self.neuron[ii - 1], + num_out=self.neuron[ii], + use_timestep=( + resnet_dt and ii > 1 and self.neuron[ii - 1] == self.neuron[ii] + ), + activate="tanh", + ) + deep_layers.append(one) + self.deep_layers = nn.ModuleList(deep_layers) + if not env.ENERGY_BIAS_TRAINABLE: + bias_atom_e = 0 + self.final_layer = SimpleLinear(self.neuron[-1], self.out_dim, bias_atom_e) + + def forward(self, inputs): + """Calculate decoded embedding for each atom. + + Args: + - inputs: Embedding net output per atom. Its shape is [nframes*nloc, self.embedding_width]. + + Returns + ------- + - `torch.Tensor`: Output layer with shape [nframes*nloc, self.neuron[-1]]. + """ + outputs = inputs + for idx, linear in enumerate(self.deep_layers): + if idx > 0 and linear.num_in == linear.num_out: + outputs = outputs + linear(outputs) + else: + outputs = linear(outputs) + outputs = self.final_layer(outputs) + return outputs + + +class TypeEmbedNet(nn.Module): + def __init__(self, type_nums, embed_dim, bavg=0.0, stddev=1.0): + """Construct a type embedding net.""" + super().__init__() + self.embedding = nn.Embedding( + type_nums + 1, + embed_dim, + padding_idx=type_nums, + dtype=env.GLOBAL_PT_FLOAT_PRECISION, + ) + # nn.init.normal_(self.embedding.weight[:-1], mean=bavg, std=stddev) + + def forward(self, atype): + """ + Args: + atype: Type of each input, [nframes, nloc] or [nframes, nloc, nnei]. + + Returns + ------- + type_embedding: + + """ + return self.embedding(atype) + + def share_params(self, base_class, shared_level, resume=False): + assert ( + self.__class__ == base_class.__class__ + ), "Only TypeEmbedNet of the same type can share params!" + if shared_level == 0: + # the following will successfully link all the params except buffers, which need manually link. + for item in self._modules: + self._modules[item] = base_class._modules[item] + else: + raise NotImplementedError + + +@torch.jit.script +def gaussian(x, mean, std: float): + pi = 3.14159 + a = (2 * pi) ** 0.5 + return torch.exp(-0.5 * (((x - mean) / std) ** 2)) / (a * std) + + +class GaussianKernel(nn.Module): + def __init__(self, K=128, num_pair=512, std_width=1.0, start=0.0, stop=9.0): + super().__init__() + self.K = K + std_width = std_width + start = start + stop = stop + mean = torch.linspace(start, stop, K, dtype=env.GLOBAL_PT_FLOAT_PRECISION) + self.std = (std_width * (mean[1] - mean[0])).item() + self.register_buffer("mean", mean) + self.mul = Embedding( + num_pair + 1, 1, padding_idx=num_pair, dtype=env.GLOBAL_PT_FLOAT_PRECISION + ) + self.bias = Embedding( + num_pair + 1, 1, padding_idx=num_pair, dtype=env.GLOBAL_PT_FLOAT_PRECISION + ) + nn.init.constant_(self.bias.weight, 0) + nn.init.constant_(self.mul.weight, 1.0) + + def forward(self, x, atom_pair): + mul = self.mul(atom_pair).abs().sum(dim=-2) + bias = self.bias(atom_pair).sum(dim=-2) + x = mul * x.unsqueeze(-1) + bias + # [nframes, nloc, nnei, K] + x = x.expand(-1, -1, -1, self.K) + mean = self.mean.view(-1) + return gaussian(x, mean, self.std) + + +class GaussianEmbedding(nn.Module): + def __init__( + self, + rcut, + kernel_num, + num_pair, + embed_dim, + pair_embed_dim, + sel, + ntypes, + atomic_sum_gbf, + ): + """Construct a gaussian kernel based embedding of pair representation. + + Args: + rcut: Radial cutoff. + kernel_num: Number of gaussian kernels. + num_pair: Number of different pairs. + embed_dim: Dimension of atomic representation. + pair_embed_dim: Dimension of pair representation. + sel: Number of neighbors. + ntypes: Number of atom types. + """ + super().__init__() + self.gbf = GaussianKernel(K=kernel_num, num_pair=num_pair, stop=rcut) + self.gbf_proj = NonLinear(kernel_num, pair_embed_dim) + self.embed_dim = embed_dim + self.pair_embed_dim = pair_embed_dim + self.atomic_sum_gbf = atomic_sum_gbf + if self.atomic_sum_gbf: + if kernel_num != self.embed_dim: + self.edge_proj = torch.nn.Linear( + kernel_num, self.embed_dim, dtype=env.GLOBAL_PT_FLOAT_PRECISION + ) + else: + self.edge_proj = None + self.ntypes = ntypes + self.nnei = sel + + def forward(self, coord_selected, atom_feature, edge_type_2dim, edge_feature): + ## local cluster forward + """Calculate decoded embedding for each atom. + Args: + coord_selected: Clustered atom coordinates with shape [nframes*nloc, natoms, 3]. + atom_feature: Previous calculated atomic features with shape [nframes*nloc, natoms, embed_dim]. + edge_type_2dim: Edge index for gbf calculation with shape [nframes*nloc, natoms, natoms, 2]. + edge_feature: Previous calculated edge features with shape [nframes*nloc, natoms, natoms, pair_dim]. + + Returns + ------- + atom_feature: Updated atomic features with shape [nframes*nloc, natoms, embed_dim]. + attn_bias: Updated edge features as attention bias with shape [nframes*nloc, natoms, natoms, pair_dim]. + delta_pos: Delta position for force/vector prediction with shape [nframes*nloc, natoms, natoms, 3]. + """ + ncluster, natoms, _ = coord_selected.shape + # ncluster x natoms x natoms x 3 + delta_pos = coord_selected.unsqueeze(1) - coord_selected.unsqueeze(2) + # (ncluster x natoms x natoms + dist = delta_pos.norm(dim=-1).view(-1, natoms, natoms) + # [ncluster, natoms, natoms, K] + gbf_feature = self.gbf(dist, edge_type_2dim) + if self.atomic_sum_gbf: + edge_features = gbf_feature + # [ncluster, natoms, K] + sum_edge_features = edge_features.sum(dim=-2) + if self.edge_proj is not None: + sum_edge_features = self.edge_proj(sum_edge_features) + # [ncluster, natoms, embed_dim] + atom_feature = atom_feature + sum_edge_features + + # [ncluster, natoms, natoms, pair_dim] + gbf_result = self.gbf_proj(gbf_feature) + + attn_bias = gbf_result + edge_feature + return atom_feature, attn_bias, delta_pos + + +class NeighborWiseAttention(nn.Module): + def __init__( + self, + layer_num, + nnei, + embed_dim, + hidden_dim, + dotr=False, + do_mask=False, + post_ln=True, + ffn=False, + ffn_embed_dim=1024, + activation="tanh", + scaling_factor=1.0, + head_num=1, + normalize=True, + temperature=None, + ): + """Construct a neighbor-wise attention net.""" + super().__init__() + self.layer_num = layer_num + attention_layers = [] + for i in range(self.layer_num): + attention_layers.append( + NeighborWiseAttentionLayer( + nnei, + embed_dim, + hidden_dim, + dotr=dotr, + do_mask=do_mask, + post_ln=post_ln, + ffn=ffn, + ffn_embed_dim=ffn_embed_dim, + activation=activation, + scaling_factor=scaling_factor, + head_num=head_num, + normalize=normalize, + temperature=temperature, + ) + ) + self.attention_layers = nn.ModuleList(attention_layers) + + def forward( + self, + input_G, + nei_mask, + input_r: Optional[torch.Tensor] = None, + sw: Optional[torch.Tensor] = None, + ): + """ + Args: + input_G: Input G, [nframes * nloc, nnei, embed_dim]. + nei_mask: neighbor mask, [nframes * nloc, nnei]. + input_r: normalized radial, [nframes, nloc, nei, 3]. + + Returns + ------- + out: Output G, [nframes * nloc, nnei, embed_dim] + + """ + out = input_G + # https://github.com/pytorch/pytorch/issues/39165#issuecomment-635472592 + for layer in self.attention_layers: + out = layer(out, nei_mask, input_r=input_r, sw=sw) + return out + + +class NeighborWiseAttentionLayer(nn.Module): + ffn: Final[bool] + + def __init__( + self, + nnei, + embed_dim, + hidden_dim, + dotr=False, + do_mask=False, + post_ln=True, + ffn=False, + ffn_embed_dim=1024, + activation="tanh", + scaling_factor=1.0, + head_num=1, + normalize=True, + temperature=None, + ): + """Construct a neighbor-wise attention layer.""" + super().__init__() + self.nnei = nnei + self.embed_dim = embed_dim + self.hidden_dim = hidden_dim + self.dotr = dotr + self.do_mask = do_mask + self.post_ln = post_ln + self.ffn = ffn + self.attention_layer = GatedSelfAttetion( + nnei, + embed_dim, + hidden_dim, + dotr=dotr, + do_mask=do_mask, + scaling_factor=scaling_factor, + head_num=head_num, + normalize=normalize, + temperature=temperature, + ) + self.attn_layer_norm = nn.LayerNorm( + self.embed_dim, dtype=env.GLOBAL_PT_FLOAT_PRECISION + ) + if self.ffn: + self.ffn_embed_dim = ffn_embed_dim + self.fc1 = nn.Linear( + self.embed_dim, self.ffn_embed_dim, dtype=env.GLOBAL_PT_FLOAT_PRECISION + ) + self.activation_fn = get_activation_fn(activation) + self.fc2 = nn.Linear( + self.ffn_embed_dim, self.embed_dim, dtype=env.GLOBAL_PT_FLOAT_PRECISION + ) + self.final_layer_norm = nn.LayerNorm( + self.embed_dim, dtype=env.GLOBAL_PT_FLOAT_PRECISION + ) + + def forward( + self, + x, + nei_mask, + input_r: Optional[torch.Tensor] = None, + sw: Optional[torch.Tensor] = None, + ): + residual = x + if not self.post_ln: + x = self.attn_layer_norm(x) + x = self.attention_layer(x, nei_mask, input_r=input_r, sw=sw) + x = residual + x + if self.post_ln: + x = self.attn_layer_norm(x) + if self.ffn: + residual = x + if not self.post_ln: + x = self.final_layer_norm(x) + x = self.fc1(x) + x = self.activation_fn(x) + x = self.fc2(x) + x = residual + x + if self.post_ln: + x = self.final_layer_norm(x) + return x + + +class GatedSelfAttetion(nn.Module): + def __init__( + self, + nnei, + embed_dim, + hidden_dim, + dotr=False, + do_mask=False, + scaling_factor=1.0, + head_num=1, + normalize=True, + temperature=None, + bias=True, + smooth=True, + ): + """Construct a neighbor-wise attention net.""" + super().__init__() + self.nnei = nnei + self.embed_dim = embed_dim + self.hidden_dim = hidden_dim + self.head_num = head_num + self.dotr = dotr + self.do_mask = do_mask + if temperature is None: + self.scaling = (self.hidden_dim * scaling_factor) ** -0.5 + else: + self.scaling = temperature + self.normalize = normalize + self.in_proj = SimpleLinear( + embed_dim, + hidden_dim * 3, + bavg=0.0, + stddev=1.0, + use_timestep=False, + bias=bias, + ) + self.out_proj = SimpleLinear( + hidden_dim, embed_dim, bavg=0.0, stddev=1.0, use_timestep=False, bias=bias + ) + self.smooth = smooth + + def forward( + self, + query, + nei_mask, + input_r: Optional[torch.Tensor] = None, + sw: Optional[torch.Tensor] = None, + attnw_shift: float = 20.0, + ): + """ + Args: + query: input G, [nframes * nloc, nnei, embed_dim]. + nei_mask: neighbor mask, [nframes * nloc, nnei]. + input_r: normalized radial, [nframes, nloc, nei, 3]. + + Returns + ------- + type_embedding: + + """ + q, k, v = self.in_proj(query).chunk(3, dim=-1) + # [nframes * nloc, nnei, hidden_dim] + q = q.view(-1, self.nnei, self.hidden_dim) + k = k.view(-1, self.nnei, self.hidden_dim) + v = v.view(-1, self.nnei, self.hidden_dim) + if self.normalize: + q = F.normalize(q, dim=-1) + k = F.normalize(k, dim=-1) + v = F.normalize(v, dim=-1) + q = q * self.scaling + k = k.transpose(1, 2) + # [nframes * nloc, nnei, nnei] + attn_weights = torch.bmm(q, k) + # [nframes * nloc, nnei] + nei_mask = nei_mask.view(-1, self.nnei) + if self.smooth: + # [nframes * nloc, nnei] + assert sw is not None + sw = sw.view([-1, self.nnei]) + attn_weights = (attn_weights + attnw_shift) * sw[:, :, None] * sw[ + :, None, : + ] - attnw_shift + else: + attn_weights = attn_weights.masked_fill( + ~nei_mask.unsqueeze(1), float("-inf") + ) + attn_weights = F.softmax(attn_weights, dim=-1) + attn_weights = attn_weights.masked_fill(~nei_mask.unsqueeze(-1), 0.0) + if self.smooth: + assert sw is not None + attn_weights = attn_weights * sw[:, :, None] * sw[:, None, :] + if self.dotr: + assert input_r is not None, "input_r must be provided when dotr is True!" + angular_weight = torch.bmm(input_r, input_r.transpose(1, 2)) + attn_weights = attn_weights * angular_weight + o = torch.bmm(attn_weights, v) + output = self.out_proj(o) + return output + + +class LocalSelfMultiheadAttention(nn.Module): + def __init__(self, feature_dim, attn_head, scaling_factor=1.0): + super().__init__() + self.feature_dim = feature_dim + self.attn_head = attn_head + self.head_dim = feature_dim // attn_head + assert ( + feature_dim % attn_head == 0 + ), f"feature_dim {feature_dim} must be divided by attn_head {attn_head}!" + self.scaling = (self.head_dim * scaling_factor) ** -0.5 + self.in_proj = SimpleLinear(self.feature_dim, self.feature_dim * 3) + # TODO debug + # self.out_proj = SimpleLinear(self.feature_dim, self.feature_dim) + + def forward( + self, + query, + attn_bias: Optional[torch.Tensor] = None, + nlist_mask: Optional[torch.Tensor] = None, + nlist: Optional[torch.Tensor] = None, + return_attn=True, + ): + nframes, nloc, feature_dim = query.size() + _, _, nnei = nlist.size() + assert feature_dim == self.feature_dim + # [nframes, nloc, feature_dim] + q, k, v = self.in_proj(query).chunk(3, dim=-1) + # [nframes * attn_head * nloc, 1, head_dim] + q = ( + q.view(nframes, nloc, self.attn_head, self.head_dim) + .transpose(1, 2) + .contiguous() + .view(nframes * self.attn_head * nloc, 1, self.head_dim) + * self.scaling + ) + # [nframes, nloc, feature_dim] --> [nframes, nloc + 1, feature_dim] + # with nlist [nframes, nloc, nnei] --> [nframes, nloc, nnei, feature_dim] + # padding = torch.zeros(feature_dim, dtype=env.GLOBAL_PT_FLOAT_PRECISION).to(k.device) + # k = torch.concat([k, padding.unsqueeze(0).unsqueeze(1)], dim=1) + # v = torch.concat([v, padding.unsqueeze(0).unsqueeze(1)], dim=1) + + # [nframes, nloc * nnei, feature_dim] + index = nlist.view(nframes, -1).unsqueeze(-1).expand(-1, -1, feature_dim) + k = torch.gather(k, dim=1, index=index) + # [nframes, nloc * nnei, feature_dim] + v = torch.gather(v, dim=1, index=index) + # [nframes * attn_head * nloc, nnei, head_dim] + k = ( + k.view(nframes, nloc, nnei, self.attn_head, self.head_dim) + .permute(0, 3, 1, 2, 4) + .contiguous() + .view(nframes * self.attn_head * nloc, nnei, self.head_dim) + ) + v = ( + v.view(nframes, nloc, nnei, self.attn_head, self.head_dim) + .permute(0, 3, 1, 2, 4) + .contiguous() + .view(nframes * self.attn_head * nloc, nnei, self.head_dim) + ) + # [nframes * attn_head * nloc, 1, nnei] + attn_weights = torch.bmm(q, k.transpose(1, 2)) + # maskfill + # [nframes, attn_head, nloc, nnei] + attn_weights = attn_weights.view( + nframes, self.attn_head, nloc, nnei + ).masked_fill(~nlist_mask.unsqueeze(1), float("-inf")) + # add bias + if return_attn: + attn_weights = attn_weights + attn_bias + # softmax + # [nframes * attn_head * nloc, 1, nnei] + attn = F.softmax(attn_weights, dim=-1).view( + nframes * self.attn_head * nloc, 1, nnei + ) + # bmm + # [nframes * attn_head * nloc, 1, head_dim] + o = torch.bmm(attn, v) + assert list(o.size()) == [nframes * self.attn_head * nloc, 1, self.head_dim] + # [nframes, nloc, feature_dim] + o = ( + o.view(nframes, self.attn_head, nloc, self.head_dim) + .transpose(1, 2) + .contiguous() + .view(nframes, nloc, self.feature_dim) + ) + # out + ## TODO debug: + # o = self.out_proj(o) + if not return_attn: + return o + else: + return o, attn_weights, attn + + +class NodeTaskHead(nn.Module): + def __init__( + self, + embed_dim: int, + pair_dim: int, + num_head: int, + ): + super().__init__() + self.layer_norm = nn.LayerNorm(embed_dim, dtype=env.GLOBAL_PT_FLOAT_PRECISION) + self.pair_norm = nn.LayerNorm(pair_dim, dtype=env.GLOBAL_PT_FLOAT_PRECISION) + self.embed_dim = embed_dim + self.q_proj = Linear(embed_dim, embed_dim, bias=False, init="glorot") + self.k_proj = Linear(embed_dim, embed_dim, bias=False, init="glorot") + self.v_proj = Linear(embed_dim, embed_dim, bias=False, init="glorot") + self.num_heads = num_head + self.head_dim = embed_dim // num_head + self.scaling = self.head_dim**-0.5 + self.force_proj = Linear(embed_dim, 1, init="final", bias=False) + self.linear_bias = Linear(pair_dim, num_head) + self.dropout = 0.1 + + def zero_init(self): + nn.init.zeros_(self.force_proj.weight) + + def forward( + self, + query: Tensor, + pair: Tensor, + delta_pos: Tensor, + attn_mask: Tensor = None, + ) -> Tensor: + ncluster, natoms, _ = query.size() + query = self.layer_norm(query) + # [ncluster, natoms, natoms, pair_dim] + pair = self.pair_norm(pair) + + # [ncluster, attn_head, natoms, head_dim] + q = ( + self.q_proj(query) + .view(ncluster, natoms, self.num_heads, -1) + .transpose(1, 2) + * self.scaling + ) + # [ncluster, attn_head, natoms, head_dim] + k = ( + self.k_proj(query) + .view(ncluster, natoms, self.num_heads, -1) + .transpose(1, 2) + ) + v = ( + self.v_proj(query) + .view(ncluster, natoms, self.num_heads, -1) + .transpose(1, 2) + ) + # [ncluster, attn_head, natoms, natoms] + attn = q @ k.transpose(-1, -2) + del q, k + # [ncluster, attn_head, natoms, natoms] + bias = self.linear_bias(pair).permute(0, 3, 1, 2).contiguous() + + # [ncluster, attn_head, natoms, natoms] + attn_probs = softmax_dropout( + attn, + self.dropout, + self.training, + mask=attn_mask, + bias=bias.contiguous(), + ).view(ncluster, self.num_heads, natoms, natoms) + + # delta_pos: [ncluster, natoms, natoms, 3] + # [ncluster, attn_head, natoms, natoms, 3] + rot_attn_probs = attn_probs.unsqueeze(-1) * delta_pos.unsqueeze(1).type_as( + attn_probs + ) + # [ncluster, attn_head, 3, natoms, natoms] + rot_attn_probs = rot_attn_probs.permute(0, 1, 4, 2, 3) + # [ncluster, attn_head, 3, natoms, head_dim] + x = rot_attn_probs @ v.unsqueeze(2) + # [ncluster, natoms, 3, embed_dim] + x = x.permute(0, 3, 2, 1, 4).contiguous().view(ncluster, natoms, 3, -1) + cur_force = self.force_proj(x).view(ncluster, natoms, 3) + return cur_force + + +class EnergyHead(nn.Module): + def __init__( + self, + input_dim, + output_dim, + ): + super().__init__() + self.layer_norm = nn.LayerNorm(input_dim, dtype=env.GLOBAL_PT_FLOAT_PRECISION) + self.linear_in = Linear(input_dim, input_dim, init="relu") + + self.linear_out = Linear(input_dim, output_dim, bias=True, init="final") + + def forward(self, x): + x = x.type(self.linear_in.weight.dtype) + x = F.gelu(self.layer_norm(self.linear_in(x))) + x = self.linear_out(x) + return x + + +class OuterProduct(nn.Module): + def __init__(self, d_atom, d_pair, d_hid=32): + super().__init__() + + self.d_atom = d_atom + self.d_pair = d_pair + self.d_hid = d_hid + + self.linear_in = nn.Linear( + d_atom, d_hid * 2, dtype=env.GLOBAL_PT_FLOAT_PRECISION + ) + self.linear_out = nn.Linear( + d_hid**2, d_pair, dtype=env.GLOBAL_PT_FLOAT_PRECISION + ) + self.act = nn.GELU() + + def _opm(self, a, b): + # [nframes, nloc, d] + nframes, nloc, d = a.shape + a = a.view(nframes, nloc, 1, d, 1) + b = b.view(nframes, 1, nloc, 1, d) + # [nframes, nloc, nloc, d, d] + outer = a * b + outer = outer.view(outer.shape[:-2] + (-1,)) + outer = self.linear_out(outer) + return outer + + def forward( + self, + m: torch.Tensor, + nlist: torch.Tensor, + op_mask: float, + op_norm: float, + ) -> torch.Tensor: + ab = self.linear_in(m) + ab = ab * op_mask + a, b = ab.chunk(2, dim=-1) + # [ncluster, natoms, natoms, d_pair] + z = self._opm(a, b) + z *= op_norm + return z + + +class Attention(nn.Module): + def __init__( + self, + q_dim: int, + k_dim: int, + v_dim: int, + head_dim: int, + num_heads: int, + gating: bool = False, + dropout: float = 0.0, + ): + super().__init__() + + self.num_heads = num_heads + self.head_dim = head_dim + total_dim = head_dim * self.num_heads + self.total_dim = total_dim + self.q_dim = q_dim + self.gating = gating + self.linear_q = Linear(q_dim, total_dim, bias=False, init="glorot") + self.linear_k = Linear(k_dim, total_dim, bias=False, init="glorot") + self.linear_v = Linear(v_dim, total_dim, bias=False, init="glorot") + self.linear_o = Linear(total_dim, q_dim, init="final") + self.linear_g = None + if self.gating: + self.linear_g = Linear(q_dim, total_dim, init="gating") + # precompute the 1/sqrt(head_dim) + self.norm = head_dim**-0.5 + self.dropout = dropout + + def forward( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + bias: torch.Tensor, + mask: torch.Tensor = None, + ) -> torch.Tensor: + nframes, nloc, embed_dim = q.size() + g = None + if self.linear_g is not None: + # gating, use raw query input + # [nframes, nloc, total_dim] + g = self.linear_g(q) + # [nframes, nloc, total_dim] + q = self.linear_q(q) + q *= self.norm + # [nframes, nloc, total_dim] + k = self.linear_k(k) + # [nframes, nloc, total_dim] + v = self.linear_v(v) + # global + # q [nframes, h, nloc, d] + # k [nframes, h, nloc, d] + # v [nframes, h, nloc, d] + # attn [nframes, h, nloc, nloc] + # o [nframes, h, nloc, d] + + # [nframes, h, nloc, d] + q = q.view(q.shape[:-1] + (self.num_heads, -1)).transpose(-2, -3).contiguous() + k = k.view(k.shape[:-1] + (self.num_heads, -1)).transpose(-2, -3).contiguous() + v = v.view(v.shape[:-1] + (self.num_heads, -1)).transpose(-2, -3) + # [nframes, h, nloc, nloc] + attn = torch.matmul(q, k.transpose(-1, -2)) + del q, k + # [nframes, h, nloc, nloc] + attn = softmax_dropout(attn, self.dropout, self.training, mask=mask, bias=bias) + # [nframes, h, nloc, d] + o = torch.matmul(attn, v) + del attn, v + + # local + # q [nframes, h, nloc, 1, d] + # k [nframes, h, nloc, nnei, d] + # v [nframes, h, nloc, nnei, d] + # attn [nframes, h, nloc, nnei] + # o [nframes, h, nloc, d] + + assert list(o.size()) == [nframes, self.num_heads, nloc, self.head_dim] + # [nframes, nloc, total_dim] + o = o.transpose(-2, -3).contiguous() + o = o.view(*o.shape[:-2], -1) + + if g is not None: + o = torch.sigmoid(g) * o + + # merge heads + o = self.linear_o(o) + return o + + +class AtomAttention(nn.Module): + def __init__( + self, + q_dim: int, + k_dim: int, + v_dim: int, + pair_dim: int, + head_dim: int, + num_heads: int, + gating: bool = False, + dropout: float = 0.0, + ): + super().__init__() + + self.mha = Attention( + q_dim, k_dim, v_dim, head_dim, num_heads, gating=gating, dropout=dropout + ) + self.layer_norm = nn.LayerNorm(pair_dim, dtype=env.GLOBAL_PT_FLOAT_PRECISION) + self.linear_bias = Linear(pair_dim, num_heads) + + def forward( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + nlist: torch.Tensor, + pair: torch.Tensor, + mask: torch.Tensor = None, + ) -> torch.Tensor: + pair = self.layer_norm(pair) + bias = self.linear_bias(pair).permute(0, 3, 1, 2).contiguous() + return self.mha(q, k, v, bias=bias, mask=mask) + + +class TriangleMultiplication(nn.Module): + def __init__(self, d_pair, d_hid): + super().__init__() + + self.linear_ab_p = Linear(d_pair, d_hid * 2) + self.linear_ab_g = Linear(d_pair, d_hid * 2, init="gating") + + self.linear_g = Linear(d_pair, d_pair, init="gating") + self.linear_z = Linear(d_hid, d_pair, init="final") + + self.layer_norm_out = nn.LayerNorm(d_hid, dtype=env.GLOBAL_PT_FLOAT_PRECISION) + + def forward( + self, + z: torch.Tensor, + mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + # z : [nframes, nloc, nloc, pair_dim] + + # [nframes, nloc, nloc, pair_dim] + g = self.linear_g(z) + if self.training: + ab = self.linear_ab_p(z) * torch.sigmoid(self.linear_ab_g(z)) + else: + ab = self.linear_ab_p(z) + ab *= torch.sigmoid(self.linear_ab_g(z)) + # [nframes, nloc, nloc, d] + a, b = torch.chunk(ab, 2, dim=-1) + del z, ab + + # [nframes, d, nloc_i, nloc_k] row not trans + a1 = a.permute(0, 3, 1, 2) + # [nframes, d, nloc_k, nloc_j(i)] trans + b1 = b.transpose(-1, -3) + # [nframes, d, nloc_i, nloc_j] + x = torch.matmul(a1, b1) + del a1, b1 + + # [nframes, d, nloc_k, nloc_j(i)] not trans + b2 = b.permute(0, 3, 1, 2) + # [nframes, d, nloc_i, nloc_k] col trans # check TODO + a2 = a.transpose(-1, -3) + + # [nframes, d, nloc_i, nloc_j] + x = x + torch.matmul(a2, b2) + del a, b, a2, b2 + + # [nframes, nloc_i, nloc_j, d] + x = x.permute(0, 2, 3, 1) + + x = self.layer_norm_out(x) + x = self.linear_z(x) + return g * x + + +class EvoformerEncoderLayer(nn.Module): + def __init__( + self, + feature_dim: int = 768, + ffn_dim: int = 2048, + attn_head: int = 8, + activation_fn: str = "gelu", + post_ln: bool = False, + ): + super().__init__() + self.feature_dim = feature_dim + self.ffn_dim = ffn_dim + self.attn_head = attn_head + self.activation_fn = ( + get_activation_fn(activation_fn) if activation_fn is not None else None + ) + self.post_ln = post_ln + self.self_attn_layer_norm = nn.LayerNorm( + self.feature_dim, dtype=env.GLOBAL_PT_FLOAT_PRECISION + ) + + self.self_attn = LocalSelfMultiheadAttention( + self.feature_dim, + self.attn_head, + ) + self.final_layer_norm = nn.LayerNorm( + self.feature_dim, dtype=env.GLOBAL_PT_FLOAT_PRECISION + ) + self.fc1 = SimpleLinear(self.feature_dim, self.ffn_dim) + self.fc2 = SimpleLinear(self.ffn_dim, self.feature_dim) + + def forward( + self, + x, + attn_bias: Optional[torch.Tensor] = None, + nlist_mask: Optional[torch.Tensor] = None, + nlist: Optional[torch.Tensor] = None, + return_attn=True, + ): + residual = x + if not self.post_ln: + x = self.self_attn_layer_norm(x) + x = self.self_attn( + query=x, + attn_bias=attn_bias, + nlist_mask=nlist_mask, + nlist=nlist, + return_attn=return_attn, + ) + if return_attn: + x, attn_weights, attn_probs = x + x = residual + x + if self.post_ln: + x = self.self_attn_layer_norm(x) + + residual = x + if not self.post_ln: + x = self.final_layer_norm(x) + x = self.fc1(x) + x = self.activation_fn(x) + x = self.fc2(x) + x = residual + x + if self.post_ln: + x = self.final_layer_norm(x) + if not return_attn: + return x + else: + return x, attn_weights, attn_probs + + +# output: atomic_rep, transformed_atomic_rep, pair_rep, delta_pair_rep, norm_x, norm_delta_pair_rep, +class Evoformer2bEncoder(nn.Module): + def __init__( + self, + nnei: int, + layer_num: int = 6, + attn_head: int = 8, + atomic_dim: int = 1024, + pair_dim: int = 100, + feature_dim: int = 1024, + ffn_dim: int = 2048, + post_ln: bool = False, + final_layer_norm: bool = True, + final_head_layer_norm: bool = False, + emb_layer_norm: bool = False, + atomic_residual: bool = False, + evo_residual: bool = False, + residual_factor: float = 1.0, + activation_function: str = "gelu", + ): + super().__init__() + self.nnei = nnei + self.layer_num = layer_num + self.attn_head = attn_head + self.atomic_dim = atomic_dim + self.pair_dim = pair_dim + self.feature_dim = feature_dim + self.ffn_dim = ffn_dim + self.post_ln = post_ln + self._final_layer_norm = final_layer_norm + self._final_head_layer_norm = final_head_layer_norm + self._emb_layer_norm = emb_layer_norm + self.activation_function = activation_function + self.evo_residual = evo_residual + self.residual_factor = residual_factor + if atomic_residual and atomic_dim == feature_dim: + self.atomic_residual = True + else: + self.atomic_residual = False + self.in_proj = SimpleLinear( + self.atomic_dim, + self.feature_dim, + bavg=0.0, + stddev=1.0, + use_timestep=False, + activate="tanh", + ) # TODO + self.out_proj = SimpleLinear( + self.feature_dim, + self.atomic_dim, + bavg=0.0, + stddev=1.0, + use_timestep=False, + activate="tanh", + ) + if self._emb_layer_norm: + self.emb_layer_norm = nn.LayerNorm( + self.feature_dim, dtype=env.GLOBAL_PT_FLOAT_PRECISION + ) + + ## TODO debug : self.in_proj_pair = NonLinearHead(self.pair_dim, self.attn_head, activation_fn=None) + self.in_proj_pair = SimpleLinear(self.pair_dim, self.attn_head, activate=None) + evoformer_encoder_layers = [] + for i in range(self.layer_num): + evoformer_encoder_layers.append( + EvoformerEncoderLayer( + feature_dim=self.feature_dim, + ffn_dim=self.ffn_dim, + attn_head=self.attn_head, + activation_fn=self.activation_function, + post_ln=self.post_ln, + ) + ) + self.evoformer_encoder_layers = nn.ModuleList(evoformer_encoder_layers) + if self._final_layer_norm: + self.final_layer_norm = nn.LayerNorm( + self.feature_dim, dtype=env.GLOBAL_PT_FLOAT_PRECISION + ) + if self._final_head_layer_norm: + self.final_head_layer_norm = nn.LayerNorm( + self.attn_head, dtype=env.GLOBAL_PT_FLOAT_PRECISION + ) + + def forward(self, atomic_rep, pair_rep, nlist, nlist_type, nlist_mask): + """Encoder the atomic and pair representations. + + Args: + - atomic_rep: Atomic representation with shape [nframes, nloc, atomic_dim]. + - pair_rep: Pair representation with shape [nframes, nloc, nnei, pair_dim]. + - nlist: Neighbor list with shape [nframes, nloc, nnei]. + - nlist_type: Neighbor types with shape [nframes, nloc, nnei]. + - nlist_mask: Neighbor mask with shape [nframes, nloc, nnei], `False` if blank. + + Returns + ------- + - atomic_rep: Atomic representation after encoder with shape [nframes, nloc, feature_dim]. + - transformed_atomic_rep: Transformed atomic representation after encoder with shape [nframes, nloc, atomic_dim]. + - pair_rep: Pair representation after encoder with shape [nframes, nloc, nnei, attn_head]. + - delta_pair_rep: Delta pair representation after encoder with shape [nframes, nloc, nnei, attn_head]. + - norm_x: Normalization loss of atomic_rep. + - norm_delta_pair_rep: Normalization loss of delta_pair_rep. + """ + # Global branch + nframes, nloc, _ = atomic_rep.size() + nnei = pair_rep.shape[2] + input_atomic_rep = atomic_rep + # [nframes, nloc, feature_dim] + if self.atomic_residual: + atomic_rep = atomic_rep + self.in_proj(atomic_rep) + else: + atomic_rep = self.in_proj(atomic_rep) + + if self._emb_layer_norm: + atomic_rep = self.emb_layer_norm(atomic_rep) + + # Local branch + # [nframes, nloc, nnei, attn_head] + pair_rep = self.in_proj_pair(pair_rep) + # [nframes, attn_head, nloc, nnei] + pair_rep = pair_rep.permute(0, 3, 1, 2).contiguous() + input_pair_rep = pair_rep + pair_rep = pair_rep.masked_fill(~nlist_mask.unsqueeze(1), float("-inf")) + + for i in range(self.layer_num): + atomic_rep, pair_rep, _ = self.evoformer_encoder_layers[i]( + atomic_rep, + attn_bias=pair_rep, + nlist_mask=nlist_mask, + nlist=nlist, + return_attn=True, + ) + + def norm_loss(x, eps=1e-10, tolerance=1.0): + # x = x.float() + max_norm = x.shape[-1] ** 0.5 + norm = torch.sqrt(torch.sum(x**2, dim=-1) + eps) + error = F.relu((norm - max_norm).abs() - tolerance) + return error + + def masked_mean(mask, value, dim=-1, eps=1e-10): + return ( + torch.sum(mask * value, dim=dim) / (eps + torch.sum(mask, dim=dim)) + ).mean() + + # atomic_rep shape: [nframes, nloc, feature_dim] + # pair_rep shape: [nframes, attn_head, nloc, nnei] + + norm_x = torch.mean(norm_loss(atomic_rep)) + if self._final_layer_norm: + atomic_rep = self.final_layer_norm(atomic_rep) + + delta_pair_rep = pair_rep - input_pair_rep + delta_pair_rep = delta_pair_rep.masked_fill(~nlist_mask.unsqueeze(1), 0) + # [nframes, nloc, nnei, attn_head] + delta_pair_rep = ( + delta_pair_rep.view(nframes, self.attn_head, nloc, nnei) + .permute(0, 2, 3, 1) + .contiguous() + ) + + # [nframes, nloc, nnei] + norm_delta_pair_rep = norm_loss(delta_pair_rep) + norm_delta_pair_rep = masked_mean(mask=nlist_mask, value=norm_delta_pair_rep) + if self._final_head_layer_norm: + delta_pair_rep = self.final_head_layer_norm(delta_pair_rep) + + if self.atomic_residual: + transformed_atomic_rep = atomic_rep + self.out_proj(atomic_rep) + else: + transformed_atomic_rep = self.out_proj(atomic_rep) + + if self.evo_residual: + transformed_atomic_rep = ( + self.residual_factor * transformed_atomic_rep + input_atomic_rep + ) * (1 / np.sqrt(2)) + + return ( + atomic_rep, + transformed_atomic_rep, + pair_rep, + delta_pair_rep, + norm_x, + norm_delta_pair_rep, + ) + + +class Evoformer3bEncoderLayer(nn.Module): + def __init__( + self, + nnei, + embedding_dim: int = 768, + pair_dim: int = 64, + pair_hidden_dim: int = 32, + ffn_embedding_dim: int = 3072, + num_attention_heads: int = 8, + dropout: float = 0.1, + droppath_prob: float = 0.0, + pair_dropout: float = 0.25, + attention_dropout: float = 0.1, + activation_dropout: float = 0.1, + pre_ln: bool = True, + tri_update: bool = True, + ): + super().__init__() + # Initialize parameters + self.nnei = nnei + self.embedding_dim = embedding_dim + self.num_attention_heads = num_attention_heads + self.attention_dropout = attention_dropout + + # self.dropout = dropout + self.activation_dropout = activation_dropout + + if droppath_prob > 0.0: + self.dropout_module = DropPath(droppath_prob) + else: + self.dropout_module = Dropout(dropout) + + # self.self_attn = AtomAttentionLocal(embedding_dim, embedding_dim, embedding_dim, pair_dim, + # embedding_dim // num_attention_heads, num_attention_heads, + # gating=False, dropout=attention_dropout) + self.self_attn = AtomAttention( + embedding_dim, + embedding_dim, + embedding_dim, + pair_dim, + embedding_dim // num_attention_heads, + num_attention_heads, + gating=False, + dropout=attention_dropout, + ) + # layer norm associated with the self attention layer + self.pre_ln = pre_ln + self.self_attn_layer_norm = nn.LayerNorm( + self.embedding_dim, dtype=env.GLOBAL_PT_FLOAT_PRECISION + ) + self.fc1 = nn.Linear( + self.embedding_dim, ffn_embedding_dim, dtype=env.GLOBAL_PT_FLOAT_PRECISION + ) + self.fc2 = nn.Linear( + ffn_embedding_dim, self.embedding_dim, dtype=env.GLOBAL_PT_FLOAT_PRECISION + ) + self.final_layer_norm = nn.LayerNorm( + self.embedding_dim, dtype=env.GLOBAL_PT_FLOAT_PRECISION + ) + + self.x_layer_norm_opm = nn.LayerNorm( + self.embedding_dim, dtype=env.GLOBAL_PT_FLOAT_PRECISION + ) + # self.opm = OuterProductLocal(self.embedding_dim, pair_dim, d_hid=pair_hidden_dim) + self.opm = OuterProduct(self.embedding_dim, pair_dim, d_hid=pair_hidden_dim) + # self.pair_layer_norm_opm = nn.LayerNorm(pair_dim, dtype=env.GLOBAL_PT_FLOAT_PRECISION) + self.pair_layer_norm_ffn = nn.LayerNorm( + pair_dim, dtype=env.GLOBAL_PT_FLOAT_PRECISION + ) + self.pair_ffn = Transition( + pair_dim, + 1, + dropout=activation_dropout, + ) + self.pair_dropout = pair_dropout + self.tri_update = tri_update + if self.tri_update: + self.pair_layer_norm_trimul = nn.LayerNorm( + pair_dim, dtype=env.GLOBAL_PT_FLOAT_PRECISION + ) + self.pair_tri_mul = TriangleMultiplication(pair_dim, pair_hidden_dim) + + def update_pair( + self, + x, + pair, + nlist, + op_mask, + op_norm, + ): + # local: + # [nframes, nloc, nnei, pair_dim] + # global: + # [nframes, nloc, nloc, pair_dim] + pair = pair + self.dropout_module( + self.opm(self.x_layer_norm_opm(x), nlist, op_mask, op_norm) + ) + if not self.pre_ln: + pair = self.pair_layer_norm_opm(pair) + return x, pair + + def shared_dropout(self, x, shared_dim, dropout): + shape = list(x.shape) + shape[shared_dim] = 1 + with torch.no_grad(): + mask = x.new_ones(shape) + return F.dropout(mask, p=dropout, training=self.training) * x + + def forward( + self, + x: torch.Tensor, + pair: torch.Tensor, + nlist: torch.Tensor = None, + attn_mask: Optional[torch.Tensor] = None, + pair_mask: Optional[torch.Tensor] = None, + op_mask: float = 1.0, + op_norm: float = 1.0, + ): + """Encoder the atomic and pair representations. + + Args: + - x: Atomic representation with shape [ncluster, natoms, embed_dim]. + - pair: Pair representation with shape [ncluster, natoms, natoms, pair_dim]. + - attn_mask: Attention mask with shape [ncluster, head, natoms, natoms]. + - pair_mask: Neighbor mask with shape [ncluster, natoms, natoms]. + + """ + # [ncluster, natoms, embed_dim] + residual = x + if self.pre_ln: + x = self.self_attn_layer_norm(x) + x = self.self_attn( + x, + x, + x, + nlist=nlist, + pair=pair, + mask=attn_mask, + ) + # x = F.dropout(x, p=self.dropout, training=self.training) + x = self.dropout_module(x) + x = residual + x + if not self.pre_ln: + x = self.self_attn_layer_norm(x) + + residual = x + if self.pre_ln: + x = self.final_layer_norm(x) + x = F.linear(x, self.fc1.weight) + # x = fused_ops.bias_torch_gelu(x, self.fc1.bias) + x = nn.GELU()(x) + self.fc1.bias + x = F.dropout(x, p=self.activation_dropout, training=self.training) + x = self.fc2(x) + # x = F.dropout(x, p=self.dropout, training=self.training) + x = self.dropout_module(x) + + x = residual + x + if not self.pre_ln: + x = self.final_layer_norm(x) + + block = [ + partial( + self.update_pair, + nlist=nlist, + op_mask=op_mask, + op_norm=op_norm, + ) + ] + + x, pair = checkpoint_sequential( + block, + input_x=(x, pair), + ) + + if self.tri_update: + residual_pair = pair + if self.pre_ln: + pair = self.pair_layer_norm_trimul(pair) + + pair = self.shared_dropout( + self.pair_tri_mul(pair, pair_mask), -3, self.pair_dropout + ) + pair = residual_pair + pair + if not self.pre_ln: + pair = self.pair_layer_norm_trimul(pair) + + residual_pair = pair + if self.pre_ln: + pair = self.pair_layer_norm_ffn(pair) + pair = self.dropout_module(self.pair_ffn(pair)) + pair = residual_pair + pair + if not self.pre_ln: + pair = self.pair_layer_norm_ffn(pair) + return x, pair + + +class Evoformer3bEncoder(nn.Module): + def __init__( + self, + nnei, + layer_num=6, + attn_head=8, + atomic_dim=768, + pair_dim=64, + pair_hidden_dim=32, + ffn_embedding_dim=3072, + dropout: float = 0.1, + droppath_prob: float = 0.0, + pair_dropout: float = 0.25, + attention_dropout: float = 0.1, + activation_dropout: float = 0.1, + pre_ln: bool = True, + tri_update: bool = True, + **kwargs, + ): + super().__init__() + self.nnei = nnei + if droppath_prob > 0: + droppath_probs = [ + x.item() for x in torch.linspace(0, droppath_prob, layer_num) + ] + else: + droppath_probs = None + + self.layers = nn.ModuleList( + [ + Evoformer3bEncoderLayer( + nnei, + atomic_dim, + pair_dim, + pair_hidden_dim, + ffn_embedding_dim, + num_attention_heads=attn_head, + dropout=dropout, + droppath_prob=droppath_probs[_], + pair_dropout=pair_dropout, + attention_dropout=attention_dropout, + activation_dropout=activation_dropout, + pre_ln=pre_ln, + tri_update=tri_update, + ) + for _ in range(layer_num) + ] + ) + + def forward(self, x, pair, attn_mask=None, pair_mask=None, atom_mask=None): + """Encoder the atomic and pair representations. + + Args: + x: Atomic representation with shape [ncluster, natoms, atomic_dim]. + pair: Pair representation with shape [ncluster, natoms, natoms, pair_dim]. + attn_mask: Attention mask (with -inf for softmax) with shape [ncluster, head, natoms, natoms]. + pair_mask: Pair mask (with 1 for real atom pair and 0 for padding) with shape [ncluster, natoms, natoms]. + atom_mask: Atom mask (with 1 for real atom and 0 for padding) with shape [ncluster, natoms]. + + Returns + ------- + x: Atomic representation with shape [ncluster, natoms, atomic_dim]. + pair: Pair representation with shape [ncluster, natoms, natoms, pair_dim]. + + """ + # [ncluster, natoms, 1] + op_mask = atom_mask.unsqueeze(-1) + op_mask = op_mask * (op_mask.size(-2) ** -0.5) + eps = 1e-3 + # [ncluster, natoms, natoms, 1] + op_norm = 1.0 / (eps + torch.einsum("...bc,...dc->...bdc", op_mask, op_mask)) + for layer in self.layers: + x, pair = layer( + x, + pair, + nlist=None, + attn_mask=attn_mask, + pair_mask=pair_mask, + op_mask=op_mask, + op_norm=op_norm, + ) + return x, pair diff --git a/deepmd/pt/model/task/__init__.py b/deepmd/pt/model/task/__init__.py new file mode 100644 index 0000000000..fcf46632f3 --- /dev/null +++ b/deepmd/pt/model/task/__init__.py @@ -0,0 +1,34 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from .atten_lcc import ( + FittingNetAttenLcc, +) +from .denoise import ( + DenoiseNet, +) +from .dipole import ( + DipoleFittingNetType, +) +from .ener import ( + EnergyFittingNet, + EnergyFittingNetDirect, +) +from .fitting import ( + Fitting, +) +from .task import ( + TaskBaseMethod, +) +from .type_predict import ( + TypePredictNet, +) + +__all__ = [ + "FittingNetAttenLcc", + "DenoiseNet", + "DipoleFittingNetType", + "EnergyFittingNet", + "EnergyFittingNetDirect", + "Fitting", + "TaskBaseMethod", + "TypePredictNet", +] diff --git a/deepmd/pt/model/task/atten_lcc.py b/deepmd/pt/model/task/atten_lcc.py new file mode 100644 index 0000000000..41ccf99330 --- /dev/null +++ b/deepmd/pt/model/task/atten_lcc.py @@ -0,0 +1,55 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import torch +import torch.nn as nn + +from deepmd.pt.model.network.network import ( + EnergyHead, + NodeTaskHead, +) +from deepmd.pt.model.task.task import ( + TaskBaseMethod, +) +from deepmd.pt.utils import ( + env, +) + + +class FittingNetAttenLcc(TaskBaseMethod): + def __init__( + self, embedding_width, bias_atom_e, pair_embed_dim, attention_heads, **kwargs + ): + super().__init__() + self.embedding_width = embedding_width + self.engergy_proj = EnergyHead(self.embedding_width, 1) + self.energe_agg_factor = nn.Embedding(4, 1, dtype=env.GLOBAL_PT_FLOAT_PRECISION) + nn.init.normal_(self.energe_agg_factor.weight, 0, 0.01) + bias_atom_e = torch.tensor(bias_atom_e) + self.register_buffer("bias_atom_e", bias_atom_e) + self.pair_embed_dim = pair_embed_dim + self.attention_heads = attention_heads + self.node_proc = NodeTaskHead( + self.embedding_width, self.pair_embed_dim, self.attention_heads + ) + self.node_proc.zero_init() + + def forward(self, output, pair, delta_pos, atype, nframes, nloc): + # [nframes x nloc x tebd_dim] + output_nloc = (output[:, 0, :]).reshape(nframes, nloc, self.embedding_width) + # Optional: GRRG or mean of gbf TODO + + # energy outut + # [nframes, nloc] + energy_out = self.engergy_proj(output_nloc).view(nframes, nloc) + # [nframes, nloc] + energy_factor = self.energe_agg_factor(torch.zeros_like(atype)).view( + nframes, nloc + ) + energy_out = (energy_out * energy_factor) + self.bias_atom_e[atype] + energy_out = energy_out.sum(dim=-1) + + # vector output + # predict_force: [(nframes x nloc) x (1 + nnei2) x 3] + predict_force = self.node_proc(output, pair, delta_pos=delta_pos) + # predict_force_nloc: [nframes x nloc x 3] + predict_force_nloc = (predict_force[:, 0, :]).reshape(nframes, nloc, 3) + return energy_out, predict_force_nloc diff --git a/deepmd/pt/model/task/denoise.py b/deepmd/pt/model/task/denoise.py new file mode 100644 index 0000000000..7e6b6dcdb6 --- /dev/null +++ b/deepmd/pt/model/task/denoise.py @@ -0,0 +1,129 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Optional, +) + +import torch + +from deepmd.model_format import ( + FittingOutputDef, + OutputVariableDef, + fitting_check_output, +) +from deepmd.pt.model.network.network import ( + MaskLMHead, + NonLinearHead, +) +from deepmd.pt.model.task.task import ( + TaskBaseMethod, +) +from deepmd.pt.utils import ( + env, +) + + +@fitting_check_output +class DenoiseNet(TaskBaseMethod): + def __init__( + self, + feature_dim, + ntypes, + attn_head=8, + prefactor=[0.5, 0.5], + activation_function="gelu", + **kwargs, + ): + """Construct a denoise net. + + Args: + - ntypes: Element count. + - embedding_width: Embedding width per atom. + - neuron: Number of neurons in each hidden layers of the fitting net. + - bias_atom_e: Average enery per atom for each element. + - resnet_dt: Using time-step in the ResNet construction. + """ + super().__init__() + self.feature_dim = feature_dim + self.ntypes = ntypes + self.attn_head = attn_head + self.prefactor = torch.tensor( + prefactor, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE + ) + + self.lm_head = MaskLMHead( + embed_dim=self.feature_dim, + output_dim=ntypes, + activation_fn=activation_function, + weight=None, + ) + + if not isinstance(self.attn_head, list): + self.pair2coord_proj = NonLinearHead( + self.attn_head, 1, activation_fn=activation_function + ) + else: + self.pair2coord_proj = [] + self.ndescriptor = len(self.attn_head) + for ii in range(self.ndescriptor): + _pair2coord_proj = NonLinearHead( + self.attn_head[ii], 1, activation_fn=activation_function + ) + self.pair2coord_proj.append(_pair2coord_proj) + self.pair2coord_proj = torch.nn.ModuleList(self.pair2coord_proj) + + def output_def(self): + return FittingOutputDef( + [ + OutputVariableDef( + "updated_coord", [3], reduciable=False, differentiable=False + ), + OutputVariableDef( + "logits", [-1], reduciable=False, differentiable=False + ), + ] + ) + + def forward( + self, + pair_weights, + diff, + nlist_mask, + features, + sw, + masked_tokens: Optional[torch.Tensor] = None, + ): + """Calculate the updated coord. + Args: + - coord: Input noisy coord with shape [nframes, nloc, 3]. + - pair_weights: Input pair weights with shape [nframes, nloc, nnei, head]. + - diff: Input pair relative coord list with shape [nframes, nloc, nnei, 3]. + - nlist_mask: Input nlist mask with shape [nframes, nloc, nnei]. + + Returns + ------- + - denoised_coord: Denoised updated coord with shape [nframes, nloc, 3]. + """ + # [nframes, nloc, nnei, 1] + logits = self.lm_head(features, masked_tokens=masked_tokens) + if not isinstance(self.attn_head, list): + attn_probs = self.pair2coord_proj(pair_weights) + out_coord = (attn_probs * diff).sum(dim=-2) / ( + sw.sum(dim=-1).unsqueeze(-1) + 1e-6 + ) + else: + assert len(self.prefactor) == self.ndescriptor + all_coord_update = [] + assert len(pair_weights) == len(diff) == len(nlist_mask) == self.ndescriptor + for ii in range(self.ndescriptor): + _attn_probs = self.pair2coord_proj[ii](pair_weights[ii]) + _coord_update = (_attn_probs * diff[ii]).sum(dim=-2) / ( + nlist_mask[ii].sum(dim=-1).unsqueeze(-1) + 1e-6 + ) + all_coord_update.append(_coord_update) + out_coord = self.prefactor[0] * all_coord_update[0] + for ii in range(self.ndescriptor - 1): + out_coord += self.prefactor[ii + 1] * all_coord_update[ii + 1] + return { + "updated_coord": out_coord, + "logits": logits, + } diff --git a/deepmd/pt/model/task/dipole.py b/deepmd/pt/model/task/dipole.py new file mode 100644 index 0000000000..8511c7dc29 --- /dev/null +++ b/deepmd/pt/model/task/dipole.py @@ -0,0 +1,65 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import logging + +import torch + +from deepmd.pt.model.network.network import ( + ResidualDeep, +) +from deepmd.pt.model.task.task import ( + TaskBaseMethod, +) + + +class DipoleFittingNetType(TaskBaseMethod): + def __init__( + self, ntypes, embedding_width, neuron, out_dim, resnet_dt=True, **kwargs + ): + """Construct a fitting net for dipole. + + Args: + - ntypes: Element count. + - embedding_width: Embedding width per atom. + - neuron: Number of neurons in each hidden layers of the fitting net. + - bias_atom_e: Average enery per atom for each element. + - resnet_dt: Using time-step in the ResNet construction. + """ + super().__init__() + self.ntypes = ntypes + self.embedding_width = embedding_width + self.out_dim = out_dim + + filter_layers = [] + one = ResidualDeep( + 0, embedding_width, neuron, 0.0, out_dim=self.out_dim, resnet_dt=resnet_dt + ) + filter_layers.append(one) + self.filter_layers = torch.nn.ModuleList(filter_layers) + + if "seed" in kwargs: + logging.info("Set seed to %d in fitting net.", kwargs["seed"]) + torch.manual_seed(kwargs["seed"]) + + def forward(self, inputs, atype, atype_tebd, rot_mat): + """Based on embedding net output, alculate total energy. + + Args: + - inputs: Descriptor. Its shape is [nframes, nloc, self.embedding_width]. + - atype: Atom type. Its shape is [nframes, nloc]. + - atype_tebd: Atom type embedding. Its shape is [nframes, nloc, tebd_dim] + - rot_mat: GR during descriptor calculation. Its shape is [nframes * nloc, m1, 3]. + + Returns + ------- + - vec_out: output vector. Its shape is [nframes, nloc, 3]. + """ + nframes, nloc, _ = inputs.size() + if atype_tebd is not None: + inputs = torch.concat([inputs, atype_tebd], dim=-1) + vec_out = self.filter_layers[0](inputs) # Shape is [nframes, nloc, m1] + assert list(vec_out.size()) == [nframes, nloc, self.out_dim] + vec_out = vec_out.view(-1, 1, self.out_dim) + vec_out = ( + torch.bmm(vec_out, rot_mat).squeeze(-2).view(nframes, nloc, 3) + ) # Shape is [nframes, nloc, 3] + return vec_out diff --git a/deepmd/pt/model/task/ener.py b/deepmd/pt/model/task/ener.py new file mode 100644 index 0000000000..7ddcbd5c54 --- /dev/null +++ b/deepmd/pt/model/task/ener.py @@ -0,0 +1,241 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import logging +from typing import ( + Optional, + Tuple, +) + +import torch + +from deepmd.model_format import ( + FittingOutputDef, + OutputVariableDef, + fitting_check_output, +) +from deepmd.pt.model.network.network import ( + ResidualDeep, +) +from deepmd.pt.model.task.fitting import ( + Fitting, +) +from deepmd.pt.utils import ( + env, +) + + +@Fitting.register("ener") +@fitting_check_output +class EnergyFittingNet(Fitting): + def __init__( + self, + ntypes, + embedding_width, + neuron, + bias_atom_e, + resnet_dt=True, + use_tebd=True, + **kwargs, + ): + """Construct a fitting net for energy. + + Args: + - ntypes: Element count. + - embedding_width: Embedding width per atom. + - neuron: Number of neurons in each hidden layers of the fitting net. + - bias_atom_e: Average enery per atom for each element. + - resnet_dt: Using time-step in the ResNet construction. + """ + super().__init__() + self.ntypes = ntypes + self.embedding_width = embedding_width + self.use_tebd = use_tebd + if not use_tebd: + assert self.ntypes == len(bias_atom_e), "Element count mismatches!" + bias_atom_e = torch.tensor(bias_atom_e) + self.register_buffer("bias_atom_e", bias_atom_e) + + filter_layers = [] + for type_i in range(self.ntypes): + bias_type = 0.0 + one = ResidualDeep( + type_i, embedding_width, neuron, bias_type, resnet_dt=resnet_dt + ) + filter_layers.append(one) + self.filter_layers = torch.nn.ModuleList(filter_layers) + + if "seed" in kwargs: + logging.info("Set seed to %d in fitting net.", kwargs["seed"]) + torch.manual_seed(kwargs["seed"]) + + def output_def(self): + return FittingOutputDef( + [ + OutputVariableDef("energy", [1], reduciable=True, differentiable=True), + ] + ) + + def forward( + self, + inputs: torch.Tensor, + atype: torch.Tensor, + atype_tebd: Optional[torch.Tensor] = None, + rot_mat: Optional[torch.Tensor] = None, + ): + """Based on embedding net output, alculate total energy. + + Args: + - inputs: Embedding matrix. Its shape is [nframes, natoms[0], self.embedding_width]. + - natoms: Tell atom count and element count. Its shape is [2+self.ntypes]. + + Returns + ------- + - `torch.Tensor`: Total energy with shape [nframes, natoms[0]]. + """ + outs = torch.zeros_like(atype).unsqueeze(-1) # jit assertion + if self.use_tebd: + if atype_tebd is not None: + inputs = torch.concat([inputs, atype_tebd], dim=-1) + atom_energy = self.filter_layers[0](inputs) + self.bias_atom_e[ + atype + ].unsqueeze(-1) + outs = outs + atom_energy # Shape is [nframes, natoms[0], 1] + else: + for type_i, filter_layer in enumerate(self.filter_layers): + mask = atype == type_i + atom_energy = filter_layer(inputs) + atom_energy = atom_energy + self.bias_atom_e[type_i] + atom_energy = atom_energy * mask.unsqueeze(-1) + outs = outs + atom_energy # Shape is [nframes, natoms[0], 1] + return {"energy": outs.to(env.GLOBAL_PT_FLOAT_PRECISION)} + + +@Fitting.register("direct_force") +@Fitting.register("direct_force_ener") +@fitting_check_output +class EnergyFittingNetDirect(Fitting): + def __init__( + self, + ntypes, + embedding_width, + neuron, + bias_atom_e, + out_dim=1, + resnet_dt=True, + use_tebd=True, + return_energy=False, + **kwargs, + ): + """Construct a fitting net for energy. + + Args: + - ntypes: Element count. + - embedding_width: Embedding width per atom. + - neuron: Number of neurons in each hidden layers of the fitting net. + - bias_atom_e: Average enery per atom for each element. + - resnet_dt: Using time-step in the ResNet construction. + """ + super().__init__() + self.ntypes = ntypes + self.embedding_width = embedding_width + self.use_tebd = use_tebd + self.out_dim = out_dim + if not use_tebd: + assert self.ntypes == len(bias_atom_e), "Element count mismatches!" + bias_atom_e = torch.tensor(bias_atom_e) + self.register_buffer("bias_atom_e", bias_atom_e) + + filter_layers_dipole = [] + for type_i in range(self.ntypes): + one = ResidualDeep( + type_i, + embedding_width, + neuron, + 0.0, + out_dim=out_dim, + resnet_dt=resnet_dt, + ) + filter_layers_dipole.append(one) + self.filter_layers_dipole = torch.nn.ModuleList(filter_layers_dipole) + + self.return_energy = return_energy + filter_layers = [] + if self.return_energy: + for type_i in range(self.ntypes): + bias_type = 0.0 if self.use_tebd else bias_atom_e[type_i] + one = ResidualDeep( + type_i, embedding_width, neuron, bias_type, resnet_dt=resnet_dt + ) + filter_layers.append(one) + self.filter_layers = torch.nn.ModuleList(filter_layers) + + if "seed" in kwargs: + logging.info("Set seed to %d in fitting net.", kwargs["seed"]) + torch.manual_seed(kwargs["seed"]) + + def output_def(self): + return FittingOutputDef( + [ + OutputVariableDef("energy", [1], reduciable=True, differentiable=False), + OutputVariableDef( + "dforce", [3], reduciable=False, differentiable=False + ), + ] + ) + + def forward( + self, + inputs: torch.Tensor, + atype: torch.Tensor, + atype_tebd: Optional[torch.Tensor] = None, + rot_mat: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, None]: + """Based on embedding net output, alculate total energy. + + Args: + - inputs: Embedding matrix. Its shape is [nframes, natoms[0], self.embedding_width]. + - natoms: Tell atom count and element count. Its shape is [2+self.ntypes]. + + Returns + ------- + - `torch.Tensor`: Total energy with shape [nframes, natoms[0]]. + """ + nframes, nloc, _ = inputs.size() + if self.use_tebd: + if atype_tebd is not None: + inputs = torch.concat([inputs, atype_tebd], dim=-1) + vec_out = self.filter_layers_dipole[0]( + inputs + ) # Shape is [nframes, nloc, m1] + assert list(vec_out.size()) == [nframes, nloc, self.out_dim] + vec_out = vec_out.view(-1, 1, self.out_dim) + assert rot_mat is not None + vec_out = ( + torch.bmm(vec_out, rot_mat).squeeze(-2).view(nframes, nloc, 3) + ) # Shape is [nframes, nloc, 3] + else: + vec_out = torch.zeros_like(atype).unsqueeze(-1) # jit assertion + for type_i, filter_layer in enumerate(self.filter_layers_dipole): + mask = atype == type_i + vec_out_type = filter_layer(inputs) # Shape is [nframes, nloc, m1] + vec_out_type = vec_out_type * mask.unsqueeze(-1) + vec_out = vec_out + vec_out_type # Shape is [nframes, natoms[0], 1] + + outs = torch.zeros_like(atype).unsqueeze(-1) # jit assertion + if self.return_energy: + if self.use_tebd: + atom_energy = self.filter_layers[0](inputs) + self.bias_atom_e[ + atype + ].unsqueeze(-1) + outs = outs + atom_energy # Shape is [nframes, natoms[0], 1] + else: + for type_i, filter_layer in enumerate(self.filter_layers): + mask = atype == type_i + atom_energy = filter_layer(inputs) + if not env.ENERGY_BIAS_TRAINABLE: + atom_energy = atom_energy + self.bias_atom_e[type_i] + atom_energy = atom_energy * mask.unsqueeze(-1) + outs = outs + atom_energy # Shape is [nframes, natoms[0], 1] + return { + "energy": outs.to(env.GLOBAL_PT_FLOAT_PRECISION), + "dforce": vec_out, + } diff --git a/deepmd/pt/model/task/fitting.py b/deepmd/pt/model/task/fitting.py new file mode 100644 index 0000000000..16e80f9c20 --- /dev/null +++ b/deepmd/pt/model/task/fitting.py @@ -0,0 +1,223 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import logging +from typing import ( + Callable, +) + +import numpy as np +import torch + +from deepmd.model_format import ( + FittingOutputDef, +) +from deepmd.pt.model.task.task import ( + TaskBaseMethod, +) +from deepmd.pt.utils.dataloader import ( + DpLoaderSet, +) +from deepmd.pt.utils.env import ( + DEVICE, +) +from deepmd.pt.utils.plugin import ( + Plugin, +) +from deepmd.pt.utils.stat import ( + make_stat_input, +) + + +class Fitting(TaskBaseMethod): + __plugins = Plugin() + + @staticmethod + def register(key: str) -> Callable: + """Register a Fitting plugin. + + Parameters + ---------- + key : str + the key of a Fitting + + Returns + ------- + Fitting + the registered Fitting + + Examples + -------- + >>> @Fitting.register("some_fitting") + class SomeFitting(Fitting): + pass + """ + return Fitting.__plugins.register(key) + + def __new__(cls, *args, **kwargs): + if cls is Fitting: + try: + fitting_type = kwargs["type"] + except KeyError: + raise KeyError("the type of fitting should be set by `type`") + if fitting_type in Fitting.__plugins.plugins: + cls = Fitting.__plugins.plugins[fitting_type] + else: + raise RuntimeError("Unknown descriptor type: " + fitting_type) + return super().__new__(cls) + + def output_def(self) -> FittingOutputDef: + """Definition for the task Output.""" + raise NotImplementedError + + def forward(self, **kwargs): + """Task Output.""" + raise NotImplementedError + + def share_params(self, base_class, shared_level, resume=False): + assert ( + self.__class__ == base_class.__class__ + ), "Only fitting nets of the same type can share params!" + if shared_level == 0: + # link buffers + if hasattr(self, "bias_atom_e"): + self.bias_atom_e = base_class.bias_atom_e + # the following will successfully link all the params except buffers, which need manually link. + for item in self._modules: + self._modules[item] = base_class._modules[item] + elif shared_level == 1: + # only not share the bias_atom_e + # the following will successfully link all the params except buffers, which need manually link. + for item in self._modules: + self._modules[item] = base_class._modules[item] + elif shared_level == 2: + # share all the layers before final layer + # the following will successfully link all the params except buffers, which need manually link. + self._modules["filter_layers"][0].deep_layers = base_class._modules[ + "filter_layers" + ][0].deep_layers + elif shared_level == 3: + # share the first layers + # the following will successfully link all the params except buffers, which need manually link. + self._modules["filter_layers"][0].deep_layers[0] = base_class._modules[ + "filter_layers" + ][0].deep_layers[0] + else: + raise NotImplementedError + + def change_energy_bias( + self, config, model, old_type_map, new_type_map, bias_shift="delta", ntest=10 + ): + """Change the energy bias according to the input data and the pretrained model. + + Parameters + ---------- + config : Dict + The configuration. + model : EnergyModel + Energy model loaded pre-trained model. + new_type_map : list + The original type_map in dataset, they are targets to change the energy bias. + old_type_map : str + The full type_map in pretrained model + bias_shift : str + The mode for changing energy bias : ['delta', 'statistic'] + 'delta' : perform predictions on energies of target dataset, + and do least sqaure on the errors to obtain the target shift as bias. + 'statistic' : directly use the statistic energy bias in the target dataset. + ntest : int + The number of test samples in a system to change the energy bias. + """ + logging.info( + "Changing energy bias in pretrained model for types {}... " + "(this step may take long time)".format(str(new_type_map)) + ) + # data + systems = config["training"]["training_data"]["systems"] + finetune_data = DpLoaderSet( + systems, ntest, config["model"], type_split=False, noise_settings=None + ) + sampled = make_stat_input(finetune_data.systems, finetune_data.dataloaders, 1) + # map + sorter = np.argsort(old_type_map) + idx_type_map = sorter[ + np.searchsorted(old_type_map, new_type_map, sorter=sorter) + ] + mixed_type = np.all([i.mixed_type for i in finetune_data.systems]) + numb_type = len(old_type_map) + type_numbs, energy_ground_truth, energy_predict = [], [], [] + for test_data in sampled: + nframes = test_data["energy"].shape[0] + if mixed_type: + atype = test_data["atype"].detach().cpu().numpy() + else: + atype = test_data["atype"][0].detach().cpu().numpy() + assert np.array( + [i.item() in idx_type_map for i in list(set(atype.reshape(-1)))] + ).all(), "Some types are not in 'type_map'!" + energy_ground_truth.append(test_data["energy"].cpu().numpy()) + if mixed_type: + type_numbs.append( + np.array( + [(atype == i).sum(axis=-1) for i in idx_type_map], + dtype=np.int32, + ).T + ) + else: + type_numbs.append( + np.tile( + np.bincount(atype, minlength=numb_type)[idx_type_map], + (nframes, 1), + ) + ) + if bias_shift == "delta": + coord = test_data["coord"].to(DEVICE) + atype = test_data["atype"].to(DEVICE) + box = ( + test_data["box"].to(DEVICE) + if test_data["box"] is not None + else None + ) + ret = model(coord, atype, box) + energy_predict.append( + ret["energy"].reshape([nframes, 1]).detach().cpu().numpy() + ) + type_numbs = np.concatenate(type_numbs) + energy_ground_truth = np.concatenate(energy_ground_truth) + old_bias = self.bias_atom_e[idx_type_map] + if bias_shift == "delta": + energy_predict = np.concatenate(energy_predict) + bias_diff = energy_ground_truth - energy_predict + delta_bias = np.linalg.lstsq(type_numbs, bias_diff, rcond=None)[0] + unbias_e = energy_predict + type_numbs @ delta_bias + atom_numbs = type_numbs.sum(-1) + rmse_ae = np.sqrt( + np.mean( + np.square( + (unbias_e.ravel() - energy_ground_truth.ravel()) / atom_numbs + ) + ) + ) + self.bias_atom_e[idx_type_map] += torch.from_numpy( + delta_bias.reshape(-1) + ).to(DEVICE) + logging.info( + f"RMSE of atomic energy after linear regression is: {rmse_ae:10.5e} eV/atom." + ) + elif bias_shift == "statistic": + statistic_bias = np.linalg.lstsq( + type_numbs, energy_ground_truth, rcond=None + )[0] + self.bias_atom_e[idx_type_map] = ( + torch.from_numpy(statistic_bias.reshape(-1)) + .type_as(self.bias_atom_e[idx_type_map]) + .to(DEVICE) + ) + else: + raise RuntimeError("Unknown bias_shift mode: " + bias_shift) + logging.info( + "Change energy bias of {} from {} to {}.".format( + str(new_type_map), + str(old_bias.detach().cpu().numpy()), + str(self.bias_atom_e[idx_type_map].detach().cpu().numpy()), + ) + ) + return None diff --git a/deepmd/pt/model/task/task.py b/deepmd/pt/model/task/task.py new file mode 100644 index 0000000000..a9b2efeb9a --- /dev/null +++ b/deepmd/pt/model/task/task.py @@ -0,0 +1,12 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import torch + + +class TaskBaseMethod(torch.nn.Module): + def __init__(self, **kwargs): + """Construct a basic head for different tasks.""" + super().__init__() + + def forward(self, **kwargs): + """Task Output.""" + raise NotImplementedError diff --git a/deepmd/pt/model/task/type_predict.py b/deepmd/pt/model/task/type_predict.py new file mode 100644 index 0000000000..57227004d0 --- /dev/null +++ b/deepmd/pt/model/task/type_predict.py @@ -0,0 +1,47 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Optional, +) + +import torch + +from deepmd.pt.model.network.network import ( + MaskLMHead, +) +from deepmd.pt.model.task import ( + TaskBaseMethod, +) + + +class TypePredictNet(TaskBaseMethod): + def __init__(self, feature_dim, ntypes, activation_function="gelu", **kwargs): + """Construct a type predict net. + + Args: + - feature_dim: Input dm. + - ntypes: Numer of types to predict. + - activation_function: Activate function. + """ + super().__init__() + self.feature_dim = feature_dim + self.ntypes = ntypes + self.lm_head = MaskLMHead( + embed_dim=self.feature_dim, + output_dim=ntypes, + activation_fn=activation_function, + weight=None, + ) + + def forward(self, features, masked_tokens: Optional[torch.Tensor] = None): + """Calculate the predicted logits. + Args: + - features: Input features with shape [nframes, nloc, feature_dim]. + - masked_tokens: Input masked tokens with shape [nframes, nloc]. + + Returns + ------- + - logits: Predicted probs with shape [nframes, nloc, ntypes]. + """ + # [nframes, nloc, ntypes] + logits = self.lm_head(features, masked_tokens=masked_tokens) + return logits diff --git a/deepmd/pt/optimizer/KFWrapper.py b/deepmd/pt/optimizer/KFWrapper.py new file mode 100644 index 0000000000..3ab7ffe7a9 --- /dev/null +++ b/deepmd/pt/optimizer/KFWrapper.py @@ -0,0 +1,145 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import math + +import numpy as np +import torch +import torch.distributed as dist +import torch.nn as nn +from torch.optim.optimizer import ( + Optimizer, +) + + +class KFOptimizerWrapper: + def __init__( + self, + model: nn.Module, + optimizer: Optimizer, + atoms_selected: int, + atoms_per_group: int, + is_distributed: bool = False, + ) -> None: + self.model = model + self.optimizer = optimizer + self.atoms_selected = atoms_selected # 24 + self.atoms_per_group = atoms_per_group # 6 + self.is_distributed = is_distributed + + def update_energy( + self, inputs: dict, Etot_label: torch.Tensor, update_prefactor: float = 1 + ) -> None: + model_pred, _, _ = self.model(**inputs, inference_only=True) + Etot_predict = model_pred["energy"] + natoms_sum = int(inputs["atype"].shape[-1]) + self.optimizer.set_grad_prefactor(natoms_sum) + + self.optimizer.zero_grad() + bs = Etot_label.shape[0] + error = Etot_label - Etot_predict + error = error / natoms_sum + mask = error < 0 + + error = error * update_prefactor + error[mask] = -1 * error[mask] + error = error.mean() + + if self.is_distributed: + dist.all_reduce(error) + error /= dist.get_world_size() + + Etot_predict = update_prefactor * Etot_predict + Etot_predict[mask] = -Etot_predict[mask] + + Etot_predict.sum().backward() + error = error * math.sqrt(bs) + self.optimizer.step(error) + return Etot_predict + + def update_force( + self, inputs: dict, Force_label: torch.Tensor, update_prefactor: float = 1 + ) -> None: + natoms_sum = int(inputs["atype"].shape[-1]) + bs = Force_label.shape[0] + self.optimizer.set_grad_prefactor(natoms_sum * self.atoms_per_group * 3) + + index = self.__sample(self.atoms_selected, self.atoms_per_group, natoms_sum) + + for i in range(index.shape[0]): + self.optimizer.zero_grad() + model_pred, _, _ = self.model(**inputs, inference_only=True) + Etot_predict = model_pred["energy"] + natoms_sum = int(inputs["atype"].shape[-1]) + force_predict = model_pred["force"] + error_tmp = Force_label[:, index[i]] - force_predict[:, index[i]] + error_tmp = update_prefactor * error_tmp + mask = error_tmp < 0 + error_tmp[mask] = -1 * error_tmp[mask] + error = error_tmp.mean() / natoms_sum + + if self.is_distributed: + dist.all_reduce(error) + error /= dist.get_world_size() + + tmp_force_predict = force_predict[:, index[i]] * update_prefactor + tmp_force_predict[mask] = -tmp_force_predict[mask] + + # In order to solve a pytorch bug, reference: https://github.com/pytorch/pytorch/issues/43259 + (tmp_force_predict.sum() + Etot_predict.sum() * 0).backward() + error = error * math.sqrt(bs) + self.optimizer.step(error) + return Etot_predict, force_predict + + def update_denoise_coord( + self, + inputs: dict, + clean_coord: torch.Tensor, + update_prefactor: float = 1, + mask_loss_coord: bool = True, + coord_mask: torch.Tensor = None, + ) -> None: + natoms_sum = int(inputs["atype"].shape[-1]) + bs = clean_coord.shape[0] + self.optimizer.set_grad_prefactor(natoms_sum * self.atoms_per_group * 3) + + index = self.__sample(self.atoms_selected, self.atoms_per_group, natoms_sum) + + for i in range(index.shape[0]): + self.optimizer.zero_grad() + model_pred, _, _ = self.model(**inputs, inference_only=True) + updated_coord = model_pred["updated_coord"] + natoms_sum = int(inputs["atype"].shape[-1]) + error_tmp = clean_coord[:, index[i]] - updated_coord[:, index[i]] + error_tmp = update_prefactor * error_tmp + if mask_loss_coord: + error_tmp[~coord_mask[:, index[i]]] = 0 + mask = error_tmp < 0 + error_tmp[mask] = -1 * error_tmp[mask] + error = error_tmp.mean() / natoms_sum + + if self.is_distributed: + dist.all_reduce(error) + error /= dist.get_world_size() + + tmp_coord_predict = updated_coord[:, index[i]] * update_prefactor + tmp_coord_predict[mask] = -update_prefactor * tmp_coord_predict[mask] + + # In order to solve a pytorch bug, reference: https://github.com/pytorch/pytorch/issues/43259 + (tmp_coord_predict.sum() + updated_coord.sum() * 0).backward() + error = error * math.sqrt(bs) + self.optimizer.step(error) + return model_pred + + def __sample( + self, atoms_selected: int, atoms_per_group: int, natoms: int + ) -> np.ndarray: + if atoms_selected % atoms_per_group: + raise Exception("divider") + index = range(natoms) + rng = np.random.default_rng() + res = rng.choice(index, atoms_selected).reshape(-1, atoms_per_group) + return res + + +# with torch.autograd.profiler.profile(enabled=True, use_cuda=True, record_shapes=False) as prof: +# the code u wanna profile +# print(prof.key_averages().table(sort_by="self_cpu_time_total")) diff --git a/deepmd/pt/optimizer/LKF.py b/deepmd/pt/optimizer/LKF.py new file mode 100644 index 0000000000..5e18797c7b --- /dev/null +++ b/deepmd/pt/optimizer/LKF.py @@ -0,0 +1,221 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import logging +import math + +import torch +from torch.optim.optimizer import ( + Optimizer, +) + + +class LKFOptimizer(Optimizer): + def __init__( + self, + params, + kalman_lambda=0.98, + kalman_nue=0.9987, + block_size=5120, + ): + defaults = { + "lr": 0.1, + "kalman_nue": kalman_nue, + "block_size": block_size, + } + super().__init__(params, defaults) + + self._params = self.param_groups[0]["params"] + + if len(self.param_groups) != 1 or len(self._params) == 0: + raise ValueError( + "LKF doesn't support per-parameter options " "(parameter groups)" + ) + + # NOTE: LKF has only global state, but we register it as state for + # the first param, because this helps with casting in load_state_dict + self._state = self.state[self._params[0]] + self._state.setdefault("kalman_lambda", kalman_lambda) + + self.__init_P() + + def __init_P(self): + param_nums = [] + param_sum = 0 + block_size = self.__get_blocksize() + data_type = self._params[0].dtype + device = self._params[0].device + + for param_group in self.param_groups: + params = param_group["params"] + for param in params: + param_num = param.data.nelement() + if param_sum + param_num > block_size: + if param_sum > 0: + param_nums.append(param_sum) + param_sum = param_num + else: + param_sum += param_num + + param_nums.append(param_sum) + + P = [] + params_packed_index = [] + logging.info("LKF parameter nums: %s" % param_nums) + for param_num in param_nums: + if param_num >= block_size: + block_num = math.ceil(param_num / block_size) + for i in range(block_num): + if i != block_num - 1: + P.append( + torch.eye( + block_size, + dtype=data_type, + device=device, + ) + ) + params_packed_index.append(block_size) + else: + P.append( + torch.eye( + param_num - block_size * i, + dtype=data_type, + device=device, + ) + ) + params_packed_index.append(param_num - block_size * i) + else: + P.append(torch.eye(param_num, dtype=data_type, device=device)) + params_packed_index.append(param_num) + + self._state.setdefault("P", P) + self._state.setdefault("weights_num", len(P)) + self._state.setdefault("params_packed_index", params_packed_index) + + def __get_blocksize(self): + return self.param_groups[0]["block_size"] + + def __get_nue(self): + return self.param_groups[0]["kalman_nue"] + + def __split_weights(self, weight): + block_size = self.__get_blocksize() + param_num = weight.nelement() + res = [] + if param_num < block_size: + res.append(weight) + else: + block_num = math.ceil(param_num / block_size) + for i in range(block_num): + if i != block_num - 1: + res.append(weight[i * block_size : (i + 1) * block_size]) + else: + res.append(weight[i * block_size :]) + return res + + def __update(self, H, error, weights): + P = self._state.get("P") + kalman_lambda = self._state.get("kalman_lambda") + weights_num = self._state.get("weights_num") + params_packed_index = self._state.get("params_packed_index") + + block_size = self.__get_blocksize() + kalman_nue = self.__get_nue() + + tmp = 0 + for i in range(weights_num): + tmp = tmp + (kalman_lambda + torch.matmul(torch.matmul(H[i].T, P[i]), H[i])) + + A = 1 / tmp + + for i in range(weights_num): + K = torch.matmul(P[i], H[i]) + + weights[i] = weights[i] + A * error * K + + P[i] = (1 / kalman_lambda) * (P[i] - A * torch.matmul(K, K.T)) + + kalman_lambda = kalman_nue * kalman_lambda + 1 - kalman_nue + self._state.update({"kalman_lambda": kalman_lambda}) + + i = 0 + param_sum = 0 + for param_group in self.param_groups: + params = param_group["params"] + for param in params: + param_num = param.nelement() + weight_tmp = weights[i][param_sum : param_sum + param_num] + if param_num < block_size: + if param.ndim > 1: + param.data = weight_tmp.reshape( + param.data.T.shape + ).T.contiguous() + else: + param.data = weight_tmp.reshape(param.data.shape) + + param_sum += param_num + + if param_sum == params_packed_index[i]: + i += 1 + param_sum = 0 + else: + block_num = math.ceil(param_num / block_size) + for j in range(block_num): + if j == 0: + tmp_weight = weights[i] + else: + tmp_weight = torch.concat([tmp_weight, weights[i]], dim=0) + i += 1 + param.data = tmp_weight.reshape(param.data.T.shape).T.contiguous() + + def set_grad_prefactor(self, grad_prefactor): + self.grad_prefactor = grad_prefactor + + def step(self, error): + params_packed_index = self._state.get("params_packed_index") + + weights = [] + H = [] + param_index = 0 + param_sum = 0 + + for param in self._params: + if param.ndim > 1: + tmp = param.data.T.contiguous().reshape(param.data.nelement(), 1) + if param.grad is None: + tmp_grad = torch.zeros_like(tmp) + else: + tmp_grad = ( + (param.grad / self.grad_prefactor) + .T.contiguous() + .reshape(param.grad.nelement(), 1) + ) + else: + tmp = param.data.reshape(param.data.nelement(), 1) + if param.grad is None: + tmp_grad = torch.zeros_like(tmp) + else: + tmp_grad = (param.grad / self.grad_prefactor).reshape( + param.grad.nelement(), 1 + ) + + tmp = self.__split_weights(tmp) + tmp_grad = self.__split_weights(tmp_grad) + + for split_grad, split_weight in zip(tmp_grad, tmp): + nelement = split_grad.nelement() + + if param_sum == 0: + res_grad = split_grad + res = split_weight + else: + res_grad = torch.concat((res_grad, split_grad), dim=0) + res = torch.concat((res, split_weight), dim=0) + + param_sum += nelement + + if param_sum == params_packed_index[param_index]: + H.append(res_grad) + weights.append(res) + param_sum = 0 + param_index += 1 + + self.__update(H, error, weights) diff --git a/deepmd/pt/optimizer/__init__.py b/deepmd/pt/optimizer/__init__.py new file mode 100644 index 0000000000..db340b3bb9 --- /dev/null +++ b/deepmd/pt/optimizer/__init__.py @@ -0,0 +1,9 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from .KFWrapper import ( + KFOptimizerWrapper, +) +from .LKF import ( + LKFOptimizer, +) + +__all__ = ["KFOptimizerWrapper", "LKFOptimizer"] diff --git a/deepmd/pt/train/__init__.py b/deepmd/pt/train/__init__.py new file mode 100644 index 0000000000..6ceb116d85 --- /dev/null +++ b/deepmd/pt/train/__init__.py @@ -0,0 +1 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py new file mode 100644 index 0000000000..049685a6e3 --- /dev/null +++ b/deepmd/pt/train/training.py @@ -0,0 +1,849 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import logging +import os +import time +from copy import ( + deepcopy, +) +from pathlib import ( + Path, +) +from typing import ( + Any, + Dict, +) + +import numpy as np +import torch +from tqdm import ( + tqdm, +) +from tqdm.contrib.logging import ( + logging_redirect_tqdm, +) + +from deepmd.pt.loss import ( + DenoiseLoss, + EnergyStdLoss, +) +from deepmd.pt.model.model import ( + get_model, +) +from deepmd.pt.optimizer import ( + KFOptimizerWrapper, + LKFOptimizer, +) +from deepmd.pt.train.wrapper import ( + ModelWrapper, +) +from deepmd.pt.utils import ( + dp_random, +) +from deepmd.pt.utils.dataloader import ( + BufferedIterator, + get_weighted_sampler, +) +from deepmd.pt.utils.env import ( + DEVICE, + DISABLE_TQDM, + JIT, + LOCAL_RANK, + NUM_WORKERS, + SAMPLER_RECORD, +) +from deepmd.pt.utils.learning_rate import ( + LearningRateExp, +) + +if torch.__version__.startswith("2"): + import torch._dynamo + +import torch.distributed as dist +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.data import ( + DataLoader, +) + + +class Trainer: + def __init__( + self, + config: Dict[str, Any], + training_data, + sampled, + validation_data=None, + init_model=None, + restart_model=None, + finetune_model=None, + force_load=False, + shared_links=None, + ): + """Construct a DeePMD trainer. + + Args: + - config: The Dict-like configuration with training options. + """ + resume_model = init_model if init_model is not None else restart_model + self.restart_training = restart_model is not None + model_params = config["model"] + training_params = config["training"] + self.multi_task = "model_dict" in model_params + self.finetune_multi_task = model_params.pop( + "finetune_multi_task", False + ) # should use pop for next finetune + self.model_keys = ( + list(model_params["model_dict"]) if self.multi_task else ["Default"] + ) + self.rank = dist.get_rank() if dist.is_initialized() else 0 + self.world_size = dist.get_world_size() if dist.is_initialized() else 1 + self.num_model = len(self.model_keys) + + # Iteration config + self.num_steps = training_params["numb_steps"] + self.disp_file = training_params.get("disp_file", "lcurve.out") + self.disp_freq = training_params.get("disp_freq", 1000) + self.save_ckpt = training_params.get("save_ckpt", "model.pt") + self.save_freq = training_params.get("save_freq", 1000) + self.lcurve_should_print_header = True + + def get_opt_param(params): + opt_type = params.get("opt_type", "Adam") + opt_param = { + "kf_blocksize": params.get("kf_blocksize", 5120), + "kf_start_pref_e": params.get("kf_start_pref_e", 1), + "kf_limit_pref_e": params.get("kf_limit_pref_e", 1), + "kf_start_pref_f": params.get("kf_start_pref_f", 1), + "kf_limit_pref_f": params.get("kf_limit_pref_f", 1), + } + return opt_type, opt_param + + def get_data_loader(_training_data, _validation_data, _training_params): + if "auto_prob" in _training_params["training_data"]: + train_sampler = get_weighted_sampler( + _training_data, _training_params["training_data"]["auto_prob"] + ) + elif "sys_probs" in _training_params["training_data"]: + train_sampler = get_weighted_sampler( + _training_data, + _training_params["training_data"]["sys_probs"], + sys_prob=True, + ) + else: + train_sampler = get_weighted_sampler(_training_data, "prob_sys_size") + + if "auto_prob" in _training_params["validation_data"]: + valid_sampler = get_weighted_sampler( + _validation_data, _training_params["validation_data"]["auto_prob"] + ) + elif "sys_probs" in _training_params["validation_data"]: + valid_sampler = get_weighted_sampler( + _validation_data, + _training_params["validation_data"]["sys_probs"], + sys_prob=True, + ) + else: + valid_sampler = get_weighted_sampler(_validation_data, "prob_sys_size") + + if train_sampler is None or valid_sampler is None: + logging.warning( + "Sampler not specified!" + ) # None sampler will lead to a premature stop iteration. Replacement should be True in attribute of the sampler to produce expected number of items in one iteration. + training_dataloader = DataLoader( + _training_data, + sampler=train_sampler, + batch_size=None, + num_workers=NUM_WORKERS, # setting to 0 diverges the behavior of its iterator; should be >=1 + drop_last=False, + pin_memory=True, + ) + training_data_buffered = BufferedIterator(iter(training_dataloader)) + validation_dataloader = DataLoader( + _validation_data, + sampler=valid_sampler, + batch_size=None, + num_workers=min(NUM_WORKERS, 1), + drop_last=False, + pin_memory=True, + ) + + validation_data_buffered = BufferedIterator(iter(validation_dataloader)) + if _training_params.get("validation_data", None) is not None: + valid_numb_batch = _training_params["validation_data"].get( + "numb_btch", 1 + ) + else: + valid_numb_batch = 1 + return ( + training_dataloader, + training_data_buffered, + validation_dataloader, + validation_data_buffered, + valid_numb_batch, + ) + + def get_single_model(_model_params, _sampled): + model = get_model(deepcopy(_model_params), _sampled).to(DEVICE) + return model + + def get_lr(lr_params): + assert ( + lr_params.get("type", "exp") == "exp" + ), "Only learning rate `exp` is supported!" + lr_params["stop_steps"] = self.num_steps - self.warmup_steps + lr_exp = LearningRateExp(**lr_params) + return lr_exp + + def get_loss(loss_params, start_lr, _ntypes): + loss_type = loss_params.get("type", "ener") + if loss_type == "ener": + loss_params["starter_learning_rate"] = start_lr + return EnergyStdLoss(**loss_params) + elif loss_type == "denoise": + loss_params["ntypes"] = _ntypes + return DenoiseLoss(**loss_params) + else: + raise NotImplementedError + + # Optimizer + if self.multi_task and training_params.get("optim_dict", None) is not None: + self.optim_dict = training_params.get("optim_dict") + missing_keys = [ + key for key in self.model_keys if key not in self.optim_dict + ] + assert ( + not missing_keys + ), f"These keys are not in optim_dict: {missing_keys}!" + self.opt_type = {} + self.opt_param = {} + for model_key in self.model_keys: + self.opt_type[model_key], self.opt_param[model_key] = get_opt_param( + self.optim_dict[model_key] + ) + else: + self.opt_type, self.opt_param = get_opt_param(training_params) + + # Data + Model + dp_random.seed(training_params["seed"]) + if not self.multi_task: + ( + self.training_dataloader, + self.training_data, + self.validation_dataloader, + self.validation_data, + self.valid_numb_batch, + ) = get_data_loader(training_data, validation_data, training_params) + self.model = get_single_model(model_params, sampled) + else: + ( + self.training_dataloader, + self.training_data, + self.validation_dataloader, + self.validation_data, + self.valid_numb_batch, + self.model, + ) = {}, {}, {}, {}, {}, {} + for model_key in self.model_keys: + ( + self.training_dataloader[model_key], + self.training_data[model_key], + self.validation_dataloader[model_key], + self.validation_data[model_key], + self.valid_numb_batch[model_key], + ) = get_data_loader( + training_data[model_key], + validation_data[model_key], + training_params["data_dict"][model_key], + ) + self.model[model_key] = get_single_model( + model_params["model_dict"][model_key], sampled[model_key] + ) + + # Learning rate + self.warmup_steps = training_params.get("warmup_steps", 0) + self.gradient_max_norm = training_params.get("gradient_max_norm", 0.0) + assert ( + self.num_steps - self.warmup_steps > 0 + ), "Warm up steps must be less than total training steps!" + if self.multi_task and config.get("learning_rate_dict", None) is not None: + self.lr_exp = {} + for model_key in self.model_keys: + self.lr_exp[model_key] = get_lr(config["learning_rate_dict"][model_key]) + else: + self.lr_exp = get_lr(config["learning_rate"]) + + # Loss + if not self.multi_task: + self.loss = get_loss( + config["loss"], + config["learning_rate"]["start_lr"], + len(model_params["type_map"]), + ) + else: + self.loss = {} + for model_key in self.model_keys: + loss_param = config["loss_dict"][model_key] + if config.get("learning_rate_dict", None) is not None: + lr_param = config["learning_rate_dict"][model_key]["start_lr"] + else: + lr_param = config["learning_rate"]["start_lr"] + ntypes = len(model_params["model_dict"][model_key]["type_map"]) + self.loss[model_key] = get_loss(loss_param, lr_param, ntypes) + + # JIT + if JIT: + self.model = torch.jit.script(self.model) + + # Model Wrapper + self.wrapper = ModelWrapper(self.model, self.loss, model_params=model_params) + self.start_step = 0 + + # resuming and finetune + optimizer_state_dict = None + if model_params["resuming"]: + ntest = model_params.get("data_bias_nsample", 1) + origin_model = ( + finetune_model if finetune_model is not None else resume_model + ) + logging.info(f"Resuming from {origin_model}.") + state_dict = torch.load(origin_model, map_location=DEVICE) + if "model" in state_dict: + optimizer_state_dict = ( + state_dict["optimizer"] if finetune_model is None else None + ) + state_dict = state_dict["model"] + self.start_step = ( + state_dict["_extra_state"]["train_infos"]["step"] + if self.restart_training + else 0 + ) + if self.rank == 0: + if force_load: + input_keys = list(state_dict.keys()) + target_keys = list(self.wrapper.state_dict().keys()) + missing_keys = [ + item for item in target_keys if item not in input_keys + ] + if missing_keys: + target_state_dict = self.wrapper.state_dict() + slim_keys = [] + for item in missing_keys: + state_dict[item] = target_state_dict[item].clone().detach() + new_key = True + for slim_key in slim_keys: + if slim_key in item: + new_key = False + break + if new_key: + tmp_keys = ".".join(item.split(".")[:3]) + slim_keys.append(tmp_keys) + slim_keys = [i + ".*" for i in slim_keys] + logging.warning( + f"Force load mode allowed! These keys are not in ckpt and will re-init: {slim_keys}" + ) + elif self.finetune_multi_task: + new_state_dict = {} + model_branch_chosen = model_params.pop("model_branch_chosen") + new_fitting = model_params.pop("new_fitting", False) + target_state_dict = self.wrapper.state_dict() + target_keys = [ + i for i in target_state_dict.keys() if i != "_extra_state" + ] + for item_key in target_keys: + if new_fitting and ".fitting_net." in item_key: + # print(f'Keep {item_key} in old model!') + new_state_dict[item_key] = ( + target_state_dict[item_key].clone().detach() + ) + else: + new_key = item_key.replace( + ".Default.", f".{model_branch_chosen}." + ) + # print(f'Replace {item_key} with {new_key} in pretrained_model!') + new_state_dict[item_key] = ( + state_dict[new_key].clone().detach() + ) + state_dict = new_state_dict + if finetune_model is not None: + state_dict["_extra_state"] = self.wrapper.state_dict()[ + "_extra_state" + ] + + self.wrapper.load_state_dict(state_dict) + # finetune + if finetune_model is not None and model_params["fitting_net"].get( + "type", "ener" + ) in ["ener", "direct_force_ener", "atten_vec_lcc"]: + old_type_map, new_type_map = ( + model_params["type_map"], + model_params["new_type_map"], + ) + self.model.fitting_net.change_energy_bias( + config, + self.model, + old_type_map, + new_type_map, + ntest=ntest, + bias_shift=model_params.get("bias_shift", "delta"), + ) + + # Set trainable params + self.wrapper.set_trainable_params() + + # Multi-task share params + if shared_links is not None: + self.wrapper.share_params(shared_links, resume=model_params["resuming"]) + + if dist.is_initialized(): + torch.cuda.set_device(LOCAL_RANK) + # DDP will guarantee the model parameters are identical across all processes + self.wrapper = DDP( + self.wrapper, + device_ids=[LOCAL_RANK], + find_unused_parameters=True, + output_device=LOCAL_RANK, + ) + + # TODO ZD add lr warmups for multitask + def warm_up_linear(step, warmup_steps): + if step < warmup_steps: + return step / warmup_steps + else: + return self.lr_exp.value(step - warmup_steps) / self.lr_exp.start_lr + + # TODO ZD add optimizers for multitask + if self.opt_type == "Adam": + self.optimizer = torch.optim.Adam( + self.wrapper.parameters(), lr=self.lr_exp.start_lr + ) + if optimizer_state_dict is not None and self.restart_training: + self.optimizer.load_state_dict(optimizer_state_dict) + self.scheduler = torch.optim.lr_scheduler.LambdaLR( + self.optimizer, + lambda step: warm_up_linear(step + self.start_step, self.warmup_steps), + ) + elif self.opt_type == "LKF": + self.optimizer = LKFOptimizer( + self.wrapper.parameters(), 0.98, 0.99870, self.opt_param["kf_blocksize"] + ) + else: + raise ValueError("Not supported optimizer type '%s'" % self.opt_type) + + # Get model prob for multi-task + if self.multi_task: + self.model_prob = np.array([0.0 for key in self.model_keys]) + if training_params.get("model_prob", None) is not None: + model_prob = training_params["model_prob"] + for ii, model_key in enumerate(self.model_keys): + if model_key in model_prob: + self.model_prob[ii] += float(model_prob[model_key]) + else: + for ii, model_key in enumerate(self.model_keys): + self.model_prob[ii] += float(len(self.training_data[model_key])) + sum_prob = np.sum(self.model_prob) + assert sum_prob > 0.0, "Sum of model prob must be larger than 0!" + self.model_prob = self.model_prob / sum_prob + + def run(self): + fout = ( + open(self.disp_file, mode="w", buffering=1) if self.rank == 0 else None + ) # line buffered + if SAMPLER_RECORD: + record_file = f"Sample_rank_{self.rank}.txt" + fout1 = open(record_file, mode="w", buffering=1) + logging.info("Start to train %d steps.", self.num_steps) + if dist.is_initialized(): + logging.info(f"Rank: {dist.get_rank()}/{dist.get_world_size()}") + + def step(_step_id, task_key="Default"): + self.wrapper.train() + if isinstance(self.lr_exp, dict): + _lr = self.lr_exp[task_key] + else: + _lr = self.lr_exp + cur_lr = _lr.value(_step_id) + pref_lr = cur_lr + self.optimizer.zero_grad(set_to_none=True) + input_dict, label_dict, log_dict = self.get_data( + is_train=True, task_key=task_key + ) + if SAMPLER_RECORD: + print_str = f"Step {_step_id}: sample system{log_dict['sid']} frame{log_dict['fid']}\n" + fout1.write(print_str) + fout1.flush() + if self.opt_type == "Adam": + cur_lr = self.scheduler.get_last_lr()[0] + if _step_id < self.warmup_steps: + pref_lr = _lr.start_lr + else: + pref_lr = cur_lr + model_pred, loss, more_loss = self.wrapper( + **input_dict, cur_lr=pref_lr, label=label_dict, task_key=task_key + ) + loss.backward() + if self.gradient_max_norm > 0.0: + grad_norm = torch.nn.utils.clip_grad_norm_( + self.wrapper.parameters(), self.gradient_max_norm + ) + if not torch.isfinite(grad_norm).all(): + # check local gradnorm single GPU case, trigger NanDetector + raise FloatingPointError("gradients are Nan/Inf") + self.optimizer.step() + self.scheduler.step() + elif self.opt_type == "LKF": + if isinstance(self.loss, EnergyStdLoss): + KFOptWrapper = KFOptimizerWrapper( + self.wrapper, self.optimizer, 24, 6, dist.is_initialized() + ) + pref_e = self.opt_param["kf_start_pref_e"] * ( + self.opt_param["kf_limit_pref_e"] + / self.opt_param["kf_start_pref_e"] + ) ** (_step_id / self.num_steps) + _ = KFOptWrapper.update_energy( + input_dict, label_dict["energy"], pref_e + ) + pref_f = self.opt_param["kf_start_pref_f"] * ( + self.opt_param["kf_limit_pref_f"] + / self.opt_param["kf_start_pref_f"] + ) ** (_step_id / self.num_steps) + p_energy, p_force = KFOptWrapper.update_force( + input_dict, label_dict["force"], pref_f + ) + # [coord, atype, natoms, mapping, shift, nlist, box] + model_pred = {"energy": p_energy, "force": p_force} + module = ( + self.wrapper.module if dist.is_initialized() else self.wrapper + ) + loss, more_loss = module.loss[task_key]( + model_pred, + label_dict, + int(input_dict["atype"].shape[-1]), + learning_rate=pref_lr, + ) + elif isinstance(self.loss, DenoiseLoss): + KFOptWrapper = KFOptimizerWrapper( + self.wrapper, self.optimizer, 24, 6, dist.is_initialized() + ) + module = ( + self.wrapper.module if dist.is_initialized() else self.wrapper + ) + model_pred = KFOptWrapper.update_denoise_coord( + input_dict, + label_dict["clean_coord"], + 1, + module.loss[task_key].mask_loss_coord, + label_dict["coord_mask"], + ) + loss, more_loss = module.loss[task_key]( + model_pred, + label_dict, + input_dict["natoms"], + learning_rate=pref_lr, + ) + else: + raise ValueError("Not supported optimizer type '%s'" % self.opt_type) + + # Log and persist + if _step_id % self.disp_freq == 0: + self.wrapper.eval() + msg = f"step={_step_id}, lr={cur_lr:.2e}" + + def log_loss_train(_loss, _more_loss, _task_key="Default"): + results = {} + if not self.multi_task: + suffix = "" + else: + suffix = f"_{_task_key}" + _msg = f"loss{suffix}={_loss:.4f}" + rmse_val = { + item: _more_loss[item] + for item in _more_loss + if "l2_" not in item + } + for item in sorted(rmse_val.keys()): + _msg += f", {item}_train{suffix}={rmse_val[item]:.4f}" + results[item] = rmse_val[item] + return _msg, results + + def log_loss_valid(_task_key="Default"): + single_results = {} + sum_natoms = 0 + if not self.multi_task: + suffix = "" + valid_numb_batch = self.valid_numb_batch + else: + suffix = f"_{_task_key}" + valid_numb_batch = self.valid_numb_batch[_task_key] + for ii in range(valid_numb_batch): + self.optimizer.zero_grad() + input_dict, label_dict, _ = self.get_data( + is_train=False, task_key=_task_key + ) + _, loss, more_loss = self.wrapper( + **input_dict, + cur_lr=pref_lr, + label=label_dict, + task_key=_task_key, + ) + # more_loss.update({"rmse": math.sqrt(loss)}) + natoms = int(input_dict["atype"].shape[-1]) + sum_natoms += natoms + for k, v in more_loss.items(): + if "l2_" not in k: + single_results[k] = ( + single_results.get(k, 0.0) + v * natoms + ) + results = {k: v / sum_natoms for k, v in single_results.items()} + _msg = "" + for item in sorted(results.keys()): + _msg += f", {item}_valid{suffix}={results[item]:.4f}" + return _msg, results + + if not self.multi_task: + temp_msg, train_results = log_loss_train(loss, more_loss) + msg += "\n" + temp_msg + temp_msg, valid_results = log_loss_valid() + msg += temp_msg + else: + train_results = {_key: {} for _key in self.model_keys} + valid_results = {_key: {} for _key in self.model_keys} + train_msg = {} + valid_msg = {} + train_msg[task_key], train_results[task_key] = log_loss_train( + loss, more_loss, _task_key=task_key + ) + for _key in self.model_keys: + if _key != task_key: + self.optimizer.zero_grad() + input_dict, label_dict, _ = self.get_data( + is_train=True, task_key=_key + ) + _, loss, more_loss = self.wrapper( + **input_dict, + cur_lr=pref_lr, + label=label_dict, + task_key=_key, + ) + train_msg[_key], train_results[_key] = log_loss_train( + loss, more_loss, _task_key=_key + ) + valid_msg[_key], valid_results[_key] = log_loss_valid( + _task_key=_key + ) + msg += "\n" + train_msg[_key] + msg += valid_msg[_key] + + train_time = time.time() - self.t0 + self.t0 = time.time() + msg += f", speed={train_time:.2f} s/{self.disp_freq if _step_id else 1} batches" + logging.info(msg) + + if fout: + if self.lcurve_should_print_header: + self.print_header(fout, train_results, valid_results) + self.lcurve_should_print_header = False + self.print_on_training( + fout, _step_id, cur_lr, train_results, valid_results + ) + + if ( + ((_step_id + 1) % self.save_freq == 0 and _step_id != self.start_step) + or (_step_id + 1) == self.num_steps + ) and (self.rank == 0 or dist.get_rank() == 0): + # Handle the case if rank 0 aborted and re-assigned + self.latest_model = Path(self.save_ckpt) + self.latest_model = self.latest_model.with_name( + f"{self.latest_model.stem}_{_step_id + 1}{self.latest_model.suffix}" + ) + module = self.wrapper.module if dist.is_initialized() else self.wrapper + self.save_model(self.latest_model, lr=cur_lr, step=_step_id) + logging.info(f"Saved model to {self.latest_model}") + + self.t0 = time.time() + with logging_redirect_tqdm(): + for step_id in tqdm( + range(self.num_steps), + disable=(bool(dist.get_rank()) if dist.is_initialized() else False) + or DISABLE_TQDM, + ): # set to None to disable on non-TTY; disable on not rank 0 + if step_id < self.start_step: + continue + if self.multi_task: + chosen_index_list = dp_random.choice( + np.arange(self.num_model), + p=np.array(self.model_prob), + size=self.world_size, + replace=True, + ) + assert chosen_index_list.size == self.world_size + model_index = chosen_index_list[self.rank] + model_key = self.model_keys[model_index] + else: + model_key = "Default" + step(step_id, model_key) + if JIT: + break + + if ( + self.rank == 0 or dist.get_rank() == 0 + ): # Handle the case if rank 0 aborted and re-assigned + if JIT: + pth_model_path = ( + "frozen_model.pth" # We use .pth to denote the frozen model + ) + self.model.save(pth_model_path) + logging.info( + f"Frozen model for inferencing has been saved to {pth_model_path}" + ) + try: + os.symlink(self.latest_model, self.save_ckpt) + except OSError: + self.save_model(self.save_ckpt, lr=0, step=self.num_steps) + logging.info(f"Trained model has been saved to: {self.save_ckpt}") + + if fout: + fout.close() + if SAMPLER_RECORD: + fout1.close() + + def save_model(self, save_path, lr=0.0, step=0): + module = self.wrapper.module if dist.is_initialized() else self.wrapper + module.train_infos["lr"] = lr + module.train_infos["step"] = step + torch.save( + {"model": module.state_dict(), "optimizer": self.optimizer.state_dict()}, + save_path, + ) + + def get_data(self, is_train=True, task_key="Default"): + if not self.multi_task: + if is_train: + try: + batch_data = next(iter(self.training_data)) + except StopIteration: + # Refresh the status of the dataloader to start from a new epoch + self.training_data = BufferedIterator( + iter(self.training_dataloader) + ) + batch_data = next(iter(self.training_data)) + else: + try: + batch_data = next(iter(self.validation_data)) + except StopIteration: + self.validation_data = BufferedIterator( + iter(self.validation_dataloader) + ) + batch_data = next(iter(self.validation_data)) + else: + if is_train: + try: + batch_data = next(iter(self.training_data[task_key])) + except StopIteration: + # Refresh the status of the dataloader to start from a new epoch + self.training_data[task_key] = BufferedIterator( + iter(self.training_dataloader[task_key]) + ) + batch_data = next(iter(self.training_data[task_key])) + else: + try: + batch_data = next(iter(self.validation_data[task_key])) + except StopIteration: + self.validation_data[task_key] = BufferedIterator( + iter(self.validation_dataloader[task_key]) + ) + batch_data = next(iter(self.validation_data[task_key])) + + for key in batch_data.keys(): + if key == "sid" or key == "fid": + continue + elif not isinstance(batch_data[key], list): + if batch_data[key] is not None: + batch_data[key] = batch_data[key].to(DEVICE) + else: + batch_data[key] = [item.to(DEVICE) for item in batch_data[key]] + input_dict = {} + for item in [ + "coord", + "atype", + "box", + ]: + if item in batch_data: + input_dict[item] = batch_data[item] + else: + input_dict[item] = None + label_dict = {} + for item in [ + "energy", + "force", + "virial", + "clean_coord", + "clean_type", + "coord_mask", + "type_mask", + ]: + if item in batch_data: + label_dict[item] = batch_data[item] + log_dict = {} + if "fid" in batch_data: + log_dict["fid"] = batch_data["fid"] + log_dict["sid"] = batch_data["sid"] + return input_dict, label_dict, log_dict + + def print_header(self, fout, train_results, valid_results): + train_keys = sorted(train_results.keys()) + print_str = "" + print_str += "# %5s" % "step" + if not self.multi_task: + if valid_results is not None: + prop_fmt = " %11s %11s" + for k in train_keys: + print_str += prop_fmt % (k + "_val", k + "_trn") + else: + prop_fmt = " %11s" + for k in train_keys: + print_str += prop_fmt % (k + "_trn") + else: + for model_key in self.model_keys: + if valid_results[model_key] is not None: + prop_fmt = " %11s %11s" + for k in sorted(train_results[model_key].keys()): + print_str += prop_fmt % ( + k + f"_val_{model_key}", + k + f"_trn_{model_key}", + ) + else: + prop_fmt = " %11s" + for k in sorted(train_results[model_key].keys()): + print_str += prop_fmt % (k + f"_trn_{model_key}") + print_str += " %8s\n" % "lr" + fout.write(print_str) + fout.flush() + + def print_on_training(self, fout, step_id, cur_lr, train_results, valid_results): + train_keys = sorted(train_results.keys()) + print_str = "" + print_str += "%7d" % step_id + if not self.multi_task: + if valid_results is not None: + prop_fmt = " %11.2e %11.2e" + for k in train_keys: + print_str += prop_fmt % (valid_results[k], train_results[k]) + else: + prop_fmt = " %11.2e" + for k in train_keys: + print_str += prop_fmt % (train_results[k]) + else: + for model_key in self.model_keys: + if valid_results[model_key] is not None: + prop_fmt = " %11.2e %11.2e" + for k in sorted(valid_results[model_key].keys()): + print_str += prop_fmt % ( + valid_results[model_key][k], + train_results[model_key][k], + ) + else: + prop_fmt = " %11.2e" + for k in sorted(train_results[model_key].keys()): + print_str += prop_fmt % (train_results[model_key][k]) + print_str += " %8.1e\n" % cur_lr + fout.write(print_str) + fout.flush() diff --git a/deepmd/pt/train/wrapper.py b/deepmd/pt/train/wrapper.py new file mode 100644 index 0000000000..fe423e6318 --- /dev/null +++ b/deepmd/pt/train/wrapper.py @@ -0,0 +1,192 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Dict, + Optional, + Union, +) + +import torch + +if torch.__version__.startswith("2"): + import torch._dynamo + + +class ModelWrapper(torch.nn.Module): + def __init__( + self, + model: Union[torch.nn.Module, Dict], + loss: Union[torch.nn.Module, Dict] = None, + model_params=None, + shared_links=None, + ): + """Construct a DeePMD model wrapper. + + Args: + - config: The Dict-like configuration with training options. + """ + super().__init__() + self.model_params = model_params if model_params is not None else {} + self.train_infos = { + "lr": 0, + "step": 0, + } + self.multi_task = False + self.model = torch.nn.ModuleDict() + # Model + if isinstance(model, torch.nn.Module): + self.model["Default"] = model + elif isinstance(model, dict): + self.multi_task = True + for task_key in model: + assert isinstance( + model[task_key], torch.nn.Module + ), f"{task_key} in model_dict is not a torch.nn.Module!" + self.model[task_key] = model[task_key] + # Loss + self.loss = None + if loss is not None: + self.loss = torch.nn.ModuleDict() + if isinstance(loss, torch.nn.Module): + self.loss["Default"] = loss + elif isinstance(loss, dict): + for task_key in loss: + assert isinstance( + loss[task_key], torch.nn.Module + ), f"{task_key} in loss_dict is not a torch.nn.Module!" + self.loss[task_key] = loss[task_key] + self.inference_only = self.loss is None + + def set_trainable_params(self): + supported_types = ["type_embedding", "descriptor", "fitting_net"] + for model_item in self.model: + for net_type in supported_types: + trainable = True + if not self.multi_task: + if net_type in self.model_params: + trainable = self.model_params[net_type].get("trainable", True) + else: + if net_type in self.model_params["model_dict"][model_item]: + trainable = self.model_params["model_dict"][model_item][ + net_type + ].get("trainable", True) + if ( + hasattr(self.model[model_item], net_type) + and getattr(self.model[model_item], net_type) is not None + ): + for param in ( + self.model[model_item].__getattr__(net_type).parameters() + ): + param.requires_grad = trainable + + def share_params(self, shared_links, resume=False): + supported_types = ["type_embedding", "descriptor", "fitting_net"] + for shared_item in shared_links: + class_name = shared_links[shared_item]["type"] + shared_base = shared_links[shared_item]["links"][0] + class_type_base = shared_base["shared_type"] + model_key_base = shared_base["model_key"] + shared_level_base = shared_base["shared_level"] + if "descriptor" in class_type_base: + if class_type_base == "descriptor": + base_class = self.model[model_key_base].__getattr__("descriptor") + elif "hybrid" in class_type_base: + hybrid_index = int(class_type_base.split("_")[-1]) + base_class = ( + self.model[model_key_base] + .__getattr__("descriptor") + .descriptor_list[hybrid_index] + ) + else: + raise RuntimeError(f"Unknown class_type {class_type_base}!") + for link_item in shared_links[shared_item]["links"][1:]: + class_type_link = link_item["shared_type"] + model_key_link = link_item["model_key"] + shared_level_link = int(link_item["shared_level"]) + assert ( + shared_level_link >= shared_level_base + ), "The shared_links must be sorted by shared_level!" + assert ( + "descriptor" in class_type_link + ), f"Class type mismatched: {class_type_base} vs {class_type_link}!" + if class_type_link == "descriptor": + link_class = self.model[model_key_link].__getattr__( + "descriptor" + ) + elif "hybrid" in class_type_link: + hybrid_index = int(class_type_link.split("_")[-1]) + link_class = ( + self.model[model_key_link] + .__getattr__("descriptor") + .descriptor_list[hybrid_index] + ) + else: + raise RuntimeError(f"Unknown class_type {class_type_link}!") + link_class.share_params( + base_class, shared_level_link, resume=resume + ) + print( + f"Shared params of {model_key_base}.{class_type_base} and {model_key_link}.{class_type_link}!" + ) + else: + if hasattr(self.model[model_key_base], class_type_base): + base_class = self.model[model_key_base].__getattr__(class_type_base) + for link_item in shared_links[shared_item]["links"][1:]: + class_type_link = link_item["shared_type"] + model_key_link = link_item["model_key"] + shared_level_link = int(link_item["shared_level"]) + assert ( + shared_level_link >= shared_level_base + ), "The shared_links must be sorted by shared_level!" + assert ( + class_type_base == class_type_link + ), f"Class type mismatched: {class_type_base} vs {class_type_link}!" + link_class = self.model[model_key_link].__getattr__( + class_type_link + ) + link_class.share_params( + base_class, shared_level_link, resume=resume + ) + print( + f"Shared params of {model_key_base}.{class_type_base} and {model_key_link}.{class_type_link}!" + ) + + def forward( + self, + coord, + atype, + box: Optional[torch.Tensor] = None, + cur_lr: Optional[torch.Tensor] = None, + label: Optional[torch.Tensor] = None, + task_key: Optional[torch.Tensor] = None, + inference_only=False, + do_atomic_virial=False, + ): + if not self.multi_task: + task_key = "Default" + else: + assert ( + task_key is not None + ), f"Multitask model must specify the inference task! Supported tasks are {list(self.model.keys())}." + model_pred = self.model[task_key]( + coord, atype, box=box, do_atomic_virial=do_atomic_virial + ) + natoms = atype.shape[-1] + if not self.inference_only and not inference_only: + loss, more_loss = self.loss[task_key]( + model_pred, label, natoms=natoms, learning_rate=cur_lr + ) + return model_pred, loss, more_loss + else: + return model_pred, None, None + + def set_extra_state(self, state: Dict): + self.model_params = state["model_params"] + self.train_infos = state["train_infos"] + return None + + def get_extra_state(self) -> Dict: + state = { + "model_params": self.model_params, + "train_infos": self.train_infos, + } + return state diff --git a/deepmd/pt/utils/__init__.py b/deepmd/pt/utils/__init__.py new file mode 100644 index 0000000000..6ceb116d85 --- /dev/null +++ b/deepmd/pt/utils/__init__.py @@ -0,0 +1 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later diff --git a/deepmd/pt/utils/ase_calc.py b/deepmd/pt/utils/ase_calc.py new file mode 100644 index 0000000000..8d5fe8bce9 --- /dev/null +++ b/deepmd/pt/utils/ase_calc.py @@ -0,0 +1,65 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + ClassVar, +) + +import dpdata +import numpy as np +from ase import ( + Atoms, +) +from ase.calculators.calculator import ( + Calculator, + PropertyNotImplementedError, +) + +from deepmd.pt.infer.deep_eval import ( + DeepPot, +) + + +class DPCalculator(Calculator): + implemented_properties: ClassVar[list] = [ + "energy", + "free_energy", + "forces", + "virial", + "stress", + ] + + def __init__(self, model): + Calculator.__init__(self) + self.dp = DeepPot(model) + self.type_map = self.dp.type_map + + def calculate(self, atoms: Atoms, properties, system_changes) -> None: + Calculator.calculate(self, atoms, properties, system_changes) + system = dpdata.System(atoms, fmt="ase/structure") + type_trans = np.array( + [self.type_map.index(i) for i in system.data["atom_names"]] + ) + input_coords = system.data["coords"] + input_cells = system.data["cells"] + input_types = list(type_trans[system.data["atom_types"]]) + model_predict = self.dp.eval(input_coords, input_cells, input_types) + self.results = { + "energy": model_predict[0].item(), + "free_energy": model_predict[0].item(), + "forces": model_predict[1].reshape(-1, 3), + "virial": model_predict[2].reshape(3, 3), + } + + # convert virial into stress for lattice relaxation + if "stress" in properties: + if sum(atoms.get_pbc()) > 0 or (atoms.cell is not None): + # the usual convention (tensile stress is positive) + # stress = -virial / volume + stress = ( + -0.5 + * (self.results["virial"].copy() + self.results["virial"].copy().T) + / atoms.get_volume() + ) + # Voigt notation + self.results["stress"] = stress.flat[[0, 4, 8, 5, 2, 1]] + else: + raise PropertyNotImplementedError diff --git a/deepmd/pt/utils/auto_batch_size.py b/deepmd/pt/utils/auto_batch_size.py new file mode 100644 index 0000000000..5af7760e2a --- /dev/null +++ b/deepmd/pt/utils/auto_batch_size.py @@ -0,0 +1,26 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import torch + +from deepmd.utils.batch_size import AutoBatchSize as AutoBatchSizeBase + + +class AutoBatchSize(AutoBatchSizeBase): + def is_gpu_available(self) -> bool: + """Check if GPU is available. + + Returns + ------- + bool + True if GPU is available + """ + return torch.cuda.is_available() + + def is_oom_error(self, e: Exception) -> bool: + """Check if the exception is an OOM error. + + Parameters + ---------- + e : Exception + Exception + """ + return isinstance(e, RuntimeError) and "CUDA out of memory." in e.args[0] diff --git a/deepmd/pt/utils/cache.py b/deepmd/pt/utils/cache.py new file mode 100644 index 0000000000..c40c4050b7 --- /dev/null +++ b/deepmd/pt/utils/cache.py @@ -0,0 +1,31 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import copy as copy_lib +import functools + + +def lru_cache(maxsize=16, typed=False, copy=False, deepcopy=False): + if deepcopy: + + def decorator(f): + cached_func = functools.lru_cache(maxsize, typed)(f) + + @functools.wraps(f) + def wrapper(*args, **kwargs): + return copy_lib.deepcopy(cached_func(*args, **kwargs)) + + return wrapper + + elif copy: + + def decorator(f): + cached_func = functools.lru_cache(maxsize, typed)(f) + + @functools.wraps(f) + def wrapper(*args, **kwargs): + return copy_lib.copy(cached_func(*args, **kwargs)) + + return wrapper + + else: + decorator = functools.lru_cache(maxsize, typed) + return decorator diff --git a/deepmd/pt/utils/dataloader.py b/deepmd/pt/utils/dataloader.py new file mode 100644 index 0000000000..7c95f66c9c --- /dev/null +++ b/deepmd/pt/utils/dataloader.py @@ -0,0 +1,319 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import logging +import os +import queue +import time +from multiprocessing.dummy import ( + Pool, +) +from threading import ( + Thread, +) +from typing import ( + List, +) + +import h5py +import torch +import torch.distributed as dist +import torch.multiprocessing +from torch.utils.data import ( + DataLoader, + Dataset, + WeightedRandomSampler, +) +from torch.utils.data.distributed import ( + DistributedSampler, +) + +from deepmd.pt.model.descriptor import ( + Descriptor, +) +from deepmd.pt.utils import ( + env, +) +from deepmd.pt.utils.dataset import ( + DeepmdDataSetForLoader, +) +from deepmd.utils.data_system import ( + prob_sys_size_ext, + process_sys_probs, +) + +torch.multiprocessing.set_sharing_strategy("file_system") + + +def setup_seed(seed): + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + torch.backends.cudnn.deterministic = True + + +class DpLoaderSet(Dataset): + """A dataset for storing DataLoaders to multiple Systems.""" + + def __init__( + self, + systems, + batch_size, + model_params, + seed=10, + type_split=True, + noise_settings=None, + shuffle=True, + ): + setup_seed(seed) + if isinstance(systems, str): + with h5py.File(systems) as file: + systems = [os.path.join(systems, item) for item in file.keys()] + + self.systems: List[DeepmdDataSetForLoader] = [] + if len(systems) >= 100: + logging.info(f"Constructing DataLoaders from {len(systems)} systems") + + def construct_dataset(system): + ### this design requires "rcut" and "sel" in the descriptor + ### VERY BAD DESIGN!!!! + ### not all descriptors provides these parameter in their constructor + if model_params["descriptor"].get("type") != "hybrid": + info_dict = Descriptor.get_data_process_key(model_params["descriptor"]) + rcut = info_dict["rcut"] + sel = info_dict["sel"] + else: ### need to remove this + rcut = [] + sel = [] + for ii in model_params["descriptor"]["list"]: + rcut.append(ii["rcut"]) + sel.append(ii["sel"]) + return DeepmdDataSetForLoader( + system=system, + type_map=model_params["type_map"], + rcut=rcut, + sel=sel, + type_split=type_split, + noise_settings=noise_settings, + shuffle=shuffle, + ) + + with Pool( + os.cpu_count() + // (int(os.environ["LOCAL_WORLD_SIZE"]) if dist.is_initialized() else 1) + ) as pool: + self.systems = pool.map(construct_dataset, systems) + + self.sampler_list: List[DistributedSampler] = [] + self.index = [] + self.total_batch = 0 + + self.dataloaders = [] + for system in self.systems: + if dist.is_initialized(): + system_sampler = DistributedSampler(system) + self.sampler_list.append(system_sampler) + else: + system_sampler = None + if isinstance(batch_size, str): + if batch_size == "auto": + rule = 32 + elif batch_size.startswith("auto:"): + rule = int(batch_size.split(":")[1]) + else: + rule = None + logging.error("Unsupported batch size type") + self.batch_size = rule // system._natoms + if self.batch_size * system._natoms < rule: + self.batch_size += 1 + else: + self.batch_size = batch_size + system_dataloader = DataLoader( + dataset=system, + batch_size=self.batch_size, + num_workers=0, # Should be 0 to avoid too many threads forked + sampler=system_sampler, + collate_fn=collate_batch, + shuffle=(not dist.is_initialized()) and shuffle, + ) + self.dataloaders.append(system_dataloader) + self.index.append(len(system_dataloader)) + self.total_batch += len(system_dataloader) + # Initialize iterator instances for DataLoader + self.iters = [] + for item in self.dataloaders: + self.iters.append(iter(item)) + + def set_noise(self, noise_settings): + # noise_settings['noise_type'] # "trunc_normal", "normal", "uniform" + # noise_settings['noise'] # float, default 1.0 + # noise_settings['noise_mode'] # "prob", "fix_num" + # noise_settings['mask_num'] # if "fix_num", int + # noise_settings['mask_prob'] # if "prob", float + # noise_settings['same_mask'] # coord and type same mask? + for system in self.systems: + system.set_noise(noise_settings) + + def __len__(self): + return len(self.dataloaders) + + def __getitem__(self, idx): + # logging.warning(str(torch.distributed.get_rank())+" idx: "+str(idx)+" index: "+str(self.index[idx])) + try: + batch = next(self.iters[idx]) + except StopIteration: + self.iters[idx] = iter(self.dataloaders[idx]) + batch = next(self.iters[idx]) + batch["sid"] = idx + return batch + + +_sentinel = object() +QUEUESIZE = 32 + + +class BackgroundConsumer(Thread): + def __init__(self, queue, source, max_len): + Thread.__init__(self) + self._queue = queue + self._source = source # Main DL iterator + self._max_len = max_len # + + def run(self): + for item in self._source: + self._queue.put(item) # Blocking if the queue is full + + # Signal the consumer we are done. + self._queue.put(_sentinel) + + +class BufferedIterator: + def __init__(self, iterable): + self._queue = queue.Queue(QUEUESIZE) + self._iterable = iterable + self._consumer = None + + self.start_time = time.time() + self.warning_time = None + self.total = len(iterable) + + def _create_consumer(self): + self._consumer = BackgroundConsumer(self._queue, self._iterable, self.total) + self._consumer.daemon = True + self._consumer.start() + + def __iter__(self): + return self + + def __len__(self): + return self.total + + def __next__(self): + # Create consumer if not created yet + if self._consumer is None: + self._create_consumer() + # Notify the user if there is a data loading bottleneck + if self._queue.qsize() < min(2, max(1, self._queue.maxsize // 2)): + if time.time() - self.start_time > 5 * 60: + if ( + self.warning_time is None + or time.time() - self.warning_time > 15 * 60 + ): + logging.warning( + "Data loading buffer is empty or nearly empty. This may " + "indicate a data loading bottleneck, and increasing the " + "number of workers (--num-workers) may help." + ) + self.warning_time = time.time() + + # Get next example + item = self._queue.get() + if isinstance(item, Exception): + raise item + if item is _sentinel: + raise StopIteration + return item + + +def collate_tensor_fn(batch): + elem = batch[0] + if not isinstance(elem, list): + out = None + if torch.utils.data.get_worker_info() is not None: + # If we're in a background process, concatenate directly into a + # shared memory tensor to avoid an extra copy + numel = sum(x.numel() for x in batch) + storage = elem._typed_storage()._new_shared(numel, device=elem.device) + out = elem.new(storage).resize_(len(batch), *list(elem.size())) + return torch.stack(batch, 0, out=out) + else: + out_hybrid = [] + for ii, hybrid_item in enumerate(elem): + out = None + tmp_batch = [x[ii] for x in batch] + if torch.utils.data.get_worker_info() is not None: + # If we're in a background process, concatenate directly into a + # shared memory tensor to avoid an extra copy + numel = sum(x.numel() for x in tmp_batch) + storage = hybrid_item._typed_storage()._new_shared( + numel, device=hybrid_item.device + ) + out = hybrid_item.new(storage).resize_( + len(tmp_batch), *list(hybrid_item.size()) + ) + out_hybrid.append(torch.stack(tmp_batch, 0, out=out)) + return out_hybrid + + +def collate_batch(batch): + example = batch[0] + result = example.copy() + for key in example.keys(): + if key == "shift" or key == "mapping": + natoms_extended = max([d[key].shape[0] for d in batch]) + n_frames = len(batch) + list = [] + for x in range(n_frames): + list.append(batch[x][key]) + if key == "shift": + result[key] = torch.zeros( + (n_frames, natoms_extended, 3), + dtype=env.GLOBAL_PT_FLOAT_PRECISION, + device=env.PREPROCESS_DEVICE, + ) + else: + result[key] = torch.zeros( + (n_frames, natoms_extended), + dtype=torch.long, + device=env.PREPROCESS_DEVICE, + ) + for i in range(len(batch)): + natoms_tmp = list[i].shape[0] + result[key][i, :natoms_tmp] = list[i] + elif "find_" in key: + result[key] = batch[0][key] + else: + if batch[0][key] is None: + result[key] = None + elif key == "fid": + result[key] = [d[key] for d in batch] + else: + result[key] = collate_tensor_fn([d[key] for d in batch]) + return result + + +def get_weighted_sampler(training_data, prob_style, sys_prob=False): + if sys_prob is False: + if prob_style == "prob_uniform": + prob_v = 1.0 / float(training_data.__len__()) + probs = [prob_v for ii in range(training_data.__len__())] + else: # prob_sys_size;A:B:p1;C:D:p2 or prob_sys_size = prob_sys_size;0:nsys:1.0 + if prob_style == "prob_sys_size": + style = f"prob_sys_size;0:{len(training_data)}:1.0" + else: + style = prob_style + probs = prob_sys_size_ext(style, len(training_data), training_data.index) + else: + probs = process_sys_probs(prob_style, training_data.index) + logging.info("Generated weighted sampler with prob array: " + str(probs)) + # training_data.total_batch is the size of one epoch, you can increase it to avoid too many rebuilding of iteraters + len_sampler = training_data.total_batch * max(env.NUM_WORKERS, 1) + sampler = WeightedRandomSampler(probs, len_sampler, replacement=True) + return sampler diff --git a/deepmd/pt/utils/dataset.py b/deepmd/pt/utils/dataset.py new file mode 100644 index 0000000000..24daa6e37e --- /dev/null +++ b/deepmd/pt/utils/dataset.py @@ -0,0 +1,918 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import glob +import os +from typing import ( + List, + Optional, +) + +import h5py +import numpy as np +import torch +import torch.distributed as dist +from torch.utils.data import ( + Dataset, +) +from tqdm import ( + trange, +) + +from deepmd.pt.utils import ( + dp_random, + env, +) +from deepmd.pt.utils.cache import ( + lru_cache, +) +from deepmd.pt.utils.preprocess import ( + Region3D, + make_env_mat, + normalize_coord, +) + + +class DeepmdDataSystem: + def __init__( + self, + sys_path: str, + rcut, + sec, + type_map: Optional[List[str]] = None, + type_split=True, + noise_settings=None, + shuffle=True, + ): + """Construct DeePMD-style frame collection of one system. + + Args: + - sys_path: Paths to the system. + - type_map: Atom types. + """ + sys_path = sys_path.replace("#", "") + if ".hdf5" in sys_path: + tmp = sys_path.split("/") + path = "/".join(tmp[:-1]) + sys = tmp[-1] + self.file = h5py.File(path)[sys] + self._dirs = [] + for item in self.file.keys(): + if "set." in item: + self._dirs.append(item) + self._dirs.sort() + else: + self.file = None + self._dirs = glob.glob(os.path.join(sys_path, "set.*")) + self._dirs.sort() + self.type_split = type_split + self.noise_settings = noise_settings + self._check_pbc(sys_path) + self.shuffle = shuffle + if noise_settings is not None: + self.noise_type = noise_settings.get("noise_type", "uniform") + self.noise = float(noise_settings.get("noise", 1.0)) + self.noise_mode = noise_settings.get("noise_mode", "fix_num") + self.mask_num = int(noise_settings.get("mask_num", 1)) + self.mask_prob = float(noise_settings.get("mask_prob", 0.15)) + self.same_mask = noise_settings.get("same_mask", False) + self.mask_coord = noise_settings.get("mask_coord", False) + self.mask_type = noise_settings.get("mask_type", False) + self.mask_type_idx = int(noise_settings.get("mask_type_idx", 0)) + self.max_fail_num = int(noise_settings.get("max_fail_num", 10)) + + # check mixed type + error_format_msg = ( + "if one of the set is of mixed_type format, " + "then all of the sets in this system should be of mixed_type format!" + ) + if len(self._dirs) == 0: + raise RuntimeError(f"No set found in system {sys_path}.") + + self.mixed_type = self._check_mode(self._dirs[0]) + for set_item in self._dirs[1:]: + assert self._check_mode(set_item) == self.mixed_type, error_format_msg + + self._atom_type = self._load_type(sys_path) + self._natoms = len(self._atom_type) + + self._type_map = self._load_type_map(sys_path) + self.enforce_type_map = False + if type_map is not None and self._type_map is not None: + if not self.mixed_type: + atom_type = [ + type_map.index(self._type_map[ii]) for ii in self._atom_type + ] + self._atom_type = np.array(atom_type, dtype=np.int32) + + else: + self.enforce_type_map = True + sorter = np.argsort(type_map) + self.type_idx_map = np.array( + sorter[np.searchsorted(type_map, self._type_map, sorter=sorter)] + ) + # padding for virtual atom + self.type_idx_map = np.append( + self.type_idx_map, np.array([-1], dtype=np.int32) + ) + self._type_map = type_map + if type_map is None and self.type_map is None and self.mixed_type: + raise RuntimeError("mixed_type format must have type_map!") + self._idx_map = _make_idx_map(self._atom_type) + + self._data_dict = {} + self.add("box", 9, must=self.pbc) + self.add("coord", 3, atomic=True, must=True) + self.add("energy", 1, atomic=False, must=False, high_prec=True) + self.add("force", 3, atomic=True, must=False, high_prec=False) + self.add("virial", 9, atomic=False, must=False, high_prec=False) + + self._sys_path = sys_path + self.rcut = rcut + self.sec = sec + if isinstance(rcut, float): + self.hybrid = False + elif isinstance(rcut, list): + self.hybrid = True + else: + RuntimeError("Unkown rcut type!") + self.sets = [None for i in range(len(self._sys_path))] + + self.nframes = 0 + i = 1 + self.prefix_sum = [0] * (len(self._dirs) + 1) + for item in self._dirs: + frames = self._load_set(item, fast=True) + self.prefix_sum[i] = self.prefix_sum[i - 1] + frames + i += 1 + self.nframes += frames + + def _check_pbc(self, sys_path): + pbc = True + if os.path.isfile(os.path.join(sys_path, "nopbc")): + pbc = False + self.pbc = pbc + + def set_noise(self, noise_settings): + # noise_settings['noise_type'] # "trunc_normal", "normal", "uniform" + # noise_settings['noise'] # float, default 1.0 + # noise_settings['noise_mode'] # "prob", "fix_num" + # noise_settings['mask_num'] # if "fix_num", int + # noise_settings['mask_prob'] # if "prob", float + # noise_settings['same_mask'] # coord and type same mask? + self.noise_settings = noise_settings + self.noise_type = noise_settings.get("noise_type", "uniform") + self.noise = float(noise_settings.get("noise", 1.0)) + self.noise_mode = noise_settings.get("noise_mode", "fix_num") + self.mask_num = int(noise_settings.get("mask_num", 1)) + self.mask_coord = noise_settings.get("mask_coord", False) + self.mask_type = noise_settings.get("mask_type", False) + self.mask_prob = float(noise_settings.get("mask_prob", 0.15)) + self.same_mask = noise_settings.get("noise_type", False) + + def add( + self, + key: str, + ndof: int, + atomic: bool = False, + must: bool = False, + high_prec: bool = False, + ): + """Add a data item that to be loaded. + + Args: + - key: The key of the item. The corresponding data is stored in `sys_path/set.*/key.npy` + - ndof: The number of dof + - atomic: The item is an atomic property. + - must: The data file `sys_path/set.*/key.npy` must exist. Otherwise, value is set to zero. + - high_prec: Load the data and store in float64, otherwise in float32. + """ + self._data_dict[key] = { + "ndof": ndof, + "atomic": atomic, + "must": must, + "high_prec": high_prec, + } + + # deprecated TODO + def get_batch_for_train(self, batch_size: int): + """Get a batch of data with at most `batch_size` frames. The frames are randomly picked from the data system. + + Args: + - batch_size: Frame count. + """ + if not hasattr(self, "_frames"): + self.set_size = 0 + self._set_count = 0 + self._iterator = 0 + if batch_size == "auto": + batch_size = -(-32 // self._natoms) + if self._iterator + batch_size > self.set_size: + set_idx = self._set_count % len(self._dirs) + if self.sets[set_idx] is None: + frames = self._load_set(self._dirs[set_idx]) + frames = self.preprocess(frames) + cnt = 0 + for item in self.sets: + if item is not None: + cnt += 1 + if cnt < env.CACHE_PER_SYS: + self.sets[set_idx] = frames + else: + frames = self.sets[set_idx] + self._frames = frames + self._shuffle_data() + if dist.is_initialized(): + world_size = dist.get_world_size() + rank = dist.get_rank() + ssize = self._frames["coord"].shape[0] + subsize = ssize // world_size + self._iterator = rank * subsize + self.set_size = min((rank + 1) * subsize, ssize) + else: + self.set_size = self._frames["coord"].shape[0] + self._iterator = 0 + self._set_count += 1 + iterator = min(self._iterator + batch_size, self.set_size) + idx = np.arange(self._iterator, iterator) + self._iterator += batch_size + return self._get_subdata(idx) + + # deprecated TODO + def get_batch(self, batch_size: int): + """Get a batch of data with at most `batch_size` frames. The frames are randomly picked from the data system. + Args: + - batch_size: Frame count. + """ + if not hasattr(self, "_frames"): + self.set_size = 0 + self._set_count = 0 + self._iterator = 0 + if batch_size == "auto": + batch_size = -(-32 // self._natoms) + if self._iterator + batch_size > self.set_size: + set_idx = self._set_count % len(self._dirs) + if self.sets[set_idx] is None: + frames = self._load_set(self._dirs[set_idx]) + frames = self.preprocess(frames) + cnt = 0 + for item in self.sets: + if item is not None: + cnt += 1 + if cnt < env.CACHE_PER_SYS: + self.sets[set_idx] = frames + else: + frames = self.sets[set_idx] + self._frames = frames + self._shuffle_data() + self.set_size = self._frames["coord"].shape[0] + self._iterator = 0 + self._set_count += 1 + iterator = min(self._iterator + batch_size, self.set_size) + idx = np.arange(self._iterator, iterator) + self._iterator += batch_size + return self._get_subdata(idx) + + def get_ntypes(self): + """Number of atom types in the system.""" + if self._type_map is not None: + return len(self._type_map) + else: + return max(self._atom_type) + 1 + + def get_natoms_vec(self, ntypes: int): + """Get number of atoms and number of atoms in different types. + + Args: + - ntypes: Number of types (may be larger than the actual number of types in the system). + """ + natoms = len(self._atom_type) + natoms_vec = np.zeros(ntypes).astype(int) + for ii in range(ntypes): + natoms_vec[ii] = np.count_nonzero(self._atom_type == ii) + tmp = [natoms, natoms] + tmp = np.append(tmp, natoms_vec) + return tmp.astype(np.int32) + + def _load_type(self, sys_path): + if self.file is not None: + return self.file["type.raw"][:] + else: + return np.loadtxt( + os.path.join(sys_path, "type.raw"), dtype=np.int32, ndmin=1 + ) + + def _load_type_map(self, sys_path): + if self.file is not None: + tmp = self.file["type_map.raw"][:].tolist() + tmp = [item.decode("ascii") for item in tmp] + return tmp + else: + fname = os.path.join(sys_path, "type_map.raw") + if os.path.isfile(fname): + with open(fname) as fin: + content = fin.read() + return content.split() + else: + return None + + def _check_mode(self, sys_path): + return os.path.isfile(sys_path + "/real_atom_types.npy") + + def _load_type_mix(self, set_name): + type_path = set_name + "/real_atom_types.npy" + real_type = np.load(type_path).astype(np.int32).reshape([-1, self._natoms]) + return real_type + + @lru_cache(maxsize=16, copy=True) + def _load_set(self, set_name, fast=False): + if self.file is None: + path = os.path.join(set_name, "coord.npy") + if self._data_dict["coord"]["high_prec"]: + coord = np.load(path).astype(env.GLOBAL_ENER_FLOAT_PRECISION) + else: + coord = np.load(path).astype(env.GLOBAL_NP_FLOAT_PRECISION) + if coord.ndim == 1: + coord = coord.reshape([1, -1]) + assert coord.shape[1] == self._data_dict["coord"]["ndof"] * self._natoms + nframes = coord.shape[0] + if fast: + return nframes + data = {"type": np.tile(self._atom_type[self._idx_map], (nframes, 1))} + for kk in self._data_dict.keys(): + data["find_" + kk], data[kk] = self._load_data( + set_name, + kk, + nframes, + self._data_dict[kk]["ndof"], + atomic=self._data_dict[kk]["atomic"], + high_prec=self._data_dict[kk]["high_prec"], + must=self._data_dict[kk]["must"], + ) + if self.mixed_type: + # nframes x natoms + atom_type_mix = self._load_type_mix(set_name) + if self.enforce_type_map: + try: + atom_type_mix_ = self.type_idx_map[atom_type_mix].astype( + np.int32 + ) + except IndexError as e: + raise IndexError( + "some types in 'real_atom_types.npy' of set {} are not contained in {} types!".format( + set_name, self.get_ntypes() + ) + ) from e + atom_type_mix = atom_type_mix_ + real_type = atom_type_mix.reshape([nframes, self._natoms]) + data["type"] = real_type + natoms = data["type"].shape[1] + # nframes x ntypes + atom_type_nums = np.array( + [(real_type == i).sum(axis=-1) for i in range(self.get_ntypes())], + dtype=np.int32, + ).T + ghost_nums = np.array( + [(real_type == -1).sum(axis=-1)], + dtype=np.int32, + ).T + assert ( + atom_type_nums.sum(axis=-1) + ghost_nums.sum(axis=-1) == natoms + ).all(), "some types in 'real_atom_types.npy' of set {} are not contained in {} types!".format( + set_name, self.get_ntypes() + ) + data["real_natoms_vec"] = np.concatenate( + ( + np.tile( + np.array([natoms, natoms], dtype=np.int32), (nframes, 1) + ), + atom_type_nums, + ), + axis=-1, + ) + + return data + else: + data = {} + nframes = self.file[set_name]["coord.npy"].shape[0] + if fast: + return nframes + for key in ["coord", "energy", "force", "box"]: + data[key] = self.file[set_name][f"{key}.npy"][:] + if self._data_dict[key]["atomic"]: + data[key] = data[key].reshape(nframes, self._natoms, -1)[ + :, self._idx_map, : + ] + if self.mixed_type: + # nframes x natoms + atom_type_mix = self._load_type_mix(set_name) + if self.enforce_type_map: + try: + atom_type_mix_ = self.type_idx_map[atom_type_mix].astype( + np.int32 + ) + except IndexError as e: + raise IndexError( + "some types in 'real_atom_types.npy' of set {} are not contained in {} types!".format( + set_name, self.get_ntypes() + ) + ) from e + atom_type_mix = atom_type_mix_ + real_type = atom_type_mix.reshape([nframes, self._natoms]) + data["type"] = real_type + natoms = data["type"].shape[1] + # nframes x ntypes + atom_type_nums = np.array( + [(real_type == i).sum(axis=-1) for i in range(self.get_ntypes())], + dtype=np.int32, + ).T + ghost_nums = np.array( + [(real_type == -1).sum(axis=-1)], + dtype=np.int32, + ).T + assert ( + atom_type_nums.sum(axis=-1) + ghost_nums.sum(axis=-1) == natoms + ).all(), "some types in 'real_atom_types.npy' of set {} are not contained in {} types!".format( + set_name, self.get_ntypes() + ) + data["real_natoms_vec"] = np.concatenate( + ( + np.tile( + np.array([natoms, natoms], dtype=np.int32), (nframes, 1) + ), + atom_type_nums, + ), + axis=-1, + ) + else: + data["type"] = np.tile(self._atom_type[self._idx_map], (nframes, 1)) + return data + + def _load_data( + self, set_name, key, nframes, ndof, atomic=False, must=True, high_prec=False + ): + if atomic: + ndof *= self._natoms + path = os.path.join(set_name, key + ".npy") + # logging.info('Loading data from: %s', path) + if os.path.isfile(path): + if high_prec: + data = np.load(path).astype(env.GLOBAL_ENER_FLOAT_PRECISION) + else: + data = np.load(path).astype(env.GLOBAL_NP_FLOAT_PRECISION) + if atomic: + data = data.reshape([nframes, self._natoms, -1]) + data = data[:, self._idx_map, :] + data = data.reshape([nframes, -1]) + data = np.reshape(data, [nframes, ndof]) + return np.float32(1.0), data + elif must: + raise RuntimeError("%s not found!" % path) + else: + if high_prec: + data = np.zeros([nframes, ndof]).astype(env.GLOBAL_ENER_FLOAT_PRECISION) + else: + data = np.zeros([nframes, ndof]).astype(env.GLOBAL_NP_FLOAT_PRECISION) + return np.float32(0.0), data + + # deprecated TODO + def preprocess(self, batch): + n_frames = batch["coord"].shape[0] + for kk in self._data_dict.keys(): + if "find_" in kk: + pass + else: + batch[kk] = torch.tensor( + batch[kk], + dtype=env.GLOBAL_PT_FLOAT_PRECISION, + device=env.PREPROCESS_DEVICE, + ) + if self._data_dict[kk]["atomic"]: + batch[kk] = batch[kk].view( + n_frames, -1, self._data_dict[kk]["ndof"] + ) + + for kk in ["type", "real_natoms_vec"]: + if kk in batch.keys(): + batch[kk] = torch.tensor( + batch[kk], dtype=torch.long, device=env.PREPROCESS_DEVICE + ) + batch["atype"] = batch.pop("type") + + keys = ["nlist", "nlist_loc", "nlist_type", "shift", "mapping"] + coord = batch["coord"] + atype = batch["atype"] + box = batch["box"] + rcut = self.rcut + sec = self.sec + assert batch["atype"].max() < len(self._type_map) + nlist, nlist_loc, nlist_type, shift, mapping = [], [], [], [], [] + + for sid in trange(n_frames, disable=env.DISABLE_TQDM): + region = Region3D(box[sid]) + nloc = atype[sid].shape[0] + _coord = normalize_coord(coord[sid], region, nloc) + coord[sid] = _coord + a, b, c, d, e = make_env_mat( + _coord, atype[sid], region, rcut, sec, type_split=self.type_split + ) + nlist.append(a) + nlist_loc.append(b) + nlist_type.append(c) + shift.append(d) + mapping.append(e) + nlist = torch.stack(nlist) + nlist_loc = torch.stack(nlist_loc) + nlist_type = torch.stack(nlist_type) + batch["nlist"] = nlist + batch["nlist_loc"] = nlist_loc + batch["nlist_type"] = nlist_type + natoms_extended = max([item.shape[0] for item in shift]) + batch["shift"] = torch.zeros( + (n_frames, natoms_extended, 3), + dtype=env.GLOBAL_PT_FLOAT_PRECISION, + device=env.PREPROCESS_DEVICE, + ) + batch["mapping"] = torch.zeros( + (n_frames, natoms_extended), dtype=torch.long, device=env.PREPROCESS_DEVICE + ) + for i in range(len(shift)): + natoms_tmp = shift[i].shape[0] + batch["shift"][i, :natoms_tmp] = shift[i] + batch["mapping"][i, :natoms_tmp] = mapping[i] + return batch + + def _shuffle_data(self): + nframes = self._frames["coord"].shape[0] + idx = np.arange(nframes) + if self.shuffle: + dp_random.shuffle(idx) + self.idx_mapping = idx + + def _get_subdata(self, idx=None): + data = self._frames + idx = self.idx_mapping[idx] + new_data = {} + for ii in data: + dd = data[ii] + if "find_" in ii: + new_data[ii] = dd + else: + if idx is not None: + new_data[ii] = dd[idx] + else: + new_data[ii] = dd + return new_data + + # note: this function needs to be optimized for single frame process + def single_preprocess(self, batch, sid): + for kk in self._data_dict.keys(): + if "find_" in kk: + pass + else: + batch[kk] = torch.tensor( + batch[kk][sid], + dtype=env.GLOBAL_PT_FLOAT_PRECISION, + device=env.PREPROCESS_DEVICE, + ) + if self._data_dict[kk]["atomic"]: + batch[kk] = batch[kk].view(-1, self._data_dict[kk]["ndof"]) + for kk in ["type", "real_natoms_vec"]: + if kk in batch.keys(): + batch[kk] = torch.tensor( + batch[kk][sid], dtype=torch.long, device=env.PREPROCESS_DEVICE + ) + clean_coord = batch.pop("coord") + clean_type = batch.pop("type") + nloc = clean_type.shape[0] + rcut = self.rcut + sec = self.sec + nlist, nlist_loc, nlist_type, shift, mapping = [], [], [], [], [] + if self.pbc: + box = batch["box"] + region = Region3D(box) + else: + box = None + batch["box"] = None + region = None + if self.noise_settings is None: + batch["atype"] = clean_type + batch["coord"] = clean_coord + coord = clean_coord + atype = batch["atype"] + if self.pbc: + _coord = normalize_coord(coord, region, nloc) + + else: + _coord = coord.clone() + batch["coord"] = _coord + nlist, nlist_loc, nlist_type, shift, mapping = make_env_mat( + _coord, + atype, + region, + rcut, + sec, + pbc=self.pbc, + type_split=self.type_split, + ) + batch["nlist"] = nlist + batch["nlist_loc"] = nlist_loc + batch["nlist_type"] = nlist_type + batch["shift"] = shift + batch["mapping"] = mapping + return batch + else: + batch["clean_type"] = clean_type + if self.pbc: + _clean_coord = normalize_coord(clean_coord, region, nloc) + else: + _clean_coord = clean_coord.clone() + batch["clean_coord"] = _clean_coord + # add noise + for i in range(self.max_fail_num): + mask_num = 0 + if self.noise_mode == "fix_num": + mask_num = self.mask_num + if len(batch["clean_type"]) < mask_num: + mask_num = len(batch["clean_type"]) + elif self.noise_mode == "prob": + mask_num = int(self.mask_prob * nloc) + if mask_num == 0: + mask_num = 1 + else: + NotImplementedError(f"Unknown noise mode {self.noise_mode}!") + rng = np.random.default_rng() + coord_mask_res = rng.choice( + range(nloc), mask_num, replace=False + ).tolist() + coord_mask = np.isin(range(nloc), coord_mask_res) + if self.same_mask: + type_mask = coord_mask.copy() + else: + rng = np.random.default_rng() + type_mask_res = rng.choice( + range(nloc), mask_num, replace=False + ).tolist() + type_mask = np.isin(range(nloc), type_mask_res) + + # add noise for coord + if self.mask_coord: + noise_on_coord = 0.0 + rng = np.random.default_rng() + if self.noise_type == "trunc_normal": + noise_on_coord = np.clip( + rng.standard_normal((mask_num, 3)) * self.noise, + a_min=-self.noise * 2.0, + a_max=self.noise * 2.0, + ) + elif self.noise_type == "normal": + noise_on_coord = rng.standard_normal((mask_num, 3)) * self.noise + elif self.noise_type == "uniform": + noise_on_coord = rng.uniform( + low=-self.noise, high=self.noise, size=(mask_num, 3) + ) + else: + NotImplementedError(f"Unknown noise type {self.noise_type}!") + noised_coord = _clean_coord.clone().detach() + noised_coord[coord_mask] += noise_on_coord + batch["coord_mask"] = torch.tensor( + coord_mask, dtype=torch.bool, device=env.PREPROCESS_DEVICE + ) + else: + noised_coord = _clean_coord + batch["coord_mask"] = torch.tensor( + np.zeros_like(coord_mask, dtype=bool), + dtype=torch.bool, + device=env.PREPROCESS_DEVICE, + ) + + # add mask for type + if self.mask_type: + masked_type = clean_type.clone().detach() + masked_type[type_mask] = self.mask_type_idx + batch["type_mask"] = torch.tensor( + type_mask, dtype=torch.bool, device=env.PREPROCESS_DEVICE + ) + else: + masked_type = clean_type + batch["type_mask"] = torch.tensor( + np.zeros_like(type_mask, dtype=bool), + dtype=torch.bool, + device=env.PREPROCESS_DEVICE, + ) + if self.pbc: + _coord = normalize_coord(noised_coord, region, nloc) + else: + _coord = noised_coord.clone() + try: + nlist, nlist_loc, nlist_type, shift, mapping = make_env_mat( + _coord, + masked_type, + region, + rcut, + sec, + pbc=self.pbc, + type_split=self.type_split, + min_check=True, + ) + except RuntimeError as e: + if i == self.max_fail_num - 1: + RuntimeError( + f"Add noise times beyond max tries {self.max_fail_num}!" + ) + continue + batch["atype"] = masked_type + batch["coord"] = noised_coord + batch["nlist"] = nlist + batch["nlist_loc"] = nlist_loc + batch["nlist_type"] = nlist_type + batch["shift"] = shift + batch["mapping"] = mapping + return batch + + def _get_item(self, index): + for i in range( + 0, len(self._dirs) + 1 + ): # note: if different sets can be merged, prefix sum is unused to calculate + if index < self.prefix_sum[i]: + break + frames = self._load_set(self._dirs[i - 1]) + frame = self.single_preprocess(frames, index - self.prefix_sum[i - 1]) + frame["fid"] = index + return frame + + +def _make_idx_map(atom_type): + natoms = atom_type.shape[0] + idx = np.arange(natoms) + idx_map = np.lexsort((idx, atom_type)) + return idx_map + + +class DeepmdDataSetForLoader(Dataset): + def __init__( + self, + system: str, + type_map: str, + rcut, + sel, + weight=None, + type_split=True, + noise_settings=None, + shuffle=True, + ): + """Construct DeePMD-style dataset containing frames cross different systems. + + Args: + - systems: Paths to systems. + - batch_size: Max frame count in a batch. + - type_map: Atom types. + """ + self._type_map = type_map + if not isinstance(rcut, list): + if isinstance(sel, int): + sel = [sel] + sec = torch.cumsum(torch.tensor(sel), dim=0) + else: + sec = [] + for sel_item in sel: + if isinstance(sel_item, int): + sel_item = [sel_item] + sec.append(torch.cumsum(torch.tensor(sel_item), dim=0)) + self._data_system = DeepmdDataSystem( + system, + rcut, + sec, + type_map=self._type_map, + type_split=type_split, + noise_settings=noise_settings, + shuffle=shuffle, + ) + self.mixed_type = self._data_system.mixed_type + self._ntypes = self._data_system.get_ntypes() + self._natoms = self._data_system._natoms + self._natoms_vec = self._data_system.get_natoms_vec(self._ntypes) + + def set_noise(self, noise_settings): + # noise_settings['noise_type'] # "trunc_normal", "normal", "uniform" + # noise_settings['noise'] # float, default 1.0 + # noise_settings['noise_mode'] # "prob", "fix_num" + # noise_settings['mask_num'] # if "fix_num", int + # noise_settings['mask_prob'] # if "prob", float + # noise_settings['same_mask'] # coord and type same mask? + self._data_system.set_noise(noise_settings) + + def __len__(self): + return self._data_system.nframes + + def __getitem__(self, index): + """Get a frame from the selected system.""" + b_data = self._data_system._get_item(index) + b_data["natoms"] = torch.tensor(self._natoms_vec, device=env.PREPROCESS_DEVICE) + return b_data + + +# deprecated TODO +class DeepmdDataSet(Dataset): + def __init__( + self, + systems: List[str], + batch_size: int, + type_map: List[str], + rcut=None, + sel=None, + weight=None, + type_split=True, + ): + """Construct DeePMD-style dataset containing frames cross different systems. + + Args: + - systems: Paths to systems. + - batch_size: Max frame count in a batch. + - type_map: Atom types. + """ + self._batch_size = batch_size + self._type_map = type_map + if sel is not None: + if isinstance(sel, int): + sel = [sel] + sec = torch.cumsum(torch.tensor(sel), dim=0) + if isinstance(systems, str): + with h5py.File(systems) as file: + systems = [os.path.join(systems, item) for item in file.keys()] + self._data_systems = [ + DeepmdDataSystem( + ii, rcut, sec, type_map=self._type_map, type_split=type_split + ) + for ii in systems + ] + # check mix_type format + error_format_msg = ( + "if one of the system is of mixed_type format, " + "then all of the systems in this dataset should be of mixed_type format!" + ) + self.mixed_type = self._data_systems[0].mixed_type + for sys_item in self._data_systems[1:]: + assert sys_item.mixed_type == self.mixed_type, error_format_msg + + if weight is None: + + def weight(name, sys): + return sys.nframes + + self.probs = [ + weight(item, self._data_systems[i]) for i, item in enumerate(systems) + ] + self.probs = np.array(self.probs, dtype=float) + self.probs /= self.probs.sum() + self._ntypes = max([ii.get_ntypes() for ii in self._data_systems]) + self._natoms_vec = [ + ii.get_natoms_vec(self._ntypes) for ii in self._data_systems + ] + self.cache = [{} for _ in self._data_systems] + + @property + def nsystems(self): + return len(self._data_systems) + + def __len__(self): + return self.nsystems + + def __getitem__(self, index=None): + """Get a batch of frames from the selected system.""" + if index is None: + index = dp_random.choice(np.arange(self.nsystems), self.probs) + b_data = self._data_systems[index].get_batch(self._batch_size) + b_data["natoms"] = torch.tensor( + self._natoms_vec[index], device=env.PREPROCESS_DEVICE + ) + batch_size = b_data["coord"].shape[0] + b_data["natoms"] = b_data["natoms"].unsqueeze(0).expand(batch_size, -1) + return b_data + + # deprecated TODO + def get_training_batch(self, index=None): + """Get a batch of frames from the selected system.""" + if index is None: + index = dp_random.choice(np.arange(self.nsystems), self.probs) + b_data = self._data_systems[index].get_batch_for_train(self._batch_size) + b_data["natoms"] = torch.tensor( + self._natoms_vec[index], device=env.PREPROCESS_DEVICE + ) + batch_size = b_data["coord"].shape[0] + b_data["natoms"] = b_data["natoms"].unsqueeze(0).expand(batch_size, -1) + return b_data + + def get_batch(self, sys_idx=None): + """TF-compatible batch for testing.""" + pt_batch = self[sys_idx] + np_batch = {} + for key in ["coord", "box", "force", "energy", "virial"]: + if key in pt_batch.keys(): + np_batch[key] = pt_batch[key].cpu().numpy() + for key in ["atype", "natoms"]: + if key in pt_batch.keys(): + np_batch[key] = pt_batch[key].cpu().numpy() + batch_size = pt_batch["coord"].shape[0] + np_batch["coord"] = np_batch["coord"].reshape(batch_size, -1) + np_batch["natoms"] = np_batch["natoms"][0] + np_batch["force"] = np_batch["force"].reshape(batch_size, -1) + return np_batch, pt_batch diff --git a/deepmd/pt/utils/dp_random.py b/deepmd/pt/utils/dp_random.py new file mode 100644 index 0000000000..e81488c506 --- /dev/null +++ b/deepmd/pt/utils/dp_random.py @@ -0,0 +1,14 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from deepmd.utils.random import ( + choice, + random, + seed, + shuffle, +) + +__all__ = [ + "choice", + "random", + "seed", + "shuffle", +] diff --git a/deepmd/pt/utils/env.py b/deepmd/pt/utils/env.py new file mode 100644 index 0000000000..5b6eaf7c14 --- /dev/null +++ b/deepmd/pt/utils/env.py @@ -0,0 +1,45 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import os + +import numpy as np +import torch + +PRECISION = os.environ.get("PRECISION", "float64") +GLOBAL_NP_FLOAT_PRECISION = getattr(np, PRECISION) +GLOBAL_PT_FLOAT_PRECISION = getattr(torch, PRECISION) +GLOBAL_ENER_FLOAT_PRECISION = getattr(np, PRECISION) +DISABLE_TQDM = os.environ.get("DISABLE_TQDM", False) +SAMPLER_RECORD = os.environ.get("SAMPLER_RECORD", False) +try: + # only linux + ncpus = len(os.sched_getaffinity(0)) +except AttributeError: + ncpus = os.cpu_count() +NUM_WORKERS = int(os.environ.get("NUM_WORKERS", min(8, ncpus))) +# Make sure DDP uses correct device if applicable +LOCAL_RANK = os.environ.get("LOCAL_RANK") +LOCAL_RANK = int(0 if LOCAL_RANK is None else LOCAL_RANK) + +if os.environ.get("DEVICE") == "cpu" or torch.cuda.is_available() is False: + DEVICE = torch.device("cpu") +else: + DEVICE = torch.device(f"cuda:{LOCAL_RANK}") + +if os.environ.get("PREPROCESS_DEVICE") == "gpu": + PREPROCESS_DEVICE = torch.device(f"cuda:{LOCAL_RANK}") +else: + PREPROCESS_DEVICE = torch.device("cpu") + +JIT = False +CACHE_PER_SYS = 5 # keep at most so many sets per sys in memory +ENERGY_BIAS_TRAINABLE = True + +PRECISION_DICT = { + "float16": torch.float16, + "float32": torch.float32, + "float64": torch.float64, + "half": torch.float16, + "single": torch.float32, + "double": torch.float64, +} +DEFAULT_PRECISION = "float64" diff --git a/deepmd/pt/utils/finetune.py b/deepmd/pt/utils/finetune.py new file mode 100644 index 0000000000..9d82783cc0 --- /dev/null +++ b/deepmd/pt/utils/finetune.py @@ -0,0 +1,98 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import logging + +import torch + +from deepmd.pt.utils import ( + env, +) + + +def change_finetune_model_params( + ckpt, finetune_model, model_config, multi_task=False, model_branch="" +): + """Load model_params according to the pretrained one. + + Args: + - ckpt & finetune_model: origin model. + - config: Read from json file. + """ + if multi_task: + # TODO + print("finetune mode need modification for multitask mode!") + if finetune_model is not None: + state_dict = torch.load(finetune_model, map_location=env.DEVICE) + if "model" in state_dict: + state_dict = state_dict["model"] + last_model_params = state_dict["_extra_state"]["model_params"] + finetune_multi_task = "model_dict" in last_model_params + trainable_param = { + "type_embedding": True, + "descriptor": True, + "fitting_net": True, + } + for net_type in trainable_param: + if net_type in model_config: + trainable_param[net_type] = model_config[net_type].get( + "trainable", True + ) + if not finetune_multi_task: + old_type_map, new_type_map = ( + last_model_params["type_map"], + model_config["type_map"], + ) + assert set(new_type_map).issubset( + old_type_map + ), "Only support for smaller type map when finetuning or resuming." + model_config = last_model_params + logging.info( + "Change the model configurations according to the pretrained one..." + ) + model_config["new_type_map"] = new_type_map + else: + model_config["finetune_multi_task"] = finetune_multi_task + model_dict_params = last_model_params["model_dict"] + new_fitting = False + if model_branch == "": + model_branch_chosen = next(iter(model_dict_params.keys())) + new_fitting = True + model_config["bias_shift"] = "statistic" # fitting net re-init + print( + "The fitting net will be re-init instead of using that in the pretrained model! " + "The bias_shift will be statistic!" + ) + else: + model_branch_chosen = model_branch + assert model_branch_chosen in model_dict_params, ( + f"No model branch named '{model_branch_chosen}'! " + f"Available ones are {list(model_dict_params.keys())}." + ) + old_type_map, new_type_map = ( + model_dict_params[model_branch_chosen]["type_map"], + model_config["type_map"], + ) + assert set(new_type_map).issubset( + old_type_map + ), "Only support for smaller type map when finetuning or resuming." + for key_item in ["type_map", "type_embedding", "descriptor"]: + if key_item in model_dict_params[model_branch_chosen]: + model_config[key_item] = model_dict_params[model_branch_chosen][ + key_item + ] + if not new_fitting: + model_config["fitting_net"] = model_dict_params[model_branch_chosen][ + "fitting_net" + ] + logging.info( + f"Change the model configurations according to the model branch " + f"{model_branch_chosen} in the pretrained one..." + ) + model_config["new_type_map"] = new_type_map + model_config["model_branch_chosen"] = model_branch_chosen + model_config["new_fitting"] = new_fitting + for net_type in trainable_param: + if net_type in model_config: + model_config[net_type]["trainable"] = trainable_param[net_type] + else: + model_config[net_type] = {"trainable": trainable_param[net_type]} + return model_config diff --git a/deepmd/pt/utils/learning_rate.py b/deepmd/pt/utils/learning_rate.py new file mode 100644 index 0000000000..eca3c6ad87 --- /dev/null +++ b/deepmd/pt/utils/learning_rate.py @@ -0,0 +1,35 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import numpy as np + + +class LearningRateExp: + def __init__(self, start_lr, stop_lr, decay_steps, stop_steps, **kwargs): + """Construct an exponential-decayed learning rate. + + Args: + - start_lr: Initial learning rate. + - stop_lr: Learning rate at the last step. + - decay_steps: Decay learning rate every N steps. + - stop_steps: When is the last step. + """ + self.start_lr = start_lr + default_ds = 100 if stop_steps // 10 > 100 else stop_steps // 100 + 1 + self.decay_steps = decay_steps + if self.decay_steps >= stop_steps: + self.decay_steps = default_ds + self.decay_rate = np.exp( + np.log(stop_lr / self.start_lr) / (stop_steps / self.decay_steps) + ) + if "decay_rate" in kwargs: + self.decay_rate = kwargs["decay_rate"] + if "min_lr" in kwargs: + self.min_lr = kwargs["min_lr"] + else: + self.min_lr = 3e-10 + + def value(self, step): + """Get the learning rate at the given step.""" + step_lr = self.start_lr * np.power(self.decay_rate, step // self.decay_steps) + if step_lr < self.min_lr: + step_lr = self.min_lr + return step_lr diff --git a/deepmd/pt/utils/multi_task.py b/deepmd/pt/utils/multi_task.py new file mode 100644 index 0000000000..f97a826b03 --- /dev/null +++ b/deepmd/pt/utils/multi_task.py @@ -0,0 +1,129 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from copy import ( + deepcopy, +) + +from deepmd.pt.model.descriptor import ( + DescrptDPA1, + DescrptDPA2, + DescrptSeA, +) +from deepmd.pt.model.network.network import ( + TypeEmbedNet, +) +from deepmd.pt.model.task import ( + EnergyFittingNet, + EnergyFittingNetDirect, + FittingNetAttenLcc, +) + + +def preprocess_shared_params(model_config): + """Preprocess the model params for multitask model, and generate the links dict for further sharing. + + Args: + model_config: Model params of multitask model. + + Returns + ------- + model_config: Preprocessed model params of multitask model. + Those string names are replaced with real params in `shared_dict` of model params. + shared_links: Dict of link infos for further sharing. + Each item, whose key must be in `shared_dict`, is a dict with following keys: + - "type": The real class type of this item. + - "links": List of shared settings, each sub-item is a dict with following keys: + - "model_key": Model key in the `model_dict` to share this item. + - "shared_type": Type of this shard item. + - "shared_level": Shared level (int) of this item in this model. + Lower for more params to share, 0 means to share all params in this item. + This list are sorted by "shared_level". + """ + assert "model_dict" in model_config, "only multi-task model can use this method!" + supported_types = ["type_map", "type_embedding", "descriptor", "fitting_net"] + shared_dict = model_config.get("shared_dict", {}) + shared_links = {} + type_map_keys = [] + + def replace_one_item(params_dict, key_type, key_in_dict, suffix="", index=None): + shared_type = key_type + shared_key = key_in_dict + shared_level = 0 + if ":" in key_in_dict: + shared_key = key_in_dict.split(":")[0] + shared_level = int(key_in_dict.split(":")[1]) + assert ( + shared_key in shared_dict + ), f"Appointed {shared_type} {shared_key} are not in the shared_dict! Please check the input params." + if index is None: + params_dict[shared_type] = deepcopy(shared_dict[shared_key]) + else: + params_dict[index] = deepcopy(shared_dict[shared_key]) + if shared_type == "type_map": + if key_in_dict not in type_map_keys: + type_map_keys.append(key_in_dict) + else: + if shared_key not in shared_links: + class_name = get_class_name(shared_type, shared_dict[key_in_dict]) + shared_links[shared_key] = {"type": class_name, "links": []} + link_item = { + "model_key": model_key, + "shared_type": shared_type + suffix, + "shared_level": shared_level, + } + shared_links[shared_key]["links"].append(link_item) + + for model_key in model_config["model_dict"]: + model_params_item = model_config["model_dict"][model_key] + for item_key in model_params_item: + if item_key in supported_types: + item_params = model_params_item[item_key] + if isinstance(item_params, str): + replace_one_item(model_params_item, item_key, item_params) + elif item_params.get("type", "") == "hybrid": + for ii, hybrid_item in enumerate(item_params["list"]): + if isinstance(hybrid_item, str): + replace_one_item( + model_params_item[item_key]["list"], + item_key, + hybrid_item, + suffix=f"_hybrid_{ii}", + index=ii, + ) + for shared_key in shared_links: + shared_links[shared_key]["links"] = sorted( + shared_links[shared_key]["links"], key=lambda x: x["shared_level"] + ) + assert len(type_map_keys) == 1, "Multitask model must have only one type_map!" + return model_config, shared_links + + +def get_class_name(item_key, item_params): + if item_key == "type_embedding": + return TypeEmbedNet.__name__ + elif item_key == "descriptor": + item_type = item_params.get("type", "se_e2_a") + if item_type == "se_e2_a": + return DescrptSeA.__name__ + elif item_type in ["se_atten", "dpa1"]: + return DescrptDPA1.__name__ + elif item_type in ["dpa2"]: + return DescrptDPA2.__name__ + # todo add support for other combination + # elif item_type == "gaussian_lcc": + # return DescrptGaussianLcc.__name__ + # elif item_type == "hybrid": + # return DescrptHybrid.__name__ + else: + raise RuntimeError(f"Unknown descriptor type {item_type}") + elif item_key == "fitting_net": + item_type = item_params.get("type", "ener") + if item_type == "ener": + return EnergyFittingNet.__name__ + elif item_type in ["direct_force", "direct_force_ener"]: + return EnergyFittingNetDirect.__name__ + elif item_type == "atten_vec_lcc": + return FittingNetAttenLcc.__name__ + else: + raise RuntimeError(f"Unknown fitting_net type {item_type}") + else: + raise RuntimeError(f"Unknown class_name type {item_key}") diff --git a/deepmd/pt/utils/nlist.py b/deepmd/pt/utils/nlist.py new file mode 100644 index 0000000000..23a11684a5 --- /dev/null +++ b/deepmd/pt/utils/nlist.py @@ -0,0 +1,431 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Dict, + List, + Optional, + Union, +) + +import torch + +from deepmd.pt.utils import ( + env, +) +from deepmd.pt.utils.region import ( + to_face_distance, +) + + +def _build_neighbor_list( + coord1: torch.Tensor, + nloc: int, + rcut: float, + nsel: int, + rmin: float = 1e-10, + cut_nearest: bool = True, +) -> torch.Tensor: + """Build neightbor list for a single frame. keeps nsel neighbors. + coord1 : [nall x 3]. + + ret: [nloc x nsel] stores indexes of coord1. + """ + nall = coord1.shape[-1] // 3 + coord0 = torch.split(coord1, [nloc * 3, (nall - nloc) * 3])[0] + # nloc x nall x 3 + diff = coord1.view([-1, 3])[None, :, :] - coord0.view([-1, 3])[:, None, :] + assert list(diff.shape) == [nloc, nall, 3] + # nloc x nall + rr = torch.linalg.norm(diff, dim=-1) + rr, nlist = torch.sort(rr, dim=-1) + if cut_nearest: + # nloc x (nall-1) + rr = torch.split(rr, [1, nall - 1], dim=-1)[-1] + nlist = torch.split(nlist, [1, nall - 1], dim=-1)[-1] + # nloc x nsel + nnei = rr.shape[1] + rr = torch.split(rr, [nsel, nnei - nsel], dim=-1)[0] + nlist = torch.split(nlist, [nsel, nnei - nsel], dim=-1)[0] + nlist = nlist.masked_fill((rr > rcut), -1) + return nlist + + +def build_neighbor_list_lower( + coord1: torch.Tensor, + atype: torch.Tensor, + nloc: int, + rcut: float, + sel: Union[int, List[int]], + distinguish_types: bool = True, +) -> torch.Tensor: + """Build neightbor list for a single frame. keeps nsel neighbors. + + Parameters + ---------- + coord1 : torch.Tensor + exptended coordinates of shape [nall x 3] + atype : torch.Tensor + extended atomic types of shape [nall] + nloc : int + number of local atoms. + rcut : float + cut-off radius + sel : int or List[int] + maximal number of neighbors (of each type). + if distinguish_types==True, nsel should be list and + the length of nsel should be equal to number of + types. + distinguish_types : bool + distinguish different types. + + Returns + ------- + neighbor_list : torch.Tensor + Neighbor list of shape [nloc, nsel], the neighbors + are stored in an ascending order. If the number of + neighbors is less than nsel, the positions are masked + with -1. The neighbor list of an atom looks like + |------ nsel ------| + xx xx xx xx -1 -1 -1 + if distinguish_types==True and we have two types + |---- nsel[0] -----| |---- nsel[1] -----| + xx xx xx xx -1 -1 -1 xx xx xx -1 -1 -1 -1 + + """ + nall = coord1.shape[0] // 3 + if isinstance(sel, int): + sel = [sel] + nsel = sum(sel) + # nloc x 3 + coord0 = coord1[: nloc * 3] + # nloc x nall x 3 + diff = coord1.view([-1, 3]).unsqueeze(0) - coord0.view([-1, 3]).unsqueeze(1) + assert list(diff.shape) == [nloc, nall, 3] + # nloc x nall + rr = torch.linalg.norm(diff, dim=-1) + rr, nlist = torch.sort(rr, dim=-1) + # nloc x (nall-1) + rr = rr[:, 1:] + nlist = nlist[:, 1:] + # nloc x nsel + nnei = rr.shape[1] + if nsel <= nnei: + rr = rr[:, :nsel] + nlist = nlist[:, :nsel] + else: + rr = torch.cat( + [rr, torch.ones([nloc, nsel - nnei]).to(rr.device) + rcut], dim=-1 + ) + nlist = torch.cat( + [nlist, torch.ones([nloc, nsel - nnei], dtype=torch.long).to(rr.device)], + dim=-1, + ) + assert list(nlist.shape) == [nloc, nsel] + nlist = nlist.masked_fill((rr > rcut), -1) + + if not distinguish_types: + return nlist + else: + ret_nlist = [] + # nloc x nall + tmp_atype = torch.tile(atype.unsqueeze(0), [nloc, 1]) + mask = nlist == -1 + # nloc x s(nsel) + tnlist = torch.gather( + tmp_atype, + 1, + nlist.masked_fill(mask, 0), + ) + tnlist = tnlist.masked_fill(mask, -1) + snsel = tnlist.shape[1] + for ii, ss in enumerate(sel): + # nloc x s(nsel) + # to int because bool cannot be sort on GPU + pick_mask = (tnlist == ii).to(torch.int32) + # nloc x s(nsel), stable sort, nearer neighbors first + pick_mask, imap = torch.sort( + pick_mask, dim=-1, descending=True, stable=True + ) + # nloc x s(nsel) + inlist = torch.gather(nlist, 1, imap) + inlist = inlist.masked_fill(~(pick_mask.to(torch.bool)), -1) + # nloc x nsel[ii] + ret_nlist.append(torch.split(inlist, [ss, snsel - ss], dim=-1)[0]) + return torch.concat(ret_nlist, dim=-1) + + +def build_neighbor_list( + coord1: torch.Tensor, + atype: torch.Tensor, + nloc: int, + rcut: float, + sel: Union[int, List[int]], + distinguish_types: bool = True, +) -> torch.Tensor: + """Build neightbor list for a single frame. keeps nsel neighbors. + + Parameters + ---------- + coord1 : torch.Tensor + exptended coordinates of shape [batch_size, nall x 3] + atype : torch.Tensor + extended atomic types of shape [batch_size, nall] + nloc : int + number of local atoms. + rcut : float + cut-off radius + sel : int or List[int] + maximal number of neighbors (of each type). + if distinguish_types==True, nsel should be list and + the length of nsel should be equal to number of + types. + distinguish_types : bool + distinguish different types. + + Returns + ------- + neighbor_list : torch.Tensor + Neighbor list of shape [batch_size, nloc, nsel], the neighbors + are stored in an ascending order. If the number of + neighbors is less than nsel, the positions are masked + with -1. The neighbor list of an atom looks like + |------ nsel ------| + xx xx xx xx -1 -1 -1 + if distinguish_types==True and we have two types + |---- nsel[0] -----| |---- nsel[1] -----| + xx xx xx xx -1 -1 -1 xx xx xx -1 -1 -1 -1 + + """ + batch_size = coord1.shape[0] + coord1 = coord1.view(batch_size, -1) + nall = coord1.shape[1] // 3 + if isinstance(sel, int): + sel = [sel] + nsel = sum(sel) + # nloc x 3 + coord0 = coord1[:, : nloc * 3] + # nloc x nall x 3 + diff = coord1.view([batch_size, -1, 3]).unsqueeze(1) - coord0.view( + [batch_size, -1, 3] + ).unsqueeze(2) + assert list(diff.shape) == [batch_size, nloc, nall, 3] + # nloc x nall + rr = torch.linalg.norm(diff, dim=-1) + rr, nlist = torch.sort(rr, dim=-1) + # nloc x (nall-1) + rr = rr[:, :, 1:] + nlist = nlist[:, :, 1:] + # nloc x nsel + nnei = rr.shape[2] + if nsel <= nnei: + rr = rr[:, :, :nsel] + nlist = nlist[:, :, :nsel] + else: + rr = torch.cat( + [rr, torch.ones([batch_size, nloc, nsel - nnei]).to(rr.device) + rcut], + dim=-1, + ) + nlist = torch.cat( + [ + nlist, + torch.ones([batch_size, nloc, nsel - nnei], dtype=torch.long).to( + rr.device + ), + ], + dim=-1, + ) + assert list(nlist.shape) == [batch_size, nloc, nsel] + nlist = nlist.masked_fill((rr > rcut), -1) + + if not distinguish_types: + return nlist + else: + ret_nlist = [] + # nloc x nall + tmp_atype = torch.tile(atype.unsqueeze(1), [1, nloc, 1]) + mask = nlist == -1 + # nloc x s(nsel) + tnlist = torch.gather( + tmp_atype, + 2, + nlist.masked_fill(mask, 0), + ) + tnlist = tnlist.masked_fill(mask, -1) + snsel = tnlist.shape[2] + for ii, ss in enumerate(sel): + # nloc x s(nsel) + # to int because bool cannot be sort on GPU + pick_mask = (tnlist == ii).to(torch.int32) + # nloc x s(nsel), stable sort, nearer neighbors first + pick_mask, imap = torch.sort( + pick_mask, dim=-1, descending=True, stable=True + ) + # nloc x s(nsel) + inlist = torch.gather(nlist, 2, imap) + inlist = inlist.masked_fill(~(pick_mask.to(torch.bool)), -1) + # nloc x nsel[ii] + ret_nlist.append(torch.split(inlist, [ss, snsel - ss], dim=-1)[0]) + return torch.concat(ret_nlist, dim=-1) + + +# build_neighbor_list = torch.vmap( +# build_neighbor_list_lower, +# in_dims=(0,0,None,None,None), +# out_dims=(0), +# ) + + +def get_multiple_nlist_key( + rcut: float, + nsel: int, +) -> str: + return str(rcut) + "_" + str(nsel) + + +def build_multiple_neighbor_list( + coord: torch.Tensor, + nlist: torch.Tensor, + rcuts: List[float], + nsels: List[int], +) -> Dict[str, torch.Tensor]: + """Input one neighbor list, and produce multiple neighbor lists with + different cutoff radius and numbers of selection out of it. The + required rcuts and nsels should be smaller or equal to the input nlist. + + Parameters + ---------- + coord : torch.Tensor + exptended coordinates of shape [batch_size, nall x 3] + nlist : torch.Tensor + Neighbor list of shape [batch_size, nloc, nsel], the neighbors + should be stored in an ascending order. + rcuts : List[float] + list of cut-off radius in ascending order. + nsels : List[int] + maximal number of neighbors in ascending order. + + Returns + ------- + nlist_dict : Dict[str, torch.Tensor] + A dict of nlists, key given by get_multiple_nlist_key(rc, nsel) + value being the corresponding nlist. + + """ + assert len(rcuts) == len(nsels) + if len(rcuts) == 0: + return {} + nb, nloc, nsel = nlist.shape + if nsel < nsels[-1]: + pad = -1 * torch.ones( + [nb, nloc, nsels[-1] - nsel], + dtype=nlist.dtype, + device=nlist.device, + ) + # nb x nloc x nsel + nlist = torch.cat([nlist, pad], dim=-1) + nsel = nsels[-1] + # nb x nall x 3 + coord1 = coord.view(nb, -1, 3) + nall = coord1.shape[1] + # nb x nloc x 3 + coord0 = coord1[:, :nloc, :] + nlist_mask = nlist == -1 + # nb x (nloc x nsel) x 3 + index = ( + nlist.masked_fill(nlist_mask, 0) + .view(nb, nloc * nsel) + .unsqueeze(-1) + .expand(-1, -1, 3) + ) + # nb x nloc x nsel x 3 + coord2 = torch.gather(coord1, dim=1, index=index).view(nb, nloc, nsel, 3) + # nb x nloc x nsel x 3 + diff = coord2 - coord0[:, :, None, :] + # nb x nloc x nsel + rr = torch.linalg.norm(diff, dim=-1) + rr.masked_fill(nlist_mask, float("inf")) + nlist0 = nlist + ret = {} + for rc, ns in zip(rcuts[::-1], nsels[::-1]): + nlist0 = nlist0[:, :, :ns].masked_fill(rr[:, :, :ns] > rc, int(-1)) + ret[get_multiple_nlist_key(rc, ns)] = nlist0 + return ret + + +def extend_coord_with_ghosts( + coord: torch.Tensor, + atype: torch.Tensor, + cell: Optional[torch.Tensor], + rcut: float, +): + """Extend the coordinates of the atoms by appending peridoc images. + The number of images is large enough to ensure all the neighbors + within rcut are appended. + + Parameters + ---------- + coord : torch.Tensor + original coordinates of shape [-1, nloc*3]. + atype : torch.Tensor + atom type of shape [-1, nloc]. + cell : torch.Tensor + simulation cell tensor of shape [-1, 9]. + + Returns + ------- + extended_coord: torch.Tensor + extended coordinates of shape [-1, nall*3]. + extended_atype: torch.Tensor + extended atom type of shape [-1, nall]. + index_mapping: torch.Tensor + maping extended index to the local index + + """ + nf, nloc = atype.shape + aidx = torch.tile(torch.arange(nloc).unsqueeze(0), [nf, 1]) + if cell is None: + nall = nloc + extend_coord = coord.clone() + extend_atype = atype.clone() + extend_aidx = aidx.clone() + else: + coord = coord.view([nf, nloc, 3]) + cell = cell.view([nf, 3, 3]) + # nf x 3 + to_face = to_face_distance(cell) + # nf x 3 + # *2: ghost copies on + and - directions + # +1: central cell + nbuff = torch.ceil(rcut / to_face).to(torch.long) + # 3 + nbuff = torch.max(nbuff, dim=0, keepdim=False).values + xi = torch.arange(-nbuff[0], nbuff[0] + 1, 1, device=env.DEVICE) + yi = torch.arange(-nbuff[1], nbuff[1] + 1, 1, device=env.DEVICE) + zi = torch.arange(-nbuff[2], nbuff[2] + 1, 1, device=env.DEVICE) + xyz = xi.view(-1, 1, 1, 1) * torch.tensor( + [1, 0, 0], dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE + ) + xyz = xyz + yi.view(1, -1, 1, 1) * torch.tensor( + [0, 1, 0], dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE + ) + xyz = xyz + zi.view(1, 1, -1, 1) * torch.tensor( + [0, 0, 1], dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE + ) + xyz = xyz.view(-1, 3) + # ns x 3 + shift_idx = xyz[torch.argsort(torch.norm(xyz, dim=1))] + ns, _ = shift_idx.shape + nall = ns * nloc + # nf x ns x 3 + shift_vec = torch.einsum("sd,fdk->fsk", shift_idx, cell) + # nf x ns x nloc x 3 + extend_coord = coord[:, None, :, :] + shift_vec[:, :, None, :] + # nf x ns x nloc + extend_atype = torch.tile(atype.unsqueeze(-2), [1, ns, 1]) + # nf x ns x nloc + extend_aidx = torch.tile(aidx.unsqueeze(-2), [1, ns, 1]) + + return ( + extend_coord.reshape([nf, nall * 3]).to(env.DEVICE), + extend_atype.view([nf, nall]).to(env.DEVICE), + extend_aidx.view([nf, nall]).to(env.DEVICE), + ) diff --git a/deepmd/pt/utils/plugin.py b/deepmd/pt/utils/plugin.py new file mode 100644 index 0000000000..c24f36f574 --- /dev/null +++ b/deepmd/pt/utils/plugin.py @@ -0,0 +1,15 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Base of plugin systems.""" +from deepmd.utils.plugin import ( + Plugin, + PluginVariant, + VariantABCMeta, + VariantMeta, +) + +__all__ = [ + "Plugin", + "VariantMeta", + "VariantABCMeta", + "PluginVariant", +] diff --git a/deepmd/pt/utils/preprocess.py b/deepmd/pt/utils/preprocess.py new file mode 100644 index 0000000000..463ac112ad --- /dev/null +++ b/deepmd/pt/utils/preprocess.py @@ -0,0 +1,318 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import logging +from typing import ( + Union, +) + +import torch + +from deepmd.pt.utils import ( + env, +) + + +class Region3D: + def __init__(self, boxt): + """Construct a simulation box.""" + boxt = boxt.reshape([3, 3]) + self.boxt = boxt # convert physical coordinates to internal ones + self.rec_boxt = torch.linalg.inv( + self.boxt + ) # convert internal coordinates to physical ones + + self.volume = torch.linalg.det(self.boxt) # compute the volume + + # boxt = boxt.permute(1, 0) + c_yz = torch.cross(boxt[1], boxt[2]) + self._h2yz = self.volume / torch.linalg.norm(c_yz) + c_zx = torch.cross(boxt[2], boxt[0]) + self._h2zx = self.volume / torch.linalg.norm(c_zx) + c_xy = torch.cross(boxt[0], boxt[1]) + self._h2xy = self.volume / torch.linalg.norm(c_xy) + + def phys2inter(self, coord): + """Convert physical coordinates to internal ones.""" + return coord @ self.rec_boxt + + def inter2phys(self, coord): + """Convert internal coordinates to physical ones.""" + return coord @ self.boxt + + def get_face_distance(self): + """Return face distinces to each surface of YZ, ZX, XY.""" + return torch.stack([self._h2yz, self._h2zx, self._h2xy]) + + +def normalize_coord(coord, region: Region3D, nloc: int): + """Move outer atoms into region by mirror. + + Args: + - coord: shape is [nloc*3] + """ + tmp_coord = coord.clone() + inter_cood = torch.remainder(region.phys2inter(tmp_coord), 1.0) + tmp_coord = region.inter2phys(inter_cood) + return tmp_coord + + +def compute_serial_cid(cell_offset, ncell): + """Tell the sequential cell ID in its 3D space. + + Args: + - cell_offset: shape is [3] + - ncell: shape is [3] + """ + cell_offset[:, 0] *= ncell[1] * ncell[2] + cell_offset[:, 1] *= ncell[2] + return cell_offset.sum(-1) + + +def compute_pbc_shift(cell_offset, ncell): + """Tell shift count to move the atom into region.""" + shift = torch.zeros_like(cell_offset) + shift = shift + (cell_offset < 0) * -( + torch.div(cell_offset, ncell, rounding_mode="floor") + ) + shift = shift + (cell_offset >= ncell) * -( + torch.div((cell_offset - ncell), ncell, rounding_mode="floor") + 1 + ) + assert torch.all(cell_offset + shift * ncell >= 0) + assert torch.all(cell_offset + shift * ncell < ncell) + return shift + + +def build_inside_clist(coord, region: Region3D, ncell): + """Build cell list on atoms inside region. + + Args: + - coord: shape is [nloc*3] + - ncell: shape is [3] + """ + loc_ncell = int(torch.prod(ncell)) # num of local cells + nloc = coord.numel() // 3 # num of local atoms + inter_cell_size = 1.0 / ncell + + inter_cood = region.phys2inter(coord.view(-1, 3)) + cell_offset = torch.floor(inter_cood / inter_cell_size).to(torch.long) + # numerical error brought by conversion from phys to inter back and force + # may lead to negative value + cell_offset[cell_offset < 0] = 0 + delta = cell_offset - ncell + a2c = compute_serial_cid(cell_offset, ncell) # cell id of atoms + arange = torch.arange(0, loc_ncell, 1, device=env.PREPROCESS_DEVICE) + cellid = a2c == arange.unsqueeze(-1) # one hot cellid + c2a = cellid.nonzero() + lst = [] + cnt = 0 + bincount = torch.bincount(a2c, minlength=loc_ncell) + for i in range(loc_ncell): + n = bincount[i] + lst.append(c2a[cnt : cnt + n, 1]) + cnt += n + return a2c, lst + + +def append_neighbors(coord, region: Region3D, atype, rcut: float): + """Make ghost atoms who are valid neighbors. + + Args: + - coord: shape is [nloc*3] + - atype: shape is [nloc] + """ + to_face = region.get_face_distance() + + # compute num and size of local cells + ncell = torch.floor(to_face / rcut).to(torch.long) + ncell[ncell == 0] = 1 + cell_size = to_face / ncell + ngcell = ( + torch.floor(rcut / cell_size).to(torch.long) + 1 + ) # num of cells out of local, which contain ghost atoms + + # add ghost atoms + a2c, c2a = build_inside_clist(coord, region, ncell) + xi = torch.arange(-ngcell[0], ncell[0] + ngcell[0], 1, device=env.PREPROCESS_DEVICE) + yi = torch.arange(-ngcell[1], ncell[1] + ngcell[1], 1, device=env.PREPROCESS_DEVICE) + zi = torch.arange(-ngcell[2], ncell[2] + ngcell[2], 1, device=env.PREPROCESS_DEVICE) + xyz = xi.view(-1, 1, 1, 1) * torch.tensor( + [1, 0, 0], dtype=torch.long, device=env.PREPROCESS_DEVICE + ) + xyz = xyz + yi.view(1, -1, 1, 1) * torch.tensor( + [0, 1, 0], dtype=torch.long, device=env.PREPROCESS_DEVICE + ) + xyz = xyz + zi.view(1, 1, -1, 1) * torch.tensor( + [0, 0, 1], dtype=torch.long, device=env.PREPROCESS_DEVICE + ) + xyz = xyz.view(-1, 3) + mask_a = (xyz >= 0).all(dim=-1) + mask_b = (xyz < ncell).all(dim=-1) + mask = ~torch.logical_and(mask_a, mask_b) + xyz = xyz[mask] # cell coord + shift = compute_pbc_shift(xyz, ncell) + coord_shift = region.inter2phys(shift.to(env.GLOBAL_PT_FLOAT_PRECISION)) + mirrored = shift * ncell + xyz + cid = compute_serial_cid(mirrored, ncell) + + n_atoms = coord.shape[0] + aid = [c2a[ci] + i * n_atoms for i, ci in enumerate(cid)] + aid = torch.cat(aid) + tmp = torch.div(aid, n_atoms, rounding_mode="trunc") + aid = aid % n_atoms + tmp_coord = coord[aid] - coord_shift[tmp] + tmp_atype = atype[aid] + + # merge local and ghost atoms + merged_coord = torch.cat([coord, tmp_coord]) + merged_coord_shift = torch.cat([torch.zeros_like(coord), coord_shift[tmp]]) + merged_atype = torch.cat([atype, tmp_atype]) + merged_mapping = torch.cat( + [torch.arange(atype.numel(), device=env.PREPROCESS_DEVICE), aid] + ) + return merged_coord_shift, merged_atype, merged_mapping + + +def build_neighbor_list( + nloc: int, coord, atype, rcut: float, sec, mapping, type_split=True, min_check=False +): + """For each atom inside region, build its neighbor list. + + Args: + - coord: shape is [nall*3] + - atype: shape is [nall] + """ + nall = coord.numel() // 3 + coord = coord.float() + nlist = [[] for _ in range(nloc)] + coord_l = coord.view(-1, 1, 3)[:nloc] + coord_r = coord.view(1, -1, 3) + distance = coord_l - coord_r + distance = torch.linalg.norm(distance, dim=-1) + DISTANCE_INF = distance.max().detach() + rcut + distance[:nloc, :nloc] += ( + torch.eye(nloc, dtype=torch.bool, device=env.PREPROCESS_DEVICE) * DISTANCE_INF + ) + if min_check: + if distance.min().abs() < 1e-6: + RuntimeError("Atom dist too close!") + if not type_split: + sec = sec[-1:] + lst = [] + nlist = torch.zeros((nloc, sec[-1].item()), device=env.PREPROCESS_DEVICE).long() - 1 + nlist_loc = ( + torch.zeros((nloc, sec[-1].item()), device=env.PREPROCESS_DEVICE).long() - 1 + ) + nlist_type = ( + torch.zeros((nloc, sec[-1].item()), device=env.PREPROCESS_DEVICE).long() - 1 + ) + for i, nnei in enumerate(sec): + if i > 0: + nnei = nnei - sec[i - 1] + if not type_split: + tmp = distance + else: + mask = atype.unsqueeze(0) == i + tmp = distance + (~mask) * DISTANCE_INF + if tmp.shape[1] >= nnei: + _sorted, indices = torch.topk(tmp, nnei, dim=1, largest=False) + else: + # when nnei > nall + indices = torch.zeros((nloc, nnei), device=env.PREPROCESS_DEVICE).long() - 1 + _sorted = ( + torch.ones((nloc, nnei), device=env.PREPROCESS_DEVICE).long() + * DISTANCE_INF + ) + _sorted_nnei, indices_nnei = torch.topk( + tmp, tmp.shape[1], dim=1, largest=False + ) + _sorted[:, : tmp.shape[1]] = _sorted_nnei + indices[:, : tmp.shape[1]] = indices_nnei + mask = (_sorted < rcut).to(torch.long) + indices_loc = mapping[indices] + indices = indices * mask + -1 * (1 - mask) # -1 for padding + indices_loc = indices_loc * mask + -1 * (1 - mask) # -1 for padding + if i == 0: + start = 0 + else: + start = sec[i - 1] + end = min(sec[i], start + indices.shape[1]) + nlist[:, start:end] = indices[:, :nnei] + nlist_loc[:, start:end] = indices_loc[:, :nnei] + nlist_type[:, start:end] = atype[indices[:, :nnei]] * mask + -1 * (1 - mask) + return nlist, nlist_loc, nlist_type + + +def compute_smooth_weight(distance, rmin: float, rmax: float): + """Compute smooth weight for descriptor elements.""" + min_mask = distance <= rmin + max_mask = distance >= rmax + mid_mask = torch.logical_not(torch.logical_or(min_mask, max_mask)) + uu = (distance - rmin) / (rmax - rmin) + vv = uu * uu * uu * (-6 * uu * uu + 15 * uu - 10) + 1 + return vv * mid_mask + min_mask + + +def make_env_mat( + coord, + atype, + region, + rcut: Union[float, list], + sec, + pbc=True, + type_split=True, + min_check=False, +): + """Based on atom coordinates, return environment matrix. + + Returns + ------- + nlist: nlist, [nloc, nnei] + merged_coord_shift: shift on nall atoms, [nall, 3] + merged_mapping: mapping from nall index to nloc index, [nall] + """ + # move outer atoms into cell + hybrid = isinstance(rcut, list) + _rcut = rcut + if hybrid: + _rcut = max(rcut) + if pbc: + merged_coord_shift, merged_atype, merged_mapping = append_neighbors( + coord, region, atype, _rcut + ) + merged_coord = coord[merged_mapping] - merged_coord_shift + if merged_coord.shape[0] <= coord.shape[0]: + logging.warning("No ghost atom is added for system ") + else: + merged_coord_shift = torch.zeros_like(coord) + merged_atype = atype.clone() + merged_mapping = torch.arange(atype.numel(), device=env.PREPROCESS_DEVICE) + merged_coord = coord.clone() + + # build nlist + if not hybrid: + nlist, nlist_loc, nlist_type = build_neighbor_list( + coord.shape[0], + merged_coord, + merged_atype, + rcut, + sec, + merged_mapping, + type_split=type_split, + min_check=min_check, + ) + else: + nlist, nlist_loc, nlist_type = [], [], [] + for ii, single_rcut in enumerate(rcut): + nlist_tmp, nlist_loc_tmp, nlist_type_tmp = build_neighbor_list( + coord.shape[0], + merged_coord, + merged_atype, + single_rcut, + sec[ii], + merged_mapping, + type_split=type_split, + min_check=min_check, + ) + nlist.append(nlist_tmp) + nlist_loc.append(nlist_loc_tmp) + nlist_type.append(nlist_type_tmp) + return nlist, nlist_loc, nlist_type, merged_coord_shift, merged_mapping diff --git a/deepmd/pt/utils/region.py b/deepmd/pt/utils/region.py new file mode 100644 index 0000000000..b07d2f73bf --- /dev/null +++ b/deepmd/pt/utils/region.py @@ -0,0 +1,116 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import torch + + +def phys2inter( + coord: torch.Tensor, + cell: torch.Tensor, +) -> torch.Tensor: + """Convert physical coordinates to internal(direct) coordinates. + + Parameters + ---------- + coord : torch.Tensor + physical coordinates of shape [*, na, 3]. + cell : torch.Tensor + simulation cell tensor of shape [*, 3, 3]. + + Returns + ------- + inter_coord: torch.Tensor + the internal coordinates + + """ + rec_cell = torch.linalg.inv(cell) + return torch.matmul(coord, rec_cell) + + +def inter2phys( + coord: torch.Tensor, + cell: torch.Tensor, +) -> torch.Tensor: + """Convert internal(direct) coordinates to physical coordinates. + + Parameters + ---------- + coord : torch.Tensor + internal coordinates of shape [*, na, 3]. + cell : torch.Tensor + simulation cell tensor of shape [*, 3, 3]. + + Returns + ------- + phys_coord: torch.Tensor + the physical coordinates + + """ + return torch.matmul(coord, cell) + + +def to_face_distance( + cell: torch.Tensor, +) -> torch.Tensor: + """Compute the to-face-distance of the simulation cell. + + Parameters + ---------- + cell : torch.Tensor + simulation cell tensor of shape [*, 3, 3]. + + Returns + ------- + dist: torch.Tensor + the to face distances of shape [*, 3] + + """ + cshape = cell.shape + dist = b_to_face_distance(cell.view([-1, 3, 3])) + return dist.view(list(cshape[:-2]) + [3]) # noqa:RUF005 + + +def _to_face_distance(cell): + volume = torch.linalg.det(cell) + c_yz = torch.cross(cell[1], cell[2]) + _h2yz = volume / torch.linalg.norm(c_yz) + c_zx = torch.cross(cell[2], cell[0]) + _h2zx = volume / torch.linalg.norm(c_zx) + c_xy = torch.cross(cell[0], cell[1]) + _h2xy = volume / torch.linalg.norm(c_xy) + return torch.stack([_h2yz, _h2zx, _h2xy]) + + +def b_to_face_distance(cell): + volume = torch.linalg.det(cell) + c_yz = torch.cross(cell[:, 1], cell[:, 2], dim=-1) + _h2yz = volume / torch.linalg.norm(c_yz, dim=-1) + c_zx = torch.cross(cell[:, 2], cell[:, 0], dim=-1) + _h2zx = volume / torch.linalg.norm(c_zx, dim=-1) + c_xy = torch.cross(cell[:, 0], cell[:, 1], dim=-1) + _h2xy = volume / torch.linalg.norm(c_xy, dim=-1) + return torch.stack([_h2yz, _h2zx, _h2xy], dim=1) + + +# b_to_face_distance = torch.vmap( +# _to_face_distance, in_dims=(0), out_dims=(0)) + + +def normalize_coord( + coord: torch.Tensor, + cell: torch.Tensor, +) -> torch.Tensor: + """Apply PBC according to the atomic coordinates. + + Parameters + ---------- + coord : torch.Tensor + orignal coordinates of shape [*, na, 3]. + + Returns + ------- + wrapped_coord: torch.Tensor + wrapped coordinates of shape [*, na, 3]. + + """ + icoord = phys2inter(coord, cell) + icoord = torch.remainder(icoord, 1.0) + return inter2phys(icoord, cell) diff --git a/deepmd/pt/utils/stat.py b/deepmd/pt/utils/stat.py new file mode 100644 index 0000000000..837a0104f9 --- /dev/null +++ b/deepmd/pt/utils/stat.py @@ -0,0 +1,112 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import logging + +import numpy as np +import torch +from tqdm import ( + trange, +) + +from deepmd.pt.utils import ( + env, +) + + +def make_stat_input(datasets, dataloaders, nbatches): + """Pack data for statistics. + + Args: + - dataset: A list of dataset to analyze. + - nbatches: Batch count for collecting stats. + + Returns + ------- + - a list of dicts, each of which contains data from a system + """ + lst = [] + keys = [ + "coord", + "force", + "energy", + "atype", + "box", + "natoms", + "mapping", + "nlist", + "nlist_loc", + "nlist_type", + "shift", + ] + if datasets[0].mixed_type: + keys.append("real_natoms_vec") + logging.info(f"Packing data for statistics from {len(datasets)} systems") + for i in trange(len(datasets), disable=env.DISABLE_TQDM): + sys_stat = {key: [] for key in keys} + iterator = iter(dataloaders[i]) + for _ in range(nbatches): + try: + stat_data = next(iterator) + except StopIteration: + iterator = iter(dataloaders[i]) + stat_data = next(iterator) + for dd in stat_data: + if dd in keys: + sys_stat[dd].append(stat_data[dd]) + for key in keys: + if key == "mapping" or key == "shift": + extend = max(d.shape[1] for d in sys_stat[key]) + for jj in range(len(sys_stat[key])): + l = [] + item = sys_stat[key][jj] + for ii in range(item.shape[0]): + l.append(item[ii]) + n_frames = len(item) + if key == "shift": + shape = torch.zeros( + (n_frames, extend, 3), + dtype=env.GLOBAL_PT_FLOAT_PRECISION, + device=env.PREPROCESS_DEVICE, + ) + else: + shape = torch.zeros( + (n_frames, extend), + dtype=torch.long, + device=env.PREPROCESS_DEVICE, + ) + for i in range(len(item)): + natoms_tmp = l[i].shape[0] + shape[i, :natoms_tmp] = l[i] + sys_stat[key][jj] = shape + if not isinstance(sys_stat[key][0], list): + if sys_stat[key][0] is None: + sys_stat[key] = None + else: + sys_stat[key] = torch.cat(sys_stat[key], dim=0) + else: + sys_stat_list = [] + for ii, _ in enumerate(sys_stat[key][0]): + tmp_stat = [x[ii] for x in sys_stat[key]] + sys_stat_list.append(torch.cat(tmp_stat, dim=0)) + sys_stat[key] = sys_stat_list + lst.append(sys_stat) + return lst + + +def compute_output_stats(energy, natoms, rcond=None): + """Update mean and stddev for descriptor elements. + + Args: + - energy: Batched energy with shape [nframes, 1]. + - natoms: Batched atom statisics with shape [self.ntypes+2]. + + Returns + ------- + - energy_coef: Average enery per atom for each element. + """ + for i in range(len(energy)): + energy[i] = energy[i].mean(dim=0, keepdim=True) + natoms[i] = natoms[i].double().mean(dim=0, keepdim=True) + sys_ener = torch.cat(energy).cpu() + sys_tynatom = torch.cat(natoms)[:, 2:].cpu() + energy_coef, _, _, _ = np.linalg.lstsq(sys_tynatom, sys_ener, rcond) + return energy_coef diff --git a/deepmd/pt/utils/utils.py b/deepmd/pt/utils/utils.py new file mode 100644 index 0000000000..780dbf7e62 --- /dev/null +++ b/deepmd/pt/utils/utils.py @@ -0,0 +1,43 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Callable, + Optional, +) + +import torch +import torch.nn.functional as F + + +def get_activation_fn(activation: str) -> Callable: + """Returns the activation function corresponding to `activation`.""" + if activation.lower() == "relu": + return F.relu + elif activation.lower() == "gelu": + return F.gelu + elif activation.lower() == "tanh": + return torch.tanh + elif activation.lower() == "linear" or activation.lower() == "none": + return lambda x: x + else: + raise RuntimeError(f"activation function {activation} not supported") + + +class ActivationFn(torch.nn.Module): + def __init__(self, activation: Optional[str]): + super().__init__() + self.activation: str = activation if activation is not None else "linear" + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Returns the tensor after applying activation function corresponding to `activation`.""" + # See jit supported types: https://pytorch.org/docs/stable/jit_language_reference.html#supported-type + + if self.activation.lower() == "relu": + return F.relu(x) + elif self.activation.lower() == "gelu": + return F.gelu(x) + elif self.activation.lower() == "tanh": + return torch.tanh(x) + elif self.activation.lower() == "linear" or self.activation.lower() == "none": + return x + else: + raise RuntimeError(f"activation function {self.activation} not supported") diff --git a/examples/water/dpa2/input_torch.json b/examples/water/dpa2/input_torch.json new file mode 100644 index 0000000000..9d783b35d5 --- /dev/null +++ b/examples/water/dpa2/input_torch.json @@ -0,0 +1,102 @@ +{ + "_comment": "that's all", + "model": { + "type_embedding": { + "neuron": [ + 8 + ], + "tebd_input_mode": "concat" + }, + "type_map": [ + "O", + "H" + ], + "descriptor": { + "type": "dpa2", + "repinit_rcut": 9.0, + "repinit_rcut_smth": 8.0, + "repinit_nsel": 120, + "repformer_rcut": 4.0, + "repformer_rcut_smth": 3.5, + "repformer_nsel": 40, + "repinit_neuron": [ + 25, + 50, + 100 + ], + "repinit_axis_neuron": 12, + "repinit_activation": "tanh", + "repformer_nlayers": 12, + "repformer_g1_dim": 128, + "repformer_g2_dim": 32, + "repformer_attn2_hidden": 32, + "repformer_attn2_nhead": 4, + "repformer_attn1_hidden": 128, + "repformer_attn1_nhead": 4, + "repformer_axis_dim": 4, + "repformer_update_h2": false, + "repformer_update_g1_has_conv": true, + "repformer_update_g1_has_grrg": true, + "repformer_update_g1_has_drrd": true, + "repformer_update_g1_has_attn": true, + "repformer_update_g2_has_g1g1": true, + "repformer_update_g2_has_attn": true, + "repformer_attn2_has_gate": true, + "repformer_add_type_ebd_to_seq": false + }, + "fitting_net": { + "neuron": [ + 240, + 240, + 240 + ], + "resnet_dt": true, + "seed": 1, + "_comment": " that's all" + }, + "_comment": " that's all" + }, + "learning_rate": { + "type": "exp", + "decay_steps": 5000, + "start_lr": 0.0002, + "stop_lr": 3.51e-08, + "_comment": "that's all" + }, + "loss": { + "type": "ener", + "start_pref_e": 0.02, + "limit_pref_e": 1, + "start_pref_f": 1000, + "limit_pref_f": 1, + "start_pref_v": 0, + "limit_pref_v": 0, + "_comment": " that's all" + }, + "training": { + "training_data": { + "systems": [ + "../data/data_0", + "../data/data_1", + "../data/data_2" + ], + "batch_size": 1, + "_comment": "that's all" + }, + "validation_data": { + "systems": [ + "../data/data_3" + ], + "batch_size": 1, + "_comment": "that's all" + }, + "numb_steps": 1000000, + "warmup_steps": 0, + "gradient_max_norm": 5.0, + "seed": 10, + "disp_file": "lcurve.out", + "disp_freq": 100, + "save_freq": 2000, + "_comment": "that's all" + } +} diff --git a/examples/water/se_atten/input_torch.json b/examples/water/se_atten/input_torch.json new file mode 100644 index 0000000000..7da3d64164 --- /dev/null +++ b/examples/water/se_atten/input_torch.json @@ -0,0 +1,91 @@ +{ + "_comment": "that's all", + "model": { + "type_map": [ + "O", + "H" + ], + "descriptor": { + "type": "dpa1", + "sel": 120, + "rcut_smth": 0.5, + "rcut": 6.0, + "neuron": [ + 25, + 50, + 100 + ], + "axis_neuron": 16, + "attn": 128, + "attn_layer": 2, + "attn_dotr": true, + "attn_mask": false, + "post_ln": true, + "ffn": false, + "ffn_embed_dim": 1024, + "activation": "tanh", + "scaling_factor": 1.0, + "head_num": 1, + "normalize": true, + "temperature": 1.0 + }, + "fitting_net": { + "neuron": [ + 240, + 240, + 240 + ], + "resnet_dt": true, + "seed": 1, + "_comment": " that's all" + }, + "_comment": " that's all" + }, + "learning_rate": { + "type": "exp", + "decay_steps": 5000, + "start_lr": 0.001, + "stop_lr": 3.51e-08, + "_comment": "that's all" + }, + "loss": { + "type": "ener", + "start_pref_e": 0.02, + "limit_pref_e": 1, + "start_pref_f": 1000, + "limit_pref_f": 1, + "start_pref_v": 0, + "limit_pref_v": 0, + "_comment": " that's all" + }, + "training": { + "training_data": { + "systems": [ + "../data/data_0", + "../data/data_1", + "../data/data_2" + ], + "batch_size": 1, + "_comment": "that's all" + }, + "validation_data": { + "systems": [ + "../data/data_3" + ], + "batch_size": 1, + "numb_btch": 3, + "_comment": "that's all" + }, + "wandb_config": { + "wandb_enabled": false, + "entity": "dp_model_engineering", + "project": "DPA" + }, + "numb_steps": 1000000, + "seed": 10, + "disp_file": "lcurve.out", + "disp_freq": 100, + "save_freq": 1000, + "_comment": "that's all" + } +} diff --git a/examples/water/se_e2_a/input_torch.json b/examples/water/se_e2_a/input_torch.json new file mode 100644 index 0000000000..053a721a44 --- /dev/null +++ b/examples/water/se_e2_a/input_torch.json @@ -0,0 +1,79 @@ +{ + "model": { + "type_map": [ + "O", + "H" + ], + "descriptor": { + "type": "se_e2_a", + "sel": [ + 46, + 92 + ], + "rcut_smth": 0.50, + "rcut": 6.00, + "neuron": [ + 25, + 50, + 100 + ], + "resnet_dt": false, + "axis_neuron": 16, + "seed": 1, + "_comment": " that's all" + }, + "fitting_net": { + "neuron": [ + 240, + 240, + 240 + ], + "resnet_dt": true, + "seed": 1, + "_comment": " that's all" + }, + "data_stat_nbatch": 20, + "_comment": " that's all" + }, + "learning_rate": { + "type": "exp", + "decay_steps": 5000, + "start_lr": 0.001, + "stop_lr": 3.51e-8, + "_comment": "that's all" + }, + "loss": { + "type": "ener", + "start_pref_e": 0.02, + "limit_pref_e": 1, + "start_pref_f": 1000, + "limit_pref_f": 1, + "_comment": " that's all" + }, + "training": { + "training_data": { + "systems": [ + "../data/data_0", + "../data/data_1", + "../data/data_2" + ], + "batch_size": 1, + "_comment": "that's all" + }, + "validation_data": { + "systems": [ + "../data/data_3" + ], + "batch_size": 1, + "numb_btch": 3, + "_comment": "that's all" + }, + "numb_steps": 100000, + "seed": 10, + "disp_file": "lcurve.out", + "disp_freq": 100, + "save_freq": 10000, + "_comment": "that's all" + }, + "_comment": "that's all" +} diff --git a/source/install/docker/Dockerfile b/source/install/docker/Dockerfile index 26b7be9f19..793272ae6a 100644 --- a/source/install/docker/Dockerfile +++ b/source/install/docker/Dockerfile @@ -6,7 +6,7 @@ RUN python -m venv /opt/deepmd-kit ENV PATH="/opt/deepmd-kit/bin:$PATH" # Install package COPY dist /dist -RUN pip install "$(ls /dist/deepmd_kit${VARIANT}-*manylinux*_x86_64.whl)[gpu,cu${CUDA_VERSION},lmp,ipi]" \ +RUN pip install "$(ls /dist/deepmd_kit${VARIANT}-*manylinux*_x86_64.whl)[gpu,cu${CUDA_VERSION},lmp,ipi,torch]" \ && dp -h \ && lmp -h \ && dp_ipi \ diff --git a/source/tests/pt/__init__.py b/source/tests/pt/__init__.py new file mode 100644 index 0000000000..fdbdd73f79 --- /dev/null +++ b/source/tests/pt/__init__.py @@ -0,0 +1,5 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import torch + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) diff --git a/source/tests/pt/models/dpa1.json b/source/tests/pt/models/dpa1.json new file mode 100644 index 0000000000..dd838ac692 --- /dev/null +++ b/source/tests/pt/models/dpa1.json @@ -0,0 +1,39 @@ +{ + "type_map": [ + "O", + "H" + ], + "descriptor": { + "type": "se_atten", + "sel": 30, + "rcut_smth": 2.0, + "rcut": 6.0, + "neuron": [ + 2, + 4, + 8 + ], + "axis_neuron": 4, + "attn": 5, + "attn_layer": 2, + "attn_dotr": true, + "attn_mask": false, + "post_ln": true, + "ffn": false, + "ffn_embed_dim": 10, + "activation": "tanh", + "scaling_factor": 1.0, + "head_num": 1, + "normalize": true, + "temperature": 1.0 + }, + "fitting_net": { + "neuron": [ + 240, + 240, + 240 + ], + "resnet_dt": true, + "seed": 1 + } +} diff --git a/source/tests/pt/models/dpa1.pth b/source/tests/pt/models/dpa1.pth new file mode 100644 index 0000000000..75acf2fa15 Binary files /dev/null and b/source/tests/pt/models/dpa1.pth differ diff --git a/source/tests/pt/models/dpa2.json b/source/tests/pt/models/dpa2.json new file mode 100644 index 0000000000..8b9c735851 --- /dev/null +++ b/source/tests/pt/models/dpa2.json @@ -0,0 +1,48 @@ +{ + "type_map": [ + "O", + "H" + ], + "descriptor": { + "type": "dpa2", + "repinit_rcut": 6.0, + "repinit_rcut_smth": 2.0, + "repinit_nsel": 30, + "repformer_rcut": 4.0, + "repformer_rcut_smth": 0.5, + "repformer_nsel": 10, + "repinit_neuron": [ + 2, + 4, + 8 + ], + "repinit_axis_neuron": 4, + "repinit_activation": "tanh", + "repformer_nlayers": 12, + "repformer_g1_dim": 8, + "repformer_g2_dim": 5, + "repformer_attn2_hidden": 3, + "repformer_attn2_nhead": 1, + "repformer_attn1_hidden": 5, + "repformer_attn1_nhead": 1, + "repformer_axis_dim": 4, + "repformer_update_h2": false, + "repformer_update_g1_has_conv": true, + "repformer_update_g1_has_grrg": true, + "repformer_update_g1_has_drrd": true, + "repformer_update_g1_has_attn": true, + "repformer_update_g2_has_g1g1": true, + "repformer_update_g2_has_attn": true, + "repformer_attn2_has_gate": true, + "repformer_add_type_ebd_to_seq": false + }, + "fitting_net": { + "neuron": [ + 240, + 240, + 240 + ], + "resnet_dt": true, + "seed": 1 + } +} diff --git a/source/tests/pt/models/dpa2.pth b/source/tests/pt/models/dpa2.pth new file mode 100644 index 0000000000..0559d30c48 Binary files /dev/null and b/source/tests/pt/models/dpa2.pth differ diff --git a/source/tests/pt/models/dpa2_hyb.json b/source/tests/pt/models/dpa2_hyb.json new file mode 100644 index 0000000000..b5d53b0246 --- /dev/null +++ b/source/tests/pt/models/dpa2_hyb.json @@ -0,0 +1,69 @@ +{ + "type_map": [ + "O", + "H" + ], + "descriptor": { + "type": "hybrid", + "hybrid_mode": "sequential", + "list": [ + { + "type": "se_atten", + "sel": 30, + "rcut_smth": 2.0, + "rcut": 6.0, + "neuron": [ + 2, + 4, + 8 + ], + "axis_neuron": 4, + "attn": 5, + "attn_layer": 0, + "attn_dotr": true, + "attn_mask": false, + "post_ln": true, + "ffn": false, + "ffn_embed_dim": 10, + "activation": "tanh", + "scaling_factor": 1.0, + "head_num": 1, + "normalize": true, + "temperature": 1.0 + }, + { + "type": "se_uni", + "sel": 10, + "rcut_smth": 0.5, + "rcut": 4.0, + "nlayers": 12, + "g1_dim": 8, + "g2_dim": 5, + "attn2_hidden": 3, + "attn2_nhead": 1, + "attn1_hidden": 5, + "attn1_nhead": 1, + "axis_dim": 4, + "update_h2": false, + "update_g1_has_conv": true, + "update_g1_has_grrg": true, + "update_g1_has_drrd": true, + "update_g1_has_attn": true, + "update_g2_has_g1g1": true, + "update_g2_has_attn": true, + "attn2_has_gate": true, + "add_type_ebd_to_seq": false, + "smooth": true + } + ] + }, + "fitting_net": { + "neuron": [ + 240, + 240, + 240 + ], + "resnet_dt": true, + "seed": 1 + } +} diff --git a/source/tests/pt/models/dpa2_tebd.pth b/source/tests/pt/models/dpa2_tebd.pth new file mode 100644 index 0000000000..3d4fc5511c Binary files /dev/null and b/source/tests/pt/models/dpa2_tebd.pth differ diff --git a/source/tests/pt/requirements.txt b/source/tests/pt/requirements.txt new file mode 100644 index 0000000000..74abad719e --- /dev/null +++ b/source/tests/pt/requirements.txt @@ -0,0 +1,6 @@ +tensorflow>=2.14.0 +deepmd-kit>=2.2.7 +dpdata +ase +coverage +pytest diff --git a/source/tests/pt/test_LKF.py b/source/tests/pt/test_LKF.py new file mode 100644 index 0000000000..33aeac7f4f --- /dev/null +++ b/source/tests/pt/test_LKF.py @@ -0,0 +1,35 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import json +import os +import unittest +from pathlib import ( + Path, +) + +from deepmd.pt.entrypoints.main import ( + main, +) + + +class TestLKF(unittest.TestCase): + def test_lkf(self): + with open(str(Path(__file__).parent / "water/lkf.json")) as fin: + content = fin.read() + self.config = json.loads(content) + self.config["training"]["training_data"]["systems"] = [ + str(Path(__file__).parent / "water/data/data_0") + ] + self.config["training"]["validation_data"]["systems"] = [ + str(Path(__file__).parent / "water/data/data_0") + ] + self.input_json = "test_lkf.json" + with open(self.input_json, "w") as fp: + json.dump(self.config, fp, indent=4) + main(["train", self.input_json]) + + def tearDown(self): + os.remove(self.input_json) + + +if __name__ == "__main__": + unittest.main() diff --git a/source/tests/pt/test_autodiff.py b/source/tests/pt/test_autodiff.py new file mode 100644 index 0000000000..4f303a8bb3 --- /dev/null +++ b/source/tests/pt/test_autodiff.py @@ -0,0 +1,190 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import copy +import unittest + +import numpy as np +import torch + +from deepmd.pt.model.model import ( + get_model, +) +from deepmd.pt.utils import ( + env, +) + +dtype = torch.float64 + +from .test_permutation import ( + eval_model, + make_sample, + model_dpa1, + model_dpa2, + model_se_e2_a, +) + + +# from deepmd-kit repo +def finite_difference(f, x, delta=1e-6): + in_shape = x.shape + y0 = f(x) + out_shape = y0.shape + res = np.empty(out_shape + in_shape) + for idx in np.ndindex(*in_shape): + diff = np.zeros(in_shape) + diff[idx] += delta + y1p = f(x + diff) + y1n = f(x - diff) + res[(Ellipsis, *idx)] = (y1p - y1n) / (2 * delta) + return res + + +def stretch_box(old_coord, old_box, new_box): + ocoord = old_coord.reshape(-1, 3) + obox = old_box.reshape(3, 3) + nbox = new_box.reshape(3, 3) + ncoord = ocoord @ np.linalg.inv(obox) @ nbox + return ncoord.reshape(old_coord.shape) + + +class ForceTest: + def test( + self, + ): + places = 8 + delta = 1e-5 + natoms = 5 + cell = torch.rand([3, 3], dtype=dtype) + cell = (cell + cell.T) + 5.0 * torch.eye(3) + coord = torch.rand([natoms, 3], dtype=dtype) + coord = torch.matmul(coord, cell) + atype = torch.IntTensor([0, 0, 0, 1, 1]) + # assumes input to be numpy tensor + coord = coord.numpy() + + def np_infer( + coord, + ): + e0, f0, v0 = eval_model( + self.model, torch.tensor(coord).unsqueeze(0), cell.unsqueeze(0), atype + ) + ret = { + "energy": e0.squeeze(0), + "force": f0.squeeze(0), + "virial": v0.squeeze(0), + } + # detach + ret = {kk: ret[kk].detach().cpu().numpy() for kk in ret} + return ret + + def ff(_coord): + return np_infer(_coord)["energy"] + + fdf = -finite_difference(ff, coord, delta=delta).squeeze() + rff = np_infer(coord)["force"] + np.testing.assert_almost_equal(fdf, rff, decimal=places) + + +class VirialTest: + def test( + self, + ): + places = 8 + delta = 1e-4 + natoms = 5 + cell = torch.rand([3, 3], dtype=dtype) + cell = (cell) + 5.0 * torch.eye(3) + coord = torch.rand([natoms, 3], dtype=dtype) + coord = torch.matmul(coord, cell) + atype = torch.IntTensor([0, 0, 0, 1, 1]) + # assumes input to be numpy tensor + coord = coord.numpy() + cell = cell.numpy() + + def np_infer( + new_cell, + ): + e0, f0, v0 = eval_model( + self.model, + torch.tensor(stretch_box(coord, cell, new_cell)).unsqueeze(0), + torch.tensor(new_cell).unsqueeze(0), + atype, + ) + ret = { + "energy": e0.squeeze(0), + "force": f0.squeeze(0), + "virial": v0.squeeze(0), + } + # detach + ret = {kk: ret[kk].detach().cpu().numpy() for kk in ret} + return ret + + def ff(bb): + return np_infer(bb)["energy"] + + fdv = -( + finite_difference(ff, cell, delta=delta).transpose(0, 2, 1) @ cell + ).squeeze() + rfv = np_infer(cell)["virial"] + np.testing.assert_almost_equal(fdv, rfv, decimal=places) + + +class TestEnergyModelSeAForce(unittest.TestCase, ForceTest): + def setUp(self): + model_params = copy.deepcopy(model_se_e2_a) + sampled = make_sample(model_params) + self.type_split = False + self.model = get_model(model_params, sampled).to(env.DEVICE) + + +class TestEnergyModelSeAVirial(unittest.TestCase, VirialTest): + def setUp(self): + model_params = copy.deepcopy(model_se_e2_a) + sampled = make_sample(model_params) + self.type_split = False + self.model = get_model(model_params, sampled).to(env.DEVICE) + + +class TestEnergyModelDPA1Force(unittest.TestCase, ForceTest): + def setUp(self): + model_params = copy.deepcopy(model_dpa1) + sampled = make_sample(model_params) + self.type_split = True + self.model = get_model(model_params, sampled).to(env.DEVICE) + + +class TestEnergyModelDPA1Virial(unittest.TestCase, VirialTest): + def setUp(self): + model_params = copy.deepcopy(model_dpa1) + sampled = make_sample(model_params) + self.type_split = True + self.model = get_model(model_params, sampled).to(env.DEVICE) + + +class TestEnergyModelDPA2Force(unittest.TestCase, ForceTest): + def setUp(self): + model_params_sample = copy.deepcopy(model_dpa2) + model_params_sample["descriptor"]["rcut"] = model_params_sample["descriptor"][ + "repinit_rcut" + ] + model_params_sample["descriptor"]["sel"] = model_params_sample["descriptor"][ + "repinit_nsel" + ] + sampled = make_sample(model_params_sample) + model_params = copy.deepcopy(model_dpa2) + self.type_split = True + self.model = get_model(model_params, sampled).to(env.DEVICE) + + +class TestEnergyModelDPAUniVirial(unittest.TestCase, VirialTest): + def setUp(self): + model_params_sample = copy.deepcopy(model_dpa2) + model_params_sample["descriptor"]["rcut"] = model_params_sample["descriptor"][ + "repinit_rcut" + ] + model_params_sample["descriptor"]["sel"] = model_params_sample["descriptor"][ + "repinit_nsel" + ] + sampled = make_sample(model_params_sample) + model_params = copy.deepcopy(model_dpa2) + self.type_split = True + self.model = get_model(model_params, sampled).to(env.DEVICE) diff --git a/source/tests/pt/test_calculator.py b/source/tests/pt/test_calculator.py new file mode 100644 index 0000000000..e8382b22b8 --- /dev/null +++ b/source/tests/pt/test_calculator.py @@ -0,0 +1,95 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import json +import unittest +from copy import ( + deepcopy, +) +from pathlib import ( + Path, +) + +import torch + +from deepmd.pt.entrypoints.main import ( + get_trainer, +) +from deepmd.pt.utils.ase_calc import ( + DPCalculator, +) + +dtype = torch.float64 + + +class TestCalculator(unittest.TestCase): + def setUp(self): + input_json = str(Path(__file__).parent / "water/se_atten.json") + with open(input_json) as f: + self.config = json.load(f) + self.config["training"]["numb_steps"] = 1 + self.config["training"]["save_freq"] = 1 + data_file = [str(Path(__file__).parent / "water/data/data_0")] + self.config["training"]["training_data"]["systems"] = data_file + self.config["training"]["validation_data"]["systems"] = [ + str(Path(__file__).parent / "water/data/single") + ] + self.input_json = "test_dp_test.json" + with open(self.input_json, "w") as fp: + json.dump(self.config, fp, indent=4) + + trainer = get_trainer(deepcopy(self.config)) + trainer.run() + + input_dict, label_dict, _ = trainer.get_data(is_train=False) + _, _, more_loss = trainer.wrapper(**input_dict, label=label_dict, cur_lr=1.0) + + self.calculator = DPCalculator("model.pt") + + def test_calculator(self): + from ase import ( + Atoms, + ) + + natoms = 5 + cell = torch.eye(3, dtype=dtype) * 10 + coord = torch.rand([natoms, 3], dtype=dtype) + coord = torch.matmul(coord, cell) + atype = torch.IntTensor([0, 0, 0, 1, 1]) + atomic_numbers = [1, 1, 1, 8, 8] + idx_perm = [1, 0, 4, 3, 2] + + prec = 1e-10 + low_prec = 1e-4 + + ase_atoms0 = Atoms( + numbers=atomic_numbers, + positions=coord, + # positions=[tuple(item) for item in coordinate], + cell=cell, + calculator=self.calculator, + ) + e0, f0 = ase_atoms0.get_potential_energy(), ase_atoms0.get_forces() + s0, v0 = ( + ase_atoms0.get_stress(voigt=True), + -ase_atoms0.get_stress(voigt=False) * ase_atoms0.get_volume(), + ) + + ase_atoms1 = Atoms( + numbers=[atomic_numbers[i] for i in idx_perm], + positions=coord[idx_perm, :], + # positions=[tuple(item) for item in coordinate], + cell=cell, + calculator=self.calculator, + ) + e1, f1 = ase_atoms1.get_potential_energy(), ase_atoms1.get_forces() + s1, v1 = ( + ase_atoms1.get_stress(voigt=True), + -ase_atoms1.get_stress(voigt=False) * ase_atoms1.get_volume(), + ) + + assert isinstance(e0, float) + assert f0.shape == (natoms, 3) + assert v0.shape == (3, 3) + torch.testing.assert_close(e0, e1, rtol=low_prec, atol=prec) + torch.testing.assert_close(f0[idx_perm, :], f1, rtol=low_prec, atol=prec) + torch.testing.assert_close(s0, s1, rtol=low_prec, atol=prec) + torch.testing.assert_close(v0, v1, rtol=low_prec, atol=prec) diff --git a/source/tests/pt/test_deeppot.py b/source/tests/pt/test_deeppot.py new file mode 100644 index 0000000000..7f3ecf7d1b --- /dev/null +++ b/source/tests/pt/test_deeppot.py @@ -0,0 +1,81 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import json +import unittest +from copy import ( + deepcopy, +) +from pathlib import ( + Path, +) + +import numpy as np + +from deepmd.pt.entrypoints.main import ( + get_trainer, +) +from deepmd.pt.infer.deep_eval import ( + DeepPot, +) + + +class TestDeepPot(unittest.TestCase): + def setUp(self): + input_json = str(Path(__file__).parent / "water/se_atten.json") + with open(input_json) as f: + self.config = json.load(f) + self.config["training"]["numb_steps"] = 1 + self.config["training"]["save_freq"] = 1 + self.config["training"]["training_data"]["systems"] = [ + str(Path(__file__).parent / "water/data/single") + ] + self.config["training"]["validation_data"]["systems"] = [ + str(Path(__file__).parent / "water/data/single") + ] + self.input_json = "test_dp_test.json" + with open(self.input_json, "w") as fp: + json.dump(self.config, fp, indent=4) + + trainer = get_trainer(deepcopy(self.config)) + trainer.run() + + input_dict, label_dict, _ = trainer.get_data(is_train=False) + trainer.wrapper(**input_dict, label=label_dict, cur_lr=1.0) + self.model = "model.pt" + + def test_dp_test(self): + dp = DeepPot(str(self.model)) + cell = np.array( + [ + 5.122106549439247480e00, + 4.016537340154059388e-01, + 6.951654033828678081e-01, + 4.016537340154059388e-01, + 6.112136112297989143e00, + 8.178091365465004481e-01, + 6.951654033828678081e-01, + 8.178091365465004481e-01, + 6.159552512682983760e00, + ] + ).reshape(1, 3, 3) + coord = np.array( + [ + 2.978060152121375648e00, + 3.588469695887098077e00, + 2.792459820604495491e00, + 3.895592322591093115e00, + 2.712091020667753760e00, + 1.366836847133650501e00, + 9.955616170888935690e-01, + 4.121324820711413039e00, + 1.817239061889086571e00, + 3.553661462345699906e00, + 5.313046969500791583e00, + 6.635182659098815883e00, + 6.088601018589653080e00, + 6.575011420004332585e00, + 6.825240650611076099e00, + ] + ).reshape(1, -1, 3) + atype = np.array([0, 0, 0, 1, 1]).reshape(1, -1) + + e, f, v, ae, av = dp.eval(coord, cell, atype, atomic=True) diff --git a/source/tests/pt/test_descriptor.py b/source/tests/pt/test_descriptor.py new file mode 100644 index 0000000000..da38cf007f --- /dev/null +++ b/source/tests/pt/test_descriptor.py @@ -0,0 +1,166 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import os +import unittest + +import numpy as np +import tensorflow.compat.v1 as tf +import torch + +tf.disable_eager_execution() + +import json +from pathlib import ( + Path, +) + +from deepmd.pt.model.descriptor import ( + prod_env_mat_se_a, +) +from deepmd.pt.utils import ( + dp_random, +) +from deepmd.pt.utils.dataset import ( + DeepmdDataSet, +) +from deepmd.pt.utils.env import ( + DEVICE, + GLOBAL_NP_FLOAT_PRECISION, + GLOBAL_PT_FLOAT_PRECISION, +) +from deepmd.tf.common import ( + expand_sys_str, +) +from deepmd.tf.env import ( + op_module, +) + +CUR_DIR = os.path.dirname(__file__) + + +def base_se_a(rcut, rcut_smth, sel, batch, mean, stddev): + g = tf.Graph() + with g.as_default(): + coord = tf.placeholder(GLOBAL_NP_FLOAT_PRECISION, [None, None]) + box = tf.placeholder(GLOBAL_NP_FLOAT_PRECISION, [None, None]) + atype = tf.placeholder(tf.int32, [None, None]) + natoms_vec = tf.placeholder(tf.int32, [None]) + default_mesh = tf.placeholder(tf.int32, [None]) + stat_descrpt, descrpt_deriv, rij, nlist = op_module.prod_env_mat_a( + coord, + atype, + natoms_vec, + box, + default_mesh, + tf.constant(mean), + tf.constant(stddev), + rcut_a=-1.0, + rcut_r=rcut, + rcut_r_smth=rcut_smth, + sel_a=sel, + sel_r=[0 for i in sel], + ) + + net_deriv_reshape = tf.ones_like(stat_descrpt) + force = op_module.prod_force_se_a( + net_deriv_reshape, + descrpt_deriv, + nlist, + natoms_vec, + n_a_sel=sum(sel), + n_r_sel=0, + ) + + with tf.Session(graph=g) as sess: + y = sess.run( + [stat_descrpt, force, nlist], + feed_dict={ + coord: batch["coord"], + box: batch["box"], + natoms_vec: batch["natoms"], + atype: batch["atype"], + default_mesh: np.array([0, 0, 0, 2, 2, 2]), + }, + ) + tf.reset_default_graph() + return y + + +class TestSeA(unittest.TestCase): + def setUp(self): + dp_random.seed(20) + with open(str(Path(__file__).parent / "water/se_e2_a.json")) as fin: + content = fin.read() + config = json.loads(content) + data_file = [str(Path(__file__).parent / "water/data/data_0")] + config["training"]["training_data"]["systems"] = data_file + config["training"]["validation_data"]["systems"] = data_file + model_config = config["model"] + self.rcut = model_config["descriptor"]["rcut"] + self.rcut_smth = model_config["descriptor"]["rcut_smth"] + self.sel = model_config["descriptor"]["sel"] + self.bsz = config["training"]["training_data"]["batch_size"] + self.systems = config["training"]["validation_data"]["systems"] + if isinstance(self.systems, str): + self.systems = expand_sys_str(self.systems) + ds = DeepmdDataSet( + self.systems, self.bsz, model_config["type_map"], self.rcut, self.sel + ) + self.np_batch, self.pt_batch = ds.get_batch() + self.sec = np.cumsum(self.sel) + self.ntypes = len(self.sel) + self.nnei = sum(self.sel) + + def test_consistency(self): + avg_zero = torch.zeros( + [self.ntypes, self.nnei * 4], dtype=GLOBAL_PT_FLOAT_PRECISION + ) + std_ones = torch.ones( + [self.ntypes, self.nnei * 4], dtype=GLOBAL_PT_FLOAT_PRECISION + ) + base_d, base_force, nlist = base_se_a( + rcut=self.rcut, + rcut_smth=self.rcut_smth, + sel=self.sel, + batch=self.np_batch, + mean=avg_zero, + stddev=std_ones, + ) + + pt_coord = self.pt_batch["coord"] + pt_coord.requires_grad_(True) + index = self.pt_batch["mapping"].unsqueeze(-1).expand(-1, -1, 3) + extended_coord = torch.gather(pt_coord, dim=1, index=index) + extended_coord = extended_coord - self.pt_batch["shift"] + my_d, _, _ = prod_env_mat_se_a( + extended_coord.to(DEVICE), + self.pt_batch["nlist"], + self.pt_batch["atype"], + avg_zero.reshape([-1, self.nnei, 4]).to(DEVICE), + std_ones.reshape([-1, self.nnei, 4]).to(DEVICE), + self.rcut, + self.rcut_smth, + ) + my_d.sum().backward() + bsz = pt_coord.shape[0] + my_force = pt_coord.grad.view(bsz, -1, 3).cpu().detach().numpy() + base_force = base_force.reshape(bsz, -1, 3) + base_d = base_d.reshape(bsz, -1, self.nnei, 4) + my_d = my_d.view(bsz, -1, self.nnei, 4).cpu().detach().numpy() + nlist = nlist.reshape(bsz, -1, self.nnei) + + mapping = self.pt_batch["mapping"].cpu() + my_nlist = self.pt_batch["nlist"].view(bsz, -1).cpu() + mask = my_nlist == -1 + my_nlist = my_nlist * ~mask + my_nlist = torch.gather(mapping, dim=-1, index=my_nlist) + my_nlist = my_nlist * ~mask - mask.long() + my_nlist = my_nlist.cpu().view(bsz, -1, self.nnei).numpy() + self.assertTrue(np.allclose(nlist, my_nlist)) + self.assertTrue(np.allclose(np.mean(base_d, axis=2), np.mean(my_d, axis=2))) + self.assertTrue(np.allclose(np.std(base_d, axis=2), np.std(my_d, axis=2))) + # descriptors may be different when there are multiple neighbors in the same distance + self.assertTrue(np.allclose(base_force, -my_force)) + + +if __name__ == "__main__": + unittest.main() diff --git a/source/tests/pt/test_descriptor_dpa1.py b/source/tests/pt/test_descriptor_dpa1.py new file mode 100644 index 0000000000..689fa7e49c --- /dev/null +++ b/source/tests/pt/test_descriptor_dpa1.py @@ -0,0 +1,367 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import json +import os +import unittest +from pathlib import ( + Path, +) + +import torch + +from deepmd.pt.model.descriptor import ( + DescrptBlockSeAtten, + DescrptDPA1, +) +from deepmd.pt.model.network.network import ( + TypeEmbedNet, +) +from deepmd.pt.utils import ( + env, +) +from deepmd.pt.utils.nlist import ( + build_neighbor_list, + extend_coord_with_ghosts, +) +from deepmd.pt.utils.region import ( + normalize_coord, +) + +dtype = torch.float64 +torch.set_default_dtype(dtype) + +CUR_DIR = os.path.dirname(__file__) + + +class TestDPA1(unittest.TestCase): + def setUp(self): + cell = [ + 5.122106549439247480e00, + 4.016537340154059388e-01, + 6.951654033828678081e-01, + 4.016537340154059388e-01, + 6.112136112297989143e00, + 8.178091365465004481e-01, + 6.951654033828678081e-01, + 8.178091365465004481e-01, + 6.159552512682983760e00, + ] + self.cell = torch.Tensor(cell).view(1, 3, 3).to(env.DEVICE) + coord = [ + 2.978060152121375648e00, + 3.588469695887098077e00, + 2.792459820604495491e00, + 3.895592322591093115e00, + 2.712091020667753760e00, + 1.366836847133650501e00, + 9.955616170888935690e-01, + 4.121324820711413039e00, + 1.817239061889086571e00, + 3.553661462345699906e00, + 5.313046969500791583e00, + 6.635182659098815883e00, + 6.088601018589653080e00, + 6.575011420004332585e00, + 6.825240650611076099e00, + ] + self.coord = torch.Tensor(coord).view(1, -1, 3).to(env.DEVICE) + self.atype = torch.IntTensor([0, 0, 0, 1, 1]).view(1, -1).to(env.DEVICE) + self.ref_d = torch.Tensor( + [ + 8.382518544113587780e-03, + -3.390120566088597812e-03, + 6.145981571114964362e-03, + -4.880300873973819273e-03, + -3.390120566088597812e-03, + 1.372540996564941464e-03, + -2.484163690574096341e-03, + 1.972313058658722688e-03, + 6.145981571114964362e-03, + -2.484163690574096341e-03, + 4.507748738021747671e-03, + -3.579717194906019764e-03, + -4.880300873973819273e-03, + 1.972313058658722688e-03, + -3.579717194906019764e-03, + 2.842794615687799838e-03, + 6.733043802494966066e-04, + -2.721540313345096771e-04, + 4.936158526085561134e-04, + -3.919743287822345223e-04, + -1.311123004527576900e-02, + 5.301179352601203924e-03, + -9.614612349318877454e-03, + 7.634884975521277241e-03, + 8.877088452901006621e-03, + -3.590945566653638409e-03, + 6.508042782015627942e-03, + -5.167671664327699171e-03, + -2.697241463040870365e-03, + 1.091350446825975137e-03, + -1.976895708961905022e-03, + 1.569671412121975348e-03, + 8.645131636261189911e-03, + -3.557395265621639355e-03, + 6.298048561552698106e-03, + -4.999272007935521948e-03, + -3.557395265621639355e-03, + 1.467866637220284964e-03, + -2.587004431651147504e-03, + 2.052752235601402672e-03, + 6.298048561552698106e-03, + -2.587004431651147504e-03, + 4.594085551315935101e-03, + -3.647656549789176847e-03, + -4.999272007935521948e-03, + 2.052752235601402672e-03, + -3.647656549789176847e-03, + 2.896359275520481256e-03, + 6.689620176492027878e-04, + -2.753606422414641049e-04, + 4.864958810186969444e-04, + -3.860599754167503119e-04, + -1.349238259226558101e-02, + 5.547478630961994242e-03, + -9.835472300819447095e-03, + 7.808197926069362048e-03, + 9.220744348752592245e-03, + -3.795799103392961601e-03, + 6.716516319358462918e-03, + -5.331265718473574867e-03, + -2.783836698392940304e-03, + 1.147461939123531121e-03, + -2.025013030986024063e-03, + 1.606944814423778541e-03, + 9.280385723343491378e-03, + -3.515852178447095942e-03, + 7.085282215778941628e-03, + -5.675852414643783178e-03, + -3.515852178447095942e-03, + 1.337760635271160884e-03, + -2.679428786337713451e-03, + 2.145400621815936413e-03, + 7.085282215778941628e-03, + -2.679428786337713451e-03, + 5.414439648102228192e-03, + -4.338426468139268931e-03, + -5.675852414643783178e-03, + 2.145400621815936413e-03, + -4.338426468139268931e-03, + 3.476467482674507146e-03, + 7.166961981167455130e-04, + -2.697932188839837972e-04, + 5.474643906631899504e-04, + -4.386556623669893621e-04, + -1.480434821331240956e-02, + 5.604647062899507579e-03, + -1.130745349141585449e-02, + 9.059113563516829268e-03, + 9.758791063112262978e-03, + -3.701477720487638626e-03, + 7.448215522796466058e-03, + -5.966057584545172120e-03, + -2.845102393948158344e-03, + 1.078743584169829543e-03, + -2.170093031447992756e-03, + 1.738010461687942770e-03, + 9.867599071916231118e-03, + -3.811041717688905522e-03, + 7.121877634386481262e-03, + -5.703120290113914553e-03, + -3.811041717688905522e-03, + 1.474046183772771213e-03, + -2.747386907428428938e-03, + 2.199711055637492037e-03, + 7.121877634386481262e-03, + -2.747386907428428938e-03, + 5.145050639440944609e-03, + -4.120642824501622239e-03, + -5.703120290113914553e-03, + 2.199711055637492037e-03, + -4.120642824501622239e-03, + 3.300262321758350853e-03, + 1.370499995344566383e-03, + -5.313041843655797901e-04, + 9.860110343046961986e-04, + -7.892505817954784597e-04, + -1.507686316307561489e-02, + 5.818961290579217904e-03, + -1.088774506142304276e-02, + 8.719460408506790952e-03, + 9.764630842803939323e-03, + -3.770134041110058572e-03, + 7.049438389985595785e-03, + -5.645302934019884485e-03, + -3.533582373572779437e-03, + 1.367148320603491559e-03, + -2.546602904764623705e-03, + 2.038882844528267305e-03, + 7.448297038731285964e-03, + -2.924276815200288742e-03, + 5.355960540523636154e-03, + -4.280386435083473329e-03, + -2.924276815200288742e-03, + 1.150311064893848757e-03, + -2.100635980860638373e-03, + 1.678427895009850001e-03, + 5.355960540523636154e-03, + -2.100635980860638373e-03, + 3.853607053247790071e-03, + -3.080076301871465493e-03, + -4.280386435083473329e-03, + 1.678427895009850001e-03, + -3.080076301871465493e-03, + 2.461876613756722523e-03, + 9.730712866459405395e-04, + -3.821759579990726546e-04, + 6.994242056622360787e-04, + -5.589662297882965055e-04, + -1.138916742131982317e-02, + 4.469391132927387489e-03, + -8.192016282448397885e-03, + 6.547234460517113892e-03, + 7.460070829043288082e-03, + -2.929867802018087421e-03, + 5.363646855497249989e-03, + -4.286347242903034739e-03, + -2.643569023340565718e-03, + 1.038826463247002245e-03, + -1.899910089750410976e-03, + 1.518237240362583541e-03, + ] + ).to(env.DEVICE) + with open(Path(CUR_DIR) / "models" / "dpa1.json") as fp: + self.model_json = json.load(fp) + self.file_model_param = Path(CUR_DIR) / "models" / "dpa1.pth" + self.file_type_embed = Path(CUR_DIR) / "models" / "dpa2_tebd.pth" + + def test_descriptor_block(self): + # torch.manual_seed(0) + model_dpa1 = self.model_json + dparams = model_dpa1["descriptor"] + ntypes = len(model_dpa1["type_map"]) + assert "se_atten" == dparams.pop("type") + dparams["ntypes"] = ntypes + des = DescrptBlockSeAtten( + **dparams, + ) + des.load_state_dict(torch.load(self.file_model_param)) + rcut = dparams["rcut"] + nsel = dparams["sel"] + coord = self.coord + atype = self.atype + box = self.cell + nf, nloc = coord.shape[:2] + coord_normalized = normalize_coord(coord, box.reshape(-1, 3, 3)) + extended_coord, extended_atype, mapping = extend_coord_with_ghosts( + coord_normalized, atype, box, rcut + ) + # single nlist + nlist = build_neighbor_list( + extended_coord, extended_atype, nloc, rcut, nsel, distinguish_types=False + ) + # handel type_embedding + type_embedding = TypeEmbedNet(ntypes, 8) + type_embedding.load_state_dict(torch.load(self.file_type_embed)) + + ## to save model parameters + # torch.save(des.state_dict(), 'model_weights.pth') + # torch.save(type_embedding.state_dict(), 'model_weights.pth') + descriptor, env_mat, diff, rot_mat, sw = des( + nlist, + extended_coord, + extended_atype, + type_embedding(extended_atype), + mapping=None, + ) + # np.savetxt('tmp.out', descriptor.detach().numpy().reshape(1,-1), delimiter=",") + self.assertEqual(descriptor.shape[-1], des.get_dim_out()) + self.assertAlmostEqual(6.0, des.get_rcut()) + self.assertEqual(30, des.get_nsel()) + self.assertEqual(2, des.get_ntype()) + torch.testing.assert_close( + descriptor.view(-1), self.ref_d, atol=1e-10, rtol=1e-10 + ) + + def test_descriptor(self): + with open(Path(CUR_DIR) / "models" / "dpa1.json") as fp: + self.model_json = json.load(fp) + model_dpa2 = self.model_json + ntypes = len(model_dpa2["type_map"]) + dparams = model_dpa2["descriptor"] + dparams["ntypes"] = ntypes + assert dparams.pop("type") == "se_atten" + dparams["concat_output_tebd"] = False + des = DescrptDPA1( + **dparams, + ) + target_dict = des.state_dict() + source_dict = torch.load(self.file_model_param) + type_embd_dict = torch.load(self.file_type_embed) + target_dict = translate_se_atten_and_type_embd_dicts_to_dpa1( + target_dict, + source_dict, + type_embd_dict, + ) + des.load_state_dict(target_dict) + + coord = self.coord + atype = self.atype + box = self.cell + nf, nloc = coord.shape[:2] + coord_normalized = normalize_coord(coord, box.reshape(-1, 3, 3)) + extended_coord, extended_atype, mapping = extend_coord_with_ghosts( + coord_normalized, atype, box, des.get_rcut() + ) + nlist = build_neighbor_list( + extended_coord, + extended_atype, + nloc, + des.get_rcut(), + des.get_nsel(), + distinguish_types=False, + ) + descriptor, env_mat, diff, rot_mat, sw = des( + extended_coord, + extended_atype, + nlist, + mapping=mapping, + ) + self.assertEqual(descriptor.shape[-1], des.get_dim_out()) + self.assertAlmostEqual(6.0, des.get_rcut()) + self.assertEqual(30, des.get_nsel()) + self.assertEqual(2, des.get_ntype()) + torch.testing.assert_close( + descriptor.view(-1), self.ref_d, atol=1e-10, rtol=1e-10 + ) + + dparams["concat_output_tebd"] = True + des = DescrptDPA1( + **dparams, + ) + descriptor, env_mat, diff, rot_mat, sw = des( + extended_coord, + extended_atype, + nlist, + mapping=mapping, + ) + self.assertEqual(descriptor.shape[-1], des.get_dim_out()) + + +def translate_se_atten_and_type_embd_dicts_to_dpa1( + target_dict, + source_dict, + type_embd_dict, +): + all_keys = list(target_dict.keys()) + record = [False for ii in all_keys] + for kk, vv in source_dict.items(): + tk = "se_atten." + kk + record[all_keys.index(tk)] = True + target_dict[tk] = vv + assert len(type_embd_dict.keys()) == 1 + kk = next(iter(type_embd_dict.keys())) + tk = "type_embedding." + kk + record[all_keys.index(tk)] = True + target_dict[tk] = type_embd_dict[kk] + assert all(record) + return target_dict diff --git a/source/tests/pt/test_descriptor_dpa2.py b/source/tests/pt/test_descriptor_dpa2.py new file mode 100644 index 0000000000..45c95961fe --- /dev/null +++ b/source/tests/pt/test_descriptor_dpa2.py @@ -0,0 +1,264 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import json +import os +import unittest +from pathlib import ( + Path, +) + +import torch + +from deepmd.pt.model.descriptor import ( + DescrptBlockHybrid, + DescrptDPA2, +) +from deepmd.pt.model.network.network import ( + TypeEmbedNet, +) +from deepmd.pt.utils import ( + env, +) +from deepmd.pt.utils.nlist import ( + build_neighbor_list, + extend_coord_with_ghosts, +) +from deepmd.pt.utils.region import ( + normalize_coord, +) + +dtype = torch.float64 +torch.set_default_dtype(dtype) + +CUR_DIR = os.path.dirname(__file__) + + +class TestDPA2(unittest.TestCase): + def setUp(self): + cell = [ + 5.122106549439247480e00, + 4.016537340154059388e-01, + 6.951654033828678081e-01, + 4.016537340154059388e-01, + 6.112136112297989143e00, + 8.178091365465004481e-01, + 6.951654033828678081e-01, + 8.178091365465004481e-01, + 6.159552512682983760e00, + ] + self.cell = torch.Tensor(cell).view(1, 3, 3).to(env.DEVICE) + coord = [ + 2.978060152121375648e00, + 3.588469695887098077e00, + 2.792459820604495491e00, + 3.895592322591093115e00, + 2.712091020667753760e00, + 1.366836847133650501e00, + 9.955616170888935690e-01, + 4.121324820711413039e00, + 1.817239061889086571e00, + 3.553661462345699906e00, + 5.313046969500791583e00, + 6.635182659098815883e00, + 6.088601018589653080e00, + 6.575011420004332585e00, + 6.825240650611076099e00, + ] + self.coord = torch.Tensor(coord).view(1, -1, 3).to(env.DEVICE) + self.atype = torch.IntTensor([0, 0, 0, 1, 1]).view(1, -1).to(env.DEVICE) + self.ref_d = torch.Tensor( + [ + 8.435412613327306630e-01, + -4.717109614540972440e-01, + -1.812643456954206256e00, + -2.315248767961955167e-01, + -7.112973006771171613e-01, + -4.162041919507591392e-01, + -1.505159810095323181e00, + -1.191652416985768403e-01, + 8.439214937875325617e-01, + -4.712976890460106594e-01, + -1.812605149396642856e00, + -2.307222236291133766e-01, + -7.115427800870099961e-01, + -4.164729253167227530e-01, + -1.505483119125936797e00, + -1.191288524278367872e-01, + 8.286420823261241297e-01, + -4.535033763979030574e-01, + -1.787877160970498425e00, + -1.961763875645104460e-01, + -7.475459187804838201e-01, + -5.231446874663764346e-01, + -1.488399984491664219e00, + -3.974117581747104583e-02, + 8.283793431613817315e-01, + -4.551551577556525729e-01, + -1.789253136645859943e00, + -1.977673627726055372e-01, + -7.448826048241211639e-01, + -5.161350182531234676e-01, + -1.487589463573479209e00, + -4.377376017839779143e-02, + 8.295404560710329944e-01, + -4.492219258475603216e-01, + -1.784484611185287450e00, + -1.901182059718481143e-01, + -7.537407667483000395e-01, + -5.384371277650709109e-01, + -1.490368056268364549e00, + -3.073744832541754762e-02, + ] + ).to(env.DEVICE) + with open(Path(CUR_DIR) / "models" / "dpa2_hyb.json") as fp: + self.model_json = json.load(fp) + self.file_model_param = Path(CUR_DIR) / "models" / "dpa2.pth" + self.file_type_embed = Path(CUR_DIR) / "models" / "dpa2_tebd.pth" + + def test_descriptor_hyb(self): + # torch.manual_seed(0) + model_hybrid_dpa2 = self.model_json + dparams = model_hybrid_dpa2["descriptor"] + ntypes = len(model_hybrid_dpa2["type_map"]) + dlist = dparams.pop("list") + des = DescrptBlockHybrid( + dlist, + ntypes, + hybrid_mode=dparams["hybrid_mode"], + ) + model_dict = torch.load(self.file_model_param) + # type_embd of repformer is removed + model_dict.pop("descriptor_list.1.type_embd.embedding.weight") + des.load_state_dict(model_dict) + all_rcut = [ii["rcut"] for ii in dlist] + all_nsel = [ii["sel"] for ii in dlist] + rcut_max = max(all_rcut) + coord = self.coord + atype = self.atype + box = self.cell + nf, nloc = coord.shape[:2] + coord_normalized = normalize_coord(coord, box.reshape(-1, 3, 3)) + extended_coord, extended_atype, mapping = extend_coord_with_ghosts( + coord_normalized, atype, box, rcut_max + ) + ## single nlist + # nlist = build_neighbor_list( + # extended_coord, extended_atype, nloc, + # rcut_max, nsel, distinguish_types=False) + nlist_list = [] + for rcut, sel in zip(all_rcut, all_nsel): + nlist_list.append( + build_neighbor_list( + extended_coord, + extended_atype, + nloc, + rcut, + sel, + distinguish_types=False, + ) + ) + nlist = torch.cat(nlist_list, -1) + # handel type_embedding + type_embedding = TypeEmbedNet(ntypes, 8) + type_embedding.load_state_dict(torch.load(self.file_type_embed)) + + ## to save model parameters + # torch.save(des.state_dict(), 'model_weights.pth') + # torch.save(type_embedding.state_dict(), 'model_weights.pth') + descriptor, env_mat, diff, rot_mat, sw = des( + nlist, + extended_coord, + extended_atype, + type_embedding(extended_atype), + mapping=mapping, + ) + torch.testing.assert_close( + descriptor.view(-1), self.ref_d, atol=1e-10, rtol=1e-10 + ) + + def test_descriptor(self): + with open(Path(CUR_DIR) / "models" / "dpa2.json") as fp: + self.model_json = json.load(fp) + model_dpa2 = self.model_json + ntypes = len(model_dpa2["type_map"]) + dparams = model_dpa2["descriptor"] + dparams["ntypes"] = ntypes + assert dparams.pop("type") == "dpa2" + dparams["concat_output_tebd"] = False + des = DescrptDPA2( + **dparams, + ) + target_dict = des.state_dict() + source_dict = torch.load(self.file_model_param) + # type_embd of repformer is removed + source_dict.pop("descriptor_list.1.type_embd.embedding.weight") + type_embd_dict = torch.load(self.file_type_embed) + target_dict = translate_hybrid_and_type_embd_dicts_to_dpa2( + target_dict, + source_dict, + type_embd_dict, + ) + des.load_state_dict(target_dict) + + coord = self.coord + atype = self.atype + box = self.cell + nf, nloc = coord.shape[:2] + coord_normalized = normalize_coord(coord, box.reshape(-1, 3, 3)) + extended_coord, extended_atype, mapping = extend_coord_with_ghosts( + coord_normalized, atype, box, des.repinit.rcut + ) + nlist = build_neighbor_list( + extended_coord, + extended_atype, + nloc, + des.repinit.rcut, + des.repinit.sel, + distinguish_types=False, + ) + descriptor, env_mat, diff, rot_mat, sw = des( + extended_coord, + extended_atype, + nlist, + mapping=mapping, + ) + self.assertEqual(descriptor.shape[-1], des.get_dim_out()) + self.assertAlmostEqual(6.0, des.get_rcut()) + self.assertEqual(30, des.get_nsel()) + self.assertEqual(2, des.get_ntype()) + torch.testing.assert_close( + descriptor.view(-1), self.ref_d, atol=1e-10, rtol=1e-10 + ) + + dparams["concat_output_tebd"] = True + des = DescrptDPA2( + **dparams, + ) + descriptor, env_mat, diff, rot_mat, sw = des( + extended_coord, + extended_atype, + nlist, + mapping=mapping, + ) + self.assertEqual(descriptor.shape[-1], des.get_dim_out()) + + +def translate_hybrid_and_type_embd_dicts_to_dpa2( + target_dict, + source_dict, + type_embd_dict, +): + all_keys = list(target_dict.keys()) + record = [False for ii in all_keys] + for kk, vv in source_dict.items(): + tk = kk.replace("descriptor_list.1", "repformers") + tk = tk.replace("descriptor_list.0", "repinit") + tk = tk.replace("sequential_transform.0", "g1_shape_tranform") + record[all_keys.index(tk)] = True + target_dict[tk] = vv + assert len(type_embd_dict.keys()) == 1 + kk = next(iter(type_embd_dict.keys())) + tk = "type_embedding." + kk + record[all_keys.index(tk)] = True + target_dict[tk] = type_embd_dict[kk] + assert all(record) + return target_dict diff --git a/source/tests/pt/test_dp_test.py b/source/tests/pt/test_dp_test.py new file mode 100644 index 0000000000..3db66f073f --- /dev/null +++ b/source/tests/pt/test_dp_test.py @@ -0,0 +1,71 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import json +import os +import shutil +import unittest +from copy import ( + deepcopy, +) +from pathlib import ( + Path, +) + +import numpy as np + +from deepmd.pt.entrypoints.main import ( + get_trainer, +) +from deepmd.pt.infer import ( + inference, +) + + +class TestDPTest(unittest.TestCase): + def setUp(self): + input_json = str(Path(__file__).parent / "water/se_atten.json") + with open(input_json) as f: + self.config = json.load(f) + self.config["training"]["numb_steps"] = 1 + self.config["training"]["save_freq"] = 1 + data_file = [str(Path(__file__).parent / "water/data/data_0")] + self.config["training"]["training_data"]["systems"] = data_file + self.config["training"]["validation_data"]["systems"] = [ + str(Path(__file__).parent / "water/data/single") + ] + self.input_json = "test_dp_test.json" + with open(self.input_json, "w") as fp: + json.dump(self.config, fp, indent=4) + + def test_dp_test(self): + trainer = get_trainer(deepcopy(self.config)) + trainer.run() + + input_dict, label_dict, _ = trainer.get_data(is_train=False) + _, _, more_loss = trainer.wrapper(**input_dict, label=label_dict, cur_lr=1.0) + + tester = inference.Tester("model.pt", input_script=self.input_json) + try: + res = tester.run() + except StopIteration: + print("Unexpected stop iteration.(test step < total batch)") + raise StopIteration + for k, v in res.items(): + if k == "rmse" or "mae" in k or k not in more_loss: + continue + np.testing.assert_allclose( + v, more_loss[k].cpu().detach().numpy(), rtol=1e-04, atol=1e-07 + ) + + def tearDown(self): + for f in os.listdir("."): + if f.startswith("model") and f.endswith(".pt"): + os.remove(f) + if f in ["lcurve.out"]: + os.remove(f) + if f in ["stat_files"]: + shutil.rmtree(f) + os.remove(self.input_json) + + +if __name__ == "__main__": + unittest.main() diff --git a/source/tests/pt/test_embedding_net.py b/source/tests/pt/test_embedding_net.py new file mode 100644 index 0000000000..fc98ddc9f9 --- /dev/null +++ b/source/tests/pt/test_embedding_net.py @@ -0,0 +1,176 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import json +import os +import re +import unittest + +import numpy as np +import tensorflow.compat.v1 as tf +import torch + +tf.disable_eager_execution() + +from pathlib import ( + Path, +) + +from deepmd.pt.model.descriptor import ( + DescrptSeA, +) +from deepmd.pt.utils import ( + dp_random, +) +from deepmd.pt.utils.dataset import ( + DeepmdDataSet, +) +from deepmd.pt.utils.env import ( + DEVICE, + GLOBAL_NP_FLOAT_PRECISION, +) +from deepmd.tf.common import ( + expand_sys_str, +) +from deepmd.tf.descriptor import DescrptSeA as DescrptSeA_tf + +CUR_DIR = os.path.dirname(__file__) + + +def gen_key(worb, depth, elemid): + return (worb, depth, elemid) + + +def base_se_a(descriptor, coord, atype, natoms, box): + g = tf.Graph() + with g.as_default(): + name_pfx = "d_sea_" + t_coord = tf.placeholder( + GLOBAL_NP_FLOAT_PRECISION, [None, None], name=name_pfx + "t_coord" + ) + t_atype = tf.placeholder(tf.int32, [None, None], name=name_pfx + "t_type") + t_natoms = tf.placeholder( + tf.int32, [descriptor.ntypes + 2], name=name_pfx + "t_natoms" + ) + t_box = tf.placeholder( + GLOBAL_NP_FLOAT_PRECISION, [None, None], name=name_pfx + "t_box" + ) + t_default_mesh = tf.placeholder(tf.int32, [None], name=name_pfx + "t_mesh") + t_embedding = descriptor.build( + t_coord, t_atype, t_natoms, t_box, t_default_mesh, input_dict={} + ) + fake_energy = tf.reduce_sum(t_embedding) + t_force = descriptor.prod_force_virial(fake_energy, t_natoms)[0] + t_vars = {} + for var in tf.global_variables(): + ms = re.findall(r"([a-z]+)_(\d)_(\d)", var.name) + if len(ms) == 1: + m = ms[0] + key = gen_key(worb=m[0], depth=int(m[1]), elemid=int(m[2])) + t_vars[key] = var + init_op = tf.global_variables_initializer() + + with tf.Session(graph=g) as sess: + sess.run(init_op) + embedding, force, values = sess.run( + [t_embedding, t_force, t_vars], + feed_dict={ + t_coord: coord, + t_atype: atype, + t_natoms: natoms, + t_box: box, + t_default_mesh: np.array([0, 0, 0, 2, 2, 2]), + }, + ) + tf.reset_default_graph() + return embedding, force, values + + +class TestSeA(unittest.TestCase): + def setUp(self): + dp_random.seed(0) + with open(str(Path(__file__).parent / "water/se_e2_a.json")) as fin: + content = fin.read() + config = json.loads(content) + data_file = [str(Path(__file__).parent / "water/data/data_0")] + config["training"]["training_data"]["systems"] = data_file + config["training"]["validation_data"]["systems"] = data_file + model_config = config["model"] + self.rcut = model_config["descriptor"]["rcut"] + self.rcut_smth = model_config["descriptor"]["rcut_smth"] + self.sel = model_config["descriptor"]["sel"] + self.bsz = config["training"]["training_data"]["batch_size"] + self.systems = config["training"]["validation_data"]["systems"] + if isinstance(self.systems, str): + self.systems = expand_sys_str(self.systems) + ds = DeepmdDataSet( + self.systems, self.bsz, model_config["type_map"], self.rcut, self.sel + ) + self.filter_neuron = model_config["descriptor"]["neuron"] + self.axis_neuron = model_config["descriptor"]["axis_neuron"] + self.np_batch, self.torch_batch = ds.get_batch() + + def test_consistency(self): + dp_d = DescrptSeA_tf( + rcut=self.rcut, + rcut_smth=self.rcut_smth, + sel=self.sel, + neuron=self.filter_neuron, + axis_neuron=self.axis_neuron, + seed=1, + ) + dp_embedding, dp_force, dp_vars = base_se_a( + descriptor=dp_d, + coord=self.np_batch["coord"], + atype=self.np_batch["atype"], + natoms=self.np_batch["natoms"], + box=self.np_batch["box"], + ) + + # Reproduced + old_impl = False + descriptor = DescrptSeA( + self.rcut, + self.rcut_smth, + self.sel, + neuron=self.filter_neuron, + axis_neuron=self.axis_neuron, + old_impl=old_impl, + ).to(DEVICE) + for name, param in descriptor.named_parameters(): + if old_impl: + ms = re.findall(r"(\d)\.deep_layers\.(\d)\.([a-z]+)", name) + else: + ms = re.findall(r"(\d)\.layers\.(\d)\.([a-z]+)", name) + if len(ms) == 1: + m = ms[0] + key = gen_key(worb=m[2], depth=int(m[1]) + 1, elemid=int(m[0])) + var = dp_vars[key] + with torch.no_grad(): + # Keep parameter value consistency between 2 implentations + param.data.copy_(torch.from_numpy(var)) + + pt_coord = self.torch_batch["coord"] + pt_coord.requires_grad_(True) + index = self.torch_batch["mapping"].unsqueeze(-1).expand(-1, -1, 3) + extended_coord = torch.gather(pt_coord, dim=1, index=index) + extended_coord = extended_coord - self.torch_batch["shift"] + extended_atype = torch.gather( + self.torch_batch["atype"], dim=1, index=self.torch_batch["mapping"] + ) + descriptor_out, _, _, _, _ = descriptor( + extended_coord, + extended_atype, + self.torch_batch["nlist"], + ) + my_embedding = descriptor_out.cpu().detach().numpy() + fake_energy = torch.sum(descriptor_out) + fake_energy.backward() + my_force = -pt_coord.grad.cpu().numpy() + + # Check + np.testing.assert_allclose(dp_embedding, my_embedding) + dp_force = dp_force.reshape(*my_force.shape) + np.testing.assert_allclose(dp_force, my_force) + + +if __name__ == "__main__": + unittest.main() diff --git a/source/tests/pt/test_env_mat.py b/source/tests/pt/test_env_mat.py new file mode 100644 index 0000000000..f4931e9ecc --- /dev/null +++ b/source/tests/pt/test_env_mat.py @@ -0,0 +1,84 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import unittest + +import numpy as np +import torch + +try: + from deepmd.model_format import ( + EnvMat, + ) + + support_env_mat = True +except ModuleNotFoundError: + support_env_mat = False +except ImportError: + support_env_mat = False + +from deepmd.pt.model.descriptor.env_mat import ( + prod_env_mat_se_a, +) +from deepmd.pt.utils import ( + env, +) + +dtype = env.GLOBAL_PT_FLOAT_PRECISION + + +class TestCaseSingleFrameWithNlist: + def setUp(self): + # nloc == 3, nall == 4 + self.nloc = 3 + self.nall = 4 + self.nf, self.nt = 1, 2 + self.coord_ext = np.array( + [ + [0, 0, 0], + [0, 1, 0], + [0, 0, 1], + [0, -2, 0], + ], + dtype=np.float64, + ).reshape([1, self.nall * 3]) + self.atype_ext = np.array([0, 0, 1, 0], dtype=int).reshape([1, self.nall]) + # sel = [5, 2] + self.sel = [5, 2] + self.nlist = np.array( + [ + [1, 3, -1, -1, -1, 2, -1], + [0, -1, -1, -1, -1, 2, -1], + [0, 1, -1, -1, -1, 0, -1], + ], + dtype=int, + ).reshape([1, self.nloc, sum(self.sel)]) + self.rcut = 0.4 + self.rcut_smth = 2.2 + + +# to be merged with the tf test case +@unittest.skipIf(not support_env_mat, "EnvMat not supported") +class TestEnvMat(unittest.TestCase, TestCaseSingleFrameWithNlist): + def setUp(self): + TestCaseSingleFrameWithNlist.setUp(self) + + def test_consistency( + self, + ): + rng = np.random.default_rng() + nf, nloc, nnei = self.nlist.shape + davg = rng.normal(size=(self.nt, nnei, 4)) + dstd = rng.normal(size=(self.nt, nnei, 4)) + dstd = 0.1 + np.abs(dstd) + em0 = EnvMat(self.rcut, self.rcut_smth) + mm0, ww0 = em0.call(self.coord_ext, self.atype_ext, self.nlist, davg, dstd) + mm1, _, ww1 = prod_env_mat_se_a( + torch.tensor(self.coord_ext, dtype=dtype), + torch.tensor(self.nlist, dtype=int), + torch.tensor(self.atype_ext[:, :nloc], dtype=int), + davg, + dstd, + self.rcut, + self.rcut_smth, + ) + np.testing.assert_allclose(mm0, mm1) + np.testing.assert_allclose(ww0, ww1) diff --git a/source/tests/pt/test_fitting_net.py b/source/tests/pt/test_fitting_net.py new file mode 100644 index 0000000000..3feb4f4739 --- /dev/null +++ b/source/tests/pt/test_fitting_net.py @@ -0,0 +1,139 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import re +import unittest + +import numpy as np +import tensorflow.compat.v1 as tf +import torch + +tf.disable_eager_execution() + +from deepmd.pt.model.task import ( + EnergyFittingNet, +) +from deepmd.pt.utils.env import ( + GLOBAL_NP_FLOAT_PRECISION, +) +from deepmd.tf.fit.ener import ( + EnerFitting, +) + + +class FakeDescriptor: + def __init__(self, ntypes, embedding_width): + self._ntypes = ntypes + self._dim_out = embedding_width + + def get_ntypes(self): + return self._ntypes + + def get_dim_out(self): + return self._dim_out + + +def gen_key(type_id, layer_id, w_or_b): + return (type_id, layer_id, w_or_b) + + +def base_fitting_net(dp_fn, embedding, natoms, atype): + g = tf.Graph() + with g.as_default(): + t_embedding = tf.placeholder(GLOBAL_NP_FLOAT_PRECISION, [None, None]) + t_natoms = tf.placeholder(tf.int32, [None]) + t_atype = tf.placeholder(tf.int32, [None, None]) + t_energy = dp_fn.build(t_embedding, t_natoms, {"atype": t_atype}) + init_op = tf.global_variables_initializer() + t_vars = {} + for var in tf.global_variables(): + key = None + matched = re.match(r"layer_(\d)_type_(\d)/([a-z]+)", var.name) + if matched: + key = gen_key( + type_id=matched.group(2), + layer_id=matched.group(1), + w_or_b=matched.group(3), + ) + else: + matched = re.match(r"final_layer_type_(\d)/([a-z]+)", var.name) + if matched: + key = gen_key( + type_id=matched.group(1), layer_id=-1, w_or_b=matched.group(2) + ) + if key is not None: + t_vars[key] = var + + with tf.Session(graph=g) as sess: + sess.run(init_op) + energy, values = sess.run( + [t_energy, t_vars], + feed_dict={ + t_embedding: embedding, + t_natoms: natoms, + t_atype: atype, + }, + ) + tf.reset_default_graph() + return energy, values + + +class TestFittingNet(unittest.TestCase): + def setUp(self): + nloc = 7 + self.embedding_width = 30 + self.natoms = np.array([nloc, nloc, 2, 5], dtype=np.int32) + rng = np.random.default_rng() + self.embedding = rng.uniform(size=[4, nloc * self.embedding_width]) + self.ntypes = self.natoms.size - 2 + self.n_neuron = [32, 32, 32] + self.atype = np.zeros([4, nloc], dtype=np.int32) + cnt = 0 + for i in range(self.ntypes): + self.atype[:, cnt : cnt + self.natoms[i + 2]] = i + cnt += self.natoms[i + 2] + + fake_d = FakeDescriptor(2, 30) + self.dp_fn = EnerFitting(fake_d, self.n_neuron) + self.dp_fn.bias_atom_e = rng.uniform(size=[self.ntypes]) + + def test_consistency(self): + dp_energy, values = base_fitting_net( + self.dp_fn, self.embedding, self.natoms, self.atype + ) + my_fn = EnergyFittingNet( + self.ntypes, + self.embedding_width, + self.n_neuron, + self.dp_fn.bias_atom_e, + use_tebd=False, + ) + for name, param in my_fn.named_parameters(): + matched = re.match("filter_layers\.(\d).deep_layers\.(\d)\.([a-z]+)", name) + key = None + if matched: + key = gen_key( + type_id=matched.group(1), + layer_id=matched.group(2), + w_or_b=matched.group(3), + ) + else: + matched = re.match("filter_layers\.(\d).final_layer\.([a-z]+)", name) + if matched: + key = gen_key( + type_id=matched.group(1), layer_id=-1, w_or_b=matched.group(2) + ) + assert key is not None + var = values[key] + with torch.no_grad(): + # Keep parameter value consistency between 2 implentations + param.data.copy_(torch.from_numpy(var)) + embedding = torch.from_numpy(self.embedding) + embedding = embedding.view(4, -1, self.embedding_width) + atype = torch.from_numpy(self.atype) + ret = my_fn(embedding, atype) + my_energy = ret["energy"] + my_energy = my_energy.detach() + self.assertTrue(np.allclose(dp_energy, my_energy.numpy().reshape([-1]))) + + +if __name__ == "__main__": + unittest.main() diff --git a/source/tests/pt/test_force_grad.py b/source/tests/pt/test_force_grad.py new file mode 100644 index 0000000000..1ea4321d21 --- /dev/null +++ b/source/tests/pt/test_force_grad.py @@ -0,0 +1,123 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import copy +import json +import unittest +from pathlib import ( + Path, +) +from typing import ( + List, + Optional, +) + +import numpy as np +import torch + +from deepmd.pt.model.model import ( + get_model, +) +from deepmd.pt.utils import ( + env, +) +from deepmd.pt.utils.dataloader import ( + DpLoaderSet, +) +from deepmd.pt.utils.dataset import ( + DeepmdDataSystem, +) +from deepmd.pt.utils.stat import ( + make_stat_input, +) + + +class CheckSymmetry(DeepmdDataSystem): + def __init__( + self, + sys_path: str, + rcut, + sec, + type_map: Optional[List[str]] = None, + type_split=True, + ): + super().__init__(sys_path, rcut, sec, type_map, type_split) + + def get_disturb(self, index, atom_index, axis_index, delta): + for i in range( + 0, len(self._dirs) + 1 + ): # note: if different sets can be merged, prefix sum is unused to calculate + if index < self.prefix_sum[i]: + break + frames = self._load_set(self._dirs[i - 1]) + tmp = copy.deepcopy(frames["coord"].reshape(self.nframes, -1, 3)) + tmp[:, atom_index, axis_index] += delta + frames["coord"] = tmp + frame = self.single_preprocess(frames, index - self.prefix_sum[i - 1]) + return frame + + +def get_data(batch): + inputs = {} + for key in ["coord", "atype", "box"]: + inputs[key] = batch[key].unsqueeze(0).to(env.DEVICE) + return inputs + + +class TestForceGrad(unittest.TestCase): + def setUp(self): + with open(str(Path(__file__).parent / "water/se_e2_a.json")) as fin: + self.config = json.load(fin) + data_file = [str(Path(__file__).parent / "water/data/data_0")] + self.config["training"]["training_data"]["systems"] = data_file + self.config["training"]["validation_data"]["systems"] = data_file + self.system_index = 0 + self.batch_index = 0 + self.get_dataset(self.system_index, self.batch_index) + self.get_model() + + def get_model(self): + training_systems = self.config["training"]["training_data"]["systems"] + model_params = self.config["model"] + data_stat_nbatch = model_params.get("data_stat_nbatch", 10) + train_data = DpLoaderSet( + training_systems, + self.config["training"]["training_data"]["batch_size"], + model_params, + ) + sampled = make_stat_input( + train_data.systems, train_data.dataloaders, data_stat_nbatch + ) + self.model = get_model(self.config["model"], sampled).to(env.DEVICE) + + def get_dataset(self, system_index=0, batch_index=0): + systems = self.config["training"]["training_data"]["systems"] + rcut = self.config["model"]["descriptor"]["rcut"] + sel = self.config["model"]["descriptor"]["sel"] + sec = torch.cumsum(torch.tensor(sel), dim=0) + type_map = self.config["model"]["type_map"] + self.dpdatasystem = CheckSymmetry( + sys_path=systems[system_index], rcut=rcut, sec=sec, type_map=type_map + ) + self.origin_batch = self.dpdatasystem._get_item(batch_index) + + @unittest.skip("it can be replaced by autodiff") + def test_force_grad(self, threshold=1e-2, delta0=1e-6, seed=20): + result0 = self.model(**get_data(self.origin_batch)) + np.random.default_rng(seed) + errors = np.zeros((self.dpdatasystem._natoms, 3)) + for atom_index in range(self.dpdatasystem._natoms): + for axis_index in range(3): + delta = np.random.random() * delta0 + disturb_batch = self.dpdatasystem.get_disturb( + self.batch_index, atom_index, axis_index, delta + ) + disturb_result = self.model(**get_data(disturb_batch)) + disturb_force = -(disturb_result["energy"] - result0["energy"]) / delta + disturb_error = ( + result0["force"][0, atom_index, axis_index] - disturb_force + ) + errors[atom_index, axis_index] = disturb_error.detach().cpu().numpy() + self.assertTrue(np.abs(errors).max() < threshold, msg=str(np.abs(errors).max())) + + +if __name__ == "__main__": + unittest.main() diff --git a/source/tests/pt/test_jit.py b/source/tests/pt/test_jit.py new file mode 100644 index 0000000000..f13dade183 --- /dev/null +++ b/source/tests/pt/test_jit.py @@ -0,0 +1,140 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import json +import os +import shutil +import unittest +from copy import ( + deepcopy, +) +from pathlib import ( + Path, +) + +import torch + +from deepmd.pt.entrypoints.main import ( + get_trainer, +) +from deepmd.pt.infer import ( + inference, +) + +from .test_permutation import ( + model_dpa1, + model_dpa2, + model_hybrid, + model_se_e2_a, +) + + +class JITTest: + def test_jit(self): + trainer = get_trainer(deepcopy(self.config)) + trainer.run() + model = torch.jit.script(inference.Tester("./model.pt", numb_test=1).model) + torch.jit.save(model, "./frozen_model.pth", {}) + + def tearDown(self): + for f in os.listdir("."): + if f.startswith("model") and f.endswith("pt"): + os.remove(f) + if f in ["lcurve.out", "frozen_model.pth"]: + os.remove(f) + if f in ["stat_files"]: + shutil.rmtree(f) + + +class TestEnergyModelSeA(unittest.TestCase, JITTest): + def setUp(self): + input_json = str(Path(__file__).parent / "water/se_atten.json") + with open(input_json) as f: + self.config = json.load(f) + data_file = [str(Path(__file__).parent / "water/data/data_0")] + self.config["training"]["training_data"]["systems"] = data_file + self.config["training"]["validation_data"]["systems"] = data_file + self.config["model"] = deepcopy(model_se_e2_a) + self.config["training"]["numb_steps"] = 10 + self.config["training"]["save_freq"] = 10 + + def tearDown(self): + JITTest.tearDown(self) + + +class TestEnergyModelDPA1(unittest.TestCase, JITTest): + def setUp(self): + input_json = str(Path(__file__).parent / "water/se_atten.json") + with open(input_json) as f: + self.config = json.load(f) + data_file = [str(Path(__file__).parent / "water/data/data_0")] + self.config["training"]["training_data"]["systems"] = data_file + self.config["training"]["validation_data"]["systems"] = data_file + self.config["model"] = deepcopy(model_dpa1) + self.config["training"]["numb_steps"] = 10 + self.config["training"]["save_freq"] = 10 + + def tearDown(self): + JITTest.tearDown(self) + + +class TestEnergyModelDPA2(unittest.TestCase, JITTest): + def setUp(self): + input_json = str(Path(__file__).parent / "water/se_atten.json") + with open(input_json) as f: + self.config = json.load(f) + data_file = [str(Path(__file__).parent / "water/data/data_0")] + self.config["training"]["training_data"]["systems"] = data_file + self.config["training"]["validation_data"]["systems"] = data_file + self.config["model"] = deepcopy(model_dpa2) + self.config["model"]["descriptor"]["rcut"] = self.config["model"]["descriptor"][ + "repinit_rcut" + ] + self.config["model"]["descriptor"]["rcut_smth"] = self.config["model"][ + "descriptor" + ]["repinit_rcut_smth"] + self.config["model"]["descriptor"]["sel"] = self.config["model"]["descriptor"][ + "repinit_nsel" + ] + self.config["training"]["numb_steps"] = 10 + self.config["training"]["save_freq"] = 10 + + def tearDown(self): + JITTest.tearDown(self) + + +@unittest.skip("hybrid not supported at the moment") +class TestEnergyModelHybrid(unittest.TestCase, JITTest): + def setUp(self): + input_json = str(Path(__file__).parent / "water/se_atten.json") + with open(input_json) as f: + self.config = json.load(f) + data_file = [str(Path(__file__).parent / "water/data/data_0")] + self.config["training"]["training_data"]["systems"] = data_file + self.config["training"]["validation_data"]["systems"] = data_file + self.config["model"] = deepcopy(model_hybrid) + self.config["training"]["numb_steps"] = 10 + self.config["training"]["save_freq"] = 10 + + def tearDown(self): + JITTest.tearDown(self) + + +@unittest.skip("hybrid not supported at the moment") +class TestEnergyModelHybrid2(unittest.TestCase, JITTest): + def setUp(self): + input_json = str(Path(__file__).parent / "water/se_atten.json") + with open(input_json) as f: + self.config = json.load(f) + data_file = [str(Path(__file__).parent / "water/data/data_0")] + self.config["training"]["training_data"]["systems"] = data_file + self.config["training"]["validation_data"]["systems"] = data_file + self.config["model"] = deepcopy(model_hybrid) + self.config["model"]["descriptor"]["hybrid_mode"] = "sequential" + self.config["training"]["numb_steps"] = 10 + self.config["training"]["save_freq"] = 10 + + def tearDown(self): + JITTest.tearDown(self) + + +if __name__ == "__main__": + unittest.main() diff --git a/source/tests/pt/test_loss.py b/source/tests/pt/test_loss.py new file mode 100644 index 0000000000..14934c7be0 --- /dev/null +++ b/source/tests/pt/test_loss.py @@ -0,0 +1,189 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import json +import os +import unittest + +import numpy as np +import tensorflow.compat.v1 as tf +import torch + +tf.disable_eager_execution() +from pathlib import ( + Path, +) + +from deepmd.pt.loss import ( + EnergyStdLoss, +) +from deepmd.pt.utils.dataset import ( + DeepmdDataSet, +) +from deepmd.tf.common import ( + expand_sys_str, +) +from deepmd.tf.loss.ener import ( + EnerStdLoss, +) + +CUR_DIR = os.path.dirname(__file__) + + +def get_batch(): + with open(str(Path(__file__).parent / "water/se_e2_a.json")) as fin: + content = fin.read() + config = json.loads(content) + data_file = [str(Path(__file__).parent / "water/data/data_0")] + config["training"]["training_data"]["systems"] = data_file + config["training"]["validation_data"]["systems"] = data_file + model_config = config["model"] + rcut = model_config["descriptor"]["rcut"] + # self.rcut_smth = model_config['descriptor']['rcut_smth'] + sel = model_config["descriptor"]["sel"] + batch_size = config["training"]["training_data"]["batch_size"] + systems = config["training"]["validation_data"]["systems"] + if isinstance(systems, str): + systems = expand_sys_str(systems) + dataset = DeepmdDataSet(systems, batch_size, model_config["type_map"], rcut, sel) + np_batch, pt_batch = dataset.get_batch() + return np_batch, pt_batch + + +class TestLearningRate(unittest.TestCase): + def setUp(self): + self.start_lr = 1.1 + self.start_pref_e = 0.02 + self.limit_pref_e = 1.0 + self.start_pref_f = 1000.0 + self.limit_pref_f = 1.0 + self.start_pref_v = 0.02 + self.limit_pref_v = 1.0 + self.cur_lr = 1.2 + # data + np_batch, pt_batch = get_batch() + natoms = np_batch["natoms"] + self.nloc = natoms[0] + l_energy, l_force, l_virial = ( + np_batch["energy"], + np_batch["force"], + np_batch["virial"], + ) + p_energy, p_force, p_virial = ( + np.ones_like(l_energy), + np.ones_like(l_force), + np.ones_like(l_virial), + ) + nloc = natoms[0] + batch_size = pt_batch["coord"].shape[0] + atom_energy = np.zeros(shape=[batch_size, nloc]) + atom_pref = np.zeros(shape=[batch_size, nloc * 3]) + # tf + base = EnerStdLoss( + self.start_lr, + self.start_pref_e, + self.limit_pref_e, + self.start_pref_f, + self.limit_pref_f, + self.start_pref_v, + self.limit_pref_v, + ) + self.g = tf.Graph() + with self.g.as_default(): + t_cur_lr = tf.placeholder(shape=[], dtype=tf.float64) + t_natoms = tf.placeholder(shape=[None], dtype=tf.int32) + t_penergy = tf.placeholder(shape=[None, 1], dtype=tf.float64) + t_pforce = tf.placeholder(shape=[None, None], dtype=tf.float64) + t_pvirial = tf.placeholder(shape=[None, 9], dtype=tf.float64) + t_patom_energy = tf.placeholder(shape=[None, None], dtype=tf.float64) + t_lenergy = tf.placeholder(shape=[None, 1], dtype=tf.float64) + t_lforce = tf.placeholder(shape=[None, None], dtype=tf.float64) + t_lvirial = tf.placeholder(shape=[None, 9], dtype=tf.float64) + t_latom_energy = tf.placeholder(shape=[None, None], dtype=tf.float64) + t_atom_pref = tf.placeholder(shape=[None, None], dtype=tf.float64) + find_energy = tf.constant(1.0, dtype=tf.float64) + find_force = tf.constant(1.0, dtype=tf.float64) + find_virial = tf.constant(1.0, dtype=tf.float64) + find_atom_energy = tf.constant(0.0, dtype=tf.float64) + find_atom_pref = tf.constant(0.0, dtype=tf.float64) + model_dict = { + "energy": t_penergy, + "force": t_pforce, + "virial": t_pvirial, + "atom_ener": t_patom_energy, + } + label_dict = { + "energy": t_lenergy, + "force": t_lforce, + "virial": t_lvirial, + "atom_ener": t_latom_energy, + "atom_pref": t_atom_pref, + "find_energy": find_energy, + "find_force": find_force, + "find_virial": find_virial, + "find_atom_ener": find_atom_energy, + "find_atom_pref": find_atom_pref, + } + self.base_loss_sess = base.build( + t_cur_lr, t_natoms, model_dict, label_dict, "" + ) + # torch + self.feed_dict = { + t_cur_lr: self.cur_lr, + t_natoms: natoms, + t_penergy: p_energy, + t_pforce: p_force, + t_pvirial: p_virial.reshape(-1, 9), + t_patom_energy: atom_energy, + t_lenergy: l_energy, + t_lforce: l_force, + t_lvirial: l_virial.reshape(-1, 9), + t_latom_energy: atom_energy, + t_atom_pref: atom_pref, + } + self.model_pred = { + "energy": torch.from_numpy(p_energy), + "force": torch.from_numpy(p_force), + "virial": torch.from_numpy(p_virial), + } + self.label = { + "energy": torch.from_numpy(l_energy), + "force": torch.from_numpy(l_force), + "virial": torch.from_numpy(l_virial), + } + self.natoms = pt_batch["natoms"] + + def tearDown(self) -> None: + tf.reset_default_graph() + return super().tearDown() + + def test_consistency(self): + with tf.Session(graph=self.g) as sess: + base_loss, base_more_loss = sess.run( + self.base_loss_sess, feed_dict=self.feed_dict + ) + mine = EnergyStdLoss( + self.start_lr, + self.start_pref_e, + self.limit_pref_e, + self.start_pref_f, + self.limit_pref_f, + self.start_pref_v, + self.limit_pref_v, + ) + my_loss, my_more_loss = mine( + self.label, + self.model_pred, + self.nloc, + self.cur_lr, + ) + my_loss = my_loss.detach().cpu() + self.assertTrue(np.allclose(base_loss, my_loss.numpy())) + for key in ["ener", "force", "virial"]: + self.assertTrue( + np.allclose( + base_more_loss["l2_%s_loss" % key], my_more_loss["l2_%s_loss" % key] + ) + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/source/tests/pt/test_lr.py b/source/tests/pt/test_lr.py new file mode 100644 index 0000000000..ca1ec7e490 --- /dev/null +++ b/source/tests/pt/test_lr.py @@ -0,0 +1,59 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import unittest + +import numpy as np +import tensorflow.compat.v1 as tf + +tf.disable_eager_execution() + +from deepmd.pt.utils.learning_rate import ( + LearningRateExp, +) +from deepmd.tf.utils import ( + learning_rate, +) + + +class TestLearningRate(unittest.TestCase): + def setUp(self): + self.start_lr = 0.001 + self.stop_lr = 3.51e-8 + self.decay_steps = np.arange(400, 601, 100) + self.stop_steps = np.arange(500, 1600, 500) + + def test_consistency(self): + for decay_step in self.decay_steps: + for stop_step in self.stop_steps: + self.decay_step = decay_step + self.stop_step = stop_step + self.judge_it() + + def judge_it(self): + base_lr = learning_rate.LearningRateExp( + self.start_lr, self.stop_lr, self.decay_step + ) + g = tf.Graph() + with g.as_default(): + global_step = tf.placeholder(shape=[], dtype=tf.int32) + t_lr = base_lr.build(global_step, self.stop_step) + + my_lr = LearningRateExp( + self.start_lr, self.stop_lr, self.decay_step, self.stop_step + ) + with tf.Session(graph=g) as sess: + base_vals = [ + sess.run(t_lr, feed_dict={global_step: step_id}) + for step_id in range(self.stop_step) + if step_id % self.decay_step != 0 + ] + my_vals = [ + my_lr.value(step_id) + for step_id in range(self.stop_step) + if step_id % self.decay_step != 0 + ] + self.assertTrue(np.allclose(base_vals, my_vals)) + tf.reset_default_graph() + + +if __name__ == "__main__": + unittest.main() diff --git a/source/tests/pt/test_mlp.py b/source/tests/pt/test_mlp.py new file mode 100644 index 0000000000..c06047b2a5 --- /dev/null +++ b/source/tests/pt/test_mlp.py @@ -0,0 +1,321 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import itertools +import unittest + +import numpy as np +import torch + +from deepmd.pt.utils.env import ( + PRECISION_DICT, +) + +try: + from deepmd.pt.model.network.mlp import ( + MLP, + MLPLayer, + ) + + support_native_net = True +except ModuleNotFoundError: + support_native_net = False + +try: + from deepmd.pt.model.network.mlp import ( + EmbeddingNet, + ) + + support_embedding_net = True +except ModuleNotFoundError: + support_embedding_net = False + +try: + from deepmd.pt.model.network.mlp import ( + FittingNet, + ) + + support_fitting_net = True +except ModuleNotFoundError: + support_fitting_net = False + + +try: + from deepmd.model_format import ( + NativeLayer, + NativeNet, + ) + + support_native_net = True +except ModuleNotFoundError: + support_native_net = False +except ImportError: + support_native_net = False + +try: + from deepmd.model_format import EmbeddingNet as DPEmbeddingNet + + support_embedding_net = True +except ModuleNotFoundError: + support_embedding_net = False +except ImportError: + support_embedding_net = False + +try: + from deepmd.model_format import FittingNet as DPFittingNet + + support_fitting_net = True +except ModuleNotFoundError: + support_fitting_net = False +except ImportError: + support_fitting_net = False + + +def get_tols(prec): + if prec in ["single", "float32"]: + rtol, atol = 0.0, 1e-4 + elif prec in ["double", "float64"]: + rtol, atol = 0.0, 1e-12 + # elif prec in ["half", "float16"]: + # rtol, atol=1e-2, 0 + else: + raise ValueError(f"unknown prec {prec}") + return rtol, atol + + +@unittest.skipIf(not support_native_net, "NativeLayer not supported") +class TestMLPLayer(unittest.TestCase): + def setUp(self): + self.test_cases = itertools.product( + [(5, 5), (5, 10), (5, 8), (8, 5)], # inp, out + [True, False], # bias + [True, False], # use time step + ["tanh", "none"], # activation + [True, False], # resnet + [None, [4], [3, 2]], # prefix shapes + ["float32", "double"], # precision + ) + + def test_match_native_layer( + self, + ): + for (ninp, nout), bias, ut, ac, resnet, ashp, prec in self.test_cases: + # input + inp_shap = [ninp] + if ashp is not None: + inp_shap = ashp + inp_shap + rtol, atol = get_tols(prec) + dtype = PRECISION_DICT[prec] + xx = torch.arange(np.prod(inp_shap), dtype=dtype).view(inp_shap) + # def mlp layer + ml = MLPLayer(ninp, nout, bias, ut, ac, resnet, precision=prec) + # check consistency + nl = NativeLayer.deserialize(ml.serialize()) + np.testing.assert_allclose( + ml.forward(xx).detach().numpy(), + nl.call(xx.detach().numpy()), + rtol=rtol, + atol=atol, + err_msg=f"(i={ninp}, o={nout}) bias={bias} use_dt={ut} act={ac} resnet={resnet} prec={prec}", + ) + # check self-consistency + ml1 = MLPLayer.deserialize(ml.serialize()) + np.testing.assert_allclose( + ml.forward(xx).detach().numpy(), + ml1.forward(xx).detach().numpy(), + rtol=rtol, + atol=atol, + err_msg=f"(i={ninp}, o={nout}) bias={bias} use_dt={ut} act={ac} resnet={resnet} prec={prec}", + ) + + def test_jit(self): + for (ninp, nout), bias, ut, ac, resnet, _, prec in self.test_cases: + ml = MLPLayer(ninp, nout, bias, ut, ac, resnet, precision=prec) + model = torch.jit.script(ml) + ml1 = MLPLayer.deserialize(ml.serialize()) + model = torch.jit.script(ml1) + + +@unittest.skipIf(not support_native_net, "NativeLayer not supported") +class TestMLP(unittest.TestCase): + def setUp(self): + self.test_cases = itertools.product( + [[2, 2, 4, 8], [1, 3, 3]], # inp and hiddens + [True, False], # bias + [True, False], # use time step + ["tanh", "none"], # activation + [True, False], # resnet + [None, [4], [3, 2]], # prefix shapes + ["float32", "double"], # precision + ) + + def test_match_native_net( + self, + ): + for ndims, bias, ut, ac, resnet, ashp, prec in self.test_cases: + # input + inp_shap = [ndims[0]] + if ashp is not None: + inp_shap = ashp + inp_shap + rtol, atol = get_tols(prec) + dtype = PRECISION_DICT[prec] + xx = torch.arange(np.prod(inp_shap), dtype=dtype).view(inp_shap) + # def MLP + layers = [] + for ii in range(1, len(ndims)): + layers.append( + MLPLayer( + ndims[ii - 1], ndims[ii], bias, ut, ac, resnet, precision=prec + ).serialize() + ) + ml = MLP(layers) + # check consistency + nl = NativeNet.deserialize(ml.serialize()) + np.testing.assert_allclose( + ml.forward(xx).detach().numpy(), + nl.call(xx.detach().numpy()), + rtol=rtol, + atol=atol, + err_msg=f"net={ndims} bias={bias} use_dt={ut} act={ac} resnet={resnet} prec={prec}", + ) + # check self-consistency + ml1 = MLP.deserialize(ml.serialize()) + np.testing.assert_allclose( + ml.forward(xx).detach().numpy(), + ml1.forward(xx).detach().numpy(), + rtol=rtol, + atol=atol, + err_msg=f"net={ndims} bias={bias} use_dt={ut} act={ac} resnet={resnet} prec={prec}", + ) + + def test_jit(self): + for ndims, bias, ut, ac, resnet, _, prec in self.test_cases: + layers = [] + for ii in range(1, len(ndims)): + ml = layers.append( + MLPLayer( + ndims[ii - 1], ndims[ii], bias, ut, ac, resnet, precision=prec + ).serialize() + ) + ml = MLP(ml) + model = torch.jit.script(ml) + ml1 = MLP.deserialize(ml.serialize()) + model = torch.jit.script(ml1) + + +@unittest.skipIf(not support_embedding_net, "EmbeddingNet not supported") +class TestEmbeddingNet(unittest.TestCase): + def setUp(self): + self.test_cases = itertools.product( + [1, 3], # inp + [[24, 48, 96], [24, 36]], # and hiddens + ["tanh", "none"], # activation + [True, False], # resnet_dt + ["float32", "double"], # precision + ) + + def test_match_embedding_net( + self, + ): + for idim, nn, act, idt, prec in self.test_cases: + # input + rtol, atol = get_tols(prec) + dtype = PRECISION_DICT[prec] + xx = torch.arange(idim, dtype=dtype) + # def MLP + ml = EmbeddingNet(idim, nn, act, idt, prec) + # check consistency + nl = DPEmbeddingNet.deserialize(ml.serialize()) + np.testing.assert_allclose( + ml.forward(xx).detach().numpy(), + nl.call(xx.detach().numpy()), + rtol=rtol, + atol=atol, + err_msg=f"idim={idim} nn={nn} use_dt={idt} act={act} prec={prec}", + ) + # check self-consistency + ml1 = EmbeddingNet.deserialize(ml.serialize()) + np.testing.assert_allclose( + ml.forward(xx).detach().numpy(), + ml1.forward(xx).detach().numpy(), + rtol=rtol, + atol=atol, + err_msg=f"idim={idim} nn={nn} use_dt={idt} act={act} prec={prec}", + ) + + def test_jit( + self, + ): + for idim, nn, act, idt, prec in self.test_cases: + # def MLP + ml = EmbeddingNet(idim, nn, act, idt, prec) + ml1 = EmbeddingNet.deserialize(ml.serialize()) + model = torch.jit.script(ml) + model = torch.jit.script(ml1) + + +@unittest.skipIf(not support_fitting_net, "FittingNet not supported") +class TestFittingNet(unittest.TestCase): + def setUp(self): + self.test_cases = itertools.product( + [1, 3], # inp + [1, 5], # out + [[24, 48, 96], [24, 36]], # and hiddens + ["tanh", "none"], # activation + [True, False], # resnet_dt + ["float32", "double"], # precision + [True, False], # bias_out + ) + + def test_match_fitting_net( + self, + ): + for idim, odim, nn, act, idt, prec, ob in self.test_cases: + # input + rtol, atol = get_tols(prec) + dtype = PRECISION_DICT[prec] + xx = torch.arange(idim, dtype=dtype) + # def MLP + ml = FittingNet( + idim, + odim, + neuron=nn, + activation_function=act, + resnet_dt=idt, + precision=prec, + bias_out=ob, + ) + # check consistency + nl = DPFittingNet.deserialize(ml.serialize()) + np.testing.assert_allclose( + ml.forward(xx).detach().numpy(), + nl.call(xx.detach().numpy()), + rtol=rtol, + atol=atol, + err_msg=f"idim={idim} nn={nn} use_dt={idt} act={act} prec={prec}", + ) + # check self-consistency + ml1 = FittingNet.deserialize(ml.serialize()) + np.testing.assert_allclose( + ml.forward(xx).detach().numpy(), + ml1.forward(xx).detach().numpy(), + rtol=rtol, + atol=atol, + err_msg=f"idim={idim} nn={nn} use_dt={idt} act={act} prec={prec}", + ) + + def test_jit( + self, + ): + for idim, odim, nn, act, idt, prec, ob in self.test_cases: + # def MLP + ml = FittingNet( + idim, + odim, + neuron=nn, + activation_function=act, + resnet_dt=idt, + precision=prec, + bias_out=ob, + ) + ml1 = FittingNet.deserialize(ml.serialize()) + model = torch.jit.script(ml) + model = torch.jit.script(ml1) diff --git a/source/tests/pt/test_model.py b/source/tests/pt/test_model.py new file mode 100644 index 0000000000..5bbbc9e352 --- /dev/null +++ b/source/tests/pt/test_model.py @@ -0,0 +1,415 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import collections +import json +import unittest + +import numpy as np +import tensorflow.compat.v1 as tf +import torch + +tf.disable_eager_execution() + +from pathlib import ( + Path, +) + +from deepmd.pt.loss import ( + EnergyStdLoss, +) +from deepmd.pt.model.model import ( + get_model, +) +from deepmd.pt.utils.dataloader import ( + DpLoaderSet, +) +from deepmd.pt.utils.env import ( + DEVICE, +) +from deepmd.pt.utils.learning_rate import LearningRateExp as MyLRExp +from deepmd.pt.utils.stat import ( + make_stat_input, +) +from deepmd.tf.common import ( + data_requirement, + expand_sys_str, +) +from deepmd.tf.descriptor import DescrptSeA as DescrptSeA_tf +from deepmd.tf.fit import ( + EnerFitting, +) +from deepmd.tf.loss import ( + EnerStdLoss, +) +from deepmd.tf.model import ( + EnerModel, +) +from deepmd.tf.utils.data_system import ( + DeepmdDataSystem, +) +from deepmd.tf.utils.learning_rate import ( + LearningRateExp, +) + +VariableState = collections.namedtuple("VariableState", ["value", "gradient"]) + + +def torch2tf(torch_name): + fields = torch_name.split(".") + offset = int(fields[2] == "networks") + element_id = int(fields[2 + offset]) + if fields[0] == "descriptor": + layer_id = int(fields[4 + offset]) + 1 + weight_type = fields[5 + offset] + return "filter_type_all/%s_%d_%d:0" % (weight_type, layer_id, element_id) + elif fields[3] == "deep_layers": + layer_id = int(fields[4]) + weight_type = fields[5] + return "layer_%d_type_%d/%s:0" % (layer_id, element_id, weight_type) + elif fields[3] == "final_layer": + weight_type = fields[4] + return "final_layer_type_%d/%s:0" % (element_id, weight_type) + else: + raise RuntimeError("Unexpected parameter name: %s" % torch_name) + + +class DpTrainer: + def __init__(self): + with open(str(Path(__file__).parent / "water/se_e2_a.json")) as fin: + content = fin.read() + config = json.loads(content) + data_file = [str(Path(__file__).parent / "water/data/data_0")] + config["training"]["training_data"]["systems"] = data_file + config["training"]["validation_data"]["systems"] = data_file + model_config = config["model"] + self.rcut = model_config["descriptor"]["rcut"] + self.rcut_smth = model_config["descriptor"]["rcut_smth"] + self.sel = model_config["descriptor"]["sel"] + self.systems = config["training"]["validation_data"]["systems"] + if isinstance(self.systems, str): + self.systems = expand_sys_str(self.systems) + self.batch_size = config["training"]["training_data"]["batch_size"] + self.type_map = model_config["type_map"] + self.filter_neuron = model_config["descriptor"]["neuron"] + self.axis_neuron = model_config["descriptor"]["axis_neuron"] + self.n_neuron = model_config["fitting_net"]["neuron"] + self.data_stat_nbatch = 3 + self.start_lr = 0.001 + self.stop_lr = 3.51e-8 + self.decay_steps = 500 + self.stop_steps = 1600 + self.start_pref_e = 1.0 + self.limit_pref_e = 2.0 + self.start_pref_f = 2.0 + self.limit_pref_f = 1.0 + self.ntypes = len(self.type_map) + + def get_intermediate_state(self, num_steps=1): + dp_model = self._get_dp_model() + dp_loss = self._get_dp_loss() + dp_lr = self._get_dp_lr() + dp_ds = self._get_dp_dataset() + dp_model.data_stat(dp_ds) + + # Build graph + g = tf.Graph() + with g.as_default(): + place_holders = self._get_dp_placeholders(dp_ds) + model_pred = dp_model.build( + coord_=place_holders["coord"], + atype_=place_holders["type"], + natoms=place_holders["natoms_vec"], + box=place_holders["box"], + mesh=place_holders["default_mesh"], + input_dict=place_holders, + ) + global_step = tf.train.get_or_create_global_step() + learning_rate = dp_lr.build(global_step, self.stop_steps) + l2_l, _ = dp_loss.build( + learning_rate=learning_rate, + natoms=place_holders["natoms_vec"], + model_dict=model_pred, + label_dict=place_holders, + suffix="test", + ) + t_vars = tf.trainable_variables() + optimizer = tf.train.AdamOptimizer(learning_rate) + t_grad_and_vars = optimizer.compute_gradients(l2_l, t_vars) + train_op = optimizer.apply_gradients(t_grad_and_vars, global_step) + init_op = tf.global_variables_initializer() + t_heads = { + "loss": l2_l, + "energy": model_pred["energy"], + "force": model_pred["force"], + "virial": model_pred["virial"], + "atomic_virial": model_pred["atom_virial"], + } + + # Get statistics of each component + stat_dict = { + "descriptor.mean": dp_model.descrpt.davg, + "descriptor.stddev": dp_model.descrpt.dstd, + "fitting_net.bias_atom_e": dp_model.fitting.bias_atom_e, + } + + # Get variables and their gradients + with tf.Session(graph=g) as sess: + sess.run(init_op) + for _ in range(num_steps): + batch = dp_ds.get_batch() + feeds = self._get_feed_dict(batch, place_holders) + sess.run(train_op, feed_dict=feeds) + + batch = dp_ds.get_batch() + feeds = self._get_feed_dict(batch, place_holders) + grads_and_vars, head_dict = sess.run( + [t_grad_and_vars, t_heads], feed_dict=feeds + ) + vs_dict = {} + for idx, one in enumerate(t_vars): + grad, var = grads_and_vars[idx] + vs_dict[one.name] = VariableState(var, grad) + + tf.reset_default_graph() + # Used for reproducing + return batch, head_dict, stat_dict, vs_dict + + def _get_dp_dataset(self): + data = DeepmdDataSystem( + systems=self.systems, + batch_size=self.batch_size, + test_size=1, + rcut=self.rcut, + type_map=self.type_map, + trn_all_set=True, + ) + data.add_dict(data_requirement) + return data + + def _get_dp_model(self): + dp_descrpt = DescrptSeA_tf( + rcut=self.rcut, + rcut_smth=self.rcut_smth, + sel=self.sel, + neuron=self.filter_neuron, + axis_neuron=self.axis_neuron, + ) + dp_fitting = EnerFitting(descrpt=dp_descrpt, neuron=self.n_neuron) + return EnerModel( + dp_descrpt, + dp_fitting, + type_map=self.type_map, + data_stat_nbatch=self.data_stat_nbatch, + ) + + def _get_dp_loss(self): + return EnerStdLoss( + starter_learning_rate=self.start_lr, + start_pref_e=self.start_pref_e, + limit_pref_e=self.limit_pref_e, + start_pref_f=self.start_pref_f, + limit_pref_f=self.limit_pref_f, + ) + + def _get_dp_lr(self): + return LearningRateExp( + start_lr=self.start_lr, stop_lr=self.stop_lr, decay_steps=self.decay_steps + ) + + def _get_dp_placeholders(self, dataset): + place_holders = {} + data_dict = dataset.get_data_dict() + for kk in data_dict.keys(): + if kk == "type": + continue + prec = tf.float64 + place_holders[kk] = tf.placeholder(prec, [None], name="t_" + kk) + place_holders["find_" + kk] = tf.placeholder( + tf.float32, name="t_find_" + kk + ) + place_holders["type"] = tf.placeholder(tf.int32, [None], name="t_type") + place_holders["natoms_vec"] = tf.placeholder( + tf.int32, [self.ntypes + 2], name="t_natoms" + ) + place_holders["default_mesh"] = tf.placeholder(tf.int32, [None], name="t_mesh") + place_holders["is_training"] = tf.placeholder(tf.bool) + return place_holders + + def _get_feed_dict(self, batch, place_holders): + feed_dict = {} + for kk in batch.keys(): + if kk == "find_type" or kk == "type": + continue + if "find_" in kk: + feed_dict[place_holders[kk]] = batch[kk] + else: + feed_dict[place_holders[kk]] = np.reshape(batch[kk], [-1]) + for ii in ["type"]: + feed_dict[place_holders[ii]] = np.reshape(batch[ii], [-1]) + for ii in ["natoms_vec", "default_mesh"]: + feed_dict[place_holders[ii]] = batch[ii] + feed_dict[place_holders["is_training"]] = True + return feed_dict + + +class TestEnergy(unittest.TestCase): + def setUp(self): + self.dp_trainer = DpTrainer() + self.wanted_step = 0 + for key in dir(self.dp_trainer): + if not key.startswith("_") or key == "get_intermediate_state": + value = getattr(self.dp_trainer, key) + setattr(self, key, value) + + def test_consistency(self): + batch, head_dict, stat_dict, vs_dict = self.dp_trainer.get_intermediate_state( + self.wanted_step + ) + # Build DeePMD graph + my_ds = DpLoaderSet( + self.systems, + self.batch_size, + model_params={ + "descriptor": { + "type": "se_e2_a", + "sel": self.sel, + "rcut": self.rcut, + }, + "type_map": self.type_map, + }, + ) + sampled = make_stat_input( + my_ds.systems, my_ds.dataloaders, self.data_stat_nbatch + ) + my_model = get_model( + model_params={ + "descriptor": { + "type": "se_e2_a", + "sel": self.sel, + "rcut_smth": self.rcut_smth, + "rcut": self.rcut, + "neuron": self.filter_neuron, + "axis_neuron": self.axis_neuron, + }, + "fitting_net": {"neuron": self.n_neuron}, + "data_stat_nbatch": self.data_stat_nbatch, + "type_map": self.type_map, + }, + sampled=sampled, + ) + my_model.to(DEVICE) + my_lr = MyLRExp(self.start_lr, self.stop_lr, self.decay_steps, self.stop_steps) + my_loss = EnergyStdLoss( + starter_learning_rate=self.start_lr, + start_pref_e=self.start_pref_e, + limit_pref_e=self.limit_pref_e, + start_pref_f=self.start_pref_f, + limit_pref_f=self.limit_pref_f, + ) + + # Keep statistics consistency between 2 implentations + my_em = my_model.descriptor + mean = stat_dict["descriptor.mean"].reshape([self.ntypes, my_em.get_nsel(), 4]) + stddev = stat_dict["descriptor.stddev"].reshape( + [self.ntypes, my_em.get_nsel(), 4] + ) + my_em.set_stat_mean_and_stddev( + torch.tensor(mean, device=DEVICE), + torch.tensor(stddev, device=DEVICE), + ) + my_model.fitting_net.bias_atom_e = torch.tensor( + stat_dict["fitting_net.bias_atom_e"], device=DEVICE + ) + + # Keep parameter value consistency between 2 implentations + for name, param in my_model.named_parameters(): + name = name.replace("sea.", "") + var_name = torch2tf(name) + var = vs_dict[var_name].value + with torch.no_grad(): + src = torch.from_numpy(var) + dst = param.data + # print(name) + # print(src.mean(), src.std()) + # print(dst.mean(), dst.std()) + dst.copy_(src) + # Start forward computing + batch = my_ds.systems[0]._data_system.preprocess(batch) + batch["coord"].requires_grad_(True) + batch["natoms"] = torch.tensor( + batch["natoms_vec"], device=batch["coord"].device + ).unsqueeze(0) + model_predict = my_model( + batch["coord"], batch["atype"], batch["box"], do_atomic_virial=True + ) + model_predict_1 = my_model( + batch["coord"], batch["atype"], batch["box"], do_atomic_virial=False + ) + p_energy, p_force, p_virial, p_atomic_virial = ( + model_predict["energy"], + model_predict["force"], + model_predict["virial"], + model_predict["atomic_virial"], + ) + cur_lr = my_lr.value(self.wanted_step) + model_pred = { + "energy": p_energy, + "force": p_force, + } + label = { + "energy": batch["energy"], + "force": batch["force"], + } + loss, _ = my_loss(model_pred, label, int(batch["natoms"][0, 0]), cur_lr) + np.testing.assert_allclose( + head_dict["energy"], p_energy.view(-1).cpu().detach().numpy() + ) + np.testing.assert_allclose( + head_dict["force"], + p_force.view(*head_dict["force"].shape).cpu().detach().numpy(), + ) + rtol = 1e-5 + atol = 1e-8 + np.testing.assert_allclose( + head_dict["loss"], loss.cpu().detach().numpy(), rtol=rtol, atol=atol + ) + np.testing.assert_allclose( + head_dict["virial"], + p_virial.view(*head_dict["virial"].shape).cpu().detach().numpy(), + ) + np.testing.assert_allclose( + head_dict["virial"], + model_predict_1["virial"] + .view(*head_dict["virial"].shape) + .cpu() + .detach() + .numpy(), + ) + self.assertIsNone(model_predict_1.get("atomic_virial", None)) + np.testing.assert_allclose( + head_dict["atomic_virial"], + p_atomic_virial.view(*head_dict["atomic_virial"].shape) + .cpu() + .detach() + .numpy(), + ) + optimizer = torch.optim.Adam(my_model.parameters(), lr=cur_lr) + optimizer.zero_grad() + + def step(step_id): + bdata = self.training_data.get_trainning_batch() + optimizer.zero_grad() + + # Compare gradient for consistency + loss.backward() + + for name, param in my_model.named_parameters(): + name = name.replace("sea.", "") + var_name = torch2tf(name) + var_grad = vs_dict[var_name].gradient + param_grad = param.grad.cpu() + var_grad = torch.tensor(var_grad) + assert np.allclose(var_grad, param_grad, rtol=rtol, atol=atol) + + +if __name__ == "__main__": + unittest.main() diff --git a/source/tests/pt/test_nlist.py b/source/tests/pt/test_nlist.py new file mode 100644 index 0000000000..27c03acfaa --- /dev/null +++ b/source/tests/pt/test_nlist.py @@ -0,0 +1,212 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import unittest + +import torch + +from deepmd.pt.utils import ( + env, +) +from deepmd.pt.utils.nlist import ( + build_multiple_neighbor_list, + build_neighbor_list, + extend_coord_with_ghosts, + get_multiple_nlist_key, +) +from deepmd.pt.utils.region import ( + inter2phys, +) + +dtype = torch.float64 + + +class TestNeighList(unittest.TestCase): + def setUp(self): + self.nf = 3 + self.nloc = 2 + self.ns = 5 * 5 * 3 + self.nall = self.ns * self.nloc + self.cell = torch.tensor( + [[1, 0, 0], [0.4, 0.8, 0], [0.1, 0.3, 2.1]], dtype=dtype + ).to(env.DEVICE) + self.icoord = torch.tensor([[0, 0, 0], [0.5, 0.5, 0.1]], dtype=dtype).to( + env.DEVICE + ) + self.atype = torch.tensor([0, 1], dtype=torch.int).to(env.DEVICE) + [self.cell, self.icoord, self.atype] = [ + ii.unsqueeze(0) for ii in [self.cell, self.icoord, self.atype] + ] + self.coord = inter2phys(self.icoord, self.cell).view([-1, self.nloc * 3]) + self.cell = self.cell.view([-1, 9]) + [self.cell, self.coord, self.atype] = [ + torch.tile(ii, [self.nf, 1]) for ii in [self.cell, self.coord, self.atype] + ] + self.rcut = 1.01 + self.prec = 1e-10 + self.nsel = [10, 10] + # genrated by preprocess.build_neighbor_list + # ref_nlist, _, _ = legacy_build_neighbor_list( + # 2, ecoord[0], eatype[0], + # self.rcut, + # torch.tensor([10,20], dtype=torch.long), + # mapping[0], type_split=True, ) + self.ref_nlist = torch.tensor( + [ + [0, 0, 0, 0, 0, 0, -1, -1, -1, -1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1], + [0, 0, 0, 0, -1, -1, -1, -1, -1, -1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1], + ] + ).to(env.DEVICE) + + def test_build_notype(self): + ecoord, eatype, mapping = extend_coord_with_ghosts( + self.coord, self.atype, self.cell, self.rcut + ) + nlist = build_neighbor_list( + ecoord, + eatype, + self.nloc, + self.rcut, + sum(self.nsel), + distinguish_types=False, + ) + torch.testing.assert_close(nlist[0], nlist[1]) + nlist_mask = nlist[0] == -1 + nlist_loc = mapping[0][nlist[0]] + nlist_loc[nlist_mask] = -1 + torch.testing.assert_close( + torch.sort(nlist_loc, dim=-1)[0], + torch.sort(self.ref_nlist, dim=-1)[0], + ) + + def test_build_type(self): + ecoord, eatype, mapping = extend_coord_with_ghosts( + self.coord, self.atype, self.cell, self.rcut + ) + nlist = build_neighbor_list( + ecoord, + eatype, + self.nloc, + self.rcut, + self.nsel, + distinguish_types=True, + ) + torch.testing.assert_close(nlist[0], nlist[1]) + nlist_mask = nlist[0] == -1 + nlist_loc = mapping[0][nlist[0]] + nlist_loc[nlist_mask] = -1 + for ii in range(2): + torch.testing.assert_close( + torch.sort(torch.split(nlist_loc, self.nsel, dim=-1)[ii], dim=-1)[0], + torch.sort(torch.split(self.ref_nlist, self.nsel, dim=-1)[ii], dim=-1)[ + 0 + ], + ) + + def test_build_multiple_nlist(self): + rcuts = [1.01, 2.01] + nsels = [20, 80] + ecoord, eatype, mapping = extend_coord_with_ghosts( + self.coord, self.atype, self.cell, max(rcuts) + ) + nlist1 = build_neighbor_list( + ecoord, + eatype, + self.nloc, + rcuts[1], + nsels[1] - 1, + distinguish_types=False, + ) + pad = -1 * torch.ones( + [self.nf, self.nloc, 1], dtype=nlist1.dtype, device=nlist1.device + ) + nlist2 = torch.cat([nlist1, pad], dim=-1) + nlist0 = build_neighbor_list( + ecoord, + eatype, + self.nloc, + rcuts[0], + nsels[0], + distinguish_types=False, + ) + nlists = build_multiple_neighbor_list(ecoord, nlist1, rcuts, nsels) + for dd in range(2): + self.assertEqual( + nlists[get_multiple_nlist_key(rcuts[dd], nsels[dd])].shape[-1], + nsels[dd], + ) + torch.testing.assert_close( + nlists[get_multiple_nlist_key(rcuts[0], nsels[0])], + nlist0, + ) + torch.testing.assert_close( + nlists[get_multiple_nlist_key(rcuts[1], nsels[1])], + nlist2, + ) + + def test_extend_coord(self): + ecoord, eatype, mapping = extend_coord_with_ghosts( + self.coord, self.atype, self.cell, self.rcut + ) + # expected ncopy x nloc + self.assertEqual(list(ecoord.shape), [self.nf, self.nall * 3]) + self.assertEqual(list(eatype.shape), [self.nf, self.nall]) + self.assertEqual(list(mapping.shape), [self.nf, self.nall]) + # check the nloc part is identical with original coord + torch.testing.assert_close( + ecoord[:, : self.nloc * 3], self.coord, rtol=self.prec, atol=self.prec + ) + # check the shift vectors are aligned with grid + shift_vec = ( + ecoord.view([-1, self.ns, self.nloc, 3]) + - self.coord.view([-1, self.nloc, 3])[:, None, :, :] + ) + shift_vec = shift_vec.view([-1, self.nall, 3]) + # hack!!! assumes identical cell across frames + shift_vec = torch.matmul( + shift_vec, torch.linalg.inv(self.cell.view([self.nf, 3, 3])[0]) + ) + # nf x nall x 3 + shift_vec = torch.round(shift_vec) + # check: identical shift vecs + torch.testing.assert_close( + shift_vec[0], shift_vec[1], rtol=self.prec, atol=self.prec + ) + # check: shift idx aligned with grid + mm, cc = torch.unique(shift_vec[0][:, 0], dim=-1, return_counts=True) + torch.testing.assert_close( + mm, + torch.tensor([-2, -1, 0, 1, 2], dtype=dtype).to(env.DEVICE), + rtol=self.prec, + atol=self.prec, + ) + torch.testing.assert_close( + cc, + torch.tensor([30, 30, 30, 30, 30], dtype=torch.long).to(env.DEVICE), + rtol=self.prec, + atol=self.prec, + ) + mm, cc = torch.unique(shift_vec[1][:, 1], dim=-1, return_counts=True) + torch.testing.assert_close( + mm, + torch.tensor([-2, -1, 0, 1, 2], dtype=dtype).to(env.DEVICE), + rtol=self.prec, + atol=self.prec, + ) + torch.testing.assert_close( + cc, + torch.tensor([30, 30, 30, 30, 30], dtype=torch.long).to(env.DEVICE), + rtol=self.prec, + atol=self.prec, + ) + mm, cc = torch.unique(shift_vec[1][:, 2], dim=-1, return_counts=True) + torch.testing.assert_close( + mm, + torch.tensor([-1, 0, 1], dtype=dtype).to(env.DEVICE), + rtol=self.prec, + atol=self.prec, + ) + torch.testing.assert_close( + cc, + torch.tensor([50, 50, 50], dtype=torch.long).to(env.DEVICE), + rtol=self.prec, + atol=self.prec, + ) diff --git a/source/tests/pt/test_permutation.py b/source/tests/pt/test_permutation.py new file mode 100644 index 0000000000..b9724bb2af --- /dev/null +++ b/source/tests/pt/test_permutation.py @@ -0,0 +1,322 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import copy +import unittest +from pathlib import ( + Path, +) + +import torch + +from deepmd.pt.infer.deep_eval import ( + eval_model, +) +from deepmd.pt.model.model import ( + get_model, +) +from deepmd.pt.utils import ( + env, +) +from deepmd.pt.utils.dataloader import ( + DpLoaderSet, +) +from deepmd.pt.utils.stat import ( + make_stat_input, +) + +dtype = torch.float64 + +model_se_e2_a = { + "type_map": ["O", "H", "B"], + "descriptor": { + "type": "se_e2_a", + "sel": [46, 92, 4], + "rcut_smth": 0.50, + "rcut": 6.00, + "neuron": [25, 50, 100], + "resnet_dt": False, + "axis_neuron": 16, + "seed": 1, + }, + "fitting_net": { + "neuron": [24, 24, 24], + "resnet_dt": True, + "seed": 1, + }, + "data_stat_nbatch": 20, +} + +model_dpa2 = { + "type_map": ["O", "H", "B"], + "descriptor": { + "type": "dpa2", + "repinit_rcut": 6.0, + "repinit_rcut_smth": 2.0, + "repinit_nsel": 30, + "repformer_rcut": 4.0, + "repformer_rcut_smth": 0.5, + "repformer_nsel": 20, + "repinit_neuron": [2, 4, 8], + "repinit_axis_neuron": 4, + "repinit_activation": "tanh", + "repformer_nlayers": 12, + "repformer_g1_dim": 8, + "repformer_g2_dim": 5, + "repformer_attn2_hidden": 3, + "repformer_attn2_nhead": 1, + "repformer_attn1_hidden": 5, + "repformer_attn1_nhead": 1, + "repformer_axis_dim": 4, + "repformer_update_h2": False, + "repformer_update_g1_has_conv": True, + "repformer_update_g1_has_grrg": True, + "repformer_update_g1_has_drrd": True, + "repformer_update_g1_has_attn": True, + "repformer_update_g2_has_g1g1": True, + "repformer_update_g2_has_attn": True, + "repformer_attn2_has_gate": True, + "repformer_add_type_ebd_to_seq": False, + }, + "fitting_net": { + "neuron": [24, 24], + "resnet_dt": True, + "seed": 1, + }, +} + +model_dpa1 = { + "type_map": ["O", "H", "B"], + "descriptor": { + "type": "se_atten", + "sel": 40, + "rcut_smth": 0.5, + "rcut": 4.0, + "neuron": [25, 50, 100], + "axis_neuron": 16, + "attn": 64, + "attn_layer": 2, + "attn_dotr": True, + "attn_mask": False, + "post_ln": True, + "ffn": False, + "ffn_embed_dim": 512, + "activation": "tanh", + "scaling_factor": 1.0, + "head_num": 1, + "normalize": False, + "temperature": 1.0, + "set_davg_zero": True, + }, + "fitting_net": { + "neuron": [24, 24, 24], + "resnet_dt": True, + "seed": 1, + }, +} + + +model_hybrid = { + "type_map": ["O", "H", "B"], + "descriptor": { + "type": "hybrid", + "list": [ + { + "type": "se_atten", + "sel": 120, + "rcut_smth": 0.5, + "rcut": 6.0, + "neuron": [25, 50, 100], + "axis_neuron": 16, + "attn": 128, + "attn_layer": 0, + "attn_dotr": True, + "attn_mask": False, + "post_ln": True, + "ffn": False, + "ffn_embed_dim": 1024, + "activation": "tanh", + "scaling_factor": 1.0, + "head_num": 1, + "normalize": True, + "temperature": 1.0, + }, + { + "type": "dpa2", + "repinit_rcut": 6.0, + "repinit_rcut_smth": 2.0, + "repinit_nsel": 30, + "repformer_rcut": 4.0, + "repformer_rcut_smth": 0.5, + "repformer_nsel": 10, + "repinit_neuron": [2, 4, 8], + "repinit_axis_neuron": 4, + "repinit_activation": "tanh", + "repformer_nlayers": 12, + "repformer_g1_dim": 8, + "repformer_g2_dim": 5, + "repformer_attn2_hidden": 3, + "repformer_attn2_nhead": 1, + "repformer_attn1_hidden": 5, + "repformer_attn1_nhead": 1, + "repformer_axis_dim": 4, + "repformer_update_h2": False, + "repformer_update_g1_has_conv": True, + "repformer_update_g1_has_grrg": True, + "repformer_update_g1_has_drrd": True, + "repformer_update_g1_has_attn": True, + "repformer_update_g2_has_g1g1": True, + "repformer_update_g2_has_attn": True, + "repformer_attn2_has_gate": True, + "repformer_add_type_ebd_to_seq": False, + }, + ], + }, + "fitting_net": { + "neuron": [240, 240, 240], + "resnet_dt": True, + "seed": 1, + "_comment": " that's all", + }, + "_comment": " that's all", +} + + +def make_sample(model_params): + training_systems = [ + str(Path(__file__).parent / "water/data/data_0"), + ] + data_stat_nbatch = model_params.get("data_stat_nbatch", 10) + train_data = DpLoaderSet( + training_systems, + batch_size=4, + model_params=model_params.copy(), + ) + sampled = make_stat_input( + train_data.systems, train_data.dataloaders, data_stat_nbatch + ) + return sampled + + +class PermutationTest: + def test( + self, + ): + natoms = 5 + cell = torch.rand([3, 3], dtype=dtype).to(env.DEVICE) + cell = (cell + cell.T) + 5.0 * torch.eye(3).to(env.DEVICE) + coord = torch.rand([natoms, 3], dtype=dtype).to(env.DEVICE) + coord = torch.matmul(coord, cell) + atype = torch.IntTensor([0, 0, 0, 1, 1]).to(env.DEVICE) + idx_perm = [1, 0, 4, 3, 2] + e0, f0, v0 = eval_model( + self.model, coord.unsqueeze(0), cell.unsqueeze(0), atype + ) + ret0 = { + "energy": e0.squeeze(0), + "force": f0.squeeze(0), + "virial": v0.squeeze(0), + } + e1, f1, v1 = eval_model( + self.model, coord[idx_perm].unsqueeze(0), cell.unsqueeze(0), atype[idx_perm] + ) + ret1 = { + "energy": e1.squeeze(0), + "force": f1.squeeze(0), + "virial": v1.squeeze(0), + } + prec = 1e-10 + torch.testing.assert_close(ret0["energy"], ret1["energy"], rtol=prec, atol=prec) + torch.testing.assert_close( + ret0["force"][idx_perm], ret1["force"], rtol=prec, atol=prec + ) + if not hasattr(self, "test_virial") or self.test_virial: + torch.testing.assert_close( + ret0["virial"], ret1["virial"], rtol=prec, atol=prec + ) + + +class TestEnergyModelSeA(unittest.TestCase, PermutationTest): + def setUp(self): + model_params = copy.deepcopy(model_se_e2_a) + sampled = make_sample(model_params) + self.type_split = False + self.model = get_model(model_params, sampled).to(env.DEVICE) + + +class TestEnergyModelDPA1(unittest.TestCase, PermutationTest): + def setUp(self): + model_params = copy.deepcopy(model_dpa1) + sampled = make_sample(model_params) + self.type_split = True + self.model = get_model(model_params, sampled).to(env.DEVICE) + + +class TestEnergyModelDPA2(unittest.TestCase, PermutationTest): + def setUp(self): + model_params_sample = copy.deepcopy(model_dpa2) + model_params_sample["descriptor"]["rcut"] = model_params_sample["descriptor"][ + "repinit_rcut" + ] + model_params_sample["descriptor"]["sel"] = model_params_sample["descriptor"][ + "repinit_nsel" + ] + sampled = make_sample(model_params_sample) + model_params = copy.deepcopy(model_dpa2) + self.type_split = True + self.model = get_model(model_params, sampled).to(env.DEVICE) + + +class TestForceModelDPA2(unittest.TestCase, PermutationTest): + def setUp(self): + model_params_sample = copy.deepcopy(model_dpa2) + model_params_sample["descriptor"]["rcut"] = model_params_sample["descriptor"][ + "repinit_rcut" + ] + model_params_sample["descriptor"]["sel"] = model_params_sample["descriptor"][ + "repinit_nsel" + ] + sampled = make_sample(model_params_sample) + model_params = copy.deepcopy(model_dpa2) + model_params["fitting_net"]["type"] = "direct_force_ener" + self.type_split = True + self.test_virial = False + self.model = get_model(model_params, sampled).to(env.DEVICE) + + +@unittest.skip("hybrid not supported at the moment") +class TestEnergyModelHybrid(unittest.TestCase, PermutationTest): + def setUp(self): + model_params = copy.deepcopy(model_hybrid) + sampled = make_sample(model_params) + self.type_split = True + self.model = get_model(model_params, sampled).to(env.DEVICE) + + +@unittest.skip("hybrid not supported at the moment") +class TestForceModelHybrid(unittest.TestCase, PermutationTest): + def setUp(self): + model_params = copy.deepcopy(model_hybrid) + model_params["fitting_net"]["type"] = "direct_force_ener" + sampled = make_sample(model_params) + self.type_split = True + self.test_virial = False + self.model = get_model(model_params, sampled).to(env.DEVICE) + + +# class TestEnergyFoo(unittest.TestCase): +# def test(self): +# model_params = model_dpau +# sampled = make_sample(model_params) +# self.model = EnergyModelDPAUni(model_params, sampled).to(env.DEVICE) + +# natoms = 5 +# cell = torch.rand([3, 3], dtype=dtype) +# cell = (cell + cell.T) + 5. * torch.eye(3) +# coord = torch.rand([natoms, 3], dtype=dtype) +# coord = torch.matmul(coord, cell) +# atype = torch.IntTensor([0, 0, 0, 1, 1]) +# idx_perm = [1, 0, 4, 3, 2] +# ret0 = infer_model(self.model, coord, cell, atype, type_split=True) + + +if __name__ == "__main__": + unittest.main() diff --git a/source/tests/pt/test_permutation_denoise.py b/source/tests/pt/test_permutation_denoise.py new file mode 100644 index 0000000000..47bd0360f2 --- /dev/null +++ b/source/tests/pt/test_permutation_denoise.py @@ -0,0 +1,102 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import copy +import unittest + +import torch + +from deepmd.pt.infer.deep_eval import ( + eval_model, +) +from deepmd.pt.model.model import ( + get_model, +) +from deepmd.pt.utils import ( + env, +) + +from .test_permutation import ( # model_dpau, + make_sample, + model_dpa1, + model_dpa2, + model_hybrid, +) + +dtype = torch.float64 + +model_dpa1 = copy.deepcopy(model_dpa1) +model_dpa2 = copy.deepcopy(model_dpa2) +model_hybrid = copy.deepcopy(model_hybrid) +model_dpa1["type_map"] = ["O", "H", "B", "MASKED_TOKEN"] +model_dpa1.pop("fitting_net") +model_dpa2["type_map"] = ["O", "H", "B", "MASKED_TOKEN"] +model_dpa2.pop("fitting_net") +model_hybrid["type_map"] = ["O", "H", "B", "MASKED_TOKEN"] +model_hybrid.pop("fitting_net") + + +class PermutationDenoiseTest: + def test( + self, + ): + natoms = 5 + cell = torch.rand([3, 3], dtype=dtype).to(env.DEVICE) + cell = (cell + cell.T) + 5.0 * torch.eye(3).to(env.DEVICE) + coord = torch.rand([natoms, 3], dtype=dtype).to(env.DEVICE) + coord = torch.matmul(coord, cell) + atype = torch.IntTensor([0, 0, 0, 1, 1]).to(env.DEVICE) + idx_perm = [1, 0, 4, 3, 2] + updated_c0, logits0 = eval_model( + self.model, coord.unsqueeze(0), cell.unsqueeze(0), atype, denoise=True + ) + ret0 = {"updated_coord": updated_c0.squeeze(0), "logits": logits0.squeeze(0)} + updated_c1, logits1 = eval_model( + self.model, + coord[idx_perm].unsqueeze(0), + cell.unsqueeze(0), + atype[idx_perm], + denoise=True, + ) + ret1 = {"updated_coord": updated_c1.squeeze(0), "logits": logits1.squeeze(0)} + prec = 1e-10 + torch.testing.assert_close( + ret0["updated_coord"][idx_perm], ret1["updated_coord"], rtol=prec, atol=prec + ) + torch.testing.assert_close( + ret0["logits"][idx_perm], ret1["logits"], rtol=prec, atol=prec + ) + + +class TestDenoiseModelDPA1(unittest.TestCase, PermutationDenoiseTest): + def setUp(self): + model_params = copy.deepcopy(model_dpa1) + sampled = make_sample(model_params) + self.type_split = True + self.model = get_model(model_params, sampled).to(env.DEVICE) + + +class TestDenoiseModelDPA2(unittest.TestCase, PermutationDenoiseTest): + def setUp(self): + model_params_sample = copy.deepcopy(model_dpa2) + model_params_sample["descriptor"]["rcut"] = model_params_sample["descriptor"][ + "repinit_rcut" + ] + model_params_sample["descriptor"]["sel"] = model_params_sample["descriptor"][ + "repinit_nsel" + ] + sampled = make_sample(model_params_sample) + model_params = copy.deepcopy(model_dpa2) + self.type_split = True + self.model = get_model(model_params, sampled).to(env.DEVICE) + + +# @unittest.skip("hybrid not supported at the moment") +# class TestDenoiseModelHybrid(unittest.TestCase, TestPermutationDenoise): +# def setUp(self): +# model_params = copy.deepcopy(model_hybrid_denoise) +# sampled = make_sample(model_params) +# self.type_split = True +# self.model = get_model(model_params, sampled).to(env.DEVICE) + + +if __name__ == "__main__": + unittest.main() diff --git a/source/tests/pt/test_region.py b/source/tests/pt/test_region.py new file mode 100644 index 0000000000..e8a3346562 --- /dev/null +++ b/source/tests/pt/test_region.py @@ -0,0 +1,78 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import unittest + +import numpy as np +import torch + +from deepmd.pt.utils.preprocess import ( + Region3D, +) +from deepmd.pt.utils.region import ( + inter2phys, + to_face_distance, +) + +dtype = torch.float64 + + +class TestRegion(unittest.TestCase): + def setUp(self): + self.cell = torch.tensor( + [[1, 0, 0], [0.4, 0.8, 0], [0.1, 0.3, 2.1]], dtype=dtype + ) + self.cell = self.cell.unsqueeze(0).unsqueeze(0) + self.cell = torch.tile(self.cell, [4, 5, 1, 1]) + self.prec = 1e-8 + + def test_inter_to_phys(self): + inter = torch.rand([4, 5, 3, 3], dtype=dtype) + phys = inter2phys(inter, self.cell) + for ii in range(4): + for jj in range(5): + expected_phys = torch.matmul(inter[ii, jj], self.cell[ii, jj]) + torch.testing.assert_close( + phys[ii, jj], expected_phys, rtol=self.prec, atol=self.prec + ) + + def test_to_face_dist(self): + cell0 = self.cell[0][0].numpy() + vol = np.linalg.det(cell0) + # area of surfaces xy, xz, yz + sxy = np.linalg.norm(np.cross(cell0[0], cell0[1])) + sxz = np.linalg.norm(np.cross(cell0[0], cell0[2])) + syz = np.linalg.norm(np.cross(cell0[1], cell0[2])) + # vol / area gives distance + dz = vol / sxy + dy = vol / sxz + dx = vol / syz + expected = torch.tensor([dx, dy, dz]) + dists = to_face_distance(self.cell) + for ii in range(4): + for jj in range(5): + torch.testing.assert_close( + dists[ii][jj], expected, rtol=self.prec, atol=self.prec + ) + + +class TestLegacyRegion(unittest.TestCase): + def setUp(self): + self.cell = torch.tensor( + [[1, 0, 0], [0.4, 0.8, 0], [0.1, 0.3, 2.1]], dtype=dtype + ) + self.prec = 1e-6 + + def test_inter_to_phys(self): + inter = torch.rand([3, 3], dtype=dtype) + reg = Region3D(self.cell) + phys = reg.inter2phys(inter) + expected_phys = torch.matmul(inter, self.cell) + torch.testing.assert_close(phys, expected_phys, rtol=self.prec, atol=self.prec) + + def test_inter_to_inter(self): + inter = torch.rand([3, 3], dtype=dtype) + reg = Region3D(self.cell) + new_inter = reg.phys2inter(reg.inter2phys(inter)) + torch.testing.assert_close(inter, new_inter, rtol=self.prec, atol=self.prec) + + def test_to_face_dist(self): + pass diff --git a/source/tests/pt/test_rot.py b/source/tests/pt/test_rot.py new file mode 100644 index 0000000000..b5d9d9b64b --- /dev/null +++ b/source/tests/pt/test_rot.py @@ -0,0 +1,181 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import copy +import unittest + +import torch + +from deepmd.pt.infer.deep_eval import ( + eval_model, +) +from deepmd.pt.model.model import ( + get_model, +) +from deepmd.pt.utils import ( + env, +) + +from .test_permutation import ( # model_dpau, + make_sample, + model_dpa1, + model_dpa2, + model_hybrid, + model_se_e2_a, +) + +dtype = torch.float64 + + +class RotTest: + def test( + self, + ): + prec = 1e-10 + natoms = 5 + cell = 10.0 * torch.eye(3, dtype=dtype).to(env.DEVICE) + coord = 2 * torch.rand([natoms, 3], dtype=dtype).to(env.DEVICE) + shift = torch.tensor([4, 4, 4], dtype=dtype).to(env.DEVICE) + atype = torch.IntTensor([0, 0, 0, 1, 1]).to(env.DEVICE) + from scipy.stats import ( + special_ortho_group, + ) + + rmat = torch.tensor(special_ortho_group.rvs(3), dtype=dtype).to(env.DEVICE) + + # rotate only coord and shift to the center of cell + coord_rot = torch.matmul(coord, rmat) + e0, f0, v0 = eval_model( + self.model, (coord + shift).unsqueeze(0), cell.unsqueeze(0), atype + ) + ret0 = { + "energy": e0.squeeze(0), + "force": f0.squeeze(0), + "virial": v0.squeeze(0), + } + e1, f1, v1 = eval_model( + self.model, (coord_rot + shift).unsqueeze(0), cell.unsqueeze(0), atype + ) + ret1 = { + "energy": e1.squeeze(0), + "force": f1.squeeze(0), + "virial": v1.squeeze(0), + } + torch.testing.assert_close(ret0["energy"], ret1["energy"], rtol=prec, atol=prec) + torch.testing.assert_close( + torch.matmul(ret0["force"], rmat), ret1["force"], rtol=prec, atol=prec + ) + if not hasattr(self, "test_virial") or self.test_virial: + torch.testing.assert_close( + torch.matmul(rmat.T, torch.matmul(ret0["virial"], rmat)), + ret1["virial"], + rtol=prec, + atol=prec, + ) + + # rotate coord and cell + torch.manual_seed(0) + cell = torch.rand([3, 3], dtype=dtype).to(env.DEVICE) + cell = (cell + cell.T) + 5.0 * torch.eye(3).to(env.DEVICE) + coord = torch.rand([natoms, 3], dtype=dtype).to(env.DEVICE) + coord = torch.matmul(coord, cell) + atype = torch.IntTensor([0, 0, 0, 1, 1]).to(env.DEVICE) + coord_rot = torch.matmul(coord, rmat) + cell_rot = torch.matmul(cell, rmat) + e0, f0, v0 = eval_model( + self.model, coord.unsqueeze(0), cell.unsqueeze(0), atype + ) + ret0 = { + "energy": e0.squeeze(0), + "force": f0.squeeze(0), + "virial": v0.squeeze(0), + } + e1, f1, v1 = eval_model( + self.model, coord_rot.unsqueeze(0), cell_rot.unsqueeze(0), atype + ) + ret1 = { + "energy": e1.squeeze(0), + "force": f1.squeeze(0), + "virial": v1.squeeze(0), + } + torch.testing.assert_close(ret0["energy"], ret1["energy"], rtol=prec, atol=prec) + torch.testing.assert_close( + torch.matmul(ret0["force"], rmat), ret1["force"], rtol=prec, atol=prec + ) + if not hasattr(self, "test_virial") or self.test_virial: + torch.testing.assert_close( + torch.matmul(rmat.T, torch.matmul(ret0["virial"], rmat)), + ret1["virial"], + rtol=prec, + atol=prec, + ) + + +class TestEnergyModelSeA(unittest.TestCase, RotTest): + def setUp(self): + model_params = copy.deepcopy(model_se_e2_a) + sampled = make_sample(model_params) + self.type_split = False + self.model = get_model(model_params, sampled).to(env.DEVICE) + + +class TestEnergyModelDPA1(unittest.TestCase, RotTest): + def setUp(self): + model_params = copy.deepcopy(model_dpa1) + sampled = make_sample(model_params) + self.type_split = True + self.model = get_model(model_params, sampled).to(env.DEVICE) + + +class TestEnergyModelDPA2(unittest.TestCase, RotTest): + def setUp(self): + model_params_sample = copy.deepcopy(model_dpa2) + model_params_sample["descriptor"]["rcut"] = model_params_sample["descriptor"][ + "repinit_rcut" + ] + model_params_sample["descriptor"]["sel"] = model_params_sample["descriptor"][ + "repinit_nsel" + ] + sampled = make_sample(model_params_sample) + model_params = copy.deepcopy(model_dpa2) + self.type_split = True + self.model = get_model(model_params, sampled).to(env.DEVICE) + + +class TestForceModelDPA2(unittest.TestCase, RotTest): + def setUp(self): + model_params_sample = copy.deepcopy(model_dpa2) + model_params_sample["descriptor"]["rcut"] = model_params_sample["descriptor"][ + "repinit_rcut" + ] + model_params_sample["descriptor"]["sel"] = model_params_sample["descriptor"][ + "repinit_nsel" + ] + sampled = make_sample(model_params_sample) + model_params = copy.deepcopy(model_dpa2) + model_params["fitting_net"]["type"] = "direct_force_ener" + self.type_split = True + self.test_virial = False + self.model = get_model(model_params, sampled).to(env.DEVICE) + + +@unittest.skip("hybrid not supported at the moment") +class TestEnergyModelHybrid(unittest.TestCase, RotTest): + def setUp(self): + model_params = copy.deepcopy(model_hybrid) + sampled = make_sample(model_params) + self.type_split = True + self.model = get_model(model_params, sampled).to(env.DEVICE) + + +@unittest.skip("hybrid not supported at the moment") +class TestForceModelHybrid(unittest.TestCase, RotTest): + def setUp(self): + model_params = copy.deepcopy(model_hybrid) + model_params["fitting_net"]["type"] = "direct_force_ener" + sampled = make_sample(model_params) + self.type_split = True + self.test_virial = False + self.model = get_model(model_params, sampled).to(env.DEVICE) + + +if __name__ == "__main__": + unittest.main() diff --git a/source/tests/pt/test_rot_denoise.py b/source/tests/pt/test_rot_denoise.py new file mode 100644 index 0000000000..cab8de7bec --- /dev/null +++ b/source/tests/pt/test_rot_denoise.py @@ -0,0 +1,133 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import copy +import unittest + +import torch + +from deepmd.pt.infer.deep_eval import ( + eval_model, +) +from deepmd.pt.model.model import ( + get_model, +) +from deepmd.pt.utils import ( + env, +) + +from .test_permutation_denoise import ( + make_sample, + model_dpa1, + model_dpa2, +) + +dtype = torch.float64 + + +class RotDenoiseTest: + def test( + self, + ): + prec = 1e-10 + natoms = 5 + cell = 10.0 * torch.eye(3, dtype=dtype).to(env.DEVICE) + coord = 2 * torch.rand([natoms, 3], dtype=dtype).to(env.DEVICE) + shift = torch.tensor([4, 4, 4], dtype=dtype).to(env.DEVICE) + atype = torch.IntTensor([0, 0, 0, 1, 1]).to(env.DEVICE) + from scipy.stats import ( + special_ortho_group, + ) + + rmat = torch.tensor(special_ortho_group.rvs(3), dtype=dtype).to(env.DEVICE) + + # rotate only coord and shift to the center of cell + coord_rot = torch.matmul(coord, rmat) + update_c0, logits0 = eval_model( + self.model, + (coord + shift).unsqueeze(0), + cell.unsqueeze(0), + atype, + denoise=True, + ) + update_c0 = update_c0 - (coord + shift).unsqueeze(0) + ret0 = {"updated_coord": update_c0.squeeze(0), "logits": logits0.squeeze(0)} + update_c1, logits1 = eval_model( + self.model, + (coord_rot + shift).unsqueeze(0), + cell.unsqueeze(0), + atype, + denoise=True, + ) + update_c1 = update_c1 - (coord_rot + shift).unsqueeze(0) + ret1 = {"updated_coord": update_c1.squeeze(0), "logits": logits1.squeeze(0)} + torch.testing.assert_close( + torch.matmul(ret0["updated_coord"], rmat), + ret1["updated_coord"], + rtol=prec, + atol=prec, + ) + torch.testing.assert_close(ret0["logits"], ret1["logits"], rtol=prec, atol=prec) + + # rotate coord and cell + torch.manual_seed(0) + cell = torch.rand([3, 3], dtype=dtype).to(env.DEVICE) + cell = (cell + cell.T) + 5.0 * torch.eye(3).to(env.DEVICE) + coord = torch.rand([natoms, 3], dtype=dtype).to(env.DEVICE) + coord = torch.matmul(coord, cell) + atype = torch.IntTensor([0, 0, 0, 1, 1]).to(env.DEVICE) + coord_rot = torch.matmul(coord, rmat) + cell_rot = torch.matmul(cell, rmat) + update_c0, logits0 = eval_model( + self.model, coord.unsqueeze(0), cell.unsqueeze(0), atype, denoise=True + ) + ret0 = {"updated_coord": update_c0.squeeze(0), "logits": logits0.squeeze(0)} + update_c1, logits1 = eval_model( + self.model, + coord_rot.unsqueeze(0), + cell_rot.unsqueeze(0), + atype, + denoise=True, + ) + ret1 = {"updated_coord": update_c1.squeeze(0), "logits": logits1.squeeze(0)} + torch.testing.assert_close(ret0["logits"], ret1["logits"], rtol=prec, atol=prec) + torch.testing.assert_close( + torch.matmul(ret0["updated_coord"], rmat), + ret1["updated_coord"], + rtol=prec, + atol=prec, + ) + + +class TestDenoiseModelDPA1(unittest.TestCase, RotDenoiseTest): + def setUp(self): + model_params = copy.deepcopy(model_dpa1) + sampled = make_sample(model_params) + self.type_split = True + self.model = get_model(model_params, sampled).to(env.DEVICE) + + +class TestDenoiseModelDPA2(unittest.TestCase, RotDenoiseTest): + def setUp(self): + model_params_sample = copy.deepcopy(model_dpa2) + model_params_sample["descriptor"]["rcut"] = model_params_sample["descriptor"][ + "repinit_rcut" + ] + model_params_sample["descriptor"]["sel"] = model_params_sample["descriptor"][ + "repinit_nsel" + ] + sampled = make_sample(model_params_sample) + model_params = copy.deepcopy(model_dpa2) + self.type_split = True + self.model = get_model(model_params, sampled).to(env.DEVICE) + + +# @unittest.skip("hybrid not supported at the moment") +# class TestEnergyModelHybrid(unittest.TestCase, TestRotDenoise): +# def setUp(self): +# model_params = copy.deepcopy(model_hybrid_denoise) +# sampled = make_sample(model_params) +# self.type_split = True +# self.model = get_model(model_params, sampled).to(env.DEVICE) + + +if __name__ == "__main__": + unittest.main() diff --git a/source/tests/pt/test_rotation.py b/source/tests/pt/test_rotation.py new file mode 100644 index 0000000000..4b49377a27 --- /dev/null +++ b/source/tests/pt/test_rotation.py @@ -0,0 +1,133 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import json +import unittest +from pathlib import ( + Path, +) +from typing import ( + List, + Optional, +) + +import numpy as np +import torch +from scipy.stats import ( + special_ortho_group, +) + +from deepmd.pt.model.model import ( + get_model, +) +from deepmd.pt.utils import ( + env, +) +from deepmd.pt.utils.dataloader import ( + DpLoaderSet, +) +from deepmd.pt.utils.dataset import ( + DeepmdDataSystem, +) +from deepmd.pt.utils.stat import ( + make_stat_input, +) + + +class CheckSymmetry(DeepmdDataSystem): + def __init__( + self, + sys_path: str, + rcut, + sec, + type_map: Optional[List[str]] = None, + type_split=True, + ): + super().__init__(sys_path, rcut, sec, type_map, type_split) + + def get_rotation(self, index, rotation_matrix): + for i in range( + 0, len(self._dirs) + 1 + ): # note: if different sets can be merged, prefix sum is unused to calculate + if index < self.prefix_sum[i]: + break + frames = self._load_set(self._dirs[i - 1]) + frames["coord"] = np.dot( + rotation_matrix, frames["coord"].reshape(-1, 3).T + ).T.reshape(self.nframes, -1) + frames["box"] = np.dot( + rotation_matrix, frames["box"].reshape(-1, 3).T + ).T.reshape(self.nframes, -1) + frames["force"] = np.dot( + rotation_matrix, frames["force"].reshape(-1, 3).T + ).T.reshape(self.nframes, -1) + frame = self.single_preprocess(frames, index - self.prefix_sum[i - 1]) + return frame + + +def get_data(batch): + inputs = {} + for key in ["coord", "atype", "box"]: + inputs[key] = batch[key].unsqueeze(0).to(env.DEVICE) + return inputs + + +class TestRotation(unittest.TestCase): + def setUp(self): + with open(str(Path(__file__).parent / "water/se_e2_a.json")) as fin: + self.config = json.load(fin) + data_file = [str(Path(__file__).parent / "water/data/data_0")] + self.config["training"]["training_data"]["systems"] = data_file + self.config["training"]["validation_data"]["systems"] = data_file + self.rotation = special_ortho_group.rvs(3) + self.get_dataset(0) + self.get_model() + + def get_model(self): + training_systems = self.config["training"]["training_data"]["systems"] + model_params = self.config["model"] + data_stat_nbatch = model_params.get("data_stat_nbatch", 10) + train_data = DpLoaderSet( + training_systems, + self.config["training"]["training_data"]["batch_size"], + model_params, + ) + sampled = make_stat_input( + train_data.systems, train_data.dataloaders, data_stat_nbatch + ) + self.model = get_model(self.config["model"], sampled).to(env.DEVICE) + + def get_dataset(self, system_index=0, batch_index=0): + systems = self.config["training"]["training_data"]["systems"] + rcut = self.config["model"]["descriptor"]["rcut"] + sel = self.config["model"]["descriptor"]["sel"] + sec = torch.cumsum(torch.tensor(sel), dim=0) + type_map = self.config["model"]["type_map"] + dpdatasystem = CheckSymmetry( + sys_path=systems[system_index], rcut=rcut, sec=sec, type_map=type_map + ) + self.origin_batch = dpdatasystem._get_item(batch_index) + self.rotated_batch = dpdatasystem.get_rotation(batch_index, self.rotation) + + def test_rotation(self): + result1 = self.model(**get_data(self.origin_batch)) + result2 = self.model(**get_data(self.rotated_batch)) + rotation = torch.from_numpy(self.rotation).to(env.DEVICE) + self.assertTrue(result1["energy"] == result2["energy"]) + if "force" in result1: + self.assertTrue( + torch.allclose( + result2["force"][0], torch.matmul(rotation, result1["force"][0].T).T + ) + ) + if "virial" in result1: + self.assertTrue( + torch.allclose( + result2["virial"][0], + torch.matmul( + torch.matmul(rotation, result1["virial"][0].T), rotation.T + ), + ) + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/source/tests/pt/test_sampler.py b/source/tests/pt/test_sampler.py new file mode 100644 index 0000000000..0ff16ed7c7 --- /dev/null +++ b/source/tests/pt/test_sampler.py @@ -0,0 +1,115 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import json +import os +import unittest +from pathlib import ( + Path, +) + +import numpy as np +from torch.utils.data import ( + DataLoader, +) + +from deepmd.pt.utils.dataloader import ( + DpLoaderSet, + get_weighted_sampler, +) +from deepmd.tf.common import ( + expand_sys_str, +) +from deepmd.tf.utils import random as tf_random +from deepmd.tf.utils.data_system import ( + DeepmdDataSystem, +) + +CUR_DIR = os.path.dirname(__file__) + + +class TestSampler(unittest.TestCase): + def setUp(self): + with open(str(Path(__file__).parent / "water/se_e2_a.json")) as fin: + content = fin.read() + config = json.loads(content) + data_file = [str(Path(__file__).parent / "water/data/data_0")] + config["training"]["training_data"]["systems"] = data_file + config["training"]["validation_data"]["systems"] = data_file + model_config = config["model"] + self.rcut = model_config["descriptor"]["rcut"] + self.rcut_smth = model_config["descriptor"]["rcut_smth"] + self.sel = model_config["descriptor"]["sel"] + self.batch_size = config["training"]["training_data"]["batch_size"] + self.systems = config["training"]["validation_data"]["systems"] + if isinstance(self.systems, str): + self.systems = expand_sys_str(self.systems) + self.my_dataset = DpLoaderSet( + self.systems, + self.batch_size, + model_params={ + "descriptor": { + "type": "se_e2_a", + "sel": self.sel, + "rcut": self.rcut, + }, + "type_map": model_config["type_map"], + }, + seed=10, + shuffle=False, + ) + + tf_random.seed(10) + self.dp_dataset = DeepmdDataSystem(self.systems, self.batch_size, 1, self.rcut) + + def test_sampler_debug_info(self): + dataloader = DataLoader( + self.my_dataset, + sampler=get_weighted_sampler(self.my_dataset, prob_style="prob_sys_size"), + batch_size=None, + num_workers=0, # setting to 0 diverges the behavior of its iterator; should be >=1 + drop_last=False, + pin_memory=True, + ) + batch_data = next(iter(dataloader)) + sid = batch_data["sid"] + fid = batch_data["fid"][0] + coord = batch_data["coord"].squeeze(0) + frame = self.my_dataset.systems[sid].__getitem__(fid) + self.assertTrue(np.allclose(coord, frame["coord"])) + + def test_auto_prob_uniform(self): + auto_prob_style = "prob_uniform" + sampler = get_weighted_sampler(self.my_dataset, prob_style=auto_prob_style) + my_probs = np.array(sampler.weights) + self.dp_dataset.set_sys_probs(auto_prob_style=auto_prob_style) + dp_probs = np.array(self.dp_dataset.sys_probs) + self.assertTrue(np.allclose(my_probs, dp_probs)) + + def test_auto_prob_sys_size(self): + auto_prob_style = "prob_sys_size" + sampler = get_weighted_sampler(self.my_dataset, prob_style=auto_prob_style) + my_probs = np.array(sampler.weights) + self.dp_dataset.set_sys_probs(auto_prob_style=auto_prob_style) + dp_probs = np.array(self.dp_dataset.sys_probs) + self.assertTrue(np.allclose(my_probs, dp_probs)) + + def test_auto_prob_sys_size_ext(self): + auto_prob_style = "prob_sys_size;0:1:0.2;1:3:0.8" + sampler = get_weighted_sampler(self.my_dataset, prob_style=auto_prob_style) + my_probs = np.array(sampler.weights) + self.dp_dataset.set_sys_probs(auto_prob_style=auto_prob_style) + dp_probs = np.array(self.dp_dataset.sys_probs) + self.assertTrue(np.allclose(my_probs, dp_probs)) + + def test_sys_probs(self): + sys_probs = [0.1, 0.4, 0.5] + sampler = get_weighted_sampler( + self.my_dataset, prob_style=sys_probs, sys_prob=True + ) + my_probs = np.array(sampler.weights) + self.dp_dataset.set_sys_probs(sys_probs=sys_probs) + dp_probs = np.array(self.dp_dataset.sys_probs) + self.assertTrue(np.allclose(my_probs, dp_probs)) + + +if __name__ == "__main__": + unittest.main() diff --git a/source/tests/pt/test_saveload_dpa1.py b/source/tests/pt/test_saveload_dpa1.py new file mode 100644 index 0000000000..d1043f7029 --- /dev/null +++ b/source/tests/pt/test_saveload_dpa1.py @@ -0,0 +1,151 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import copy +import json +import os +import unittest +from pathlib import ( + Path, +) + +import torch +from torch.utils.data import ( + DataLoader, +) + +from deepmd.pt.loss import ( + EnergyStdLoss, +) +from deepmd.pt.model.model import ( + get_model, +) +from deepmd.pt.train.wrapper import ( + ModelWrapper, +) +from deepmd.pt.utils import ( + env, +) +from deepmd.pt.utils.dataloader import ( + BufferedIterator, + DpLoaderSet, +) +from deepmd.pt.utils.stat import ( + make_stat_input, +) +from deepmd.tf.common import ( + expand_sys_str, +) + + +def get_dataset(config): + model_config = config["model"] + rcut = model_config["descriptor"]["rcut"] + sel = model_config["descriptor"]["sel"] + systems = config["training"]["validation_data"]["systems"] + if isinstance(systems, str): + systems = expand_sys_str(systems) + batch_size = config["training"]["training_data"]["batch_size"] + type_map = model_config["type_map"] + + dataset = DpLoaderSet( + systems, + batch_size, + model_params={ + "descriptor": { + "type": "dpa1", + "sel": sel, + "rcut": rcut, + }, + "type_map": type_map, + }, + ) + data_stat_nbatch = model_config.get("data_stat_nbatch", 10) + sampled = make_stat_input(dataset.systems, dataset.dataloaders, data_stat_nbatch) + return dataset, sampled + + +class TestSaveLoadDPA1(unittest.TestCase): + def setUp(self): + input_json = str(Path(__file__).parent / "water/se_atten.json") + with open(input_json) as fin: + self.config = json.load(fin) + self.config["loss"]["starter_learning_rate"] = self.config["learning_rate"][ + "start_lr" + ] + data_file = [str(Path(__file__).parent / "water/data/data_0")] + self.config["training"]["training_data"]["systems"] = data_file + self.config["training"]["validation_data"]["systems"] = data_file + self.dataset, self.sampled = get_dataset(self.config) + self.training_dataloader = DataLoader( + self.dataset, + sampler=torch.utils.data.RandomSampler(self.dataset), + batch_size=None, + num_workers=0, # setting to 0 diverges the behavior of its iterator; should be >=1 + drop_last=False, + pin_memory=True, + ) + self.training_data = BufferedIterator(iter(self.training_dataloader)) + self.loss = EnergyStdLoss(**self.config["loss"]) + self.cur_lr = 1 + self.task_key = "Default" + self.input_dict, self.label_dict = self.get_data() + self.start_lr = self.config["learning_rate"]["start_lr"] + + def get_model_result(self, read=False, model_file="tmp_model.pt"): + wrapper = self.create_wrapper(read) + optimizer = torch.optim.Adam(wrapper.parameters(), lr=self.start_lr) + optimizer.zero_grad() + if read: + wrapper.load_state_dict(torch.load(model_file, map_location=env.DEVICE)) + os.remove(model_file) + else: + torch.save(wrapper.state_dict(), model_file) + result = wrapper( + **self.input_dict, + cur_lr=self.cur_lr, + label=self.label_dict, + task_key=self.task_key, + )[0] + return result + + def create_wrapper(self, read: bool): + model_config = copy.deepcopy(self.config["model"]) + sampled = copy.deepcopy(self.sampled) + model_config["resuming"] = read + model_config["stat_file_dir"] = "stat_files" + model_config["stat_file"] = "stat.npz" + model_config["stat_file_path"] = os.path.join( + model_config["stat_file_dir"], model_config["stat_file"] + ) + model = get_model(model_config, sampled).to(env.DEVICE) + return ModelWrapper(model, self.loss) + + def get_data(self): + try: + batch_data = next(iter(self.training_data)) + except StopIteration: + # Refresh the status of the dataloader to start from a new epoch + self.training_data = BufferedIterator(iter(self.training_dataloader)) + batch_data = next(iter(self.training_data)) + input_dict = {} + for item in ["coord", "atype", "box"]: + if item in batch_data: + input_dict[item] = batch_data[item] + else: + input_dict[item] = None + label_dict = {} + for item in ["energy", "force", "virial"]: + if item in batch_data: + label_dict[item] = batch_data[item] + return input_dict, label_dict + + def test_saveload(self): + result1 = self.get_model_result() + result2 = self.get_model_result(read=True) + final_result = all( + torch.allclose(result1[item], result2[item]) for item in result1 + ) + self.assertTrue(final_result) + + +if __name__ == "__main__": + unittest.main() diff --git a/source/tests/pt/test_saveload_se_e2_a.py b/source/tests/pt/test_saveload_se_e2_a.py new file mode 100644 index 0000000000..95d7f97a88 --- /dev/null +++ b/source/tests/pt/test_saveload_se_e2_a.py @@ -0,0 +1,145 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import copy +import json +import os +import unittest +from pathlib import ( + Path, +) + +import torch +from torch.utils.data import ( + DataLoader, +) + +from deepmd.pt.loss import ( + EnergyStdLoss, +) +from deepmd.pt.model.model import ( + get_model, +) +from deepmd.pt.train.wrapper import ( + ModelWrapper, +) +from deepmd.pt.utils import ( + env, +) +from deepmd.pt.utils.dataloader import ( + BufferedIterator, + DpLoaderSet, +) +from deepmd.pt.utils.stat import ( + make_stat_input, +) +from deepmd.tf.common import ( + expand_sys_str, +) + + +def get_dataset(config): + model_config = config["model"] + rcut = model_config["descriptor"]["rcut"] + sel = model_config["descriptor"]["sel"] + systems = config["training"]["validation_data"]["systems"] + if isinstance(systems, str): + systems = expand_sys_str(systems) + batch_size = config["training"]["training_data"]["batch_size"] + type_map = model_config["type_map"] + + dataset = DpLoaderSet( + systems, + batch_size, + model_params={ + "descriptor": { + "type": "se_e2_a", + "sel": sel, + "rcut": rcut, + }, + "type_map": type_map, + }, + ) + data_stat_nbatch = model_config.get("data_stat_nbatch", 10) + sampled = make_stat_input(dataset.systems, dataset.dataloaders, data_stat_nbatch) + return dataset, sampled + + +class TestSaveLoadSeA(unittest.TestCase): + def setUp(self): + input_json = str(Path(__file__).parent / "water/se_e2_a.json") + with open(input_json) as fin: + self.config = json.load(fin) + data_file = [str(Path(__file__).parent / "water/data/data_0")] + self.config["training"]["training_data"]["systems"] = data_file + self.config["training"]["validation_data"]["systems"] = data_file + self.config["loss"]["starter_learning_rate"] = self.config["learning_rate"][ + "start_lr" + ] + self.dataset, self.sampled = get_dataset(self.config) + self.training_dataloader = DataLoader( + self.dataset, + sampler=torch.utils.data.RandomSampler(self.dataset), + batch_size=None, + num_workers=0, # setting to 0 diverges the behavior of its iterator; should be >=1 + drop_last=False, + pin_memory=True, + ) + self.training_data = BufferedIterator(iter(self.training_dataloader)) + self.loss = EnergyStdLoss(**self.config["loss"]) + self.cur_lr = 1 + self.task_key = "Default" + self.input_dict, self.label_dict = self.get_data() + self.start_lr = self.config["learning_rate"]["start_lr"] + + def get_model_result(self, read=False, model_file="tmp_model.pt"): + wrapper = self.create_wrapper() + optimizer = torch.optim.Adam(wrapper.parameters(), lr=self.start_lr) + optimizer.zero_grad() + if read: + wrapper.load_state_dict(torch.load(model_file, map_location=env.DEVICE)) + os.remove(model_file) + else: + torch.save(wrapper.state_dict(), model_file) + result = wrapper( + **self.input_dict, + cur_lr=self.cur_lr, + label=self.label_dict, + task_key=self.task_key, + )[0] + return result + + def create_wrapper(self): + model_config = copy.deepcopy(self.config["model"]) + sampled = copy.deepcopy(self.sampled) + model = get_model(model_config, sampled).to(env.DEVICE) + return ModelWrapper(model, self.loss) + + def get_data(self): + try: + batch_data = next(iter(self.training_data)) + except StopIteration: + # Refresh the status of the dataloader to start from a new epoch + self.training_data = BufferedIterator(iter(self.training_dataloader)) + batch_data = next(iter(self.training_data)) + input_dict = {} + for item in ["coord", "atype", "box"]: + if item in batch_data: + input_dict[item] = batch_data[item] + else: + input_dict[item] = None + label_dict = {} + for item in ["energy", "force", "virial"]: + if item in batch_data: + label_dict[item] = batch_data[item] + return input_dict, label_dict + + def test_saveload(self): + result1 = self.get_model_result() + result2 = self.get_model_result(read=True) + final_result = all( + torch.allclose(result1[item], result2[item]) for item in result1 + ) + self.assertTrue(final_result) + + +if __name__ == "__main__": + unittest.main() diff --git a/source/tests/pt/test_se_e2_a.py b/source/tests/pt/test_se_e2_a.py new file mode 100644 index 0000000000..96a17c2bad --- /dev/null +++ b/source/tests/pt/test_se_e2_a.py @@ -0,0 +1,199 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import itertools +import unittest + +import numpy as np +import torch + +try: + # from deepmd.model_format import PRECISION_DICT as DP_PRECISION_DICT + from deepmd.model_format import DescrptSeA as DPDescrptSeA + + support_se_e2_a = True +except ModuleNotFoundError: + support_se_e2_a = False +except ImportError: + support_se_e2_a = False + +from deepmd.pt.model.descriptor.se_a import ( + DescrptSeA, +) +from deepmd.pt.utils import ( + env, +) +from deepmd.pt.utils.env import ( + PRECISION_DICT, +) + +from .test_mlp import ( + get_tols, +) + +dtype = env.GLOBAL_PT_FLOAT_PRECISION + + +class TestCaseSingleFrameWithNlist: + def setUp(self): + # nloc == 3, nall == 4 + self.nloc = 3 + self.nall = 4 + self.nf, self.nt = 1, 2 + self.coord_ext = np.array( + [ + [0, 0, 0], + [0, 1, 0], + [0, 0, 1], + [0, -2, 0], + ], + dtype=np.float64, + ).reshape([1, self.nall * 3]) + self.atype_ext = np.array([0, 0, 1, 0], dtype=int).reshape([1, self.nall]) + # sel = [5, 2] + self.sel = [5, 2] + self.nlist = np.array( + [ + [1, 3, -1, -1, -1, 2, -1], + [0, -1, -1, -1, -1, 2, -1], + [0, 1, -1, -1, -1, 0, -1], + ], + dtype=int, + ).reshape([1, self.nloc, sum(self.sel)]) + self.rcut = 0.4 + self.rcut_smth = 2.2 + + +# to be merged with the tf test case +@unittest.skipIf(not support_se_e2_a, "EnvMat not supported") +class TestDescrptSeA(unittest.TestCase, TestCaseSingleFrameWithNlist): + def setUp(self): + TestCaseSingleFrameWithNlist.setUp(self) + + def test_consistency( + self, + ): + rng = np.random.default_rng() + nf, nloc, nnei = self.nlist.shape + davg = rng.normal(size=(self.nt, nnei, 4)) + dstd = rng.normal(size=(self.nt, nnei, 4)) + dstd = 0.1 + np.abs(dstd) + + for idt, prec in itertools.product( + [False, True], + ["float64", "float32"], + ): + dtype = PRECISION_DICT[prec] + rtol, atol = get_tols(prec) + err_msg = f"idt={idt} prec={prec}" + # sea new impl + dd0 = DescrptSeA( + self.rcut, + self.rcut_smth, + self.sel, + precision=prec, + resnet_dt=idt, + old_impl=False, + ).to(env.DEVICE) + dd0.sea.mean = torch.tensor(davg, dtype=dtype, device=env.DEVICE) + dd0.sea.dstd = torch.tensor(dstd, dtype=dtype, device=env.DEVICE) + rd0, _, _, _, _ = dd0( + torch.tensor(self.coord_ext, dtype=dtype, device=env.DEVICE), + torch.tensor(self.atype_ext, dtype=int, device=env.DEVICE), + torch.tensor(self.nlist, dtype=int, device=env.DEVICE), + ) + # serialization + dd1 = DescrptSeA.deserialize(dd0.serialize()) + rd1, _, _, _, _ = dd1( + torch.tensor(self.coord_ext, dtype=dtype, device=env.DEVICE), + torch.tensor(self.atype_ext, dtype=int, device=env.DEVICE), + torch.tensor(self.nlist, dtype=int, device=env.DEVICE), + ) + np.testing.assert_allclose( + rd0.detach().cpu().numpy(), + rd1.detach().cpu().numpy(), + rtol=rtol, + atol=atol, + err_msg=err_msg, + ) + # dp impl + dd2 = DPDescrptSeA.deserialize(dd0.serialize()) + rd2 = dd2.call( + self.coord_ext, + self.atype_ext, + self.nlist, + ) + np.testing.assert_allclose( + rd0.detach().cpu().numpy(), + rd2, + rtol=rtol, + atol=atol, + err_msg=err_msg, + ) + # old impl + if idt is False and prec == "float64": + dd3 = DescrptSeA( + self.rcut, + self.rcut_smth, + self.sel, + precision=prec, + resnet_dt=idt, + old_impl=True, + ).to(env.DEVICE) + dd0_state_dict = dd0.sea.state_dict() + dd3_state_dict = dd3.sea.state_dict() + for i in dd3_state_dict: + dd3_state_dict[i] = ( + dd0_state_dict[ + i.replace(".deep_layers.", ".layers.").replace( + "filter_layers_old.", "filter_layers.networks." + ) + ] + .detach() + .clone() + ) + if ".bias" in i: + dd3_state_dict[i] = dd3_state_dict[i].unsqueeze(0) + dd3.sea.load_state_dict(dd3_state_dict) + + rd3, _, _, _, _ = dd3( + torch.tensor(self.coord_ext, dtype=dtype, device=env.DEVICE), + torch.tensor(self.atype_ext, dtype=int, device=env.DEVICE), + torch.tensor(self.nlist, dtype=int, device=env.DEVICE), + ) + np.testing.assert_allclose( + rd0.detach().cpu().numpy(), + rd3.detach().cpu().numpy(), + rtol=rtol, + atol=atol, + err_msg=err_msg, + ) + + def test_jit( + self, + ): + rng = np.random.default_rng() + nf, nloc, nnei = self.nlist.shape + davg = rng.normal(size=(self.nt, nnei, 4)) + dstd = rng.normal(size=(self.nt, nnei, 4)) + dstd = 0.1 + np.abs(dstd) + + for idt, prec in itertools.product( + [False, True], + ["float64", "float32"], + ): + dtype = PRECISION_DICT[prec] + rtol, atol = get_tols(prec) + err_msg = f"idt={idt} prec={prec}" + # sea new impl + dd0 = DescrptSeA( + self.rcut, + self.rcut_smth, + self.sel, + precision=prec, + resnet_dt=idt, + old_impl=False, + ) + dd0.sea.mean = torch.tensor(davg, dtype=dtype, device=env.DEVICE) + dd0.sea.dstd = torch.tensor(dstd, dtype=dtype, device=env.DEVICE) + dd1 = DescrptSeA.deserialize(dd0.serialize()) + model = torch.jit.script(dd0) + model = torch.jit.script(dd1) diff --git a/source/tests/pt/test_smooth.py b/source/tests/pt/test_smooth.py new file mode 100644 index 0000000000..2e3bf61d10 --- /dev/null +++ b/source/tests/pt/test_smooth.py @@ -0,0 +1,230 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import copy +import unittest + +import torch + +from deepmd.pt.infer.deep_eval import ( + eval_model, +) +from deepmd.pt.model.model import ( + get_model, +) +from deepmd.pt.utils import ( + env, +) + +from .test_permutation import ( # model_dpau, + make_sample, + model_dpa1, + model_dpa2, + model_hybrid, + model_se_e2_a, +) + +dtype = torch.float64 + + +class SmoothTest: + def test( + self, + ): + # displacement of atoms + epsilon = 1e-5 if self.epsilon is None else self.epsilon + # required prec. relative prec is not checked. + rprec = 0 + aprec = 1e-5 if self.aprec is None else self.aprec + + natoms = 10 + cell = 8.6 * torch.eye(3, dtype=dtype).to(env.DEVICE) + atype = torch.randint(0, 3, [natoms]) + coord0 = ( + torch.tensor( + [ + 0.0, + 0.0, + 0.0, + 4.0 - 0.5 * epsilon, + 0.0, + 0.0, + 0.0, + 4.0 - 0.5 * epsilon, + 0.0, + ], + dtype=dtype, + ) + .view([-1, 3]) + .to(env.DEVICE) + ) + coord1 = torch.rand([natoms - coord0.shape[0], 3], dtype=dtype).to(env.DEVICE) + coord1 = torch.matmul(coord1, cell) + coord = torch.concat([coord0, coord1], dim=0) + + coord0 = torch.clone(coord) + coord1 = torch.clone(coord) + coord1[1][0] += epsilon + coord2 = torch.clone(coord) + coord2[2][1] += epsilon + coord3 = torch.clone(coord) + coord3[1][0] += epsilon + coord3[2][1] += epsilon + + e0, f0, v0 = eval_model( + self.model, coord0.unsqueeze(0), cell.unsqueeze(0), atype + ) + ret0 = { + "energy": e0.squeeze(0), + "force": f0.squeeze(0), + "virial": v0.squeeze(0), + } + e1, f1, v1 = eval_model( + self.model, coord1.unsqueeze(0), cell.unsqueeze(0), atype + ) + ret1 = { + "energy": e1.squeeze(0), + "force": f1.squeeze(0), + "virial": v1.squeeze(0), + } + e2, f2, v2 = eval_model( + self.model, coord2.unsqueeze(0), cell.unsqueeze(0), atype + ) + ret2 = { + "energy": e2.squeeze(0), + "force": f2.squeeze(0), + "virial": v2.squeeze(0), + } + e3, f3, v3 = eval_model( + self.model, coord3.unsqueeze(0), cell.unsqueeze(0), atype + ) + ret3 = { + "energy": e3.squeeze(0), + "force": f3.squeeze(0), + "virial": v3.squeeze(0), + } + + def compare(ret0, ret1): + torch.testing.assert_close( + ret0["energy"], ret1["energy"], rtol=rprec, atol=aprec + ) + # plus 1. to avoid the divided-by-zero issue + torch.testing.assert_close( + 1.0 + ret0["force"], 1.0 + ret1["force"], rtol=rprec, atol=aprec + ) + if not hasattr(self, "test_virial") or self.test_virial: + torch.testing.assert_close( + 1.0 + ret0["virial"], 1.0 + ret1["virial"], rtol=rprec, atol=aprec + ) + + compare(ret0, ret1) + compare(ret1, ret2) + compare(ret0, ret3) + + +class TestEnergyModelSeA(unittest.TestCase, SmoothTest): + def setUp(self): + model_params = copy.deepcopy(model_se_e2_a) + sampled = make_sample(model_params) + self.type_split = False + self.model = get_model(model_params, sampled).to(env.DEVICE) + self.epsilon, self.aprec = None, None + + +# @unittest.skip("dpa-1 not smooth at the moment") +class TestEnergyModelDPA1(unittest.TestCase, SmoothTest): + def setUp(self): + model_params = copy.deepcopy(model_dpa1) + sampled = make_sample(model_params) + self.type_split = True + self.model = get_model(model_params, sampled).to(env.DEVICE) + # less degree of smoothness, + # error can be systematically removed by reducing epsilon + self.epsilon = 1e-5 + self.aprec = 1e-5 + + +class TestEnergyModelDPA2(unittest.TestCase, SmoothTest): + def setUp(self): + model_params = copy.deepcopy(model_dpa2) + model_params["descriptor"]["repinit_rcut"] = 8 + model_params["descriptor"]["repinit_rcut_smth"] = 3.5 + model_params_sample = copy.deepcopy(model_params) + ####################################################### + # dirty hack here! the interface of dataload should be + # redesigned to support specifying rcut and sel + ####################################################### + model_params_sample["descriptor"]["rcut"] = model_params_sample["descriptor"][ + "repinit_rcut" + ] + model_params_sample["descriptor"]["sel"] = model_params_sample["descriptor"][ + "repinit_nsel" + ] + sampled = make_sample(model_params_sample) + self.type_split = True + self.model = get_model(model_params, sampled).to(env.DEVICE) + self.epsilon, self.aprec = 1e-5, 1e-4 + + +class TestEnergyModelDPA2_1(unittest.TestCase, SmoothTest): + def setUp(self): + model_params = copy.deepcopy(model_dpa2) + model_params["fitting_net"]["type"] = "ener" + model_params_sample = copy.deepcopy(model_params) + model_params_sample["descriptor"]["rcut"] = model_params_sample["descriptor"][ + "repinit_rcut" + ] + model_params_sample["descriptor"]["sel"] = model_params_sample["descriptor"][ + "repinit_nsel" + ] + sampled = make_sample(model_params_sample) + self.type_split = True + self.test_virial = False + self.model = get_model(model_params, sampled).to(env.DEVICE) + self.epsilon, self.aprec = None, None + + +class TestEnergyModelDPA2_2(unittest.TestCase, SmoothTest): + def setUp(self): + model_params = copy.deepcopy(model_dpa2) + model_params["fitting_net"]["type"] = "ener" + model_params_sample = copy.deepcopy(model_params) + model_params_sample["descriptor"]["rcut"] = model_params_sample["descriptor"][ + "repinit_rcut" + ] + model_params_sample["descriptor"]["sel"] = model_params_sample["descriptor"][ + "repinit_nsel" + ] + sampled = make_sample(model_params_sample) + self.type_split = True + self.test_virial = False + self.model = get_model(model_params, sampled).to(env.DEVICE) + self.epsilon, self.aprec = None, None + + +@unittest.skip("hybrid not supported at the moment") +class TestEnergyModelHybrid(unittest.TestCase, SmoothTest): + def setUp(self): + model_params = copy.deepcopy(model_hybrid) + sampled = make_sample(model_params) + self.type_split = True + self.model = get_model(model_params, sampled).to(env.DEVICE) + self.epsilon, self.aprec = None, None + + +# class TestEnergyFoo(unittest.TestCase): +# def test(self): +# model_params = model_dpau +# sampled = make_sample(model_params) +# self.model = EnergyModelDPAUni(model_params, sampled).to(env.DEVICE) + +# natoms = 5 +# cell = torch.rand([3, 3], dtype=dtype) +# cell = (cell + cell.T) + 5. * torch.eye(3) +# coord = torch.rand([natoms, 3], dtype=dtype) +# coord = torch.matmul(coord, cell) +# atype = torch.IntTensor([0, 0, 0, 1, 1]) +# idx_perm = [1, 0, 4, 3, 2] +# ret0 = infer_model(self.model, coord, cell, atype, type_split=True) + + +if __name__ == "__main__": + unittest.main() diff --git a/source/tests/pt/test_smooth_denoise.py b/source/tests/pt/test_smooth_denoise.py new file mode 100644 index 0000000000..a66e5df957 --- /dev/null +++ b/source/tests/pt/test_smooth_denoise.py @@ -0,0 +1,151 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import copy +import unittest + +import torch + +from deepmd.pt.infer.deep_eval import ( + eval_model, +) +from deepmd.pt.model.model import ( + get_model, +) +from deepmd.pt.utils import ( + env, +) + +from .test_permutation_denoise import ( + make_sample, + model_dpa2, +) + +dtype = torch.float64 + + +class SmoothDenoiseTest: + def test( + self, + ): + # displacement of atoms + epsilon = 1e-5 if self.epsilon is None else self.epsilon + # required prec. relative prec is not checked. + rprec = 0 + aprec = 1e-5 if self.aprec is None else self.aprec + + natoms = 10 + cell = 8.6 * torch.eye(3, dtype=dtype).to(env.DEVICE) + atype = torch.randint(0, 3, [natoms]) + coord0 = ( + torch.tensor( + [ + 0.0, + 0.0, + 0.0, + 4.0 - 0.5 * epsilon, + 0.0, + 0.0, + 0.0, + 4.0 - 0.5 * epsilon, + 0.0, + ], + dtype=dtype, + ) + .view([-1, 3]) + .to(env.DEVICE) + ) + coord1 = torch.rand([natoms - coord0.shape[0], 3], dtype=dtype).to(env.DEVICE) + coord1 = torch.matmul(coord1, cell) + coord = torch.concat([coord0, coord1], dim=0) + + coord0 = torch.clone(coord) + coord1 = torch.clone(coord) + coord1[1][0] += epsilon + coord2 = torch.clone(coord) + coord2[2][1] += epsilon + coord3 = torch.clone(coord) + coord3[1][0] += epsilon + coord3[2][1] += epsilon + + update_c0, logits0 = eval_model( + self.model, coord0.unsqueeze(0), cell.unsqueeze(0), atype, denoise=True + ) + ret0 = {"updated_coord": update_c0.squeeze(0), "logits": logits0.squeeze(0)} + update_c1, logits1 = eval_model( + self.model, coord1.unsqueeze(0), cell.unsqueeze(0), atype, denoise=True + ) + ret1 = {"updated_coord": update_c1.squeeze(0), "logits": logits1.squeeze(0)} + update_c2, logits2 = eval_model( + self.model, coord2.unsqueeze(0), cell.unsqueeze(0), atype, denoise=True + ) + ret2 = {"updated_coord": update_c2.squeeze(0), "logits": logits2.squeeze(0)} + update_c3, logits3 = eval_model( + self.model, coord3.unsqueeze(0), cell.unsqueeze(0), atype, denoise=True + ) + ret3 = {"updated_coord": update_c3.squeeze(0), "logits": logits3.squeeze(0)} + + def compare(ret0, ret1): + torch.testing.assert_close( + ret0["updated_coord"], ret1["updated_coord"], rtol=rprec, atol=aprec + ) + torch.testing.assert_close( + ret0["logits"], ret1["logits"], rtol=rprec, atol=aprec + ) + + compare(ret0, ret1) + compare(ret1, ret2) + compare(ret0, ret3) + + +class TestDenoiseModelDPA2(unittest.TestCase, SmoothDenoiseTest): + def setUp(self): + model_params_sample = copy.deepcopy(model_dpa2) + model_params_sample["descriptor"]["rcut"] = model_params_sample["descriptor"][ + "repinit_rcut" + ] + model_params_sample["descriptor"]["sel"] = model_params_sample["descriptor"][ + "repinit_nsel" + ] + sampled = make_sample(model_params_sample) + model_params = copy.deepcopy(model_dpa2) + model_params["descriptor"]["sel"] = 8 + model_params["descriptor"]["rcut_smth"] = 3.5 + self.type_split = True + self.model = get_model(model_params, sampled).to(env.DEVICE) + self.epsilon, self.aprec = None, None + self.epsilon = 1e-7 + self.aprec = 1e-5 + + +class TestDenoiseModelDPA2_1(unittest.TestCase, SmoothDenoiseTest): + def setUp(self): + model_params_sample = copy.deepcopy(model_dpa2) + model_params_sample["descriptor"]["rcut"] = model_params_sample["descriptor"][ + "repinit_rcut" + ] + model_params_sample["descriptor"]["sel"] = model_params_sample["descriptor"][ + "repinit_nsel" + ] + sampled = make_sample(model_params_sample) + model_params = copy.deepcopy(model_dpa2) + # model_params["descriptor"]["combine_grrg"] = True + self.type_split = True + self.model = get_model(model_params, sampled).to(env.DEVICE) + self.epsilon, self.aprec = None, None + self.epsilon = 1e-7 + self.aprec = 1e-5 + + +# @unittest.skip("hybrid not supported at the moment") +# class TestDenoiseModelHybrid(unittest.TestCase, TestSmoothDenoise): +# def setUp(self): +# model_params = copy.deepcopy(model_hybrid_denoise) +# sampled = make_sample(model_params) +# self.type_split = True +# self.model = get_model(model_params, sampled).to(env.DEVICE) +# self.epsilon, self.aprec = None, None +# self.epsilon = 1e-7 +# self.aprec = 1e-5 + + +if __name__ == "__main__": + unittest.main() diff --git a/source/tests/pt/test_stat.py b/source/tests/pt/test_stat.py new file mode 100644 index 0000000000..08fc12ff11 --- /dev/null +++ b/source/tests/pt/test_stat.py @@ -0,0 +1,194 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import json +import os +import unittest +from pathlib import ( + Path, +) + +import numpy as np +import torch + +from deepmd.pt.model.descriptor import ( + DescrptSeA, +) +from deepmd.pt.utils import ( + env, +) +from deepmd.pt.utils.dataloader import ( + DpLoaderSet, +) +from deepmd.pt.utils.stat import ( + compute_output_stats, +) +from deepmd.pt.utils.stat import make_stat_input as my_make +from deepmd.tf.common import ( + expand_sys_str, +) +from deepmd.tf.descriptor.se_a import DescrptSeA as DescrptSeA_tf +from deepmd.tf.fit.ener import ( + EnerFitting, +) +from deepmd.tf.model.model_stat import make_stat_input as dp_make +from deepmd.tf.model.model_stat import merge_sys_stat as dp_merge +from deepmd.tf.utils import random as tf_random +from deepmd.tf.utils.data_system import ( + DeepmdDataSystem, +) + +CUR_DIR = os.path.dirname(__file__) + + +def compare(ut, base, given): + if isinstance(base, list): + ut.assertEqual(len(base), len(given)) + for idx in range(len(base)): + compare(ut, base[idx], given[idx]) + elif isinstance(base, np.ndarray): + ut.assertTrue(np.allclose(base.reshape(-1), given.reshape(-1))) + else: + ut.assertEqual(base, given) + + +class TestDataset(unittest.TestCase): + def setUp(self): + with open(str(Path(__file__).parent / "water/se_e2_a.json")) as fin: + content = fin.read() + config = json.loads(content) + data_file = [str(Path(__file__).parent / "water/data/data_0")] + config["training"]["training_data"]["systems"] = data_file + config["training"]["validation_data"]["systems"] = data_file + model_config = config["model"] + self.rcut = model_config["descriptor"]["rcut"] + self.rcut_smth = model_config["descriptor"]["rcut_smth"] + self.sel = model_config["descriptor"]["sel"] + self.batch_size = config["training"]["training_data"]["batch_size"] + self.systems = config["training"]["validation_data"]["systems"] + if isinstance(self.systems, str): + self.systems = expand_sys_str(self.systems) + self.my_dataset = DpLoaderSet( + self.systems, + self.batch_size, + model_params={ + "descriptor": { + "type": "se_e2_a", + "sel": self.sel, + "rcut": self.rcut, + }, + "type_map": model_config["type_map"], + }, + seed=10, + ) + self.filter_neuron = model_config["descriptor"]["neuron"] + self.axis_neuron = model_config["descriptor"]["axis_neuron"] + self.data_stat_nbatch = 2 + self.filter_neuron = model_config["descriptor"]["neuron"] + self.axis_neuron = model_config["descriptor"]["axis_neuron"] + self.n_neuron = model_config["fitting_net"]["neuron"] + + self.my_sampled = my_make( + self.my_dataset.systems, self.my_dataset.dataloaders, self.data_stat_nbatch + ) + + tf_random.seed(10) + dp_dataset = DeepmdDataSystem(self.systems, self.batch_size, 1, self.rcut) + dp_dataset.add("energy", 1, atomic=False, must=False, high_prec=True) + dp_dataset.add("force", 3, atomic=True, must=False, high_prec=False) + self.dp_sampled = dp_make(dp_dataset, self.data_stat_nbatch, False) + self.dp_merged = dp_merge(self.dp_sampled) + self.dp_mesh = self.dp_merged.pop("default_mesh") + self.dp_d = DescrptSeA_tf( + rcut=self.rcut, + rcut_smth=self.rcut_smth, + sel=self.sel, + neuron=self.filter_neuron, + axis_neuron=self.axis_neuron, + ) + + def test_stat_output(self): + def my_merge(energy, natoms): + energy_lst = [] + natoms_lst = [] + for i in range(len(energy)): + for j in range(len(energy[i])): + energy_lst.append(torch.tensor(energy[i][j])) + natoms_lst.append( + torch.tensor(natoms[i][j]) + .unsqueeze(0) + .expand(energy[i][j].shape[0], -1) + ) + return energy_lst, natoms_lst + + energy = self.dp_sampled["energy"] + natoms = self.dp_sampled["natoms_vec"] + energy, natoms = my_merge(energy, natoms) + dp_fn = EnerFitting(self.dp_d, self.n_neuron) + dp_fn.compute_output_stats(self.dp_sampled) + bias_atom_e = compute_output_stats(energy, natoms) + self.assertTrue(np.allclose(dp_fn.bias_atom_e, bias_atom_e[:, 0])) + + # temporarily delete this function for performance of seeds in tf and pytorch may be different + """ + def test_stat_input(self): + my_sampled = self.my_sampled + # list of dicts, each dict contains samples from a system + dp_keys = set(self.dp_merged.keys()) # dict of list of batches + self.dp_merged['natoms'] = self.dp_merged['natoms_vec'] + for key in dp_keys: + if not key in my_sampled[0] or key in 'coord': + # coord is pre-normalized + continue + lst = [] + for item in my_sampled: + bsz = item['energy'].shape[0]//self.data_stat_nbatch + for j in range(self.data_stat_nbatch): + lst.append(item[key][j*bsz:(j+1)*bsz].cpu().numpy()) + compare(self, self.dp_merged[key], lst) + """ + + def test_descriptor(self): + coord = self.dp_merged["coord"] + atype = self.dp_merged["type"] + natoms = self.dp_merged["natoms_vec"] + box = self.dp_merged["box"] + self.dp_d.compute_input_stats(coord, box, atype, natoms, self.dp_mesh, {}) + + my_en = DescrptSeA( + self.rcut, self.rcut_smth, self.sel, self.filter_neuron, self.axis_neuron + ) + my_en = my_en.sea # get the block who has stat as private vars + sampled = self.my_sampled + for sys in sampled: + for key in [ + "coord", + "force", + "energy", + "atype", + "natoms", + "extended_coord", + "nlist", + "shift", + "mapping", + ]: + if key in sys.keys(): + sys[key] = sys[key].to(env.DEVICE) + sumr, suma, sumn, sumr2, suma2 = my_en.compute_input_stats(sampled) + my_en.init_desc_stat(sumr, suma, sumn, sumr2, suma2) + my_en.mean = my_en.mean + my_en.stddev = my_en.stddev + self.assertTrue( + np.allclose( + self.dp_d.davg.reshape([-1]), my_en.mean.cpu().reshape([-1]), rtol=0.01 + ) + ) + self.assertTrue( + np.allclose( + self.dp_d.dstd.reshape([-1]), + my_en.stddev.cpu().reshape([-1]), + rtol=0.01, + ) + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/source/tests/pt/test_training.py b/source/tests/pt/test_training.py new file mode 100644 index 0000000000..574ca8688e --- /dev/null +++ b/source/tests/pt/test_training.py @@ -0,0 +1,116 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import json +import os +import shutil +import unittest +from copy import ( + deepcopy, +) +from pathlib import ( + Path, +) + +from deepmd.pt.entrypoints.main import ( + get_trainer, +) + +from .test_permutation import ( + model_dpa1, + model_dpa2, + model_hybrid, + model_se_e2_a, +) + + +class DPTrainTest: + def test_dp_train(self): + trainer = get_trainer(deepcopy(self.config)) + trainer.run() + self.tearDown() + + def tearDown(self): + for f in os.listdir("."): + if f.startswith("model") and f.endswith(".pt"): + os.remove(f) + if f in ["lcurve.out"]: + os.remove(f) + if f in ["stat_files"]: + shutil.rmtree(f) + + +class TestEnergyModelSeA(unittest.TestCase, DPTrainTest): + def setUp(self): + input_json = str(Path(__file__).parent / "water/se_atten.json") + with open(input_json) as f: + self.config = json.load(f) + data_file = [str(Path(__file__).parent / "water/data/data_0")] + self.config["training"]["training_data"]["systems"] = data_file + self.config["training"]["validation_data"]["systems"] = data_file + self.config["model"] = deepcopy(model_se_e2_a) + self.config["training"]["numb_steps"] = 1 + self.config["training"]["save_freq"] = 1 + + def tearDown(self) -> None: + DPTrainTest.tearDown(self) + + +class TestEnergyModelDPA1(unittest.TestCase, DPTrainTest): + def setUp(self): + input_json = str(Path(__file__).parent / "water/se_atten.json") + with open(input_json) as f: + self.config = json.load(f) + data_file = [str(Path(__file__).parent / "water/data/data_0")] + self.config["training"]["training_data"]["systems"] = data_file + self.config["training"]["validation_data"]["systems"] = data_file + self.config["model"] = deepcopy(model_dpa1) + self.config["training"]["numb_steps"] = 1 + self.config["training"]["save_freq"] = 1 + + def tearDown(self) -> None: + DPTrainTest.tearDown(self) + + +class TestEnergyModelDPA2(unittest.TestCase, DPTrainTest): + def setUp(self): + input_json = str(Path(__file__).parent / "water/se_atten.json") + with open(input_json) as f: + self.config = json.load(f) + data_file = [str(Path(__file__).parent / "water/data/data_0")] + self.config["training"]["training_data"]["systems"] = data_file + self.config["training"]["validation_data"]["systems"] = data_file + self.config["model"] = deepcopy(model_dpa2) + self.config["model"]["descriptor"]["rcut"] = self.config["model"]["descriptor"][ + "repinit_rcut" + ] + self.config["model"]["descriptor"]["rcut_smth"] = self.config["model"][ + "descriptor" + ]["repinit_rcut_smth"] + self.config["model"]["descriptor"]["sel"] = self.config["model"]["descriptor"][ + "repinit_nsel" + ] + self.config["training"]["numb_steps"] = 1 + self.config["training"]["save_freq"] = 1 + + def tearDown(self) -> None: + DPTrainTest.tearDown(self) + + +@unittest.skip("hybrid not supported at the moment") +class TestEnergyModelHybrid(unittest.TestCase, DPTrainTest): + def setUp(self): + input_json = str(Path(__file__).parent / "water/se_atten.json") + with open(input_json) as f: + self.config = json.load(f) + data_file = [str(Path(__file__).parent / "water/data/data_0")] + self.config["training"]["training_data"]["systems"] = data_file + self.config["training"]["validation_data"]["systems"] = data_file + self.config["model"] = deepcopy(model_hybrid) + self.config["training"]["numb_steps"] = 1 + self.config["training"]["save_freq"] = 1 + + def tearDown(self) -> None: + DPTrainTest.tearDown(self) + + +if __name__ == "__main__": + unittest.main() diff --git a/source/tests/pt/test_trans.py b/source/tests/pt/test_trans.py new file mode 100644 index 0000000000..e5d379b9ff --- /dev/null +++ b/source/tests/pt/test_trans.py @@ -0,0 +1,137 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import copy +import unittest + +import torch + +from deepmd.pt.infer.deep_eval import ( + eval_model, +) +from deepmd.pt.model.model import ( + get_model, +) +from deepmd.pt.utils import ( + env, +) + +from .test_permutation import ( # model_dpau, + make_sample, + model_dpa1, + model_dpa2, + model_hybrid, + model_se_e2_a, +) + +dtype = torch.float64 + + +class TransTest: + def test( + self, + ): + natoms = 5 + cell = torch.rand([3, 3], dtype=dtype).to(env.DEVICE) + cell = (cell + cell.T) + 5.0 * torch.eye(3).to(env.DEVICE) + coord = torch.rand([natoms, 3], dtype=dtype).to(env.DEVICE) + coord = torch.matmul(coord, cell) + atype = torch.IntTensor([0, 0, 0, 1, 1]).to(env.DEVICE) + shift = (torch.rand([3], dtype=dtype) - 0.5).to(env.DEVICE) * 2.0 + coord_s = torch.matmul( + torch.remainder(torch.matmul(coord + shift, torch.linalg.inv(cell)), 1.0), + cell, + ) + e0, f0, v0 = eval_model( + self.model, coord.unsqueeze(0), cell.unsqueeze(0), atype + ) + ret0 = { + "energy": e0.squeeze(0), + "force": f0.squeeze(0), + "virial": v0.squeeze(0), + } + e1, f1, v1 = eval_model( + self.model, coord_s.unsqueeze(0), cell.unsqueeze(0), atype + ) + ret1 = { + "energy": e1.squeeze(0), + "force": f1.squeeze(0), + "virial": v1.squeeze(0), + } + prec = 1e-10 + torch.testing.assert_close(ret0["energy"], ret1["energy"], rtol=prec, atol=prec) + torch.testing.assert_close(ret0["force"], ret1["force"], rtol=prec, atol=prec) + if not hasattr(self, "test_virial") or self.test_virial: + torch.testing.assert_close( + ret0["virial"], ret1["virial"], rtol=prec, atol=prec + ) + + +class TestEnergyModelSeA(unittest.TestCase, TransTest): + def setUp(self): + model_params = copy.deepcopy(model_se_e2_a) + sampled = make_sample(model_params) + self.type_split = False + self.model = get_model(model_params, sampled).to(env.DEVICE) + + +class TestEnergyModelDPA1(unittest.TestCase, TransTest): + def setUp(self): + model_params = copy.deepcopy(model_dpa1) + sampled = make_sample(model_params) + self.type_split = True + self.model = get_model(model_params, sampled).to(env.DEVICE) + + +class TestEnergyModelDPA2(unittest.TestCase, TransTest): + def setUp(self): + model_params_sample = copy.deepcopy(model_dpa2) + model_params_sample["descriptor"]["rcut"] = model_params_sample["descriptor"][ + "repinit_rcut" + ] + model_params_sample["descriptor"]["sel"] = model_params_sample["descriptor"][ + "repinit_nsel" + ] + sampled = make_sample(model_params_sample) + model_params = copy.deepcopy(model_dpa2) + self.type_split = True + self.model = get_model(model_params, sampled).to(env.DEVICE) + + +class TestForceModelDPA2(unittest.TestCase, TransTest): + def setUp(self): + model_params_sample = copy.deepcopy(model_dpa2) + model_params_sample["descriptor"]["rcut"] = model_params_sample["descriptor"][ + "repinit_rcut" + ] + model_params_sample["descriptor"]["sel"] = model_params_sample["descriptor"][ + "repinit_nsel" + ] + sampled = make_sample(model_params_sample) + model_params = copy.deepcopy(model_dpa2) + model_params["fitting_net"]["type"] = "direct_force_ener" + self.type_split = True + self.test_virial = False + self.model = get_model(model_params, sampled).to(env.DEVICE) + + +@unittest.skip("hybrid not supported at the moment") +class TestEnergyModelHybrid(unittest.TestCase, TransTest): + def setUp(self): + model_params = copy.deepcopy(model_hybrid) + sampled = make_sample(model_params) + self.type_split = True + self.model = get_model(model_params, sampled).to(env.DEVICE) + + +@unittest.skip("hybrid not supported at the moment") +class TestForceModelHybrid(unittest.TestCase, TransTest): + def setUp(self): + model_params = copy.deepcopy(model_hybrid) + model_params["fitting_net"]["type"] = "direct_force_ener" + sampled = make_sample(model_params) + self.type_split = True + self.test_virial = False + self.model = get_model(model_params, sampled).to(env.DEVICE) + + +if __name__ == "__main__": + unittest.main() diff --git a/source/tests/pt/test_trans_denoise.py b/source/tests/pt/test_trans_denoise.py new file mode 100644 index 0000000000..360633278c --- /dev/null +++ b/source/tests/pt/test_trans_denoise.py @@ -0,0 +1,92 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import copy +import unittest + +import torch + +from deepmd.pt.infer.deep_eval import ( + eval_model, +) +from deepmd.pt.model.model import ( + get_model, +) +from deepmd.pt.utils import ( + env, +) + +from .test_permutation_denoise import ( + make_sample, + model_dpa1, + model_dpa2, + model_hybrid, +) + +dtype = torch.float64 + + +class TransDenoiseTest: + def test( + self, + ): + natoms = 5 + cell = torch.rand([3, 3], dtype=dtype).to(env.DEVICE) + cell = (cell + cell.T) + 5.0 * torch.eye(3).to(env.DEVICE) + coord = torch.rand([natoms, 3], dtype=dtype).to(env.DEVICE) + coord = torch.matmul(coord, cell) + atype = torch.IntTensor([0, 0, 0, 1, 1]).to(env.DEVICE) + shift = (torch.rand([3], dtype=dtype) - 0.5).to(env.DEVICE) * 2.0 + coord_s = torch.matmul( + torch.remainder(torch.matmul(coord + shift, torch.linalg.inv(cell)), 1.0), + cell, + ) + updated_c0, logits0 = eval_model( + self.model, coord.unsqueeze(0), cell.unsqueeze(0), atype, denoise=True + ) + updated_c0 = updated_c0 - coord.unsqueeze(0) + ret0 = {"updated_coord": updated_c0.squeeze(0), "logits": logits0.squeeze(0)} + updated_c1, logits1 = eval_model( + self.model, coord_s.unsqueeze(0), cell.unsqueeze(0), atype, denoise=True + ) + updated_c1 = updated_c1 - coord_s.unsqueeze(0) + ret1 = {"updated_coord": updated_c1.squeeze(0), "logits": logits1.squeeze(0)} + prec = 1e-10 + torch.testing.assert_close( + ret0["updated_coord"], ret1["updated_coord"], rtol=prec, atol=prec + ) + torch.testing.assert_close(ret0["logits"], ret1["logits"], rtol=prec, atol=prec) + + +class TestDenoiseModelDPA1(unittest.TestCase, TransDenoiseTest): + def setUp(self): + model_params = copy.deepcopy(model_dpa1) + sampled = make_sample(model_params) + self.type_split = True + self.model = get_model(model_params, sampled).to(env.DEVICE) + + +class TestDenoiseModelDPA2(unittest.TestCase, TransDenoiseTest): + def setUp(self): + model_params_sample = copy.deepcopy(model_dpa2) + model_params_sample["descriptor"]["rcut"] = model_params_sample["descriptor"][ + "repinit_rcut" + ] + model_params_sample["descriptor"]["sel"] = model_params_sample["descriptor"][ + "repinit_nsel" + ] + sampled = make_sample(model_params_sample) + model_params = copy.deepcopy(model_dpa2) + self.type_split = True + self.model = get_model(model_params, sampled).to(env.DEVICE) + + +@unittest.skip("hybrid not supported at the moment") +class TestDenoiseModelHybrid(unittest.TestCase, TransDenoiseTest): + def setUp(self): + model_params = copy.deepcopy(model_hybrid) + sampled = make_sample(model_params) + self.type_split = True + self.model = get_model(model_params, sampled).to(env.DEVICE) + + +if __name__ == "__main__": + unittest.main() diff --git a/source/tests/pt/test_unused_params.py b/source/tests/pt/test_unused_params.py new file mode 100644 index 0000000000..a924979466 --- /dev/null +++ b/source/tests/pt/test_unused_params.py @@ -0,0 +1,98 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import copy +import unittest + +import torch + +from deepmd.pt.infer.deep_eval import ( + eval_model, +) +from deepmd.pt.model.model import ( + get_model, +) +from deepmd.pt.utils import ( + env, +) + +from .test_permutation import ( + make_sample, + model_dpa2, +) + +dtype = torch.float64 + + +class TestUnusedParamsDPA2(unittest.TestCase): + def test_unused(self): + import itertools + + for conv, drrd, grrg, attn1, g1g1, attn2, h2 in itertools.product( + [True], + [True], + [True], + [True], + [True], + [True], + [True], + ): + if (not drrd) and (not grrg) and h2: + # skip the case h2 is not envolved + continue + if (not grrg) and (not conv): + # skip the case g2 is not envolved + continue + model = copy.deepcopy(model_dpa2) + model["descriptor"]["rcut"] = model["descriptor"]["repinit_rcut"] + model["descriptor"]["sel"] = model["descriptor"]["repinit_nsel"] + model["descriptor"]["repformer_nlayers"] = 2 + # model["descriptor"]["combine_grrg"] = cmbg2 + model["descriptor"]["repformer_update_g1_has_conv"] = conv + model["descriptor"]["repformer_update_g1_has_drrd"] = drrd + model["descriptor"]["repformer_update_g1_has_grrg"] = grrg + model["descriptor"]["repformer_update_g1_has_attn"] = attn1 + model["descriptor"]["repformer_update_g2_has_g1g1"] = g1g1 + model["descriptor"]["repformer_update_g2_has_attn"] = attn2 + model["descriptor"]["repformer_update_h2"] = h2 + model["fitting_net"]["neuron"] = [12, 12, 12] + self._test_unused(model) + + def _test_unused(self, model_params): + sampled = make_sample(model_params) + self.model = get_model(model_params, sampled).to(env.DEVICE) + natoms = 5 + cell = torch.rand([3, 3], dtype=dtype).to(env.DEVICE) + cell = (cell + cell.T) + 5.0 * torch.eye(3).to(env.DEVICE) + coord = torch.rand([natoms, 3], dtype=dtype).to(env.DEVICE) + coord = torch.matmul(coord, cell) + atype = torch.IntTensor([0, 0, 0, 1, 1]).to(env.DEVICE) + idx_perm = [1, 0, 4, 3, 2] + e0, f0, v0 = eval_model( + self.model, coord.unsqueeze(0), cell.unsqueeze(0), atype + ) + ret0 = { + "energy": e0.squeeze(0), + "force": f0.squeeze(0), + "virial": v0.squeeze(0), + } + + # use computation graph to find all contributing tensors + def get_contributing_params(y, top_level=True): + nf = y.grad_fn.next_functions if top_level else y.next_functions + for f, _ in nf: + try: + yield f.variable + except AttributeError: + pass # node has no tensor + if f is not None: + yield from get_contributing_params(f, top_level=False) + + contributing_parameters = set(get_contributing_params(ret0["energy"])) + all_parameters = set(self.model.parameters()) + non_contributing = all_parameters - contributing_parameters + for ii in non_contributing: + print(ii.shape) + self.assertEqual(len(non_contributing), 0) + + +if __name__ == "__main__": + unittest.main() diff --git a/source/tests/pt/water/data/data_0/set.000/box.npy b/source/tests/pt/water/data/data_0/set.000/box.npy new file mode 100644 index 0000000000..6ad2de625b Binary files /dev/null and b/source/tests/pt/water/data/data_0/set.000/box.npy differ diff --git a/source/tests/pt/water/data/data_0/set.000/coord.npy b/source/tests/pt/water/data/data_0/set.000/coord.npy new file mode 100644 index 0000000000..8bd448b125 Binary files /dev/null and b/source/tests/pt/water/data/data_0/set.000/coord.npy differ diff --git a/source/tests/pt/water/data/data_0/set.000/energy.npy b/source/tests/pt/water/data/data_0/set.000/energy.npy new file mode 100644 index 0000000000..d03db103f5 Binary files /dev/null and b/source/tests/pt/water/data/data_0/set.000/energy.npy differ diff --git a/source/tests/pt/water/data/data_0/set.000/force.npy b/source/tests/pt/water/data/data_0/set.000/force.npy new file mode 100644 index 0000000000..10b2ab83a2 Binary files /dev/null and b/source/tests/pt/water/data/data_0/set.000/force.npy differ diff --git a/source/tests/pt/water/data/data_0/type.raw b/source/tests/pt/water/data/data_0/type.raw new file mode 100644 index 0000000000..97e8fdfcf8 --- /dev/null +++ b/source/tests/pt/water/data/data_0/type.raw @@ -0,0 +1,192 @@ +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 diff --git a/source/tests/pt/water/data/data_0/type_map.raw b/source/tests/pt/water/data/data_0/type_map.raw new file mode 100644 index 0000000000..e900768b1d --- /dev/null +++ b/source/tests/pt/water/data/data_0/type_map.raw @@ -0,0 +1,2 @@ +O +H diff --git a/source/tests/pt/water/data/single/set.000/box.npy b/source/tests/pt/water/data/single/set.000/box.npy new file mode 100644 index 0000000000..65897e0f9c Binary files /dev/null and b/source/tests/pt/water/data/single/set.000/box.npy differ diff --git a/source/tests/pt/water/data/single/set.000/coord.npy b/source/tests/pt/water/data/single/set.000/coord.npy new file mode 100644 index 0000000000..6e0594a803 Binary files /dev/null and b/source/tests/pt/water/data/single/set.000/coord.npy differ diff --git a/source/tests/pt/water/data/single/set.000/energy.npy b/source/tests/pt/water/data/single/set.000/energy.npy new file mode 100644 index 0000000000..a0a88fb78a Binary files /dev/null and b/source/tests/pt/water/data/single/set.000/energy.npy differ diff --git a/source/tests/pt/water/data/single/set.000/force.npy b/source/tests/pt/water/data/single/set.000/force.npy new file mode 100644 index 0000000000..d5b847a86e Binary files /dev/null and b/source/tests/pt/water/data/single/set.000/force.npy differ diff --git a/source/tests/pt/water/data/single/type.raw b/source/tests/pt/water/data/single/type.raw new file mode 100644 index 0000000000..97e8fdfcf8 --- /dev/null +++ b/source/tests/pt/water/data/single/type.raw @@ -0,0 +1,192 @@ +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 diff --git a/source/tests/pt/water/data/single/type_map.raw b/source/tests/pt/water/data/single/type_map.raw new file mode 100644 index 0000000000..e900768b1d --- /dev/null +++ b/source/tests/pt/water/data/single/type_map.raw @@ -0,0 +1,2 @@ +O +H diff --git a/source/tests/pt/water/lkf.json b/source/tests/pt/water/lkf.json new file mode 100644 index 0000000000..4385d02136 --- /dev/null +++ b/source/tests/pt/water/lkf.json @@ -0,0 +1,79 @@ +{ + "model": { + "type_map": [ + "O", + "H" + ], + "descriptor": { + "type": "se_e2_a", + "sel": [ + 46, + 92 + ], + "rcut_smth": 0.50, + "rcut": 6.00, + "neuron": [ + 25, + 25, + 25 + ], + "resnet_dt": false, + "axis_neuron": 16, + "seed": 1, + "_comment": " that's all" + }, + "fitting_net": { + "neuron": [ + 100, + 100, + 100 + ], + "resnet_dt": true, + "seed": 1, + "_comment": " that's all" + }, + "data_stat_nbatch": 20, + "_comment": " that's all" + }, + "learning_rate": { + "type": "exp", + "decay_steps": 5000, + "start_lr": 0.001, + "stop_lr": 3.51e-8, + "_comment": "that's all" + }, + "loss": { + "type": "ener", + "start_pref_e": 0.02, + "limit_pref_e": 1, + "start_pref_f": 1000, + "limit_pref_f": 1, + "_comment": " that's all" + }, + "training": { + "training_data": { + "systems": [ + "pt/water/data/data_0" + ], + "batch_size": 3, + "_comment": "that's all" + }, + "validation_data": { + "systems": [ + "pt/water/data/data_0" + ], + "batch_size": 1, + "numb_btch": 3, + "_comment": "that's all" + }, + "numb_steps": 1, + "seed": 10, + "disp_file": "lcurve.out", + "disp_freq": 1, + "save_freq": 1, + "opt_type": "LKF", + "kf_blocksize": 1024, + "_comment": "that's all" + }, + "_comment": "that's all" +} diff --git a/source/tests/pt/water/se_atten.json b/source/tests/pt/water/se_atten.json new file mode 100644 index 0000000000..8867e0db41 --- /dev/null +++ b/source/tests/pt/water/se_atten.json @@ -0,0 +1,84 @@ +{ + "_comment": "that's all", + "model": { + "type_map": [ + "O", + "H" + ], + "descriptor": { + "type": "se_atten", + "sel": 40, + "rcut_smth": 0.5, + "rcut": 4.0, + "neuron": [ + 25, + 50, + 100 + ], + "axis_neuron": 16, + "attn": 64, + "attn_layer": 2, + "attn_dotr": true, + "attn_mask": false, + "post_ln": true, + "ffn": false, + "ffn_embed_dim": 512, + "activation": "tanh", + "scaling_factor": 1.0, + "head_num": 1, + "normalize": false, + "temperature": 1.0 + }, + "fitting_net": { + "neuron": [ + 240, + 240, + 240 + ], + "resnet_dt": true, + "seed": 1, + "_comment": " that's all" + }, + "_comment": " that's all" + }, + "learning_rate": { + "type": "exp", + "decay_steps": 5000, + "start_lr": 0.001, + "stop_lr": 3.51e-08, + "_comment": "that's all" + }, + "loss": { + "type": "ener", + "start_pref_e": 0.02, + "limit_pref_e": 1, + "start_pref_f": 1000, + "limit_pref_f": 1, + "start_pref_v": 0, + "limit_pref_v": 0, + "_comment": " that's all" + }, + "training": { + "training_data": { + "systems": [ + "pt/water/data/data_0" + ], + "batch_size": 1, + "_comment": "that's all" + }, + "validation_data": { + "systems": [ + "pt/water/data/data_0" + ], + "batch_size": 1, + "numb_btch": 1, + "_comment": "that's all" + }, + "numb_steps": 1000000, + "seed": 10, + "disp_file": "lcurve.out", + "disp_freq": 100, + "save_freq": 1000, + "_comment": "that's all" + } +} diff --git a/source/tests/pt/water/se_e2_a.json b/source/tests/pt/water/se_e2_a.json new file mode 100644 index 0000000000..425ca3cbf5 --- /dev/null +++ b/source/tests/pt/water/se_e2_a.json @@ -0,0 +1,77 @@ +{ + "model": { + "type_map": [ + "O", + "H" + ], + "descriptor": { + "type": "se_e2_a", + "sel": [ + 46, + 92 + ], + "rcut_smth": 0.50, + "rcut": 6.00, + "neuron": [ + 25, + 50, + 100 + ], + "resnet_dt": false, + "axis_neuron": 16, + "seed": 1, + "_comment": " that's all" + }, + "fitting_net": { + "neuron": [ + 240, + 240, + 240 + ], + "resnet_dt": true, + "seed": 1, + "_comment": " that's all" + }, + "data_stat_nbatch": 20, + "_comment": " that's all" + }, + "learning_rate": { + "type": "exp", + "decay_steps": 5000, + "start_lr": 0.001, + "stop_lr": 3.51e-8, + "_comment": "that's all" + }, + "loss": { + "type": "ener", + "start_pref_e": 0.02, + "limit_pref_e": 1, + "start_pref_f": 1000, + "limit_pref_f": 1, + "_comment": " that's all" + }, + "training": { + "training_data": { + "systems": [ + "pt/water/data/data_0" + ], + "batch_size": 1, + "_comment": "that's all" + }, + "validation_data": { + "systems": [ + "pt/water/data/data_0" + ], + "batch_size": 1, + "numb_btch": 3, + "_comment": "that's all" + }, + "numb_steps": 100000, + "seed": 10, + "disp_file": "lcurve.out", + "disp_freq": 100, + "save_freq": 10000, + "_comment": "that's all" + }, + "_comment": "that's all" +} diff --git a/source/tests/test_adjust_sel.py b/source/tests/test_adjust_sel.py index b1cbdc5afc..9bed3606fd 100644 --- a/source/tests/test_adjust_sel.py +++ b/source/tests/test_adjust_sel.py @@ -82,12 +82,10 @@ def _init_models(): return INPUT, frozen_model, decreased_model, increased_model -INPUT, FROZEN_MODEL, DECREASED_MODEL, INCREASED_MODEL = _init_models() - - class TestDeepPotAAdjustSel(unittest.TestCase): @classmethod def setUpClass(self): + INPUT, FROZEN_MODEL, DECREASED_MODEL, INCREASED_MODEL = _init_models() self.dp_original = DeepPot(FROZEN_MODEL) self.dp_decreased = DeepPot(DECREASED_MODEL) self.dp_increased = DeepPot(INCREASED_MODEL) diff --git a/source/tests/test_finetune_se_atten.py b/source/tests/test_finetune_se_atten.py index 3614fcb13a..47fedcf685 100644 --- a/source/tests/test_finetune_se_atten.py +++ b/source/tests/test_finetune_se_atten.py @@ -147,67 +147,77 @@ def _init_models(setup_model, i): ) -if not parse_version(tf.__version__) < parse_version("1.15"): - - def previous_se_atten(jdata): - jdata["model"]["descriptor"]["stripped_type_embedding"] = False - jdata["model"]["descriptor"]["attn_layer"] = 2 - - def stripped_model(jdata): - jdata["model"]["descriptor"]["stripped_type_embedding"] = True - jdata["model"]["descriptor"]["attn_layer"] = 2 - - def compressible_model(jdata): - jdata["model"]["descriptor"]["stripped_type_embedding"] = True - jdata["model"]["descriptor"]["attn_layer"] = 0 - - models = [previous_se_atten, stripped_model, compressible_model] - INPUT_PRES = [] - INPUT_FINETUNES = [] - INPUT_FINETUNE_MIXS = [] - PRE_MODELS = [] - FINETUNED_MODELS = [] - FINETUNED_MODEL_MIXS = [] - PRE_MAPS = [] - FINETUNED_MAPS = [] - VALID_DATAS = [] - for i, model in enumerate(models): - ( - INPUT_PRE, - INPUT_FINETUNE, - INPUT_FINETUNE_MIX, - PRE_MODEL, - FINETUNED_MODEL, - FINETUNED_MODEL_MIX, - PRE_MAP, - FINETUNED_MAP, - VALID_DATA, - ) = _init_models(model, i) - INPUT_PRES.append(INPUT_PRE) - INPUT_FINETUNES.append(INPUT_FINETUNE) - INPUT_FINETUNE_MIXS.append(INPUT_FINETUNE_MIX) - PRE_MODELS.append(PRE_MODEL) - FINETUNED_MODELS.append(FINETUNED_MODEL) - FINETUNED_MODEL_MIXS.append(FINETUNED_MODEL_MIX) - PRE_MAPS.append(PRE_MAP) - FINETUNED_MAPS.append(FINETUNED_MAP) - VALID_DATAS.append(VALID_DATA) - - @unittest.skipIf( parse_version(tf.__version__) < parse_version("1.15"), f"The current tf version {tf.__version__} is too low to run the new testing model.", ) class TestFinetuneSeAtten(unittest.TestCase): + @classmethod + def setUpClass(cls) -> None: + if not parse_version(tf.__version__) < parse_version("1.15"): + + def previous_se_atten(jdata): + jdata["model"]["descriptor"]["stripped_type_embedding"] = False + jdata["model"]["descriptor"]["attn_layer"] = 2 + + def stripped_model(jdata): + jdata["model"]["descriptor"]["stripped_type_embedding"] = True + jdata["model"]["descriptor"]["attn_layer"] = 2 + + def compressible_model(jdata): + jdata["model"]["descriptor"]["stripped_type_embedding"] = True + jdata["model"]["descriptor"]["attn_layer"] = 0 + + models = [previous_se_atten, stripped_model, compressible_model] + INPUT_PRES = [] + INPUT_FINETUNES = [] + INPUT_FINETUNE_MIXS = [] + PRE_MODELS = [] + FINETUNED_MODELS = [] + FINETUNED_MODEL_MIXS = [] + PRE_MAPS = [] + FINETUNED_MAPS = [] + VALID_DATAS = [] + for i, model in enumerate(models): + ( + INPUT_PRE, + INPUT_FINETUNE, + INPUT_FINETUNE_MIX, + PRE_MODEL, + FINETUNED_MODEL, + FINETUNED_MODEL_MIX, + PRE_MAP, + FINETUNED_MAP, + VALID_DATA, + ) = _init_models(model, i) + INPUT_PRES.append(INPUT_PRE) + INPUT_FINETUNES.append(INPUT_FINETUNE) + INPUT_FINETUNE_MIXS.append(INPUT_FINETUNE_MIX) + PRE_MODELS.append(PRE_MODEL) + FINETUNED_MODELS.append(FINETUNED_MODEL) + FINETUNED_MODEL_MIXS.append(FINETUNED_MODEL_MIX) + PRE_MAPS.append(PRE_MAP) + FINETUNED_MAPS.append(FINETUNED_MAP) + VALID_DATAS.append(VALID_DATA) + cls.INPUT_PRES = INPUT_PRES + cls.INPUT_FINETUNES = INPUT_FINETUNES + cls.INPUT_FINETUNE_MIXS = INPUT_FINETUNE_MIXS + cls.PRE_MODELS = PRE_MODELS + cls.FINETUNED_MODELS = FINETUNED_MODELS + cls.FINETUNED_MODEL_MIXS = FINETUNED_MODEL_MIXS + cls.PRE_MAPS = PRE_MAPS + cls.FINETUNED_MAPS = FINETUNED_MAPS + cls.VALID_DATAS = VALID_DATAS + @classmethod def tearDownClass(self): - for i in range(len(INPUT_PRES)): - _file_delete(INPUT_PRES[i]) - _file_delete(INPUT_FINETUNES[i]) - _file_delete(INPUT_FINETUNE_MIXS[i]) - _file_delete(PRE_MODELS[i]) - _file_delete(FINETUNED_MODELS[i]) - _file_delete(FINETUNED_MODEL_MIXS[i]) + for i in range(len(self.INPUT_PRES)): + _file_delete(self.INPUT_PRES[i]) + _file_delete(self.INPUT_FINETUNES[i]) + _file_delete(self.INPUT_FINETUNE_MIXS[i]) + _file_delete(self.PRE_MODELS[i]) + _file_delete(self.FINETUNED_MODELS[i]) + _file_delete(self.FINETUNED_MODEL_MIXS[i]) _file_delete("out.json") _file_delete("model.ckpt.meta") _file_delete("model.ckpt.index") @@ -223,22 +233,22 @@ def tearDownClass(self): _file_delete("lcurve.out") def test_finetune_standard(self): - for i in range(len(INPUT_PRES)): - self.valid_data = VALID_DATAS[i] + for i in range(len(self.INPUT_PRES)): + self.valid_data = self.VALID_DATAS[i] pretrained_bias = get_tensor_by_name( - PRE_MODELS[i], "fitting_attr/t_bias_atom_e" + self.PRE_MODELS[i], "fitting_attr/t_bias_atom_e" ) finetuned_bias = get_tensor_by_name( - FINETUNED_MODELS[i], "fitting_attr/t_bias_atom_e" + self.FINETUNED_MODELS[i], "fitting_attr/t_bias_atom_e" ) - sorter = np.argsort(PRE_MAPS[i]) + sorter = np.argsort(self.PRE_MAPS[i]) idx_type_map = sorter[ - np.searchsorted(PRE_MAPS[i], FINETUNED_MAPS[i], sorter=sorter) + np.searchsorted(self.PRE_MAPS[i], self.FINETUNED_MAPS[i], sorter=sorter) ] test_data = self.valid_data.get_test() atom_nums = np.tile(np.bincount(test_data["type"][0])[idx_type_map], (4, 1)) - dp = DeepPotential(PRE_MODELS[i]) + dp = DeepPotential(self.PRE_MODELS[i]) energy = dp.eval( test_data["coord"], test_data["box"], test_data["type"][0] )[0] @@ -250,7 +260,7 @@ def test_finetune_standard(self): 0 ].reshape(-1) - dp_finetuned = DeepPotential(FINETUNED_MODELS[i]) + dp_finetuned = DeepPotential(self.FINETUNED_MODELS[i]) energy_finetuned = dp_finetuned.eval( test_data["coord"], test_data["box"], test_data["type"][0] )[0] @@ -266,22 +276,22 @@ def test_finetune_standard(self): np.testing.assert_almost_equal(finetune_results, 0.0, default_places) def test_finetune_mixed_type(self): - for i in range(len(INPUT_PRES)): - self.valid_data = VALID_DATAS[i] + for i in range(len(self.INPUT_PRES)): + self.valid_data = self.VALID_DATAS[i] pretrained_bias = get_tensor_by_name( - PRE_MODELS[i], "fitting_attr/t_bias_atom_e" + self.PRE_MODELS[i], "fitting_attr/t_bias_atom_e" ) finetuned_bias_mixed_type = get_tensor_by_name( - FINETUNED_MODEL_MIXS[i], "fitting_attr/t_bias_atom_e" + self.FINETUNED_MODEL_MIXS[i], "fitting_attr/t_bias_atom_e" ) - sorter = np.argsort(PRE_MAPS[i]) + sorter = np.argsort(self.PRE_MAPS[i]) idx_type_map = sorter[ - np.searchsorted(PRE_MAPS[i], FINETUNED_MAPS[i], sorter=sorter) + np.searchsorted(self.PRE_MAPS[i], self.FINETUNED_MAPS[i], sorter=sorter) ] test_data = self.valid_data.get_test() atom_nums = np.tile(np.bincount(test_data["type"][0])[idx_type_map], (4, 1)) - dp = DeepPotential(PRE_MODELS[i]) + dp = DeepPotential(self.PRE_MODELS[i]) energy = dp.eval( test_data["coord"], test_data["box"], test_data["type"][0] )[0] @@ -293,7 +303,7 @@ def test_finetune_mixed_type(self): 0 ].reshape(-1) - dp_finetuned_mixed_type = DeepPotential(FINETUNED_MODEL_MIXS[i]) + dp_finetuned_mixed_type = DeepPotential(self.FINETUNED_MODEL_MIXS[i]) energy_finetuned = dp_finetuned_mixed_type.eval( test_data["coord"], test_data["box"], test_data["type"][0] )[0] diff --git a/source/tests/test_init_frz_model_multi.py b/source/tests/test_init_frz_model_multi.py index e5e5733c7d..fc37d82397 100644 --- a/source/tests/test_init_frz_model_multi.py +++ b/source/tests/test_init_frz_model_multi.py @@ -180,20 +180,19 @@ def _init_models(): return INPUT, ckpt, frozen_model, model_ckpt, model_frz, data, stop_batch -( - INPUT, - CKPT, - FROZEN_MODEL, - CKPT_TRAINER, - FRZ_TRAINER, - VALID_DATA, - STOP_BATCH, -) = _init_models() - - class TestInitFrzModelMulti(unittest.TestCase): @classmethod def setUpClass(cls): + ( + cls.INPUT, + cls.CKPT, + cls.FROZEN_MODEL, + CKPT_TRAINER, + FRZ_TRAINER, + VALID_DATA, + STOP_BATCH, + ) = _init_models() + cls.dp_ckpt = CKPT_TRAINER cls.dp_frz = FRZ_TRAINER cls.valid_data_dict = {"water_ener": VALID_DATA} @@ -205,19 +204,19 @@ def setUpClass(cls): @classmethod def tearDownClass(cls): - _file_delete(INPUT) - _file_delete(FROZEN_MODEL) + _file_delete(cls.INPUT) + _file_delete(cls.FROZEN_MODEL) _file_delete("out.json") _file_delete(str(tests_path / "checkpoint")) - _file_delete(CKPT + ".meta") - _file_delete(CKPT + ".index") - _file_delete(CKPT + ".data-00000-of-00001") - _file_delete(CKPT + "-0.meta") - _file_delete(CKPT + "-0.index") - _file_delete(CKPT + "-0.data-00000-of-00001") - _file_delete(CKPT + "-1.meta") - _file_delete(CKPT + "-1.index") - _file_delete(CKPT + "-1.data-00000-of-00001") + _file_delete(cls.CKPT + ".meta") + _file_delete(cls.CKPT + ".index") + _file_delete(cls.CKPT + ".data-00000-of-00001") + _file_delete(cls.CKPT + "-0.meta") + _file_delete(cls.CKPT + "-0.index") + _file_delete(cls.CKPT + "-0.data-00000-of-00001") + _file_delete(cls.CKPT + "-1.meta") + _file_delete(cls.CKPT + "-1.index") + _file_delete(cls.CKPT + "-1.data-00000-of-00001") _file_delete("input_v2_compat.json") _file_delete("lcurve.out") diff --git a/source/tests/test_init_frz_model_se_a.py b/source/tests/test_init_frz_model_se_a.py index d98c2bc14f..7545e3aae9 100644 --- a/source/tests/test_init_frz_model_se_a.py +++ b/source/tests/test_init_frz_model_se_a.py @@ -128,20 +128,18 @@ def _init_models(): return INPUT, ckpt, frozen_model, model_ckpt, model_frz, data, stop_batch -( - INPUT, - CKPT, - FROZEN_MODEL, - CKPT_TRAINER, - FRZ_TRAINER, - VALID_DATA, - STOP_BATCH, -) = _init_models() - - class TestInitFrzModelA(unittest.TestCase): @classmethod def setUpClass(cls): + ( + cls.INPUT, + cls.CKPT, + cls.FROZEN_MODEL, + CKPT_TRAINER, + FRZ_TRAINER, + VALID_DATA, + STOP_BATCH, + ) = _init_models() cls.dp_ckpt = CKPT_TRAINER cls.dp_frz = FRZ_TRAINER cls.valid_data = VALID_DATA @@ -149,19 +147,19 @@ def setUpClass(cls): @classmethod def tearDownClass(cls): - _file_delete(INPUT) - _file_delete(FROZEN_MODEL) + _file_delete(cls.INPUT) + _file_delete(cls.FROZEN_MODEL) _file_delete("out.json") _file_delete(str(tests_path / "checkpoint")) - _file_delete(CKPT + ".meta") - _file_delete(CKPT + ".index") - _file_delete(CKPT + ".data-00000-of-00001") - _file_delete(CKPT + "-0.meta") - _file_delete(CKPT + "-0.index") - _file_delete(CKPT + "-0.data-00000-of-00001") - _file_delete(CKPT + "-1.meta") - _file_delete(CKPT + "-1.index") - _file_delete(CKPT + "-1.data-00000-of-00001") + _file_delete(cls.CKPT + ".meta") + _file_delete(cls.CKPT + ".index") + _file_delete(cls.CKPT + ".data-00000-of-00001") + _file_delete(cls.CKPT + "-0.meta") + _file_delete(cls.CKPT + "-0.index") + _file_delete(cls.CKPT + "-0.data-00000-of-00001") + _file_delete(cls.CKPT + "-1.meta") + _file_delete(cls.CKPT + "-1.index") + _file_delete(cls.CKPT + "-1.data-00000-of-00001") _file_delete("input_v2_compat.json") _file_delete("lcurve.out") diff --git a/source/tests/test_init_frz_model_se_a_tebd.py b/source/tests/test_init_frz_model_se_a_tebd.py index 594bf83085..1b282c00d5 100644 --- a/source/tests/test_init_frz_model_se_a_tebd.py +++ b/source/tests/test_init_frz_model_se_a_tebd.py @@ -129,20 +129,19 @@ def _init_models(): return INPUT, ckpt, frozen_model, model_ckpt, model_frz, data, stop_batch -( - INPUT, - CKPT, - FROZEN_MODEL, - CKPT_TRAINER, - FRZ_TRAINER, - VALID_DATA, - STOP_BATCH, -) = _init_models() - - class TestInitFrzModelA(unittest.TestCase): @classmethod def setUpClass(cls): + ( + cls.INPUT, + cls.CKPT, + cls.FROZEN_MODEL, + CKPT_TRAINER, + FRZ_TRAINER, + VALID_DATA, + STOP_BATCH, + ) = _init_models() + cls.dp_ckpt = CKPT_TRAINER cls.dp_frz = FRZ_TRAINER cls.valid_data = VALID_DATA @@ -150,19 +149,19 @@ def setUpClass(cls): @classmethod def tearDownClass(cls): - _file_delete(INPUT) - _file_delete(FROZEN_MODEL) + _file_delete(cls.INPUT) + _file_delete(cls.FROZEN_MODEL) _file_delete("out.json") _file_delete(str(tests_path / "checkpoint")) - _file_delete(CKPT + ".meta") - _file_delete(CKPT + ".index") - _file_delete(CKPT + ".data-00000-of-00001") - _file_delete(CKPT + "-0.meta") - _file_delete(CKPT + "-0.index") - _file_delete(CKPT + "-0.data-00000-of-00001") - _file_delete(CKPT + "-1.meta") - _file_delete(CKPT + "-1.index") - _file_delete(CKPT + "-1.data-00000-of-00001") + _file_delete(cls.CKPT + ".meta") + _file_delete(cls.CKPT + ".index") + _file_delete(cls.CKPT + ".data-00000-of-00001") + _file_delete(cls.CKPT + "-0.meta") + _file_delete(cls.CKPT + "-0.index") + _file_delete(cls.CKPT + "-0.data-00000-of-00001") + _file_delete(cls.CKPT + "-1.meta") + _file_delete(cls.CKPT + "-1.index") + _file_delete(cls.CKPT + "-1.data-00000-of-00001") _file_delete("input_v2_compat.json") _file_delete("lcurve.out") diff --git a/source/tests/test_init_frz_model_se_a_type.py b/source/tests/test_init_frz_model_se_a_type.py index 3221245065..b356dbf6d0 100644 --- a/source/tests/test_init_frz_model_se_a_type.py +++ b/source/tests/test_init_frz_model_se_a_type.py @@ -132,20 +132,18 @@ def _init_models(): return INPUT, ckpt, frozen_model, model_ckpt, model_frz, data, stop_batch -( - INPUT, - CKPT, - FROZEN_MODEL, - CKPT_TRAINER, - FRZ_TRAINER, - VALID_DATA, - STOP_BATCH, -) = _init_models() - - class TestInitFrzModelAType(unittest.TestCase): @classmethod def setUpClass(cls): + ( + cls.INPUT, + cls.CKPT, + cls.FROZEN_MODEL, + CKPT_TRAINER, + FRZ_TRAINER, + VALID_DATA, + STOP_BATCH, + ) = _init_models() cls.dp_ckpt = CKPT_TRAINER cls.dp_frz = FRZ_TRAINER cls.valid_data = VALID_DATA @@ -153,19 +151,19 @@ def setUpClass(cls): @classmethod def tearDownClass(cls): - _file_delete(INPUT) - _file_delete(FROZEN_MODEL) + _file_delete(cls.INPUT) + _file_delete(cls.FROZEN_MODEL) _file_delete("out.json") _file_delete(str(tests_path / "checkpoint")) - _file_delete(CKPT + ".meta") - _file_delete(CKPT + ".index") - _file_delete(CKPT + ".data-00000-of-00001") - _file_delete(CKPT + "-0.meta") - _file_delete(CKPT + "-0.index") - _file_delete(CKPT + "-0.data-00000-of-00001") - _file_delete(CKPT + "-1.meta") - _file_delete(CKPT + "-1.index") - _file_delete(CKPT + "-1.data-00000-of-00001") + _file_delete(cls.CKPT + ".meta") + _file_delete(cls.CKPT + ".index") + _file_delete(cls.CKPT + ".data-00000-of-00001") + _file_delete(cls.CKPT + "-0.meta") + _file_delete(cls.CKPT + "-0.index") + _file_delete(cls.CKPT + "-0.data-00000-of-00001") + _file_delete(cls.CKPT + "-1.meta") + _file_delete(cls.CKPT + "-1.index") + _file_delete(cls.CKPT + "-1.data-00000-of-00001") _file_delete("input_v2_compat.json") _file_delete("lcurve.out") diff --git a/source/tests/test_init_frz_model_se_atten.py b/source/tests/test_init_frz_model_se_atten.py index 5554ae415c..7889440cd3 100644 --- a/source/tests/test_init_frz_model_se_atten.py +++ b/source/tests/test_init_frz_model_se_atten.py @@ -146,32 +146,6 @@ def compressible_model(jdata): jdata["model"]["descriptor"]["stripped_type_embedding"] = True jdata["model"]["descriptor"]["attn_layer"] = 0 - models = [previous_se_atten, stripped_model, compressible_model] - INPUTS = [] - CKPTS = [] - FROZEN_MODELS = [] - CKPT_TRAINERS = [] - FRZ_TRAINERS = [] - VALID_DATAS = [] - STOP_BATCHS = [] - for i, model in enumerate(models): - ( - INPUT, - CKPT, - FROZEN_MODEL, - CKPT_TRAINER, - FRZ_TRAINER, - VALID_DATA, - STOP_BATCH, - ) = _init_models(model, i) - INPUTS.append(INPUT) - CKPTS.append(CKPT) - FROZEN_MODELS.append(FROZEN_MODEL) - CKPT_TRAINERS.append(CKPT_TRAINER) - FRZ_TRAINERS.append(FRZ_TRAINER) - VALID_DATAS.append(VALID_DATA) - STOP_BATCHS.append(STOP_BATCH) - @unittest.skipIf( parse_version(tf.__version__) < parse_version("1.15"), @@ -180,6 +154,38 @@ def compressible_model(jdata): class TestInitFrzModelAtten(unittest.TestCase): @classmethod def setUpClass(cls): + models = [previous_se_atten, stripped_model, compressible_model] + INPUTS = [] + CKPTS = [] + FROZEN_MODELS = [] + CKPT_TRAINERS = [] + FRZ_TRAINERS = [] + VALID_DATAS = [] + STOP_BATCHS = [] + for i, model in enumerate(models): + ( + INPUT, + CKPT, + FROZEN_MODEL, + CKPT_TRAINER, + FRZ_TRAINER, + VALID_DATA, + STOP_BATCH, + ) = _init_models(model, i) + INPUTS.append(INPUT) + CKPTS.append(CKPT) + FROZEN_MODELS.append(FROZEN_MODEL) + CKPT_TRAINERS.append(CKPT_TRAINER) + FRZ_TRAINERS.append(FRZ_TRAINER) + VALID_DATAS.append(VALID_DATA) + STOP_BATCHS.append(STOP_BATCH) + cls.INPUTS = INPUTS + cls.CKPTS = CKPTS + cls.FROZEN_MODELS = FROZEN_MODELS + cls.CKPT_TRAINERS = CKPT_TRAINERS + cls.FRZ_TRAINERS = FRZ_TRAINERS + cls.VALID_DATAS = VALID_DATAS + cls.STOP_BATCHS = STOP_BATCHS cls.dp_ckpts = CKPT_TRAINERS cls.dp_frzs = FRZ_TRAINERS cls.valid_datas = VALID_DATAS @@ -188,28 +194,28 @@ def setUpClass(cls): @classmethod def tearDownClass(cls): for i in range(len(cls.dp_ckpts)): - _file_delete(INPUTS[i]) - _file_delete(FROZEN_MODELS[i]) + _file_delete(cls.INPUTS[i]) + _file_delete(cls.FROZEN_MODELS[i]) _file_delete("out.json") _file_delete(str(tests_path / "checkpoint")) - _file_delete(CKPT[i] + ".meta") - _file_delete(CKPT[i] + ".index") - _file_delete(CKPT[i] + ".data-00000-of-00001") - _file_delete(CKPT[i] + "-0.meta") - _file_delete(CKPT[i] + "-0.index") - _file_delete(CKPT[i] + "-0.data-00000-of-00001") - _file_delete(CKPT[i] + "-1.meta") - _file_delete(CKPT[i] + "-1.index") - _file_delete(CKPT[i] + "-1.data-00000-of-00001") + _file_delete(cls.CKPTS[i] + ".meta") + _file_delete(cls.CKPTS[i] + ".index") + _file_delete(cls.CKPTS[i] + ".data-00000-of-00001") + _file_delete(cls.CKPTS[i] + "-0.meta") + _file_delete(cls.CKPTS[i] + "-0.index") + _file_delete(cls.CKPTS[i] + "-0.data-00000-of-00001") + _file_delete(cls.CKPTS[i] + "-1.meta") + _file_delete(cls.CKPTS[i] + "-1.index") + _file_delete(cls.CKPTS[i] + "-1.data-00000-of-00001") _file_delete(f"input_v2_compat{i}.json") _file_delete("lcurve.out") def test_single_frame(self): for i in range(len(self.dp_ckpts)): - self.dp_ckpt = CKPT_TRAINERS[i] - self.dp_frz = FRZ_TRAINERS[i] - self.valid_data = VALID_DATAS[i] - self.stop_batch = STOP_BATCHS[i] + self.dp_ckpt = self.CKPT_TRAINERS[i] + self.dp_frz = self.FRZ_TRAINERS[i] + self.valid_data = self.VALID_DATAS[i] + self.stop_batch = self.STOP_BATCHS[i] valid_batch = self.valid_data.get_batch() natoms = valid_batch["natoms_vec"] diff --git a/source/tests/test_init_frz_model_se_r.py b/source/tests/test_init_frz_model_se_r.py index 84d109bcfd..fd916b3fdc 100644 --- a/source/tests/test_init_frz_model_se_r.py +++ b/source/tests/test_init_frz_model_se_r.py @@ -136,20 +136,19 @@ def _init_models(): return INPUT, ckpt, frozen_model, model_ckpt, model_frz, data, stop_batch -( - INPUT, - CKPT, - FROZEN_MODEL, - CKPT_TRAINER, - FRZ_TRAINER, - VALID_DATA, - STOP_BATCH, -) = _init_models() - - class TestInitFrzModelR(unittest.TestCase): @classmethod def setUpClass(cls): + ( + cls.INPUT, + cls.CKPT, + cls.FROZEN_MODEL, + CKPT_TRAINER, + FRZ_TRAINER, + VALID_DATA, + STOP_BATCH, + ) = _init_models() + cls.dp_ckpt = CKPT_TRAINER cls.dp_frz = FRZ_TRAINER cls.valid_data = VALID_DATA @@ -157,19 +156,19 @@ def setUpClass(cls): @classmethod def tearDownClass(cls): - _file_delete(INPUT) - _file_delete(FROZEN_MODEL) + _file_delete(cls.INPUT) + _file_delete(cls.FROZEN_MODEL) _file_delete("out.json") _file_delete(str(tests_path / "checkpoint")) - _file_delete(CKPT + ".meta") - _file_delete(CKPT + ".index") - _file_delete(CKPT + ".data-00000-of-00001") - _file_delete(CKPT + "-0.meta") - _file_delete(CKPT + "-0.index") - _file_delete(CKPT + "-0.data-00000-of-00001") - _file_delete(CKPT + "-1.meta") - _file_delete(CKPT + "-1.index") - _file_delete(CKPT + "-1.data-00000-of-00001") + _file_delete(cls.CKPT + ".meta") + _file_delete(cls.CKPT + ".index") + _file_delete(cls.CKPT + ".data-00000-of-00001") + _file_delete(cls.CKPT + "-0.meta") + _file_delete(cls.CKPT + "-0.index") + _file_delete(cls.CKPT + "-0.data-00000-of-00001") + _file_delete(cls.CKPT + "-1.meta") + _file_delete(cls.CKPT + "-1.index") + _file_delete(cls.CKPT + "-1.data-00000-of-00001") _file_delete("input_v2_compat.json") _file_delete("lcurve.out") diff --git a/source/tests/test_init_frz_model_spin.py b/source/tests/test_init_frz_model_spin.py index 7aa3d514dc..b5c480c2ba 100644 --- a/source/tests/test_init_frz_model_spin.py +++ b/source/tests/test_init_frz_model_spin.py @@ -140,20 +140,19 @@ def _init_models(): return INPUT, ckpt, frozen_model, model_ckpt, model_frz, data, stop_batch -( - INPUT, - CKPT, - FROZEN_MODEL, - CKPT_TRAINER, - FRZ_TRAINER, - VALID_DATA, - STOP_BATCH, -) = _init_models() - - class TestInitFrzModelR(unittest.TestCase): @classmethod def setUpClass(cls): + ( + cls.INPUT, + cls.CKPT, + cls.FROZEN_MODEL, + CKPT_TRAINER, + FRZ_TRAINER, + VALID_DATA, + STOP_BATCH, + ) = _init_models() + cls.dp_ckpt = CKPT_TRAINER cls.dp_frz = FRZ_TRAINER cls.valid_data = VALID_DATA @@ -161,19 +160,19 @@ def setUpClass(cls): @classmethod def tearDownClass(cls): - _file_delete(INPUT) - _file_delete(FROZEN_MODEL) + _file_delete(cls.INPUT) + _file_delete(cls.FROZEN_MODEL) _file_delete("out.json") _file_delete(str(tests_path / "checkpoint")) - _file_delete(CKPT + ".meta") - _file_delete(CKPT + ".index") - _file_delete(CKPT + ".data-00000-of-00001") - _file_delete(CKPT + "-0.meta") - _file_delete(CKPT + "-0.index") - _file_delete(CKPT + "-0.data-00000-of-00001") - _file_delete(CKPT + "-1.meta") - _file_delete(CKPT + "-1.index") - _file_delete(CKPT + "-1.data-00000-of-00001") + _file_delete(cls.CKPT + ".meta") + _file_delete(cls.CKPT + ".index") + _file_delete(cls.CKPT + ".data-00000-of-00001") + _file_delete(cls.CKPT + "-0.meta") + _file_delete(cls.CKPT + "-0.index") + _file_delete(cls.CKPT + "-0.data-00000-of-00001") + _file_delete(cls.CKPT + "-1.meta") + _file_delete(cls.CKPT + "-1.index") + _file_delete(cls.CKPT + "-1.data-00000-of-00001") _file_delete("input_v2_compat.json") _file_delete("lcurve.out") diff --git a/source/tests/test_model_compression_se_a_ebd_type_one_side.py b/source/tests/test_model_compression_se_a_ebd_type_one_side.py index 9ad1970e9b..741c95b26e 100644 --- a/source/tests/test_model_compression_se_a_ebd_type_one_side.py +++ b/source/tests/test_model_compression_se_a_ebd_type_one_side.py @@ -98,7 +98,6 @@ def _init_models_exclude_types(): INPUT, FROZEN_MODEL, COMPRESSED_MODEL = _init_models() -INPUT_ET, FROZEN_MODEL_ET, COMPRESSED_MODEL_ET = _init_models_exclude_types() class TestDeepPotAPBC(unittest.TestCase): @@ -444,8 +443,13 @@ def test_ase(self): class TestDeepPotAPBCExcludeTypes(unittest.TestCase): @classmethod def setUpClass(self): - self.dp_original = DeepPot(FROZEN_MODEL_ET) - self.dp_compressed = DeepPot(COMPRESSED_MODEL_ET) + ( + self.INPUT_ET, + self.FROZEN_MODEL_ET, + self.COMPRESSED_MODEL_ET, + ) = _init_models_exclude_types() + self.dp_original = DeepPot(self.FROZEN_MODEL_ET) + self.dp_compressed = DeepPot(self.COMPRESSED_MODEL_ET) self.coords = np.array( [ 12.83, @@ -473,9 +477,9 @@ def setUpClass(self): @classmethod def tearDownClass(self): - _file_delete(INPUT_ET) - _file_delete(FROZEN_MODEL_ET) - _file_delete(COMPRESSED_MODEL_ET) + _file_delete(self.INPUT_ET) + _file_delete(self.FROZEN_MODEL_ET) + _file_delete(self.COMPRESSED_MODEL_ET) _file_delete("out.json") _file_delete("compress.json") _file_delete("checkpoint") diff --git a/source/tests/test_model_compression_se_a_type_one_side_exclude_types.py b/source/tests/test_model_compression_se_a_type_one_side_exclude_types.py index 5b6ac4e13e..bdf09cf3e8 100644 --- a/source/tests/test_model_compression_se_a_type_one_side_exclude_types.py +++ b/source/tests/test_model_compression_se_a_type_one_side_exclude_types.py @@ -66,12 +66,11 @@ def _init_models(): return INPUT, frozen_model, compressed_model -INPUT, FROZEN_MODEL, COMPRESSED_MODEL = _init_models() - - class TestDeepPotAPBCTypeOneSideExcludeTypes(unittest.TestCase): @classmethod def setUpClass(self): + INPUT, FROZEN_MODEL, COMPRESSED_MODEL = _init_models() + self.dp_original = DeepPot(FROZEN_MODEL) self.dp_compressed = DeepPot(COMPRESSED_MODEL) self.coords = np.array(