From beeb3d932695c872809f10ba9b35917f877af80d Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Wed, 25 Dec 2024 11:58:19 +0800 Subject: [PATCH] pd: support dpa2 (#4418) Support DPA-2 in paddle backend. This PR will be updated after #4414 is merged. ### Training curve: ![training_curves_comparison_dpa2](https://github.com/user-attachments/assets/29bdeffa-cf2d-4586-afcf-7df0569997c3) ### Accuracy test(left: paddle, right: torch): ![image](https://github.com/user-attachments/assets/5bff55f3-1c39-4b95-93f0-68783e794716) Ralated optimization of Paddle framework: - [x] https://github.com/PaddlePaddle/Paddle/pull/69349 - [x] https://github.com/PaddlePaddle/Paddle/pull/69333 - [x] https://github.com/PaddlePaddle/Paddle/pull/69479 - [x] https://github.com/PaddlePaddle/Paddle/pull/69515 - [x] https://github.com/PaddlePaddle/Paddle/pull/69487 - [x] https://github.com/PaddlePaddle/Paddle/pull/69661 - [x] https://github.com/PaddlePaddle/Paddle/pull/69660 - [x] https://github.com/PaddlePaddle/Paddle/pull/69596 - [x] https://github.com/PaddlePaddle/Paddle/pull/69556 ## Summary by CodeRabbit - **New Features** - Introduced new classes for molecular descriptors: `DescrptDPA2`, `DescrptBlockRepformers`, `DescrptSeTTebd`, and `DescrptBlockSeTTebd`. - Added new functions for tensor operations and descriptor management, enhancing the capabilities of the module. - Updated JSON configurations for multitask models to refine selection criteria and data paths. - **Bug Fixes** - Improved error handling and parameter validation across various descriptor classes. - **Documentation** - Enhanced test coverage for new descriptor functionalities and configurations. - **Tests** - Added new test classes to validate the functionality of `DescrptDPA2` and multitask training scenarios. - Expanded test capabilities for descriptor classes based on installed dependencies. - Updated existing tests to support new configurations and functionalities. 
--------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- deepmd/pd/model/descriptor/__init__.py | 14 + deepmd/pd/model/descriptor/dpa2.py | 902 ++++++++++ deepmd/pd/model/descriptor/repformer_layer.py | 1484 +++++++++++++++++ deepmd/pd/model/descriptor/repformers.py | 576 +++++++ deepmd/pd/model/descriptor/se_t_tebd.py | 931 +++++++++++ deepmd/pd/model/task/fitting.py | 4 +- deepmd/pd/utils/multi_task.py | 4 +- deepmd/pd/utils/spin.py | 30 + .../tests/consistent/descriptor/test_dpa2.py | 51 + .../consistent/descriptor/test_se_t_tebd.py | 17 + source/tests/pd/model/models/dpa2.json | 57 + source/tests/pd/model/models/dpa2.pd | Bin 0 -> 119535 bytes source/tests/pd/model/test_autodiff.py | 43 +- source/tests/pd/model/test_descriptor_dpa2.py | 208 +++ source/tests/pd/model/test_dpa2.py | 333 ++++ source/tests/pd/model/test_forward_lower.py | 15 +- source/tests/pd/model/test_null_input.py | 12 +- source/tests/pd/model/test_permutation.py | 1 - source/tests/pd/model/test_rot.py | 1 - source/tests/pd/model/test_rot_denoise.py | 11 +- source/tests/pd/model/test_smooth.py | 31 + source/tests/pd/model/test_trans.py | 1 - source/tests/pd/model/test_unused_params.py | 92 + source/tests/pd/model/water/multitask.json | 3 +- .../pd/model/water/multitask_sharefit.json | 8 +- source/tests/pd/test_finetune.py | 15 +- source/tests/pd/test_multitask.py | 127 ++ source/tests/pd/test_training.py | 17 + source/tests/pd/test_update_sel.py | 62 +- 29 files changed, 4987 insertions(+), 63 deletions(-) create mode 100644 deepmd/pd/model/descriptor/dpa2.py create mode 100644 deepmd/pd/model/descriptor/repformer_layer.py create mode 100644 deepmd/pd/model/descriptor/repformers.py create mode 100644 deepmd/pd/model/descriptor/se_t_tebd.py create mode 100644 deepmd/pd/utils/spin.py create mode 100644 source/tests/pd/model/models/dpa2.json create mode 100644 source/tests/pd/model/models/dpa2.pd create mode 100644 
source/tests/pd/model/test_descriptor_dpa2.py create mode 100644 source/tests/pd/model/test_dpa2.py create mode 100644 source/tests/pd/model/test_unused_params.py diff --git a/deepmd/pd/model/descriptor/__init__.py b/deepmd/pd/model/descriptor/__init__.py index 7eaa0df85b..cee9dbf226 100644 --- a/deepmd/pd/model/descriptor/__init__.py +++ b/deepmd/pd/model/descriptor/__init__.py @@ -9,20 +9,34 @@ DescrptBlockSeAtten, DescrptDPA1, ) +from .dpa2 import ( + DescrptDPA2, +) from .env_mat import ( prod_env_mat, ) +from .repformers import ( + DescrptBlockRepformers, +) from .se_a import ( DescrptBlockSeA, DescrptSeA, ) +from .se_t_tebd import ( + DescrptBlockSeTTebd, + DescrptSeTTebd, +) __all__ = [ "BaseDescriptor", "DescriptorBlock", + "DescrptBlockRepformers", "DescrptBlockSeA", "DescrptBlockSeAtten", + "DescrptBlockSeTTebd", "DescrptDPA1", + "DescrptDPA2", "DescrptSeA", + "DescrptSeTTebd", "prod_env_mat", ] diff --git a/deepmd/pd/model/descriptor/dpa2.py b/deepmd/pd/model/descriptor/dpa2.py new file mode 100644 index 0000000000..8d4e13edae --- /dev/null +++ b/deepmd/pd/model/descriptor/dpa2.py @@ -0,0 +1,902 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Callable, + Optional, + Union, +) + +import paddle + +from deepmd.dpmodel.descriptor.dpa2 import ( + RepformerArgs, + RepinitArgs, +) +from deepmd.dpmodel.utils import EnvMat as DPEnvMat +from deepmd.dpmodel.utils.seed import ( + child_seed, +) +from deepmd.pd.model.network.mlp import ( + Identity, + MLPLayer, + NetworkCollection, +) +from deepmd.pd.model.network.network import ( + TypeEmbedNet, + TypeEmbedNetConsistent, +) +from deepmd.pd.utils import ( + env, +) +from deepmd.pd.utils.env import ( + PRECISION_DICT, +) +from deepmd.pd.utils.nlist import ( + build_multiple_neighbor_list, + get_multiple_nlist_key, +) +from deepmd.pd.utils.update_sel import ( + UpdateSel, +) +from deepmd.pd.utils.utils import ( + to_numpy_array, +) +from deepmd.utils.data_system import ( + DeepmdDataSystem, +) 
+from deepmd.utils.finetune import ( + get_index_between_two_maps, + map_pair_exclude_types, +) +from deepmd.utils.path import ( + DPPath, +) +from deepmd.utils.version import ( + check_version_compatibility, +) + +from .base_descriptor import ( + BaseDescriptor, +) +from .descriptor import ( + extend_descrpt_stat, +) +from .repformer_layer import ( + RepformerLayer, +) +from .repformers import ( + DescrptBlockRepformers, +) +from .se_atten import ( + DescrptBlockSeAtten, +) +from .se_t_tebd import ( + DescrptBlockSeTTebd, +) + + +@BaseDescriptor.register("dpa2") +class DescrptDPA2(BaseDescriptor, paddle.nn.Layer): + def __init__( + self, + ntypes: int, + # args for repinit + repinit: Union[RepinitArgs, dict], + # args for repformer + repformer: Union[RepformerArgs, dict], + # kwargs for descriptor + concat_output_tebd: bool = True, + precision: str = "float64", + smooth: bool = True, + exclude_types: list[tuple[int, int]] = [], + env_protection: float = 0.0, + trainable: bool = True, + seed: Optional[Union[int, list[int]]] = None, + add_tebd_to_repinit_out: bool = False, + use_econf_tebd: bool = False, + use_tebd_bias: bool = False, + type_map: Optional[list[str]] = None, + ) -> None: + r"""The DPA-2 descriptor[1]_. + + Parameters + ---------- + repinit : Union[RepinitArgs, dict] + The arguments used to initialize the repinit block, see docstr in `RepinitArgs` for details information. + repformer : Union[RepformerArgs, dict] + The arguments used to initialize the repformer block, see docstr in `RepformerArgs` for details information. + concat_output_tebd : bool, optional + Whether to concat type embedding at the output of the descriptor. + precision : str, optional + The precision of the embedding net parameters. + smooth : bool, optional + Whether to use smoothness in processes such as attention weights calculation. + exclude_types : list[list[int]], optional + The excluded pairs of types which have no interaction with each other. 
+ For example, `[[0, 1]]` means no interaction between type 0 and type 1. + env_protection : float, optional + Protection parameter to prevent division by zero errors during environment matrix calculations. + For example, when using paddings, there may be zero distances of neighbors, which may make division by zero error during environment matrix calculations without protection. + trainable : bool, optional + If the parameters are trainable. + seed : int, optional + Random seed for parameter initialization. + add_tebd_to_repinit_out : bool, optional + Whether to add type embedding to the output representation from repinit before inputting it into repformer. + use_econf_tebd : bool, Optional + Whether to use electronic configuration type embedding. + use_tebd_bias : bool, Optional + Whether to use bias in the type embedding layer. + type_map : list[str], Optional + A list of strings. Give the name to each type of atoms. + + Returns + ------- + descriptor: paddle.Tensor + the descriptor of shape nb x nloc x g1_dim. + invariant single-atom representation. + g2: paddle.Tensor + invariant pair-atom representation. + h2: paddle.Tensor + equivariant pair-atom representation. + rot_mat: paddle.Tensor + rotation matrix for equivariant fittings + sw: paddle.Tensor + The switch function for decaying inverse distance. + + References + ---------- + .. [1] Zhang, D., Liu, X., Zhang, X. et al. DPA-2: a + large atomic model as a multi-task learner. npj + Comput Mater 10, 293 (2024). https://doi.org/10.1038/s41524-024-01493-2 + """ + super().__init__() + + def init_subclass_params(sub_data, sub_class): + if isinstance(sub_data, dict): + return sub_class(**sub_data) + elif isinstance(sub_data, sub_class): + return sub_data + else: + raise ValueError( + f"Input args must be a {sub_class.__name__} class or a dict!" 
+ ) + + self.repinit_args = init_subclass_params(repinit, RepinitArgs) + self.repformer_args = init_subclass_params(repformer, RepformerArgs) + self.tebd_input_mode = self.repinit_args.tebd_input_mode + + self.repinit = DescrptBlockSeAtten( + self.repinit_args.rcut, + self.repinit_args.rcut_smth, + self.repinit_args.nsel, + ntypes, + attn_layer=0, + neuron=self.repinit_args.neuron, + axis_neuron=self.repinit_args.axis_neuron, + tebd_dim=self.repinit_args.tebd_dim, + tebd_input_mode=self.repinit_args.tebd_input_mode, + set_davg_zero=self.repinit_args.set_davg_zero, + exclude_types=exclude_types, + env_protection=env_protection, + activation_function=self.repinit_args.activation_function, + precision=precision, + resnet_dt=self.repinit_args.resnet_dt, + smooth=smooth, + type_one_side=self.repinit_args.type_one_side, + seed=child_seed(seed, 0), + ) + self.use_three_body = self.repinit_args.use_three_body + if self.use_three_body: + self.repinit_three_body = DescrptBlockSeTTebd( + self.repinit_args.three_body_rcut, + self.repinit_args.three_body_rcut_smth, + self.repinit_args.three_body_sel, + ntypes, + neuron=self.repinit_args.three_body_neuron, + tebd_dim=self.repinit_args.tebd_dim, + tebd_input_mode=self.repinit_args.tebd_input_mode, + set_davg_zero=self.repinit_args.set_davg_zero, + exclude_types=exclude_types, + env_protection=env_protection, + activation_function=self.repinit_args.activation_function, + precision=precision, + resnet_dt=self.repinit_args.resnet_dt, + smooth=smooth, + seed=child_seed(seed, 5), + ) + else: + self.repinit_three_body = None + self.repformers = DescrptBlockRepformers( + self.repformer_args.rcut, + self.repformer_args.rcut_smth, + self.repformer_args.nsel, + ntypes, + nlayers=self.repformer_args.nlayers, + g1_dim=self.repformer_args.g1_dim, + g2_dim=self.repformer_args.g2_dim, + axis_neuron=self.repformer_args.axis_neuron, + direct_dist=self.repformer_args.direct_dist, + update_g1_has_conv=self.repformer_args.update_g1_has_conv, + 
update_g1_has_drrd=self.repformer_args.update_g1_has_drrd, + update_g1_has_grrg=self.repformer_args.update_g1_has_grrg, + update_g1_has_attn=self.repformer_args.update_g1_has_attn, + update_g2_has_g1g1=self.repformer_args.update_g2_has_g1g1, + update_g2_has_attn=self.repformer_args.update_g2_has_attn, + update_h2=self.repformer_args.update_h2, + attn1_hidden=self.repformer_args.attn1_hidden, + attn1_nhead=self.repformer_args.attn1_nhead, + attn2_hidden=self.repformer_args.attn2_hidden, + attn2_nhead=self.repformer_args.attn2_nhead, + attn2_has_gate=self.repformer_args.attn2_has_gate, + activation_function=self.repformer_args.activation_function, + update_style=self.repformer_args.update_style, + update_residual=self.repformer_args.update_residual, + update_residual_init=self.repformer_args.update_residual_init, + set_davg_zero=self.repformer_args.set_davg_zero, + smooth=smooth, + exclude_types=exclude_types, + env_protection=env_protection, + precision=precision, + trainable_ln=self.repformer_args.trainable_ln, + ln_eps=self.repformer_args.ln_eps, + use_sqrt_nnei=self.repformer_args.use_sqrt_nnei, + g1_out_conv=self.repformer_args.g1_out_conv, + g1_out_mlp=self.repformer_args.g1_out_mlp, + seed=child_seed(seed, 1), + ) + self.rcsl_list = [ + (self.repformers.get_rcut(), self.repformers.get_nsel()), + (self.repinit.get_rcut(), self.repinit.get_nsel()), + ] + if self.use_three_body: + self.rcsl_list.append( + (self.repinit_three_body.get_rcut(), self.repinit_three_body.get_nsel()) + ) + self.rcsl_list.sort() + for ii in range(1, len(self.rcsl_list)): + assert ( + self.rcsl_list[ii - 1][1] <= self.rcsl_list[ii][1] + ), "rcut and sel are not in the same order" + self.rcut_list = [ii[0] for ii in self.rcsl_list] + self.nsel_list = [ii[1] for ii in self.rcsl_list] + self.use_econf_tebd = use_econf_tebd + self.use_tebd_bias = use_tebd_bias + self.type_map = type_map + self.type_embedding = TypeEmbedNet( + ntypes, + self.repinit_args.tebd_dim, + precision=precision, + 
seed=child_seed(seed, 2), + use_econf_tebd=self.use_econf_tebd, + use_tebd_bias=use_tebd_bias, + type_map=type_map, + ) + self.concat_output_tebd = concat_output_tebd + self.precision = precision + self.prec = PRECISION_DICT[self.precision] + self.smooth = smooth + self.exclude_types = exclude_types + self.env_protection = env_protection + self.trainable = trainable + self.add_tebd_to_repinit_out = add_tebd_to_repinit_out + + self.repinit_out_dim = self.repinit.dim_out + if self.repinit_args.use_three_body: + assert self.repinit_three_body is not None + self.repinit_out_dim += self.repinit_three_body.dim_out + + if self.repinit_out_dim == self.repformers.dim_in: + self.g1_shape_tranform = Identity() + else: + self.g1_shape_tranform = MLPLayer( + self.repinit_out_dim, + self.repformers.dim_in, + bias=False, + precision=precision, + init="glorot", + seed=child_seed(seed, 3), + ) + self.tebd_transform = None + if self.add_tebd_to_repinit_out: + self.tebd_transform = MLPLayer( + self.repinit_args.tebd_dim, + self.repformers.dim_in, + bias=False, + precision=precision, + seed=child_seed(seed, 4), + ) + assert self.repinit.rcut > self.repformers.rcut + assert self.repinit.sel[0] > self.repformers.sel[0] + + self.tebd_dim = self.repinit_args.tebd_dim + self.rcut = self.repinit.get_rcut() + self.rcut_smth = self.repinit.get_rcut_smth() + self.ntypes = ntypes + self.sel = self.repinit.sel + # set trainable + for param in self.parameters(): + param.stop_gradient = not trainable + self.compress = False + + def get_rcut(self) -> float: + """Returns the cut-off radius.""" + return self.rcut + + def get_rcut_smth(self) -> float: + """Returns the radius where the neighbor information starts to smoothly decay to 0.""" + return self.rcut_smth + + def get_nsel(self) -> int: + """Returns the number of selected atoms in the cut-off radius.""" + return sum(self.sel) + + def get_sel(self) -> list[int]: + """Returns the number of selected atoms for each type.""" + return self.sel + + def 
get_ntypes(self) -> int: + """Returns the number of element types.""" + return self.ntypes + + def get_type_map(self) -> list[str]: + """Get the name to each type of atoms.""" + return self.type_map + + def get_dim_out(self) -> int: + """Returns the output dimension of this descriptor.""" + ret = self.repformers.dim_out + if self.concat_output_tebd: + ret += self.tebd_dim + return ret + + def get_dim_emb(self) -> int: + """Returns the embedding dimension of this descriptor.""" + return self.repformers.dim_emb + + def mixed_types(self) -> bool: + """If true, the descriptor + 1. assumes total number of atoms aligned across frames; + 2. requires a neighbor list that does not distinguish different atomic types. + + If false, the descriptor + 1. assumes total number of atoms of each atom type aligned across frames; + 2. requires a neighbor list that distinguishes different atomic types. + + """ + return True + + def has_message_passing(self) -> bool: + """Returns whether the descriptor has message passing.""" + return any( + [self.repinit.has_message_passing(), self.repformers.has_message_passing()] + ) + + def need_sorted_nlist_for_lower(self) -> bool: + """Returns whether the descriptor needs sorted nlist when using `forward_lower`.""" + return True + + def get_env_protection(self) -> float: + """Returns the protection of building environment matrix.""" + # the env_protection of repinit is the same as that of the repformer + return self.repinit.get_env_protection() + + def share_params(self, base_class, shared_level, resume=False) -> None: + """ + Share the parameters of self to the base_class with shared_level during multitask training. + If not start from checkpoint (resume is False), + some separated parameters (e.g. mean and stddev) will be re-calculated across different classes. + """ + assert ( + self.__class__ == base_class.__class__ + ), "Only descriptors of the same type can share params!" 
+ # For DPA2 descriptors, the user-defined share-level + # shared_level: 0 + # share all parameters in type_embedding, repinit and repformers + if shared_level == 0: + self._sub_layers["type_embedding"] = base_class._sub_layers[ + "type_embedding" + ] + self.repinit.share_params(base_class.repinit, 0, resume=resume) + if self.use_three_body: + self.repinit_three_body.share_params( + base_class.repinit_three_body, 0, resume=resume + ) + self._sub_layers["g1_shape_tranform"] = base_class._sub_layers[ + "g1_shape_tranform" + ] + self.repformers.share_params(base_class.repformers, 0, resume=resume) + # shared_level: 1 + # share all parameters in type_embedding + elif shared_level == 1: + self._modules["type_embedding"] = base_class._modules["type_embedding"] + # Other shared levels + else: + raise NotImplementedError + + def change_type_map( + self, type_map: list[str], model_with_new_type_stat=None + ) -> None: + """Change the type related params to new ones, according to `type_map` and the original one in the model. + If there are new types in `type_map`, statistics will be updated accordingly to `model_with_new_type_stat` for these new types. + """ + assert ( + self.type_map is not None + ), "'type_map' must be defined when performing type changing!" 
+ remap_index, has_new_type = get_index_between_two_maps(self.type_map, type_map) + self.type_map = type_map + self.type_embedding.change_type_map(type_map=type_map) + self.exclude_types = map_pair_exclude_types(self.exclude_types, remap_index) + self.ntypes = len(type_map) + repinit = self.repinit + repformers = self.repformers + repinit_three_body = self.repinit_three_body + if has_new_type: + # the avg and std of new types need to be updated + extend_descrpt_stat( + repinit, + type_map, + des_with_stat=model_with_new_type_stat.repinit + if model_with_new_type_stat is not None + else None, + ) + extend_descrpt_stat( + repformers, + type_map, + des_with_stat=model_with_new_type_stat.repformers + if model_with_new_type_stat is not None + else None, + ) + if self.use_three_body: + extend_descrpt_stat( + repinit_three_body, + type_map, + des_with_stat=model_with_new_type_stat.repinit_three_body + if model_with_new_type_stat is not None + else None, + ) + repinit.ntypes = self.ntypes + repformers.ntypes = self.ntypes + repinit.reinit_exclude(self.exclude_types) + repformers.reinit_exclude(self.exclude_types) + repinit["davg"] = repinit["davg"][remap_index] + repinit["dstd"] = repinit["dstd"][remap_index] + repformers["davg"] = repformers["davg"][remap_index] + repformers["dstd"] = repformers["dstd"][remap_index] + if self.use_three_body: + repinit_three_body.ntypes = self.ntypes + repinit_three_body.reinit_exclude(self.exclude_types) + repinit_three_body["davg"] = repinit_three_body["davg"][remap_index] + repinit_three_body["dstd"] = repinit_three_body["dstd"][remap_index] + + @property + def dim_out(self): + return self.get_dim_out() + + @property + def dim_emb(self): + """Returns the embedding dimension g2.""" + return self.get_dim_emb() + + def compute_input_stats( + self, + merged: Union[Callable[[], list[dict]], list[dict]], + path: Optional[DPPath] = None, + ) -> None: + """ + Compute the input statistics (e.g. 
mean and stddev) for the descriptors from packed data. + + Parameters + ---------- + merged : Union[Callable[[], list[dict]], list[dict]] + - list[dict]: A list of data samples from various data systems. + Each element, `merged[i]`, is a data dictionary containing `keys`: `paddle.Tensor` + originating from the `i`-th data system. + - Callable[[], list[dict]]: A lazy function that returns data samples in the above format + only when needed. Since the sampling process can be slow and memory-intensive, + the lazy function helps by only sampling once. + path : Optional[DPPath] + The path to the stat file. + + """ + descrpt_list = [self.repinit, self.repformers] + if self.use_three_body: + descrpt_list.append(self.repinit_three_body) + for ii, descrpt in enumerate(descrpt_list): + descrpt.compute_input_stats(merged, path) + + def set_stat_mean_and_stddev( + self, + mean: list[paddle.Tensor], + stddev: list[paddle.Tensor], + ) -> None: + """Update mean and stddev for descriptor.""" + descrpt_list = [self.repinit, self.repformers] + if self.use_three_body: + descrpt_list.append(self.repinit_three_body) + for ii, descrpt in enumerate(descrpt_list): + descrpt.mean = mean[ii] + descrpt.stddev = stddev[ii] + + def get_stat_mean_and_stddev( + self, + ) -> tuple[list[paddle.Tensor], list[paddle.Tensor]]: + """Get mean and stddev for descriptor.""" + mean_list = [self.repinit.mean, self.repformers.mean] + stddev_list = [ + self.repinit.stddev, + self.repformers.stddev, + ] + if self.use_three_body: + mean_list.append(self.repinit_three_body.mean) + stddev_list.append(self.repinit_three_body.stddev) + return mean_list, stddev_list + + def serialize(self) -> dict: + repinit = self.repinit + repformers = self.repformers + repinit_three_body = self.repinit_three_body + data = { + "@class": "Descriptor", + "type": "dpa2", + "@version": 3, + "ntypes": self.ntypes, + "repinit_args": self.repinit_args.serialize(), + "repformer_args": self.repformer_args.serialize(), + 
"concat_output_tebd": self.concat_output_tebd, + "precision": self.precision, + "smooth": self.smooth, + "exclude_types": self.exclude_types, + "env_protection": self.env_protection, + "trainable": self.trainable, + "add_tebd_to_repinit_out": self.add_tebd_to_repinit_out, + "use_econf_tebd": self.use_econf_tebd, + "use_tebd_bias": self.use_tebd_bias, + "type_map": self.type_map, + "type_embedding": self.type_embedding.embedding.serialize(), + "g1_shape_tranform": self.g1_shape_tranform.serialize(), + } + if self.add_tebd_to_repinit_out: + data.update( + { + "tebd_transform": self.tebd_transform.serialize(), + } + ) + repinit_variable = { + "embeddings": repinit.filter_layers.serialize(), + "env_mat": DPEnvMat(repinit.rcut, repinit.rcut_smth).serialize(), + "@variables": { + "davg": to_numpy_array(repinit["davg"]), + "dstd": to_numpy_array(repinit["dstd"]), + }, + } + if repinit.tebd_input_mode in ["strip"]: + repinit_variable.update( + {"embeddings_strip": repinit.filter_layers_strip.serialize()} + ) + repformers_variable = { + "g2_embd": repformers.g2_embd.serialize(), + "repformer_layers": [layer.serialize() for layer in repformers.layers], + "env_mat": DPEnvMat(repformers.rcut, repformers.rcut_smth).serialize(), + "@variables": { + "davg": to_numpy_array(repformers["davg"]), + "dstd": to_numpy_array(repformers["dstd"]), + }, + } + data.update( + { + "repinit_variable": repinit_variable, + "repformers_variable": repformers_variable, + } + ) + if self.use_three_body: + repinit_three_body_variable = { + "embeddings": repinit_three_body.filter_layers.serialize(), + "env_mat": DPEnvMat( + repinit_three_body.rcut, repinit_three_body.rcut_smth + ).serialize(), + "@variables": { + "davg": to_numpy_array(repinit_three_body["davg"]), + "dstd": to_numpy_array(repinit_three_body["dstd"]), + }, + } + if repinit_three_body.tebd_input_mode in ["strip"]: + repinit_three_body_variable.update( + { + "embeddings_strip": repinit_three_body.filter_layers_strip.serialize() + } + ) + 
data.update( + { + "repinit_three_body_variable": repinit_three_body_variable, + } + ) + return data + + @classmethod + def deserialize(cls, data: dict) -> "DescrptDPA2": + data = data.copy() + version = data.pop("@version") + check_version_compatibility(version, 3, 1) + data.pop("@class") + data.pop("type") + repinit_variable = data.pop("repinit_variable").copy() + repformers_variable = data.pop("repformers_variable").copy() + repinit_three_body_variable = ( + data.pop("repinit_three_body_variable").copy() + if "repinit_three_body_variable" in data + else None + ) + type_embedding = data.pop("type_embedding") + g1_shape_tranform = data.pop("g1_shape_tranform") + tebd_transform = data.pop("tebd_transform", None) + add_tebd_to_repinit_out = data["add_tebd_to_repinit_out"] + if version < 3: + # compat with old version + data["repformer_args"]["use_sqrt_nnei"] = False + data["repformer_args"]["g1_out_conv"] = False + data["repformer_args"]["g1_out_mlp"] = False + data["repinit"] = RepinitArgs(**data.pop("repinit_args")) + data["repformer"] = RepformerArgs(**data.pop("repformer_args")) + # compat with version 1 + if "use_tebd_bias" not in data: + data["use_tebd_bias"] = True + obj = cls(**data) + obj.type_embedding.embedding = TypeEmbedNetConsistent.deserialize( + type_embedding + ) + if add_tebd_to_repinit_out: + assert isinstance(tebd_transform, dict) + obj.tebd_transform = MLPLayer.deserialize(tebd_transform) + if obj.repinit.dim_out != obj.repformers.dim_in: + obj.g1_shape_tranform = MLPLayer.deserialize(g1_shape_tranform) + + def t_cvt(xx): + return paddle.to_tensor(xx, dtype=obj.repinit.prec, place=env.DEVICE) + + # deserialize repinit + statistic_repinit = repinit_variable.pop("@variables") + env_mat = repinit_variable.pop("env_mat") + tebd_input_mode = data["repinit"].tebd_input_mode + obj.repinit.filter_layers = NetworkCollection.deserialize( + repinit_variable.pop("embeddings") + ) + if tebd_input_mode in ["strip"]: + obj.repinit.filter_layers_strip = 
NetworkCollection.deserialize( + repinit_variable.pop("embeddings_strip") + ) + obj.repinit["davg"] = t_cvt(statistic_repinit["davg"]) + obj.repinit["dstd"] = t_cvt(statistic_repinit["dstd"]) + + if data["repinit"].use_three_body: + # deserialize repinit_three_body + statistic_repinit_three_body = repinit_three_body_variable.pop("@variables") + env_mat = repinit_three_body_variable.pop("env_mat") + tebd_input_mode = data["repinit"].tebd_input_mode + obj.repinit_three_body.filter_layers = NetworkCollection.deserialize( + repinit_three_body_variable.pop("embeddings") + ) + if tebd_input_mode in ["strip"]: + obj.repinit_three_body.filter_layers_strip = ( + NetworkCollection.deserialize( + repinit_three_body_variable.pop("embeddings_strip") + ) + ) + obj.repinit_three_body["davg"] = t_cvt(statistic_repinit_three_body["davg"]) + obj.repinit_three_body["dstd"] = t_cvt(statistic_repinit_three_body["dstd"]) + + # deserialize repformers + statistic_repformers = repformers_variable.pop("@variables") + env_mat = repformers_variable.pop("env_mat") + repformer_layers = repformers_variable.pop("repformer_layers") + obj.repformers.g2_embd = MLPLayer.deserialize( + repformers_variable.pop("g2_embd") + ) + obj.repformers["davg"] = t_cvt(statistic_repformers["davg"]) + obj.repformers["dstd"] = t_cvt(statistic_repformers["dstd"]) + obj.repformers.layers = paddle.nn.LayerList( + [RepformerLayer.deserialize(layer) for layer in repformer_layers] + ) + return obj + + def forward( + self, + extended_coord: paddle.Tensor, + extended_atype: paddle.Tensor, + nlist: paddle.Tensor, + mapping: Optional[paddle.Tensor] = None, + comm_dict: Optional[dict[str, paddle.Tensor]] = None, + ): + """Compute the descriptor. + + Parameters + ---------- + extended_coord + The extended coordinates of atoms. shape: nf x (nallx3) + extended_atype + The extended aotm types. shape: nf x nall + nlist + The neighbor list. 
shape: nf x nloc x nnei + mapping + The index mapping, mapps extended region index to local region. + comm_dict + The data needed for communication for parallel inference. + + Returns + ------- + descriptor + The descriptor. shape: nf x nloc x (ng x axis_neuron) + gr + The rotationally equivariant and permutationally invariant single particle + representation. shape: nf x nloc x ng x 3 + g2 + The rotationally invariant pair-partical representation. + shape: nf x nloc x nnei x ng + h2 + The rotationally equivariant pair-partical representation. + shape: nf x nloc x nnei x 3 + sw + The smooth switch function. shape: nf x nloc x nnei + + """ + # cast the input to internal precsion + extended_coord = extended_coord.to(dtype=self.prec) + + use_three_body = self.use_three_body + nframes, nloc, nnei = nlist.shape + nall = extended_coord.reshape([nframes, -1]).shape[1] // 3 + # nlists + nlist_dict = build_multiple_neighbor_list( + extended_coord.detach(), + nlist, + self.rcut_list, + self.nsel_list, + ) + # repinit + g1_ext = self.type_embedding(extended_atype) + g1_inp = g1_ext[:, :nloc, :] + if self.tebd_input_mode in ["strip"]: + type_embedding = self.type_embedding.get_full_embedding(g1_ext.place) + else: + type_embedding = None + g1, _, _, _, _ = self.repinit( + nlist_dict[ + get_multiple_nlist_key(self.repinit.get_rcut(), self.repinit.get_nsel()) + ], + extended_coord, + extended_atype, + g1_ext, + mapping, + type_embedding, + ) + if use_three_body: + assert self.repinit_three_body is not None + g1_three_body, __, __, __, __ = self.repinit_three_body( + nlist_dict[ + get_multiple_nlist_key( + self.repinit_three_body.get_rcut(), + self.repinit_three_body.get_nsel(), + ) + ], + extended_coord, + extended_atype, + g1_ext, + mapping, + type_embedding, + ) + g1 = paddle.concat([g1, g1_three_body], axis=-1) + # linear to change shape + g1 = self.g1_shape_tranform(g1) + if self.add_tebd_to_repinit_out: + assert self.tebd_transform is not None + g1 = g1 + 
self.tebd_transform(g1_inp) + # mapping g1 + if comm_dict is None: + assert mapping is not None + mapping_ext = ( + mapping.reshape([nframes, nall]) + .unsqueeze(-1) + .expand([-1, -1, g1.shape[-1]]) + ) + g1_ext = paddle.take_along_axis(g1, mapping_ext, 1) + g1 = g1_ext + # repformer + g1, g2, h2, rot_mat, sw = self.repformers( + nlist_dict[ + get_multiple_nlist_key( + self.repformers.get_rcut(), self.repformers.get_nsel() + ) + ], + extended_coord, + extended_atype, + g1, + mapping, + comm_dict=comm_dict, + ) + if self.concat_output_tebd: + g1 = paddle.concat([g1, g1_inp], axis=-1) + return ( + g1.to(dtype=env.GLOBAL_PD_FLOAT_PRECISION), + rot_mat.to(dtype=env.GLOBAL_PD_FLOAT_PRECISION), + g2.to(dtype=env.GLOBAL_PD_FLOAT_PRECISION), + h2.to(dtype=env.GLOBAL_PD_FLOAT_PRECISION), + sw.to(dtype=env.GLOBAL_PD_FLOAT_PRECISION), + ) + + @classmethod + def update_sel( + cls, + train_data: DeepmdDataSystem, + type_map: Optional[list[str]], + local_jdata: dict, + ) -> tuple[dict, Optional[float]]: + """Update the selection and perform neighbor statistics. 
+ + Parameters + ---------- + train_data : DeepmdDataSystem + data used to do neighbor statistics + type_map : list[str], optional + The name of each type of atoms + local_jdata : dict + The local data refer to the current class + + Returns + ------- + dict + The updated local data + float + The minimum distance between two atoms + """ + local_jdata_cpy = local_jdata.copy() + update_sel = UpdateSel() + min_nbor_dist, repinit_sel = update_sel.update_one_sel( + train_data, + type_map, + local_jdata_cpy["repinit"]["rcut"], + local_jdata_cpy["repinit"]["nsel"], + True, + ) + local_jdata_cpy["repinit"]["nsel"] = repinit_sel[0] + min_nbor_dist, repinit_three_body_sel = update_sel.update_one_sel( + train_data, + type_map, + local_jdata_cpy["repinit"]["three_body_rcut"], + local_jdata_cpy["repinit"]["three_body_sel"], + True, + ) + local_jdata_cpy["repinit"]["three_body_sel"] = repinit_three_body_sel[0] + min_nbor_dist, repformer_sel = update_sel.update_one_sel( + train_data, + type_map, + local_jdata_cpy["repformer"]["rcut"], + local_jdata_cpy["repformer"]["nsel"], + True, + ) + local_jdata_cpy["repformer"]["nsel"] = repformer_sel[0] + return local_jdata_cpy, min_nbor_dist + + def enable_compression( + self, + min_nbor_dist: float, + table_extrapolate: float = 5, + table_stride_1: float = 0.01, + table_stride_2: float = 0.1, + check_frequency: int = -1, + ) -> None: + """Receive the statistics (distance, max_nbor_size and env_mat_range) of the training data. 
def get_residual(
    _dim: int,
    _scale: float,
    _mode: str = "norm",
    trainable: bool = True,
    precision: str = "float64",
    seed: Optional[Union[int, list[int]]] = None,
) -> paddle.Tensor:
    r"""
    Get residual tensor for one update vector.

    Parameters
    ----------
    _dim : int
        The dimension of the update vector.
    _scale
        The initial scale of the residual tensor. See `_mode` for details.
    _mode
        The mode of residual initialization for the residual tensor.
        - "norm" (default): init residual using normal with `_scale` std.
        - "const": init residual using element-wise constants of `_scale`.
    trainable
        Whether the residual tensor is trainable.
    precision
        The precision of the residual tensor.
    seed : int, optional
        Random seed for parameter initialization.
    """
    rng = get_generator(seed)
    # start from a zero-filled parameter and overwrite it in place below
    residual = paddle.create_parameter(
        [_dim],
        dtype=PRECISION_DICT[precision],
        default_initializer=nn.initializer.Constant(0),
    ).to(device=env.DEVICE)
    residual.stop_gradient = not trainable
    if _mode == "norm":
        normal_(residual.data, std=_scale, generator=rng)
    elif _mode == "const":
        constant_(residual.data, val=_scale)
    else:
        raise RuntimeError(f"Unsupported initialization mode '{_mode}'!")
    return residual


# common ops
def _make_nei_g1(
    g1_ext: paddle.Tensor,
    nlist: paddle.Tensor,
) -> paddle.Tensor:
    """
    Make neighbor-wise atomic invariant rep.

    Parameters
    ----------
    g1_ext
        Extended atomic invariant rep, with shape nb x nall x ng1.
    nlist
        Neighbor list, with shape nb x nloc x nnei.

    Returns
    -------
    gg1: paddle.Tensor
        Neighbor-wise atomic invariant rep, with shape nb x nloc x nnei x ng1.

    """
    nframes, nloc, nnei = nlist.shape
    ng1 = g1_ext.shape[-1]
    # flatten (nloc, nnei) so one gather along the atom axis suffices;
    # index shape: nb x (nloc x nnei) x ng1
    flat_index = (
        nlist.reshape([nframes, nloc * nnei]).unsqueeze(-1).expand([-1, -1, ng1])
    )
    # nb x (nloc x nnei) x ng1
    gathered = paddle.take_along_axis(g1_ext, axis=1, indices=flat_index)
    # nb x nloc x nnei x ng1
    return gathered.reshape([nframes, nloc, nnei, ng1])


def _apply_nlist_mask(
    gg: paddle.Tensor,
    nlist_mask: paddle.Tensor,
) -> paddle.Tensor:
    """
    Apply nlist mask to neighbor-wise rep tensors.

    Parameters
    ----------
    gg
        Neighbor-wise rep tensors, with shape nf x nloc x nnei x d.
    nlist_mask
        Neighbor list mask, where zero means no neighbor, with shape nf x nloc x nnei.
    """
    # broadcast the inverted mask over the feature axis and zero padded slots
    invalid = ~nlist_mask.unsqueeze(-1)
    return gg.masked_fill(invalid, 0.0)
def _apply_switch(gg: paddle.Tensor, sw: paddle.Tensor) -> paddle.Tensor:
    """
    Apply switch function to neighbor-wise rep tensors.

    Parameters
    ----------
    gg
        Neighbor-wise rep tensors, with shape nf x nloc x nnei x d.
    sw
        The switch function, which equals 1 within the rcut_smth range, smoothly decays from 1 to 0 between rcut_smth and rcut,
        and remains 0 beyond rcut, with shape nf x nloc x nnei.
    """
    # broadcast sw over the feature axis
    return gg * sw.unsqueeze(-1)


class Atten2Map(paddle.nn.Layer):
    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        head_num: int,
        has_gate: bool = False,  # apply gate to attn map
        smooth: bool = True,
        attnw_shift: float = 20.0,
        precision: str = "float64",
        seed: Optional[Union[int, list[int]]] = None,
    ):
        """Return neighbor-wise multi-head self-attention maps, with gate mechanism."""
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.head_num = head_num
        # one linear layer emits queries and keys for every head at once
        self.mapqk = MLPLayer(
            input_dim,
            hidden_dim * 2 * head_num,
            bias=False,
            precision=precision,
            seed=seed,
        )
        self.has_gate = has_gate
        self.smooth = smooth
        self.attnw_shift = attnw_shift
        self.precision = precision

    def forward(
        self,
        g2: paddle.Tensor,  # nb x nloc x nnei x ng2
        h2: paddle.Tensor,  # nb x nloc x nnei x 3
        nlist_mask: paddle.Tensor,  # nb x nloc x nnei
        sw: paddle.Tensor,  # nb x nloc x nnei
    ) -> paddle.Tensor:
        nb, nloc, nnei, _ = g2.shape
        nd, nh = self.hidden_dim, self.head_num
        # nb x nloc x nnei x nd x (nh x 2)
        qk = self.mapqk(g2).reshape([nb, nloc, nnei, nd, nh * 2])
        # nb x nloc x (nh x 2) x nnei x nd
        qk = paddle.transpose(qk, (0, 1, 4, 2, 3))
        # queries / keys, each nb x nloc x nh x nnei x nd
        query, key = paddle.split(qk, decomp.sec(qk.shape[2], nh), axis=2)
        # scaled dot-product scores: nb x nloc x nh x nnei x nnei
        attnw = (
            paddle.matmul(query, paddle.transpose(key, [0, 1, 2, 4, 3])) / nd**0.5
        )
        if self.has_gate:
            gate = paddle.matmul(h2, paddle.transpose(h2, [0, 1, 3, 2])).unsqueeze(-3)
            attnw = attnw * gate
        # row mask of the attention map, nb x nloc x 1 x 1 x nnei
        attnw_mask = ~nlist_mask.unsqueeze(2).unsqueeze(2)
        # column mask of the attention map, nb x nloc x 1 x nnei x 1
        attnw_mask_c = ~nlist_mask.unsqueeze(2).unsqueeze(-1)
        if self.smooth:
            # shift, damp by the switch on both axes, unshift: keeps the map
            # smooth when neighbors cross the cutoff
            attnw = (attnw + self.attnw_shift) * sw[:, :, None, :, None] * sw[
                :, :, None, None, :
            ] - self.attnw_shift
        else:
            attnw = attnw.masked_fill(
                attnw_mask,
                float("-inf"),
            )
        attnw = paddle.nn.functional.softmax(attnw, axis=-1)
        attnw = attnw.masked_fill(
            attnw_mask,
            0.0,
        )
        # also zero invalid rows, nb x nloc x nh x nnei x nnei
        attnw = attnw.masked_fill(
            attnw_mask_c,
            0.0,
        )
        if self.smooth:
            attnw = attnw * sw[:, :, None, :, None] * sw[:, :, None, None, :]
        # geometric overlap of pair directions, nb x nloc x nnei x nnei
        h2h2t = paddle.matmul(h2, paddle.transpose(h2, [0, 1, 3, 2])) / 3.0**0.5
        # nb x nloc x nh x nnei x nnei
        weighted = attnw * h2h2t[:, :, None, :, :]
        # nb x nloc x nnei x nnei x nh
        return paddle.transpose(weighted, (0, 1, 3, 4, 2))

    def serialize(self) -> dict:
        """Serialize the networks to a dict.

        Returns
        -------
        dict
            The serialized networks.
        """
        return {
            "@class": "Atten2Map",
            "@version": 1,
            "input_dim": self.input_dim,
            "hidden_dim": self.hidden_dim,
            "head_num": self.head_num,
            "has_gate": self.has_gate,
            "smooth": self.smooth,
            "attnw_shift": self.attnw_shift,
            "precision": self.precision,
            "mapqk": self.mapqk.serialize(),
        }

    @classmethod
    def deserialize(cls, data: dict) -> "Atten2Map":
        """Deserialize the networks from a dict.

        Parameters
        ----------
        data : dict
            The dict to deserialize from.
        """
        data = data.copy()
        check_version_compatibility(data.pop("@version"), 1, 1)
        data.pop("@class")
        mapqk = data.pop("mapqk")
        obj = cls(**data)
        obj.mapqk = MLPLayer.deserialize(mapqk)
        return obj
class Atten2MultiHeadApply(paddle.nn.Layer):
    def __init__(
        self,
        input_dim: int,
        head_num: int,
        precision: str = "float64",
        seed: Optional[Union[int, list[int]]] = None,
    ) -> None:
        super().__init__()
        self.input_dim = input_dim
        self.head_num = head_num
        # per-head value projection of g2
        self.mapv = MLPLayer(
            input_dim,
            input_dim * head_num,
            bias=False,
            precision=precision,
            seed=child_seed(seed, 0),
        )
        # merges the concatenated heads back to input_dim
        self.head_map = MLPLayer(
            input_dim * head_num,
            input_dim,
            precision=precision,
            seed=child_seed(seed, 1),
        )
        self.precision = precision

    def forward(
        self,
        AA: paddle.Tensor,  # nf x nloc x nnei x nnei x nh
        g2: paddle.Tensor,  # nf x nloc x nnei x ng2
    ) -> paddle.Tensor:
        nf, nloc, nnei, ng2 = g2.shape
        nh = self.head_num
        # values: nf x nloc x nnei x ng2 x nh -> nf x nloc x nh x nnei x ng2
        value = paddle.transpose(
            self.mapv(g2).reshape([nf, nloc, nnei, ng2, nh]), (0, 1, 4, 2, 3)
        )
        # attention maps: nf x nloc x nh x nnei x nnei
        attn = paddle.transpose(AA, (0, 1, 4, 2, 3))
        # attention-weighted values: nf x nloc x nh x nnei x ng2
        mixed = paddle.matmul(attn, value)
        # concat heads: nf x nloc x nnei x (ng2 x nh)
        mixed = paddle.transpose(mixed, (0, 1, 3, 4, 2)).reshape(
            [nf, nloc, nnei, (ng2 * nh)]
        )
        # nf x nloc x nnei x ng2
        return self.head_map(mixed)

    def serialize(self) -> dict:
        """Serialize the networks to a dict.

        Returns
        -------
        dict
            The serialized networks.
        """
        return {
            "@class": "Atten2MultiHeadApply",
            "@version": 1,
            "input_dim": self.input_dim,
            "head_num": self.head_num,
            "precision": self.precision,
            "mapv": self.mapv.serialize(),
            "head_map": self.head_map.serialize(),
        }

    @classmethod
    def deserialize(cls, data: dict) -> "Atten2MultiHeadApply":
        """Deserialize the networks from a dict.

        Parameters
        ----------
        data : dict
            The dict to deserialize from.
        """
        data = data.copy()
        check_version_compatibility(data.pop("@version"), 1, 1)
        data.pop("@class")
        mapv = data.pop("mapv")
        head_map = data.pop("head_map")
        obj = cls(**data)
        obj.mapv = MLPLayer.deserialize(mapv)
        obj.head_map = MLPLayer.deserialize(head_map)
        return obj


class Atten2EquiVarApply(paddle.nn.Layer):
    def __init__(
        self,
        input_dim: int,
        head_num: int,
        precision: str = "float64",
        seed: Optional[Union[int, list[int]]] = None,
    ) -> None:
        super().__init__()
        self.input_dim = input_dim
        self.head_num = head_num
        # collapses the head axis to a single equivariant channel
        self.head_map = MLPLayer(
            head_num, 1, bias=False, precision=precision, seed=seed
        )
        self.precision = precision

    def forward(
        self,
        AA: paddle.Tensor,  # nf x nloc x nnei x nnei x nh
        h2: paddle.Tensor,  # nf x nloc x nnei x 3
    ) -> paddle.Tensor:
        nf, nloc, nnei, _ = h2.shape
        nh = self.head_num
        # attention maps: nf x nloc x nh x nnei x nnei
        attn = paddle.transpose(AA, (0, 1, 4, 2, 3))
        # replicate h2 per head: nf x nloc x nh x nnei x 3
        h2rep = paddle.tile(paddle.unsqueeze(h2, axis=2), [1, 1, nh, 1, 1])
        # attention-weighted directions: nf x nloc x nh x nnei x 3
        mixed = paddle.matmul(attn, h2rep)
        # nf x nloc x nnei x 3 x nh
        mixed = paddle.transpose(mixed, (0, 1, 3, 4, 2)).reshape(
            [nf, nloc, nnei, 3, nh]
        )
        # nf x nloc x nnei x 3
        return paddle.squeeze(self.head_map(mixed), axis=-1)

    def serialize(self) -> dict:
        """Serialize the networks to a dict.

        Returns
        -------
        dict
            The serialized networks.
        """
        return {
            "@class": "Atten2EquiVarApply",
            "@version": 1,
            "input_dim": self.input_dim,
            "head_num": self.head_num,
            "precision": self.precision,
            "head_map": self.head_map.serialize(),
        }

    @classmethod
    def deserialize(cls, data: dict) -> "Atten2EquiVarApply":
        """Deserialize the networks from a dict.

        Parameters
        ----------
        data : dict
            The dict to deserialize from.
        """
        data = data.copy()
        check_version_compatibility(data.pop("@version"), 1, 1)
        data.pop("@class")
        head_map = data.pop("head_map")
        obj = cls(**data)
        obj.head_map = MLPLayer.deserialize(head_map)
        return obj
class LocalAtten(paddle.nn.Layer):
    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        head_num: int,
        smooth: bool = True,
        attnw_shift: float = 20.0,
        precision: str = "float64",
        seed: Optional[Union[int, list[int]]] = None,
    ) -> None:
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.head_num = head_num
        # query projection of the central atom
        self.mapq = MLPLayer(
            input_dim,
            hidden_dim * 1 * head_num,
            bias=False,
            precision=precision,
            seed=child_seed(seed, 0),
        )
        # fused key/value projection of the neighbors
        self.mapkv = MLPLayer(
            input_dim,
            (hidden_dim + input_dim) * head_num,
            bias=False,
            precision=precision,
            seed=child_seed(seed, 1),
        )
        # merges the heads back to input_dim
        self.head_map = MLPLayer(
            input_dim * head_num,
            input_dim,
            precision=precision,
            seed=child_seed(seed, 2),
        )
        self.smooth = smooth
        self.attnw_shift = attnw_shift
        self.precision = precision

    def forward(
        self,
        g1: paddle.Tensor,  # nb x nloc x ng1
        gg1: paddle.Tensor,  # nb x nloc x nnei x ng1
        nlist_mask: paddle.Tensor,  # nb x nloc x nnei
        sw: paddle.Tensor,  # nb x nloc x nnei
    ) -> paddle.Tensor:
        nb, nloc, nnei = nlist_mask.shape
        ni, nd, nh = self.input_dim, self.hidden_dim, self.head_num
        assert ni == g1.shape[-1]
        assert ni == gg1.shape[-1]
        # queries: nb x nloc x nd x nh -> nb x nloc x nh x nd
        query = paddle.transpose(self.mapq(g1).reshape([nb, nloc, nd, nh]), (0, 1, 3, 2))
        # fused keys/values: nb x nloc x nnei x (nd+ni) x nh
        kv = self.mapkv(gg1).reshape([nb, nloc, nnei, nd + ni, nh])
        kv = paddle.transpose(kv, (0, 1, 4, 2, 3))
        # keys nb x nloc x nh x nnei x nd, values nb x nloc x nh x nnei x ng1
        key, value = paddle.split(kv, [nd, ni], axis=-1)

        # scaled dot-product, nb x nloc x nh x 1 x nnei
        attnw = (
            paddle.matmul(query.unsqueeze(-2), paddle.transpose(key, [0, 1, 2, 4, 3]))
            / nd**0.5
        )
        # nb x nloc x nh x nnei
        attnw = attnw.squeeze(-2)
        # invalid-neighbor mask, nb x nloc x 1 x nnei
        attnw_mask = ~nlist_mask.unsqueeze(-2)
        if self.smooth:
            # shift-damp-unshift keeps the weights smooth at the cutoff
            attnw = (attnw + self.attnw_shift) * sw.unsqueeze(-2) - self.attnw_shift
        else:
            attnw = attnw.masked_fill(
                attnw_mask,
                float("-inf"),
            )
        attnw = paddle.nn.functional.softmax(attnw, axis=-1)
        attnw = attnw.masked_fill(
            attnw_mask,
            0.0,
        )
        if self.smooth:
            attnw = attnw * sw.unsqueeze(-2)

        # weighted values, concat heads: nb x nloc x (nh x ng1)
        out = (
            paddle.matmul(attnw.unsqueeze(-2), value)
            .squeeze(-2)
            .reshape([nb, nloc, nh * ni])
        )
        # nb x nloc x ng1
        return self.head_map(out)

    def serialize(self) -> dict:
        """Serialize the networks to a dict.

        Returns
        -------
        dict
            The serialized networks.
        """
        return {
            "@class": "LocalAtten",
            "@version": 1,
            "input_dim": self.input_dim,
            "hidden_dim": self.hidden_dim,
            "head_num": self.head_num,
            "smooth": self.smooth,
            "attnw_shift": self.attnw_shift,
            "precision": self.precision,
            "mapq": self.mapq.serialize(),
            "mapkv": self.mapkv.serialize(),
            "head_map": self.head_map.serialize(),
        }

    @classmethod
    def deserialize(cls, data: dict) -> "LocalAtten":
        """Deserialize the networks from a dict.

        Parameters
        ----------
        data : dict
            The dict to deserialize from.
        """
        data = data.copy()
        check_version_compatibility(data.pop("@version"), 1, 1)
        data.pop("@class")
        mapq = data.pop("mapq")
        mapkv = data.pop("mapkv")
        head_map = data.pop("head_map")
        obj = cls(**data)
        obj.mapq = MLPLayer.deserialize(mapq)
        obj.mapkv = MLPLayer.deserialize(mapkv)
        obj.head_map = MLPLayer.deserialize(head_map)
        return obj
    def __init__(
        self,
        rcut,
        rcut_smth,
        sel: int,
        ntypes: int,
        g1_dim=128,
        g2_dim=16,
        axis_neuron: int = 4,
        update_chnnl_2: bool = True,
        update_g1_has_conv: bool = True,
        update_g1_has_drrd: bool = True,
        update_g1_has_grrg: bool = True,
        update_g1_has_attn: bool = True,
        update_g2_has_g1g1: bool = True,
        update_g2_has_attn: bool = True,
        update_h2: bool = False,
        attn1_hidden: int = 64,
        attn1_nhead: int = 4,
        attn2_hidden: int = 16,
        attn2_nhead: int = 4,
        attn2_has_gate: bool = False,
        activation_function: str = "tanh",
        update_style: str = "res_avg",
        update_residual: float = 0.001,
        update_residual_init: str = "norm",
        smooth: bool = True,
        precision: str = "float64",
        trainable_ln: bool = True,
        ln_eps: Optional[float] = 1e-5,
        use_sqrt_nnei: bool = True,
        g1_out_conv: bool = True,
        g1_out_mlp: bool = True,
        seed: Optional[Union[int, list[int]]] = None,
    ) -> None:
        """One repformer layer: builds the sub-networks that update the
        single-atom channel g1, the pair channel g2 and the equivariant
        pair channel h2, according to the enabled update flags.

        NOTE(review): the child_seed indices below are fixed and must stay in
        this exact order to reproduce parameter initialization of serialized
        models — do not renumber.
        """
        super().__init__()
        self.epsilon = 1e-4  # protection of 1./nnei
        self.rcut = float(rcut)
        self.rcut_smth = float(rcut_smth)
        self.ntypes = ntypes
        sel = [sel] if isinstance(sel, int) else sel
        self.nnei = sum(sel)
        # only a single (merged) selection group is supported here
        assert len(sel) == 1
        self.sel = sel
        self.sec = self.sel
        self.axis_neuron = axis_neuron
        self.activation_function = activation_function
        self.act = ActivationFn(activation_function)
        self.update_g1_has_grrg = update_g1_has_grrg
        self.update_g1_has_drrd = update_g1_has_drrd
        self.update_g1_has_conv = update_g1_has_conv
        self.update_g1_has_attn = update_g1_has_attn
        self.update_chnnl_2 = update_chnnl_2
        # g2/h2 updates only make sense when the second channel is updated at all
        self.update_g2_has_g1g1 = update_g2_has_g1g1 if self.update_chnnl_2 else False
        self.update_g2_has_attn = update_g2_has_attn if self.update_chnnl_2 else False
        self.update_h2 = update_h2 if self.update_chnnl_2 else False
        # drop the raw arguments so only the reconciled self.* values are used below
        del update_g2_has_g1g1, update_g2_has_attn, update_h2
        self.attn1_hidden = attn1_hidden
        self.attn1_nhead = attn1_nhead
        self.attn2_hidden = attn2_hidden
        self.attn2_nhead = attn2_nhead
        self.attn2_has_gate = attn2_has_gate
        self.update_style = update_style
        self.update_residual = update_residual
        self.update_residual_init = update_residual_init
        self.smooth = smooth
        self.g1_dim = g1_dim
        self.g2_dim = g2_dim
        self.trainable_ln = trainable_ln
        self.ln_eps = ln_eps
        self.precision = precision
        self.seed = seed
        self.use_sqrt_nnei = use_sqrt_nnei
        self.g1_out_conv = g1_out_conv
        self.g1_out_mlp = g1_out_mlp

        assert update_residual_init in [
            "norm",
            "const",
        ], "'update_residual_init' only support 'norm' or 'const'!"
        self.update_residual = update_residual
        self.update_residual_init = update_residual_init
        # per-term residual weights, only populated for the "res_residual" style;
        # append order below must match the consumption order in list_update
        self.g1_residual = []
        self.g2_residual = []
        self.h2_residual = []

        if self.update_style == "res_residual":
            self.g1_residual.append(
                get_residual(
                    g1_dim,
                    self.update_residual,
                    self.update_residual_init,
                    precision=precision,
                    seed=child_seed(seed, 0),
                )
            )

        # input width of linear1 depends on which g1 update terms are enabled
        g1_in_dim = self.cal_1_dim(g1_dim, g2_dim, self.axis_neuron)
        self.linear1 = MLPLayer(
            g1_in_dim,
            g1_dim,
            precision=precision,
            seed=child_seed(seed, 1),
        )
        # optional sub-networks; stay None unless the matching flag is set
        self.linear2 = None
        self.proj_g1g2 = None
        self.proj_g1g1g2 = None
        self.attn2g_map = None
        self.attn2_mh_apply = None
        self.attn2_lm = None
        self.attn2_ev_apply = None
        self.loc_attn = None

        if self.update_chnnl_2:
            self.linear2 = MLPLayer(
                g2_dim,
                g2_dim,
                precision=precision,
                seed=child_seed(seed, 2),
            )
            if self.update_style == "res_residual":
                self.g2_residual.append(
                    get_residual(
                        g2_dim,
                        self.update_residual,
                        self.update_residual_init,
                        precision=precision,
                        seed=child_seed(seed, 3),
                    )
                )
        if self.g1_out_mlp:
            self.g1_self_mlp = MLPLayer(
                g1_dim,
                g1_dim,
                precision=precision,
                seed=child_seed(seed, 15),
            )
            if self.update_style == "res_residual":
                self.g1_residual.append(
                    get_residual(
                        g1_dim,
                        self.update_residual,
                        self.update_residual_init,
                        precision=precision,
                        seed=child_seed(seed, 16),
                    )
                )
        else:
            self.g1_self_mlp = None
        if self.update_g1_has_conv:
            # projection direction depends on g1_out_conv:
            # g1->g2 width when the conv output feeds the g1 MLP input,
            # g2->g1 width when the conv result is an output term itself
            if not self.g1_out_conv:
                self.proj_g1g2 = MLPLayer(
                    g1_dim,
                    g2_dim,
                    bias=False,
                    precision=precision,
                    seed=child_seed(seed, 4),
                )
            else:
                self.proj_g1g2 = MLPLayer(
                    g2_dim,
                    g1_dim,
                    bias=False,
                    precision=precision,
                    seed=child_seed(seed, 4),
                )
                if self.update_style == "res_residual":
                    self.g1_residual.append(
                        get_residual(
                            g1_dim,
                            self.update_residual,
                            self.update_residual_init,
                            precision=precision,
                            seed=child_seed(seed, 17),
                        )
                    )
        if self.update_g2_has_g1g1:
            self.proj_g1g1g2 = MLPLayer(
                g1_dim,
                g2_dim,
                bias=False,
                precision=precision,
                seed=child_seed(seed, 5),
            )
            if self.update_style == "res_residual":
                self.g2_residual.append(
                    get_residual(
                        g2_dim,
                        self.update_residual,
                        self.update_residual_init,
                        precision=precision,
                        seed=child_seed(seed, 6),
                    )
                )
        if self.update_g2_has_attn or self.update_h2:
            # shared attention map used by both the g2 and the h2 update
            self.attn2g_map = Atten2Map(
                g2_dim,
                attn2_hidden,
                attn2_nhead,
                attn2_has_gate,
                self.smooth,
                precision=precision,
                seed=child_seed(seed, 7),
            )
            if self.update_g2_has_attn:
                self.attn2_mh_apply = Atten2MultiHeadApply(
                    g2_dim, attn2_nhead, precision=precision, seed=child_seed(seed, 8)
                )
                self.attn2_lm = LayerNorm(
                    g2_dim,
                    eps=ln_eps,
                    trainable=trainable_ln,
                    precision=precision,
                    seed=child_seed(seed, 9),
                )
                if self.update_style == "res_residual":
                    self.g2_residual.append(
                        get_residual(
                            g2_dim,
                            self.update_residual,
                            self.update_residual_init,
                            precision=precision,
                            seed=child_seed(seed, 10),
                        )
                    )
self.attn2_ev_apply = Atten2EquiVarApply( + g2_dim, attn2_nhead, precision=precision, seed=child_seed(seed, 11) + ) + if self.update_style == "res_residual": + self.h2_residual.append( + get_residual( + 1, + self.update_residual, + self.update_residual_init, + precision=precision, + seed=child_seed(seed, 12), + ) + ) + if self.update_g1_has_attn: + self.loc_attn = LocalAtten( + g1_dim, + attn1_hidden, + attn1_nhead, + self.smooth, + precision=precision, + seed=child_seed(seed, 13), + ) + if self.update_style == "res_residual": + self.g1_residual.append( + get_residual( + g1_dim, + self.update_residual, + self.update_residual_init, + precision=precision, + seed=child_seed(seed, 14), + ) + ) + + self.g1_residual = nn.ParameterList(self.g1_residual) + self.g2_residual = nn.ParameterList(self.g2_residual) + self.h2_residual = nn.ParameterList(self.h2_residual) + + def cal_1_dim(self, g1d: int, g2d: int, ax: int) -> int: + ret = g1d if not self.g1_out_mlp else 0 + if self.update_g1_has_grrg: + ret += g2d * ax + if self.update_g1_has_drrd: + ret += g1d * ax + if self.update_g1_has_conv and not self.g1_out_conv: + ret += g2d + return ret + + def _update_h2( + self, + h2: paddle.Tensor, + attn: paddle.Tensor, + ) -> paddle.Tensor: + """ + Calculate the attention weights update for pair-wise equivariant rep. + + Parameters + ---------- + h2 + Pair-wise equivariant rep tensors, with shape nf x nloc x nnei x 3. + attn + Attention weights from g2 attention, with shape nf x nloc x nnei x nnei x nh2. + """ + assert self.attn2_ev_apply is not None + # nf x nloc x nnei x nh2 + h2_1 = self.attn2_ev_apply(attn, h2) + return h2_1 + + def _update_g1_conv( + self, + gg1: paddle.Tensor, + g2: paddle.Tensor, + nlist_mask: paddle.Tensor, + sw: paddle.Tensor, + ) -> paddle.Tensor: + """ + Calculate the convolution update for atomic invariant rep. + + Parameters + ---------- + gg1 + Neighbor-wise atomic invariant rep, with shape nb x nloc x nnei x ng1. 
+ g2 + Pair invariant rep, with shape nb x nloc x nnei x ng2. + nlist_mask + Neighbor list mask, where zero means no neighbor, with shape nb x nloc x nnei. + sw + The switch function, which equals 1 within the rcut_smth range, smoothly decays from 1 to 0 between rcut_smth and rcut, + and remains 0 beyond rcut, with shape nb x nloc x nnei. + """ + assert self.proj_g1g2 is not None + nb, nloc, nnei, _ = g2.shape + ng1 = gg1.shape[-1] + ng2 = g2.shape[-1] + if not self.g1_out_conv: + # gg1 : nb x nloc x nnei x ng2 + gg1 = self.proj_g1g2(gg1).reshape([nb, nloc, nnei, ng2]) + else: + gg1 = gg1.reshape([nb, nloc, nnei, ng1]) + # nb x nloc x nnei x ng2/ng1 + gg1 = _apply_nlist_mask(gg1, nlist_mask) + if not self.smooth: + # normalized by number of neighbors, not smooth + # nb x nloc x 1 + # must use astype here to convert bool to float, otherwise there will be numerical difference from numpy + invnnei = 1.0 / ( + self.epsilon + paddle.sum(nlist_mask.astype(gg1.dtype), axis=-1) + ).unsqueeze(-1) + else: + gg1 = _apply_switch(gg1, sw) + invnnei = (1.0 / float(nnei)) * paddle.ones( + (nb, nloc, 1), dtype=gg1.dtype + ).to(device=gg1.place) + if not self.g1_out_conv: + # nb x nloc x ng2 + g1_11 = paddle.sum(g2 * gg1, axis=2) * invnnei + else: + g2 = self.proj_g1g2(g2).reshape([nb, nloc, nnei, ng1]) + # nb x nloc x ng1 + g1_11 = paddle.sum(g2 * gg1, axis=2) * invnnei + return g1_11 + + @staticmethod + def _cal_hg( + g2: paddle.Tensor, + h2: paddle.Tensor, + nlist_mask: paddle.Tensor, + sw: paddle.Tensor, + smooth: bool = True, + epsilon: float = 1e-4, + use_sqrt_nnei: bool = True, + ) -> paddle.Tensor: + """ + Calculate the transposed rotation matrix. + + Parameters + ---------- + g2 + Neighbor-wise/Pair-wise invariant rep tensors, with shape nb x nloc x nnei x ng2. + h2 + Neighbor-wise/Pair-wise equivariant rep tensors, with shape nb x nloc x nnei x 3. + nlist_mask + Neighbor list mask, where zero means no neighbor, with shape nb x nloc x nnei. 
+ sw + The switch function, which equals 1 within the rcut_smth range, smoothly decays from 1 to 0 between rcut_smth and rcut, + and remains 0 beyond rcut, with shape nb x nloc x nnei. + smooth + Whether to use smoothness in processes such as attention weights calculation. + epsilon + Protection of 1./nnei. + + Returns + ------- + hg + The transposed rotation matrix, with shape nb x nloc x 3 x ng2. + """ + # g2: nb x nloc x nnei x ng2 + # h2: nb x nloc x nnei x 3 + # msk: nb x nloc x nnei + nb, nloc, nnei, _ = g2.shape + ng2 = g2.shape[-1] + # nb x nloc x nnei x ng2 + g2 = _apply_nlist_mask(g2, nlist_mask) + if not smooth: + # nb x nloc + # must use astype here to convert bool to float, otherwise there will be numerical difference from numpy + if not use_sqrt_nnei: + invnnei = 1.0 / ( + epsilon + paddle.sum(nlist_mask.astype(g2.dtype), axis=-1) + ) + else: + invnnei = 1.0 / ( + epsilon + + paddle.sqrt(paddle.sum(nlist_mask.astype(g2.dtype), axis=-1)) + ) + # nb x nloc x 1 x 1 + invnnei = invnnei.unsqueeze(-1).unsqueeze(-1) + else: + g2 = _apply_switch(g2, sw) + if not use_sqrt_nnei: + invnnei = (1.0 / float(nnei)) * paddle.ones( + (nb, nloc, 1, 1), dtype=g2.dtype + ).to(device=g2.place) + else: + invnnei = paddle.rsqrt( + float(nnei) + * paddle.ones((nb, nloc, 1, 1), dtype=g2.dtype).to(device=g2.place) + ) + # nb x nloc x 3 x ng2 + h2g2 = paddle.matmul(paddle.transpose(h2, [0, 1, 3, 2]), g2) * invnnei + return h2g2 + + @staticmethod + def _cal_grrg(h2g2: paddle.Tensor, axis_neuron: int) -> paddle.Tensor: + """ + Calculate the atomic invariant rep. + + Parameters + ---------- + h2g2 + The transposed rotation matrix, with shape nb x nloc x 3 x ng2. + axis_neuron + Size of the submatrix. 
+ + Returns + ------- + grrg + Atomic invariant rep, with shape nb x nloc x (axis_neuron x ng2) + """ + # nb x nloc x 3 x ng2 + nb, nloc, _, ng2 = h2g2.shape + # nb x nloc x 3 x axis + # h2g2m = paddle.split(h2g2, decomp.sec(h2g2.shape[-1], axis_neuron), axis=-1)[0] + h2g2m = h2g2[..., :axis_neuron] # use slice instead of split + # nb x nloc x axis x ng2 + g1_13 = paddle.matmul(paddle.transpose(h2g2m, [0, 1, 3, 2]), h2g2) / (3.0**1) + # nb x nloc x (axisxng2) + g1_13 = g1_13.reshape([nb, nloc, axis_neuron * ng2]) + return g1_13 + + def symmetrization_op( + self, + g2: paddle.Tensor, + h2: paddle.Tensor, + nlist_mask: paddle.Tensor, + sw: paddle.Tensor, + axis_neuron: int, + smooth: bool = True, + epsilon: float = 1e-4, + ) -> paddle.Tensor: + """ + Symmetrization operator to obtain atomic invariant rep. + + Parameters + ---------- + g2 + Neighbor-wise/Pair-wise invariant rep tensors, with shape nb x nloc x nnei x ng2. + h2 + Neighbor-wise/Pair-wise equivariant rep tensors, with shape nb x nloc x nnei x 3. + nlist_mask + Neighbor list mask, where zero means no neighbor, with shape nb x nloc x nnei. + sw + The switch function, which equals 1 within the rcut_smth range, smoothly decays from 1 to 0 between rcut_smth and rcut, + and remains 0 beyond rcut, with shape nb x nloc x nnei. + axis_neuron + Size of the submatrix. + smooth + Whether to use smoothness in processes such as attention weights calculation. + epsilon + Protection of 1./nnei. 
    def _update_g2_g1g1(
        self,
        g1: paddle.Tensor,  # nb x nloc x ng1
        gg1: paddle.Tensor,  # nb x nloc x nnei x ng1
        nlist_mask: paddle.Tensor,  # nb x nloc x nnei
        sw: paddle.Tensor,  # nb x nloc x nnei
    ) -> paddle.Tensor:
        """
        Update the g2 using element-wise dot g1_i * g1_j.

        Parameters
        ----------
        g1
            Atomic invariant rep, with shape nb x nloc x ng1.
        gg1
            Neighbor-wise atomic invariant rep, with shape nb x nloc x nnei x ng1.
        nlist_mask
            Neighbor list mask, where zero means no neighbor, with shape nb x nloc x nnei.
        sw
            The switch function, which equals 1 within the rcut_smth range, smoothly decays from 1 to 0 between rcut_smth and rcut,
            and remains 0 beyond rcut, with shape nb x nloc x nnei.
        """
        # broadcast the central-atom rep over the neighbor axis
        ret = g1.unsqueeze(-2) * gg1
        # nb x nloc x nnei x ng1
        ret = _apply_nlist_mask(ret, nlist_mask)
        if self.smooth:
            ret = _apply_switch(ret, sw)
        return ret

    def forward(
        self,
        g1_ext: paddle.Tensor,  # nf x nall x ng1
        g2: paddle.Tensor,  # nf x nloc x nnei x ng2
        h2: paddle.Tensor,  # nf x nloc x nnei x 3
        nlist: paddle.Tensor,  # nf x nloc x nnei
        nlist_mask: paddle.Tensor,  # nf x nloc x nnei
        sw: paddle.Tensor,  # switch func, nf x nloc x nnei
    ):
        """
        Parameters
        ----------
        g1_ext : nf x nall x ng1 extended single-atom channel
        g2 : nf x nloc x nnei x ng2 pair-atom channel, invariant
        h2 : nf x nloc x nnei x 3 pair-atom channel, equivariant
        nlist : nf x nloc x nnei neighbor list (padded neis are set to 0)
        nlist_mask : nf x nloc x nnei masks of the neighbor list. real nei 1 otherwise 0
        sw : nf x nloc x nnei switch function

        Returns
        -------
        g1: nf x nloc x ng1 updated single-atom channel
        g2: nf x nloc x nnei x ng2 updated pair-atom channel, invariant
        h2: nf x nloc x nnei x 3 updated pair-atom channel, equivariant
        """
        # gg1 is only needed when at least one enabled term consumes it
        cal_gg1 = (
            self.update_g1_has_drrd
            or self.update_g1_has_conv
            or self.update_g1_has_attn
            or self.update_g2_has_g1g1
        )

        nb, nloc, nnei, _ = g2.shape
        nall = g1_ext.shape[1]
        # local atoms are the leading nloc entries of the extended rep
        g1, _ = paddle.split(g1_ext, [nloc, nall - nloc], axis=1)
        # shape checks are skipped in static-graph mode
        if paddle.in_dynamic_mode():
            assert [nb, nloc] == g1.shape[:2]
        if paddle.in_dynamic_mode():
            assert [nb, nloc, nnei] == h2.shape[:3]

        # candidate update terms; the append order below must match the
        # residual-weight order built in __init__ for "res_residual"
        g2_update: list[paddle.Tensor] = [g2]
        h2_update: list[paddle.Tensor] = [h2]
        g1_update: list[paddle.Tensor] = [g1]
        g1_mlp: list[paddle.Tensor] = [g1] if not self.g1_out_mlp else []
        if self.g1_out_mlp:
            if paddle.in_dynamic_mode():
                assert self.g1_self_mlp is not None
            g1_self_mlp = self.act(self.g1_self_mlp(g1))
            g1_update.append(g1_self_mlp)

        if cal_gg1:
            gg1 = _make_nei_g1(g1_ext, nlist)
        else:
            gg1 = None

        if self.update_chnnl_2:
            # mlp(g2)
            if paddle.in_dynamic_mode():
                assert self.linear2 is not None
            # nb x nloc x nnei x ng2
            g2_1 = self.act(self.linear2(g2))
            g2_update.append(g2_1)

            if self.update_g2_has_g1g1:
                # linear(g1_i * g1_j)
                if paddle.in_dynamic_mode():
                    assert gg1 is not None
                if paddle.in_dynamic_mode():
                    assert self.proj_g1g1g2 is not None
                g2_update.append(
                    self.proj_g1g1g2(self._update_g2_g1g1(g1, gg1, nlist_mask, sw))
                )

            if self.update_g2_has_attn or self.update_h2:
                # gated_attention(g2, h2) — one attention map shared by both updates
                if paddle.in_dynamic_mode():
                    assert self.attn2g_map is not None
                # nb x nloc x nnei x nnei x nh
                AAg = self.attn2g_map(g2, h2, nlist_mask, sw)

                if self.update_g2_has_attn:
                    if paddle.in_dynamic_mode():
                        assert self.attn2_mh_apply is not None
                    if paddle.in_dynamic_mode():
                        assert self.attn2_lm is not None
                    # nb x nloc x nnei x ng2
                    g2_2 = self.attn2_mh_apply(AAg, g2)
                    g2_2 = self.attn2_lm(g2_2)
                    g2_update.append(g2_2)

                if self.update_h2:
                    # linear_head(attention_weights * h2)
                    h2_update.append(self._update_h2(h2, AAg))

        if self.update_g1_has_conv:
            if paddle.in_dynamic_mode():
                assert gg1 is not None
            g1_conv = self._update_g1_conv(gg1, g2, nlist_mask, sw)
            # the conv term either feeds the g1 MLP input or is an output term
            if not self.g1_out_conv:
                g1_mlp.append(g1_conv)
            else:
                g1_update.append(g1_conv)

        if self.update_g1_has_grrg:
            g1_mlp.append(
                self.symmetrization_op(
                    g2,
                    h2,
                    nlist_mask,
                    sw,
                    self.axis_neuron,
                    smooth=self.smooth,
                    epsilon=self.epsilon,
                )
            )

        if self.update_g1_has_drrd:
            if paddle.in_dynamic_mode():
                assert gg1 is not None
            g1_mlp.append(
                self.symmetrization_op(
                    gg1,
                    h2,
                    nlist_mask,
                    sw,
                    self.axis_neuron,
                    smooth=self.smooth,
                    epsilon=self.epsilon,
                )
            )

        # nb x nloc x [ng1+ng2+(axisxng2)+(axisxng1)]
        #                  conv   grrg      drrd
        g1_1 = self.act(self.linear1(paddle.concat(g1_mlp, axis=-1)))
        g1_update.append(g1_1)

        if self.update_g1_has_attn:
            assert gg1 is not None
            assert self.loc_attn is not None
            g1_update.append(self.loc_attn(g1, gg1, nlist_mask, sw))

        # update
        if self.update_chnnl_2:
            g2_new = self.list_update(g2_update, "g2")
            h2_new = self.list_update(h2_update, "h2")
        else:
            g2_new, h2_new = g2, h2
        g1_new = self.list_update(g1_update, "g1")
        return g1_new, g2_new, h2_new

    def list_update_res_avg(
        self,
        update_list: list[paddle.Tensor],
    ) -> paddle.Tensor:
        # average-style residual update: sum all terms and rescale by
        # 1/sqrt(n) to keep the variance of the result stable
        nitem = len(update_list)
        uu = update_list[0]
        for ii in range(1, nitem):
            uu = uu + update_list[ii]
        return uu / (float(nitem) ** 0.5)

    def list_update_res_incr(self, update_list: list[paddle.Tensor]) -> paddle.Tensor:
        # incremental-style residual update: keep the first term unscaled and
        # add the remaining terms scaled by 1/sqrt(n-1)
        nitem = len(update_list)
        uu = update_list[0]
        scale = 1.0 / (float(nitem - 1) ** 0.5) if nitem > 1 else 0.0
        for ii in range(1, nitem):
            uu = uu + scale * update_list[ii]
        return uu
= "g1" + ) -> paddle.Tensor: + nitem = len(update_list) + uu = update_list[0] + # make jit happy + if update_name == "g1": + for ii, vv in enumerate(self.g1_residual): + uu = uu + vv * update_list[ii + 1] + elif update_name == "g2": + for ii, vv in enumerate(self.g2_residual): + uu = uu + vv * update_list[ii + 1] + elif update_name == "h2": + for ii, vv in enumerate(self.h2_residual): + uu = uu + vv * update_list[ii + 1] + else: + raise NotImplementedError + return uu + + def list_update( + self, update_list: list[paddle.Tensor], update_name: str = "g1" + ) -> paddle.Tensor: + if self.update_style == "res_avg": + return self.list_update_res_avg(update_list) + elif self.update_style == "res_incr": + return self.list_update_res_incr(update_list) + elif self.update_style == "res_residual": + return self.list_update_res_residual(update_list, update_name=update_name) + else: + raise RuntimeError(f"unknown update style {self.update_style}") + + def serialize(self) -> dict: + """Serialize the networks to a dict. + + Returns + ------- + dict + The serialized networks. 
+ """ + data = { + "@class": "RepformerLayer", + "@version": 2, + "rcut": self.rcut, + "rcut_smth": self.rcut_smth, + "sel": self.sel, + "ntypes": self.ntypes, + "g1_dim": self.g1_dim, + "g2_dim": self.g2_dim, + "axis_neuron": self.axis_neuron, + "update_chnnl_2": self.update_chnnl_2, + "update_g1_has_conv": self.update_g1_has_conv, + "update_g1_has_drrd": self.update_g1_has_drrd, + "update_g1_has_grrg": self.update_g1_has_grrg, + "update_g1_has_attn": self.update_g1_has_attn, + "update_g2_has_g1g1": self.update_g2_has_g1g1, + "update_g2_has_attn": self.update_g2_has_attn, + "update_h2": self.update_h2, + "attn1_hidden": self.attn1_hidden, + "attn1_nhead": self.attn1_nhead, + "attn2_hidden": self.attn2_hidden, + "attn2_nhead": self.attn2_nhead, + "attn2_has_gate": self.attn2_has_gate, + "activation_function": self.activation_function, + "update_style": self.update_style, + "smooth": self.smooth, + "precision": self.precision, + "trainable_ln": self.trainable_ln, + "use_sqrt_nnei": self.use_sqrt_nnei, + "g1_out_conv": self.g1_out_conv, + "g1_out_mlp": self.g1_out_mlp, + "ln_eps": self.ln_eps, + "linear1": self.linear1.serialize(), + } + if self.update_chnnl_2: + data.update( + { + "linear2": self.linear2.serialize(), + } + ) + if self.update_g1_has_conv: + data.update( + { + "proj_g1g2": self.proj_g1g2.serialize(), + } + ) + if self.update_g2_has_g1g1: + data.update( + { + "proj_g1g1g2": self.proj_g1g1g2.serialize(), + } + ) + if self.update_g2_has_attn or self.update_h2: + data.update( + { + "attn2g_map": self.attn2g_map.serialize(), + } + ) + if self.update_g2_has_attn: + data.update( + { + "attn2_mh_apply": self.attn2_mh_apply.serialize(), + "attn2_lm": self.attn2_lm.serialize(), + } + ) + + if self.update_h2: + data.update( + { + "attn2_ev_apply": self.attn2_ev_apply.serialize(), + } + ) + if self.update_g1_has_attn: + data.update( + { + "loc_attn": self.loc_attn.serialize(), + } + ) + if self.g1_out_mlp: + data.update( + { + "g1_self_mlp": 
self.g1_self_mlp.serialize(), + } + ) + if self.update_style == "res_residual": + data.update( + { + "@variables": { + "g1_residual": [to_numpy_array(t) for t in self.g1_residual], + "g2_residual": [to_numpy_array(t) for t in self.g2_residual], + "h2_residual": [to_numpy_array(t) for t in self.h2_residual], + } + } + ) + return data + + @classmethod + def deserialize(cls, data: dict) -> "RepformerLayer": + """Deserialize the networks from a dict. + + Parameters + ---------- + data : dict + The dict to deserialize from. + """ + data = data.copy() + check_version_compatibility(data.pop("@version"), 2, 1) + data.pop("@class") + linear1 = data.pop("linear1") + update_chnnl_2 = data["update_chnnl_2"] + update_g1_has_conv = data["update_g1_has_conv"] + update_g2_has_g1g1 = data["update_g2_has_g1g1"] + update_g2_has_attn = data["update_g2_has_attn"] + update_h2 = data["update_h2"] + update_g1_has_attn = data["update_g1_has_attn"] + update_style = data["update_style"] + g1_out_mlp = data["g1_out_mlp"] + + linear2 = data.pop("linear2", None) + proj_g1g2 = data.pop("proj_g1g2", None) + proj_g1g1g2 = data.pop("proj_g1g1g2", None) + attn2g_map = data.pop("attn2g_map", None) + attn2_mh_apply = data.pop("attn2_mh_apply", None) + attn2_lm = data.pop("attn2_lm", None) + attn2_ev_apply = data.pop("attn2_ev_apply", None) + loc_attn = data.pop("loc_attn", None) + g1_self_mlp = data.pop("g1_self_mlp", None) + variables = data.pop("@variables", {}) + g1_residual = variables.get("g1_residual", data.pop("g1_residual", [])) + g2_residual = variables.get("g2_residual", data.pop("g2_residual", [])) + h2_residual = variables.get("h2_residual", data.pop("h2_residual", [])) + + obj = cls(**data) + obj.linear1 = MLPLayer.deserialize(linear1) + if update_chnnl_2: + assert isinstance(linear2, dict) + obj.linear2 = MLPLayer.deserialize(linear2) + if update_g1_has_conv: + assert isinstance(proj_g1g2, dict) + obj.proj_g1g2 = MLPLayer.deserialize(proj_g1g2) + if update_g2_has_g1g1: + assert 
isinstance(proj_g1g1g2, dict) + obj.proj_g1g1g2 = MLPLayer.deserialize(proj_g1g1g2) + if update_g2_has_attn or update_h2: + assert isinstance(attn2g_map, dict) + obj.attn2g_map = Atten2Map.deserialize(attn2g_map) + if update_g2_has_attn: + assert isinstance(attn2_mh_apply, dict) + assert isinstance(attn2_lm, dict) + obj.attn2_mh_apply = Atten2MultiHeadApply.deserialize(attn2_mh_apply) + obj.attn2_lm = LayerNorm.deserialize(attn2_lm) + if update_h2: + assert isinstance(attn2_ev_apply, dict) + obj.attn2_ev_apply = Atten2EquiVarApply.deserialize(attn2_ev_apply) + if update_g1_has_attn: + assert isinstance(loc_attn, dict) + obj.loc_attn = LocalAtten.deserialize(loc_attn) + if g1_out_mlp: + assert isinstance(g1_self_mlp, dict) + obj.g1_self_mlp = MLPLayer.deserialize(g1_self_mlp) + if update_style == "res_residual": + for ii, t in enumerate(obj.g1_residual): + t.data = to_paddle_tensor(g1_residual[ii]) + for ii, t in enumerate(obj.g2_residual): + t.data = to_paddle_tensor(g2_residual[ii]) + for ii, t in enumerate(obj.h2_residual): + t.data = to_paddle_tensor(h2_residual[ii]) + return obj diff --git a/deepmd/pd/model/descriptor/repformers.py b/deepmd/pd/model/descriptor/repformers.py new file mode 100644 index 0000000000..47d92317df --- /dev/null +++ b/deepmd/pd/model/descriptor/repformers.py @@ -0,0 +1,576 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Callable, + Optional, + Union, +) + +import paddle + +from deepmd.dpmodel.utils.seed import ( + child_seed, +) +from deepmd.pd.model.descriptor.descriptor import ( + DescriptorBlock, +) +from deepmd.pd.model.descriptor.env_mat import ( + prod_env_mat, +) +from deepmd.pd.model.network.mlp import ( + MLPLayer, +) +from deepmd.pd.utils import ( + env, +) +from deepmd.pd.utils.env import ( + PRECISION_DICT, +) +from deepmd.pd.utils.env_mat_stat import ( + EnvMatStatSe, +) +from deepmd.pd.utils.exclude_mask import ( + PairExcludeMask, +) +from deepmd.pd.utils.utils import ( + ActivationFn, +) +from 
deepmd.utils.env_mat_stat import ( + StatItem, +) +from deepmd.utils.path import ( + DPPath, +) + +from .repformer_layer import ( + RepformerLayer, +) + + +@DescriptorBlock.register("se_repformer") +@DescriptorBlock.register("se_uni") +class DescrptBlockRepformers(DescriptorBlock): + def __init__( + self, + rcut, + rcut_smth, + sel: int, + ntypes: int, + nlayers: int = 3, + g1_dim=128, + g2_dim=16, + axis_neuron: int = 4, + direct_dist: bool = False, + update_g1_has_conv: bool = True, + update_g1_has_drrd: bool = True, + update_g1_has_grrg: bool = True, + update_g1_has_attn: bool = True, + update_g2_has_g1g1: bool = True, + update_g2_has_attn: bool = True, + update_h2: bool = False, + attn1_hidden: int = 64, + attn1_nhead: int = 4, + attn2_hidden: int = 16, + attn2_nhead: int = 4, + attn2_has_gate: bool = False, + activation_function: str = "tanh", + update_style: str = "res_avg", + update_residual: float = 0.001, + update_residual_init: str = "norm", + set_davg_zero: bool = True, + smooth: bool = True, + exclude_types: list[tuple[int, int]] = [], + env_protection: float = 0.0, + precision: str = "float64", + trainable_ln: bool = True, + ln_eps: Optional[float] = 1e-5, + seed: Optional[Union[int, list[int]]] = None, + use_sqrt_nnei: bool = True, + g1_out_conv: bool = True, + g1_out_mlp: bool = True, + ) -> None: + r""" + The repformer descriptor block. + + Parameters + ---------- + rcut : float + The cut-off radius. + rcut_smth : float + Where to start smoothing. For example the 1/r term is smoothed from rcut to rcut_smth. + sel : int + Maximally possible number of selected neighbors. + ntypes : int + Number of element types + nlayers : int, optional + Number of repformer layers. + g1_dim : int, optional + Dimension of the first graph convolution layer. + g2_dim : int, optional + Dimension of the second graph convolution layer. + axis_neuron : int, optional + Size of the submatrix of G (embedding matrix). 
+        direct_dist : bool, optional
+            Whether to use direct distance information (1/r term) in the repformer block.
+        update_g1_has_conv : bool, optional
+            Whether to update the g1 rep with convolution term.
+        update_g1_has_drrd : bool, optional
+            Whether to update the g1 rep with the drrd term.
+        update_g1_has_grrg : bool, optional
+            Whether to update the g1 rep with the grrg term.
+        update_g1_has_attn : bool, optional
+            Whether to update the g1 rep with the localized self-attention.
+        update_g2_has_g1g1 : bool, optional
+            Whether to update the g2 rep with the g1xg1 term.
+        update_g2_has_attn : bool, optional
+            Whether to update the g2 rep with the gated self-attention.
+        update_h2 : bool, optional
+            Whether to update the h2 rep.
+        attn1_hidden : int, optional
+            The hidden dimension of localized self-attention to update the g1 rep.
+        attn1_nhead : int, optional
+            The number of heads in localized self-attention to update the g1 rep.
+        attn2_hidden : int, optional
+            The hidden dimension of gated self-attention to update the g2 rep.
+        attn2_nhead : int, optional
+            The number of heads in gated self-attention to update the g2 rep.
+        attn2_has_gate : bool, optional
+            Whether to use gate in the gated self-attention to update the g2 rep.
+        activation_function : str, optional
+            The activation function in the embedding net.
+        update_style : str, optional
+            Style to update a representation.
+            Supported options are:
+            -'res_avg': Updates a rep `u` with: u = 1/\\sqrt{n+1} (u + u_1 + u_2 + ... + u_n)
+            -'res_incr': Updates a rep `u` with: u = u + 1/\\sqrt{n} (u_1 + u_2 + ... + u_n)
+            -'res_residual': Updates a rep `u` with: u = u + (r1*u_1 + r2*u_2 + ... + rn*u_n)
+            where `r1`, `r2` ... `rn` are residual weights defined by `update_residual`
+            and `update_residual_init`.
+        update_residual : float, optional
+            When update using residual mode, the initial std of residual vector weights.
+ update_residual_init : str, optional + When update using residual mode, the initialization mode of residual vector weights. + set_davg_zero : bool, optional + Set the normalization average to zero. + precision : str, optional + The precision of the embedding net parameters. + smooth : bool, optional + Whether to use smoothness in processes such as attention weights calculation. + exclude_types : list[list[int]], optional + The excluded pairs of types which have no interaction with each other. + For example, `[[0, 1]]` means no interaction between type 0 and type 1. + env_protection : float, optional + Protection parameter to prevent division by zero errors during environment matrix calculations. + For example, when using paddings, there may be zero distances of neighbors, which may make division by zero error during environment matrix calculations without protection. + trainable_ln : bool, optional + Whether to use trainable shift and scale weights in layer normalization. + use_sqrt_nnei : bool, optional + Whether to use the square root of the number of neighbors for symmetrization_op normalization instead of using the number of neighbors directly. + g1_out_conv : bool, optional + Whether to put the convolutional update of g1 separately outside the concatenated MLP update. + g1_out_mlp : bool, optional + Whether to put the self MLP update of g1 separately outside the concatenated MLP update. + ln_eps : float, optional + The epsilon value for layer normalization. + seed : int, optional + Random seed for parameter initialization. + """ + super().__init__() + self.rcut = float(rcut) + self.rcut_smth = float(rcut_smth) + self.ntypes = ntypes + self.nlayers = nlayers + sel = [sel] if isinstance(sel, int) else sel + self.nnei = sum(sel) + self.ndescrpt = self.nnei * 4 # use full descriptor. 
+ assert len(sel) == 1 + self.sel = sel + self.sec = self.sel + self.split_sel = self.sel + self.axis_neuron = axis_neuron + self.set_davg_zero = set_davg_zero + self.g1_dim = g1_dim + self.g2_dim = g2_dim + self.update_g1_has_conv = update_g1_has_conv + self.update_g1_has_drrd = update_g1_has_drrd + self.update_g1_has_grrg = update_g1_has_grrg + self.update_g1_has_attn = update_g1_has_attn + self.update_g2_has_g1g1 = update_g2_has_g1g1 + self.update_g2_has_attn = update_g2_has_attn + self.update_h2 = update_h2 + self.attn1_hidden = attn1_hidden + self.attn1_nhead = attn1_nhead + self.attn2_has_gate = attn2_has_gate + self.attn2_hidden = attn2_hidden + self.attn2_nhead = attn2_nhead + self.activation_function = activation_function + self.update_style = update_style + self.update_residual = update_residual + self.update_residual_init = update_residual_init + self.direct_dist = direct_dist + self.act = ActivationFn(activation_function) + self.smooth = smooth + self.use_sqrt_nnei = use_sqrt_nnei + self.g1_out_conv = g1_out_conv + self.g1_out_mlp = g1_out_mlp + # order matters, placed after the assignment of self.ntypes + self.reinit_exclude(exclude_types) + self.env_protection = env_protection + self.precision = precision + self.prec = PRECISION_DICT[precision] + self.trainable_ln = trainable_ln + self.ln_eps = ln_eps + self.epsilon = 1e-4 + self.seed = seed + + self.g2_embd = MLPLayer( + 1, self.g2_dim, precision=precision, seed=child_seed(seed, 0) + ) + layers = [] + for ii in range(nlayers): + layers.append( + RepformerLayer( + self.rcut, + self.rcut_smth, + self.sel, + self.ntypes, + self.g1_dim, + self.g2_dim, + axis_neuron=self.axis_neuron, + update_chnnl_2=(ii != nlayers - 1), + update_g1_has_conv=self.update_g1_has_conv, + update_g1_has_drrd=self.update_g1_has_drrd, + update_g1_has_grrg=self.update_g1_has_grrg, + update_g1_has_attn=self.update_g1_has_attn, + update_g2_has_g1g1=self.update_g2_has_g1g1, + update_g2_has_attn=self.update_g2_has_attn, + 
update_h2=self.update_h2, + attn1_hidden=self.attn1_hidden, + attn1_nhead=self.attn1_nhead, + attn2_has_gate=self.attn2_has_gate, + attn2_hidden=self.attn2_hidden, + attn2_nhead=self.attn2_nhead, + activation_function=self.activation_function, + update_style=self.update_style, + update_residual=self.update_residual, + update_residual_init=self.update_residual_init, + smooth=self.smooth, + trainable_ln=self.trainable_ln, + ln_eps=self.ln_eps, + precision=precision, + use_sqrt_nnei=self.use_sqrt_nnei, + g1_out_conv=self.g1_out_conv, + g1_out_mlp=self.g1_out_mlp, + seed=child_seed(child_seed(seed, 1), ii), + ) + ) + self.layers = paddle.nn.LayerList(layers) + + wanted_shape = (self.ntypes, self.nnei, 4) + mean = paddle.zeros(wanted_shape, dtype=env.GLOBAL_PD_FLOAT_PRECISION).to( + device=env.DEVICE + ) + stddev = paddle.ones(wanted_shape, dtype=env.GLOBAL_PD_FLOAT_PRECISION).to( + device=env.DEVICE + ) + self.register_buffer("mean", mean) + self.register_buffer("stddev", stddev) + self.stats = None + + def get_rcut(self) -> float: + """Returns the cut-off radius.""" + return self.rcut + + def get_rcut_smth(self) -> float: + """Returns the radius where the neighbor information starts to smoothly decay to 0.""" + return self.rcut_smth + + def get_nsel(self) -> int: + """Returns the number of selected atoms in the cut-off radius.""" + return sum(self.sel) + + def get_sel(self) -> list[int]: + """Returns the number of selected atoms for each type.""" + return self.sel + + def get_ntypes(self) -> int: + """Returns the number of element types.""" + return self.ntypes + + def get_dim_out(self) -> int: + """Returns the output dimension.""" + return self.dim_out + + def get_dim_in(self) -> int: + """Returns the input dimension.""" + return self.dim_in + + def get_dim_emb(self) -> int: + """Returns the embedding dimension g2.""" + return self.g2_dim + + def __setitem__(self, key, value) -> None: + if key in ("avg", "data_avg", "davg"): + self.mean = value + elif key in ("std", 
"data_std", "dstd"): + self.stddev = value + else: + raise KeyError(key) + + def __getitem__(self, key): + if key in ("avg", "data_avg", "davg"): + return self.mean + elif key in ("std", "data_std", "dstd"): + return self.stddev + else: + raise KeyError(key) + + def mixed_types(self) -> bool: + """If true, the descriptor + 1. assumes total number of atoms aligned across frames; + 2. requires a neighbor list that does not distinguish different atomic types. + + If false, the descriptor + 1. assumes total number of atoms of each atom type aligned across frames; + 2. requires a neighbor list that distinguishes different atomic types. + + """ + return True + + def get_env_protection(self) -> float: + """Returns the protection of building environment matrix.""" + return self.env_protection + + @property + def dim_out(self): + """Returns the output dimension of this descriptor.""" + return self.g1_dim + + @property + def dim_in(self): + """Returns the atomic input dimension of this descriptor.""" + return self.g1_dim + + @property + def dim_emb(self): + """Returns the embedding dimension g2.""" + return self.get_dim_emb() + + def reinit_exclude( + self, + exclude_types: list[tuple[int, int]] = [], + ) -> None: + self.exclude_types = exclude_types + self.emask = PairExcludeMask(self.ntypes, exclude_types=exclude_types) + + def forward( + self, + nlist: paddle.Tensor, + extended_coord: paddle.Tensor, + extended_atype: paddle.Tensor, + extended_atype_embd: Optional[paddle.Tensor] = None, + mapping: Optional[paddle.Tensor] = None, + type_embedding: Optional[paddle.Tensor] = None, + comm_dict: Optional[dict[str, paddle.Tensor]] = None, + ): + if comm_dict is None: + assert mapping is not None + assert extended_atype_embd is not None + nframes, nloc, nnei = nlist.shape + nall = extended_coord.reshape([nframes, -1]).shape[1] // 3 + atype = extended_atype[:, :nloc] + # nb x nloc x nnei + exclude_mask = self.emask(nlist, extended_atype) + nlist = paddle.where(exclude_mask != 0, 
nlist, paddle.full_like(nlist, -1)) + # nb x nloc x nnei x 4, nb x nloc x nnei x 3, nb x nloc x nnei x 1 + dmatrix, diff, sw = prod_env_mat( + extended_coord, + nlist, + atype, + self.mean, + self.stddev, + self.rcut, + self.rcut_smth, + protection=self.env_protection, + ) + nlist_mask = nlist != -1 + sw = paddle.squeeze(sw, -1) + # beyond the cutoff sw should be 0.0 + sw = sw.masked_fill(~nlist_mask, 0.0) + + # [nframes, nloc, tebd_dim] + if comm_dict is None: + if paddle.in_dynamic_mode(): + assert isinstance(extended_atype_embd, paddle.Tensor) # for jit + atype_embd = extended_atype_embd[:, :nloc, :] + if paddle.in_dynamic_mode(): + assert list(atype_embd.shape) == [nframes, nloc, self.g1_dim] + else: + atype_embd = extended_atype_embd + if paddle.in_dynamic_mode(): + assert isinstance(atype_embd, paddle.Tensor) # for jit + g1 = self.act(atype_embd) + ng1 = g1.shape[-1] + # nb x nloc x nnei x 1, nb x nloc x nnei x 3 + if not self.direct_dist: + g2, h2 = paddle.split(dmatrix, [1, 3], axis=-1) + else: + # g2, h2 = paddle.linalg.norm(diff, axis=-1, keepdim=True), diff + g2, h2 = paddle.linalg.norm(diff, axis=-1, keepdim=True), diff + g2 = g2 / self.rcut + h2 = h2 / self.rcut + # nb x nloc x nnei x ng2 + g2 = self.act(self.g2_embd(g2)) + + # set all padding positions to index of 0 + # if the a neighbor is real or not is indicated by nlist_mask + nlist[nlist == -1] = 0 + # nb x nall x ng1 + if comm_dict is None: + assert mapping is not None + mapping = ( + mapping.reshape([nframes, nall]) + .unsqueeze(-1) + .expand([-1, -1, self.g1_dim]) + ) + for idx, ll in enumerate(self.layers): + # g1: nb x nloc x ng1 + # g1_ext: nb x nall x ng1 + if comm_dict is None: + assert mapping is not None + g1_ext = paddle.take_along_axis(g1, axis=1, indices=mapping) + else: + raise NotImplementedError("Not implemented yet") + # has_spin = "has_spin" in comm_dict + # if not has_spin: + # n_padding = nall - nloc + # g1 = paddle.nn.functional.pad( + # g1.squeeze(0), (0, 0, 0, n_padding), 
value=0.0 + # ) + # real_nloc = nloc + # real_nall = nall + # else: + # # for spin + # real_nloc = nloc // 2 + # real_nall = nall // 2 + # real_n_padding = real_nall - real_nloc + # g1_real, g1_virtual = paddle.split( + # g1, [real_nloc, real_nloc], axis=1 + # ) + # # mix_g1: nb x real_nloc x (ng1 * 2) + # mix_g1 = paddle.concat([g1_real, g1_virtual], axis=2) + # # nb x real_nall x (ng1 * 2) + # g1 = paddle.nn.functional.pad( + # mix_g1.squeeze(0), (0, 0, 0, real_n_padding), value=0.0 + # ) + + # assert "send_list" in comm_dict + # assert "send_proc" in comm_dict + # assert "recv_proc" in comm_dict + # assert "send_num" in comm_dict + # assert "recv_num" in comm_dict + # assert "communicator" in comm_dict + # ret = paddle.ops.deepmd.border_op( + # comm_dict["send_list"], + # comm_dict["send_proc"], + # comm_dict["recv_proc"], + # comm_dict["send_num"], + # comm_dict["recv_num"], + # g1, + # comm_dict["communicator"], + # paddle.to_tensor( + # real_nloc, + # dtype=paddle.int32, + # place=env.DEVICE, + # ), # should be int of c++ + # paddle.to_tensor( + # real_nall - real_nloc, + # dtype=paddle.int32, + # place=env.DEVICE, + # ), # should be int of c++ + # ) + # g1_ext = ret[0].unsqueeze(0) + # if has_spin: + # g1_real_ext, g1_virtual_ext = paddle.split( + # g1_ext, [ng1, ng1], axis=2 + # ) + # g1_ext = concat_switch_virtual( + # g1_real_ext, g1_virtual_ext, real_nloc + # ) + g1, g2, h2 = ll.forward( + g1_ext, + g2, + h2, + nlist, + nlist_mask, + sw, + ) + + # nb x nloc x 3 x ng2 + h2g2 = RepformerLayer._cal_hg( + g2, + h2, + nlist_mask, + sw, + smooth=self.smooth, + epsilon=self.epsilon, + use_sqrt_nnei=self.use_sqrt_nnei, + ) + # (nb x nloc) x ng2 x 3 + rot_mat = paddle.transpose(h2g2, (0, 1, 3, 2)) + + return g1, g2, h2, rot_mat.reshape([nframes, nloc, self.dim_emb, 3]), sw + + def compute_input_stats( + self, + merged: Union[Callable[[], list[dict]], list[dict]], + path: Optional[DPPath] = None, + ) -> None: + """ + Compute the input statistics (e.g. 
mean and stddev) for the descriptors from packed data. + + Parameters + ---------- + merged : Union[Callable[[], list[dict]], list[dict]] + - list[dict]: A list of data samples from various data systems. + Each element, `merged[i]`, is a data dictionary containing `keys`: `paddle.Tensor` + originating from the `i`-th data system. + - Callable[[], list[dict]]: A lazy function that returns data samples in the above format + only when needed. Since the sampling process can be slow and memory-intensive, + the lazy function helps by only sampling once. + path : Optional[DPPath] + The path to the stat file. + + """ + env_mat_stat = EnvMatStatSe(self) + if path is not None: + path = path / env_mat_stat.get_hash() + if path is None or not path.is_dir(): + if callable(merged): + # only get data for once + sampled = merged() + else: + sampled = merged + else: + sampled = [] + env_mat_stat.load_or_compute_stats(sampled, path) + self.stats = env_mat_stat.stats + mean, stddev = env_mat_stat() + if not self.set_davg_zero: + paddle.assign(paddle.to_tensor(mean).to(device=env.DEVICE), self.mean) # pylint: disable=no-explicit-dtype + paddle.assign(paddle.to_tensor(stddev).to(device=env.DEVICE), self.stddev) # pylint: disable=no-explicit-dtype + + def get_stats(self) -> dict[str, StatItem]: + """Get the statistics of the descriptor.""" + if self.stats is None: + raise RuntimeError( + "The statistics of the descriptor has not been computed." 
+ ) + return self.stats + + def has_message_passing(self) -> bool: + """Returns whether the descriptor block has message passing.""" + return True + + def need_sorted_nlist_for_lower(self) -> bool: + """Returns whether the descriptor block needs sorted nlist when using `forward_lower`.""" + return False diff --git a/deepmd/pd/model/descriptor/se_t_tebd.py b/deepmd/pd/model/descriptor/se_t_tebd.py new file mode 100644 index 0000000000..a8b9a6a417 --- /dev/null +++ b/deepmd/pd/model/descriptor/se_t_tebd.py @@ -0,0 +1,931 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Callable, + Optional, + Union, +) + +import paddle + +from deepmd.dpmodel.utils import EnvMat as DPEnvMat +from deepmd.dpmodel.utils.seed import ( + child_seed, +) +from deepmd.pd.model.descriptor import ( + DescriptorBlock, +) +from deepmd.pd.model.descriptor.env_mat import ( + prod_env_mat, +) +from deepmd.pd.model.network.mlp import ( + EmbeddingNet, + NetworkCollection, +) +from deepmd.pd.model.network.network import ( + TypeEmbedNet, + TypeEmbedNetConsistent, +) +from deepmd.pd.utils import ( + env, +) +from deepmd.pd.utils.env import ( + PRECISION_DICT, + RESERVED_PRECISON_DICT, +) +from deepmd.pd.utils.env_mat_stat import ( + EnvMatStatSe, +) +from deepmd.pd.utils.exclude_mask import ( + PairExcludeMask, +) +from deepmd.pd.utils.update_sel import ( + UpdateSel, +) +from deepmd.utils.data_system import ( + DeepmdDataSystem, +) +from deepmd.utils.env_mat_stat import ( + StatItem, +) +from deepmd.utils.finetune import ( + get_index_between_two_maps, + map_pair_exclude_types, +) +from deepmd.utils.path import ( + DPPath, +) +from deepmd.utils.version import ( + check_version_compatibility, +) + +from .base_descriptor import ( + BaseDescriptor, +) +from .descriptor import ( + extend_descrpt_stat, +) + + +@BaseDescriptor.register("se_e3_tebd") +class DescrptSeTTebd(BaseDescriptor, paddle.nn.Layer): + r"""Construct an embedding net that takes angles between two neighboring atoms 
and type embeddings as input.
+
+    Parameters
+    ----------
+    rcut
+        The cut-off radius
+    rcut_smth
+        From where the environment matrix should be smoothed
+    sel : Union[list[int], int]
+        list[int]: sel[i] specifies the maximum number of type i atoms in the cut-off radius
+        int: the total maximum number of atoms in the cut-off radius
+    ntypes : int
+        Number of element types
+    neuron : list[int]
+        Number of neurons in each hidden layer of the embedding net
+    tebd_dim : int
+        Dimension of the type embedding
+    tebd_input_mode : str
+        The input mode of the type embedding. Supported modes are ["concat", "strip"].
+        - "concat": Concatenate the type embedding with the smoothed angular information as the union input for the embedding network.
+        - "strip": Use a separated embedding network for the type embedding and combine the output with the angular embedding network output.
+    resnet_dt
+        Time-step `dt` in the resnet construction:
+        y = x + dt * \phi (Wx + b)
+    set_davg_zero
+        Set the shift of embedding net input to zero.
+    activation_function
+        The activation function in the embedding net. Supported options are |ACTIVATION_FN|
+    env_protection: float
+        Protection parameter to prevent division by zero errors during environment matrix calculations.
+    exclude_types : list[tuple[int, int]]
+        The excluded pairs of types which have no interaction with each other.
+        For example, `[[0, 1]]` means no interaction between type 0 and type 1.
+    precision
+        The precision of the embedding net parameters. Supported options are |PRECISION|
+    trainable
+        If the weights of embedding net are trainable.
+    seed
+        Random seed for initializing the network parameters.
+    type_map: list[str], Optional
+        A list of strings. Give the name to each type of atoms.
+    concat_output_tebd: bool
+        Whether to concat type embedding at the output of the descriptor.
+    use_econf_tebd: bool, Optional
+        Whether to use electronic configuration type embedding.
+ use_tebd_bias : bool, Optional + Whether to use bias in the type embedding layer. + smooth: bool + Whether to use smooth process in calculation. + + """ + + def __init__( + self, + rcut: float, + rcut_smth: float, + sel: Union[list[int], int], + ntypes: int, + neuron: list = [2, 4, 8], + tebd_dim: int = 8, + tebd_input_mode: str = "concat", + resnet_dt: bool = False, + set_davg_zero: bool = True, + activation_function: str = "tanh", + env_protection: float = 0.0, + exclude_types: list[tuple[int, int]] = [], + precision: str = "float64", + trainable: bool = True, + seed: Optional[Union[int, list[int]]] = None, + type_map: Optional[list[str]] = None, + concat_output_tebd: bool = True, + use_econf_tebd: bool = False, + use_tebd_bias=False, + smooth: bool = True, + ) -> None: + super().__init__() + self.se_ttebd = DescrptBlockSeTTebd( + rcut, + rcut_smth, + sel, + ntypes, + neuron=neuron, + tebd_dim=tebd_dim, + tebd_input_mode=tebd_input_mode, + set_davg_zero=set_davg_zero, + activation_function=activation_function, + precision=precision, + resnet_dt=resnet_dt, + exclude_types=exclude_types, + env_protection=env_protection, + smooth=smooth, + seed=child_seed(seed, 1), + ) + self.prec = PRECISION_DICT[precision] + self.use_econf_tebd = use_econf_tebd + self.type_map = type_map + self.smooth = smooth + self.type_embedding = TypeEmbedNet( + ntypes, + tebd_dim, + precision=precision, + seed=child_seed(seed, 2), + use_econf_tebd=use_econf_tebd, + type_map=type_map, + use_tebd_bias=use_tebd_bias, + ) + self.tebd_dim = tebd_dim + self.tebd_input_mode = tebd_input_mode + self.concat_output_tebd = concat_output_tebd + self.trainable = trainable + # set trainable + for param in self.parameters(): + param.stop_gradient = not trainable + + def get_rcut(self) -> float: + """Returns the cut-off radius.""" + return self.se_ttebd.get_rcut() + + def get_rcut_smth(self) -> float: + """Returns the radius where the neighbor information starts to smoothly decay to 0.""" + return 
self.se_ttebd.get_rcut_smth() + + def get_nsel(self) -> int: + """Returns the number of selected atoms in the cut-off radius.""" + return self.se_ttebd.get_nsel() + + def get_sel(self) -> list[int]: + """Returns the number of selected atoms for each type.""" + return self.se_ttebd.get_sel() + + def get_ntypes(self) -> int: + """Returns the number of element types.""" + return self.se_ttebd.get_ntypes() + + def get_type_map(self) -> list[str]: + """Get the name to each type of atoms.""" + return self.type_map + + def get_dim_out(self) -> int: + """Returns the output dimension.""" + ret = self.se_ttebd.get_dim_out() + if self.concat_output_tebd: + ret += self.tebd_dim + return ret + + def get_dim_emb(self) -> int: + return self.se_ttebd.dim_emb + + def mixed_types(self) -> bool: + """If true, the descriptor + 1. assumes total number of atoms aligned across frames; + 2. requires a neighbor list that does not distinguish different atomic types. + + If false, the descriptor + 1. assumes total number of atoms of each atom type aligned across frames; + 2. requires a neighbor list that distinguishes different atomic types. + + """ + return self.se_ttebd.mixed_types() + + def has_message_passing(self) -> bool: + """Returns whether the descriptor has message passing.""" + return self.se_ttebd.has_message_passing() + + def need_sorted_nlist_for_lower(self) -> bool: + """Returns whether the descriptor needs sorted nlist when using `forward_lower`.""" + return self.se_ttebd.need_sorted_nlist_for_lower() + + def get_env_protection(self) -> float: + """Returns the protection of building environment matrix.""" + return self.se_ttebd.get_env_protection() + + def share_params(self, base_class, shared_level, resume=False) -> None: + """ + Share the parameters of self to the base_class with shared_level during multitask training. + If not start from checkpoint (resume is False), + some separated parameters (e.g. mean and stddev) will be re-calculated across different classes. 
+ """ + assert ( + self.__class__ == base_class.__class__ + ), "Only descriptors of the same type can share params!" + # For DPA1 descriptors, the user-defined share-level + # shared_level: 0 + # share all parameters in both type_embedding and se_ttebd + if shared_level == 0: + self._sub_layers["type_embedding"] = base_class._sub_layers[ + "type_embedding" + ] + self.se_ttebd.share_params(base_class.se_ttebd, 0, resume=resume) + # shared_level: 1 + # share all parameters in type_embedding + elif shared_level == 1: + self._sub_layers["type_embedding"] = base_class._sub_layers[ + "type_embedding" + ] + # Other shared levels + else: + raise NotImplementedError + + @property + def dim_out(self): + return self.get_dim_out() + + @property + def dim_emb(self): + return self.get_dim_emb() + + def compute_input_stats( + self, + merged: Union[Callable[[], list[dict]], list[dict]], + path: Optional[DPPath] = None, + ): + """ + Compute the input statistics (e.g. mean and stddev) for the descriptors from packed data. + + Parameters + ---------- + merged : Union[Callable[[], list[dict]], list[dict]] + - list[dict]: A list of data samples from various data systems. + Each element, `merged[i]`, is a data dictionary containing `keys`: `paddle.Tensor` + originating from the `i`-th data system. + - Callable[[], list[dict]]: A lazy function that returns data samples in the above format + only when needed. Since the sampling process can be slow and memory-intensive, + the lazy function helps by only sampling once. + path : Optional[DPPath] + The path to the stat file. 
+ + """ + return self.se_ttebd.compute_input_stats(merged, path) + + def set_stat_mean_and_stddev( + self, + mean: paddle.Tensor, + stddev: paddle.Tensor, + ) -> None: + """Update mean and stddev for descriptor.""" + self.se_ttebd.mean = mean + self.se_ttebd.stddev = stddev + + def get_stat_mean_and_stddev(self) -> tuple[paddle.Tensor, paddle.Tensor]: + """Get mean and stddev for descriptor.""" + return self.se_ttebd.mean, self.se_ttebd.stddev + + def change_type_map( + self, type_map: list[str], model_with_new_type_stat=None + ) -> None: + """Change the type related params to new ones, according to `type_map` and the original one in the model. + If there are new types in `type_map`, statistics will be updated accordingly to `model_with_new_type_stat` for these new types. + """ + assert ( + self.type_map is not None + ), "'type_map' must be defined when performing type changing!" + remap_index, has_new_type = get_index_between_two_maps(self.type_map, type_map) + obj = self.se_ttebd + obj.ntypes = len(type_map) + self.type_map = type_map + self.type_embedding.change_type_map(type_map=type_map) + obj.reinit_exclude(map_pair_exclude_types(obj.exclude_types, remap_index)) + if has_new_type: + # the avg and std of new types need to be updated + extend_descrpt_stat( + obj, + type_map, + des_with_stat=model_with_new_type_stat.se_ttebd + if model_with_new_type_stat is not None + else None, + ) + obj["davg"] = obj["davg"][remap_index] + obj["dstd"] = obj["dstd"][remap_index] + + def serialize(self) -> dict: + obj = self.se_ttebd + data = { + "@class": "Descriptor", + "type": "se_e3_tebd", + "@version": 1, + "rcut": obj.rcut, + "rcut_smth": obj.rcut_smth, + "sel": obj.sel, + "ntypes": obj.ntypes, + "neuron": obj.neuron, + "tebd_dim": obj.tebd_dim, + "tebd_input_mode": obj.tebd_input_mode, + "set_davg_zero": obj.set_davg_zero, + "activation_function": obj.activation_function, + "resnet_dt": obj.resnet_dt, + "concat_output_tebd": self.concat_output_tebd, + "use_econf_tebd": 
self.use_econf_tebd, + "type_map": self.type_map, + # make deterministic + "precision": RESERVED_PRECISON_DICT[obj.prec], + "embeddings": obj.filter_layers.serialize(), + "env_mat": DPEnvMat(obj.rcut, obj.rcut_smth).serialize(), + "type_embedding": self.type_embedding.embedding.serialize(), + "exclude_types": obj.exclude_types, + "env_protection": obj.env_protection, + "smooth": self.smooth, + "@variables": { + "davg": obj["davg"].numpy(), + "dstd": obj["dstd"].numpy(), + }, + "trainable": self.trainable, + } + if obj.tebd_input_mode in ["strip"]: + data.update({"embeddings_strip": obj.filter_layers_strip.serialize()}) + return data + + @classmethod + def deserialize(cls, data: dict) -> "DescrptSeTTebd": + data = data.copy() + check_version_compatibility(data.pop("@version"), 1, 1) + data.pop("@class") + data.pop("type") + variables = data.pop("@variables") + embeddings = data.pop("embeddings") + type_embedding = data.pop("type_embedding") + env_mat = data.pop("env_mat") + tebd_input_mode = data["tebd_input_mode"] + if tebd_input_mode in ["strip"]: + embeddings_strip = data.pop("embeddings_strip") + else: + embeddings_strip = None + obj = cls(**data) + + def t_cvt(xx): + return paddle.to_tensor(xx, dtype=obj.se_ttebd.prec).to(device=env.DEVICE) + + obj.type_embedding.embedding = TypeEmbedNetConsistent.deserialize( + type_embedding + ) + obj.se_ttebd["davg"] = t_cvt(variables["davg"]) + obj.se_ttebd["dstd"] = t_cvt(variables["dstd"]) + obj.se_ttebd.filter_layers = NetworkCollection.deserialize(embeddings) + if tebd_input_mode in ["strip"]: + obj.se_ttebd.filter_layers_strip = NetworkCollection.deserialize( + embeddings_strip + ) + return obj + + def forward( + self, + extended_coord: paddle.Tensor, + extended_atype: paddle.Tensor, + nlist: paddle.Tensor, + mapping: Optional[paddle.Tensor] = None, + comm_dict: Optional[dict[str, paddle.Tensor]] = None, + ): + """Compute the descriptor. + + Parameters + ---------- + extended_coord + The extended coordinates of atoms. 
shape: nf x (nallx3) + extended_atype + The extended aotm types. shape: nf x nall + nlist + The neighbor list. shape: nf x nloc x nnei + mapping + The index mapping, not required by this descriptor. + comm_dict + The data needed for communication for parallel inference. + + Returns + ------- + descriptor + The descriptor. shape: nf x nloc x (ng x axis_neuron) + gr + The rotationally equivariant and permutationally invariant single particle + representation. shape: nf x nloc x ng x 3 + g2 + The rotationally invariant pair-partical representation. + shape: nf x nloc x nnei x ng + h2 + The rotationally equivariant pair-partical representation. + shape: nf x nloc x nnei x 3 + sw + The smooth switch function. shape: nf x nloc x nnei + + """ + # cast the input to internal precsion + extended_coord = extended_coord.to(dtype=self.prec) + del mapping + nframes, nloc, nnei = nlist.shape + nall = extended_coord.reshape([nframes, -1]).shape[1] // 3 + g1_ext = self.type_embedding(extended_atype) + g1_inp = g1_ext[:, :nloc, :] + if self.tebd_input_mode in ["strip"]: + type_embedding = self.type_embedding.get_full_embedding(g1_ext.place) + else: + type_embedding = None + g1, _, _, _, sw = self.se_ttebd( + nlist, + extended_coord, + extended_atype, + g1_ext, + mapping=None, + type_embedding=type_embedding, + ) + if self.concat_output_tebd: + g1 = paddle.concat([g1, g1_inp], axis=-1) + + return ( + g1.to(dtype=env.GLOBAL_PD_FLOAT_PRECISION), + None, + None, + None, + sw.to(dtype=env.GLOBAL_PD_FLOAT_PRECISION), + ) + + @classmethod + def update_sel( + cls, + train_data: DeepmdDataSystem, + type_map: Optional[list[str]], + local_jdata: dict, + ) -> tuple[dict, Optional[float]]: + """Update the selection and perform neighbor statistics. 
+ + Parameters + ---------- + train_data : DeepmdDataSystem + data used to do neighbor statistics + type_map : list[str], optional + The name of each type of atoms + local_jdata : dict + The local data refer to the current class + + Returns + ------- + dict + The updated local data + float + The minimum distance between two atoms + """ + local_jdata_cpy = local_jdata.copy() + min_nbor_dist, sel = UpdateSel().update_one_sel( + train_data, type_map, local_jdata_cpy["rcut"], local_jdata_cpy["sel"], True + ) + local_jdata_cpy["sel"] = sel[0] + return local_jdata_cpy, min_nbor_dist + + +@DescriptorBlock.register("se_ttebd") +class DescrptBlockSeTTebd(DescriptorBlock): + def __init__( + self, + rcut: float, + rcut_smth: float, + sel: Union[list[int], int], + ntypes: int, + neuron: list = [25, 50, 100], + tebd_dim: int = 8, + tebd_input_mode: str = "concat", + set_davg_zero: bool = True, + activation_function="tanh", + precision: str = "float64", + resnet_dt: bool = False, + exclude_types: list[tuple[int, int]] = [], + env_protection: float = 0.0, + smooth: bool = True, + seed: Optional[Union[int, list[int]]] = None, + ) -> None: + super().__init__() + self.rcut = float(rcut) + self.rcut_smth = float(rcut_smth) + self.neuron = neuron + self.filter_neuron = self.neuron + self.tebd_dim = tebd_dim + self.tebd_input_mode = tebd_input_mode + self.set_davg_zero = set_davg_zero + self.activation_function = activation_function + self.precision = precision + self.prec = PRECISION_DICT[self.precision] + self.resnet_dt = resnet_dt + self.env_protection = env_protection + self.seed = seed + self.smooth = smooth + + if isinstance(sel, int): + sel = [sel] + + self.ntypes = ntypes + self.sel = sel + self.sec = self.sel + self.split_sel = self.sel + self.nnei = sum(sel) + self.ndescrpt = self.nnei * 4 + # order matters, placed after the assignment of self.ntypes + self.reinit_exclude(exclude_types) + + wanted_shape = (self.ntypes, self.nnei, 4) + mean = paddle.zeros(wanted_shape, 
dtype=env.GLOBAL_PD_FLOAT_PRECISION).to( + device=env.DEVICE + ) + stddev = paddle.ones(wanted_shape, dtype=env.GLOBAL_PD_FLOAT_PRECISION).to( + device=env.DEVICE + ) + self.register_buffer("mean", mean) + self.register_buffer("stddev", stddev) + self.tebd_dim_input = self.tebd_dim * 2 + if self.tebd_input_mode in ["concat"]: + self.embd_input_dim = 1 + self.tebd_dim_input + else: + self.embd_input_dim = 1 + + self.filter_layers = None + self.filter_layers_strip = None + filter_layers = NetworkCollection( + ndim=0, ntypes=self.ntypes, network_type="embedding_network" + ) + filter_layers[0] = EmbeddingNet( + self.embd_input_dim, + self.filter_neuron, + activation_function=self.activation_function, + precision=self.precision, + resnet_dt=self.resnet_dt, + seed=child_seed(self.seed, 1), + ) + self.filter_layers = filter_layers + if self.tebd_input_mode in ["strip"]: + filter_layers_strip = NetworkCollection( + ndim=0, ntypes=self.ntypes, network_type="embedding_network" + ) + filter_layers_strip[0] = EmbeddingNet( + self.tebd_dim_input, + self.filter_neuron, + activation_function=self.activation_function, + precision=self.precision, + resnet_dt=self.resnet_dt, + seed=child_seed(self.seed, 2), + ) + self.filter_layers_strip = filter_layers_strip + self.stats = None + + def get_rcut(self) -> float: + """Returns the cut-off radius.""" + return self.rcut + + def get_rcut_smth(self) -> float: + """Returns the radius where the neighbor information starts to smoothly decay to 0.""" + return self.rcut_smth + + def get_nsel(self) -> int: + """Returns the number of selected atoms in the cut-off radius.""" + return sum(self.sel) + + def get_sel(self) -> list[int]: + """Returns the number of selected atoms for each type.""" + return self.sel + + def get_ntypes(self) -> int: + """Returns the number of element types.""" + return self.ntypes + + def get_dim_in(self) -> int: + """Returns the input dimension.""" + return self.dim_in + + def get_dim_out(self) -> int: + """Returns the 
output dimension.""" + return self.dim_out + + def get_dim_emb(self) -> int: + """Returns the output dimension of embedding.""" + return self.filter_neuron[-1] + + def __setitem__(self, key, value) -> None: + if key in ("avg", "data_avg", "davg"): + self.mean = value + elif key in ("std", "data_std", "dstd"): + self.stddev = value + else: + raise KeyError(key) + + def __getitem__(self, key): + if key in ("avg", "data_avg", "davg"): + return self.mean + elif key in ("std", "data_std", "dstd"): + return self.stddev + else: + raise KeyError(key) + + def mixed_types(self) -> bool: + """If true, the descriptor + 1. assumes total number of atoms aligned across frames; + 2. requires a neighbor list that does not distinguish different atomic types. + + If false, the descriptor + 1. assumes total number of atoms of each atom type aligned across frames; + 2. requires a neighbor list that distinguishes different atomic types. + + """ + return True + + def get_env_protection(self) -> float: + """Returns the protection of building environment matrix.""" + return self.env_protection + + @property + def dim_out(self): + """Returns the output dimension of this descriptor.""" + return self.filter_neuron[-1] + + @property + def dim_in(self): + """Returns the atomic input dimension of this descriptor.""" + return self.tebd_dim + + @property + def dim_emb(self): + """Returns the output dimension of embedding.""" + return self.get_dim_emb() + + def compute_input_stats( + self, + merged: Union[Callable[[], list[dict]], list[dict]], + path: Optional[DPPath] = None, + ) -> None: + """ + Compute the input statistics (e.g. mean and stddev) for the descriptors from packed data. + + Parameters + ---------- + merged : Union[Callable[[], list[dict]], list[dict]] + - list[dict]: A list of data samples from various data systems. + Each element, `merged[i]`, is a data dictionary containing `keys`: `paddle.Tensor` + originating from the `i`-th data system. 
+ - Callable[[], list[dict]]: A lazy function that returns data samples in the above format + only when needed. Since the sampling process can be slow and memory-intensive, + the lazy function helps by only sampling once. + path : Optional[DPPath] + The path to the stat file. + + """ + env_mat_stat = EnvMatStatSe(self) + if path is not None: + path = path / env_mat_stat.get_hash() + if path is None or not path.is_dir(): + if callable(merged): + # only get data for once + sampled = merged() + else: + sampled = merged + else: + sampled = [] + env_mat_stat.load_or_compute_stats(sampled, path) + self.stats = env_mat_stat.stats + mean, stddev = env_mat_stat() + if not self.set_davg_zero: + paddle.assign(paddle.to_tensor(mean).to(device=env.DEVICE), self.mean) # pylint: disable=no-explicit-dtype + paddle.assign(paddle.to_tensor(stddev).to(device=env.DEVICE), self.stddev) # pylint: disable=no-explicit-dtype + + def get_stats(self) -> dict[str, StatItem]: + """Get the statistics of the descriptor.""" + if self.stats is None: + raise RuntimeError( + "The statistics of the descriptor has not been computed." + ) + return self.stats + + def reinit_exclude( + self, + exclude_types: list[tuple[int, int]] = [], + ) -> None: + self.exclude_types = exclude_types + self.emask = PairExcludeMask(self.ntypes, exclude_types=exclude_types) + + def forward( + self, + nlist: paddle.Tensor, + extended_coord: paddle.Tensor, + extended_atype: paddle.Tensor, + extended_atype_embd: Optional[paddle.Tensor] = None, + mapping: Optional[paddle.Tensor] = None, + type_embedding: Optional[paddle.Tensor] = None, + ): + """Compute the descriptor. + + Parameters + ---------- + nlist + The neighbor list. shape: nf x nloc x nnei + extended_coord + The extended coordinates of atoms. shape: nf x (nallx3) + extended_atype + The extended aotm types. shape: nf x nall x nt + extended_atype_embd + The extended type embedding of atoms. 
shape: nf x nall + mapping + The index mapping, not required by this descriptor. + type_embedding + Full type embeddings. shape: (ntypes+1) x nt + Required for stripped type embeddings. + + Returns + ------- + result + The descriptor. shape: nf x nloc x (ng x axis_neuron) + g2 + The rotationally invariant pair-partical representation. + shape: nf x nloc x nnei x ng + h2 + The rotationally equivariant pair-partical representation. + shape: nf x nloc x nnei x 3 + gr + The rotationally equivariant and permutationally invariant single particle + representation. shape: nf x nloc x ng x 3 + sw + The smooth switch function. shape: nf x nloc x nnei + + """ + del mapping + assert extended_atype_embd is not None + nframes, nloc, nnei = nlist.shape + atype = extended_atype[:, :nloc] + nb = nframes + nall = extended_coord.reshape([nb, -1, 3]).shape[1] + dmatrix, diff, sw = prod_env_mat( + extended_coord, + nlist, + atype, + self.mean, + self.stddev, + self.rcut, + self.rcut_smth, + protection=self.env_protection, + ) + # nb x nloc x nnei + exclude_mask = self.emask(nlist, extended_atype) + nlist = paddle.where(exclude_mask != 0, nlist, paddle.full_like(nlist, -1)) + nlist_mask = nlist != -1 + nlist = paddle.where(nlist == -1, paddle.zeros_like(nlist), nlist) + sw = paddle.squeeze(sw, -1) + # nf x nall x nt + nt = extended_atype_embd.shape[-1] + # beyond the cutoff sw should be 0.0 + sw = sw.masked_fill(~nlist_mask, 0.0) + # (nb x nloc) x nnei + exclude_mask = exclude_mask.reshape([nb * nloc, nnei]) + assert self.filter_layers is not None + # nfnl x nnei x 4 + dmatrix = dmatrix.reshape([-1, self.nnei, 4]) + nfnl = dmatrix.shape[0] + # nfnl x nnei x 4 + rr = dmatrix + rr = rr * exclude_mask[:, :, None].astype(rr.dtype) + + # nfnl x nt_i x 3 + rr_i = rr[:, :, 1:] + # nfnl x nt_j x 3 + rr_j = rr[:, :, 1:] + # nfnl x nt_i x nt_j + # env_ij = paddle.einsum("ijm,ikm->ijk", rr_i, rr_j) + env_ij = ( + # ij1m x i1km -> ijkm -> ijk + rr_i.unsqueeze(2) * rr_j.unsqueeze(1) + ).sum(-1) + # 
nfnl x nt_i x nt_j x 1 + ss = env_ij.unsqueeze(-1) + if self.tebd_input_mode in ["concat"]: + atype_tebd_ext = extended_atype_embd + # nb x (nloc x nnei) x nt + index = nlist.reshape([nb, nloc * nnei]).unsqueeze(-1).expand([-1, -1, nt]) + # nb x (nloc x nnei) x nt + # atype_tebd_nlist = paddle.take_along_axis(atype_tebd_ext, axis=1, index=index) + atype_tebd_nlist = paddle.take_along_axis( + atype_tebd_ext, axis=1, indices=index + ) + # nb x nloc x nnei x nt + atype_tebd_nlist = atype_tebd_nlist.reshape([nb, nloc, nnei, nt]) + # nfnl x nnei x tebd_dim + nlist_tebd = atype_tebd_nlist.reshape([nfnl, nnei, self.tebd_dim]) + # nfnl x nt_i x nt_j x tebd_dim + nlist_tebd_i = nlist_tebd.unsqueeze(2).expand([-1, -1, self.nnei, -1]) + nlist_tebd_j = nlist_tebd.unsqueeze(1).expand([-1, self.nnei, -1, -1]) + # nfnl x nt_i x nt_j x (1 + tebd_dim * 2) + ss = paddle.concat([ss, nlist_tebd_i, nlist_tebd_j], axis=-1) + # nfnl x nt_i x nt_j x ng + gg = self.filter_layers.networks[0](ss) + elif self.tebd_input_mode in ["strip"]: + # nfnl x nt_i x nt_j x ng + gg_s = self.filter_layers.networks[0](ss) + assert self.filter_layers_strip is not None + assert type_embedding is not None + ng = self.filter_neuron[-1] + ntypes_with_padding = type_embedding.shape[0] + # nf x (nl x nnei) + nlist_index = nlist.reshape([nb, nloc * nnei]) + # nf x (nl x nnei) + nei_type = paddle.take_along_axis( + extended_atype, indices=nlist_index, axis=1 + ) + # nfnl x nnei + nei_type = nei_type.reshape([nfnl, nnei]) + # nfnl x nnei x nnei + nei_type_i = nei_type.unsqueeze(2).expand([-1, -1, nnei]) + nei_type_j = nei_type.unsqueeze(1).expand([-1, nnei, -1]) + idx_i = nei_type_i * ntypes_with_padding + idx_j = nei_type_j + # (nf x nl x nt_i x nt_j) x ng + idx = ( + (idx_i + idx_j) + .reshape([-1, 1]) + .expand([-1, ng]) + .astype(paddle.int64) + .to(paddle.int64) + ) + # ntypes * (ntypes) * nt + type_embedding_i = paddle.tile( + type_embedding.reshape([ntypes_with_padding, 1, nt]), + [1, ntypes_with_padding, 
1], + ) + # (ntypes) * ntypes * nt + type_embedding_j = paddle.tile( + type_embedding.reshape([1, ntypes_with_padding, nt]), + [ntypes_with_padding, 1, 1], + ) + # (ntypes * ntypes) * (nt+nt) + two_side_type_embedding = paddle.concat( + [type_embedding_i, type_embedding_j], -1 + ).reshape([-1, nt * 2]) + tt_full = self.filter_layers_strip.networks[0](two_side_type_embedding) + # (nfnl x nt_i x nt_j) x ng + gg_t = paddle.take_along_axis(tt_full, indices=idx, axis=0) + # (nfnl x nt_i x nt_j) x ng + gg_t = gg_t.reshape([nfnl, nnei, nnei, ng]) + if self.smooth: + gg_t = ( + gg_t + * sw.reshape([nfnl, self.nnei, 1, 1]) + * sw.reshape([nfnl, 1, self.nnei, 1]) + ) + # nfnl x nt_i x nt_j x ng + gg = gg_s * gg_t + gg_s + else: + raise NotImplementedError + + # nfnl x ng + # res_ij = paddle.einsum("ijk,ijkm->im", env_ij, gg) + res_ij = ( + # ijk1 x ijkm -> ijkm -> im + env_ij.unsqueeze(-1) * gg + ).sum([1, 2]) + res_ij = res_ij * (1.0 / float(self.nnei) / float(self.nnei)) + # nf x nl x ng + result = res_ij.reshape([nframes, nloc, self.filter_neuron[-1]]) + return ( + result, + None, + None, + None, + sw, + ) + + def has_message_passing(self) -> bool: + """Returns whether the descriptor block has message passing.""" + return False + + def need_sorted_nlist_for_lower(self) -> bool: + """Returns whether the descriptor block needs sorted nlist when using `forward_lower`.""" + return False diff --git a/deepmd/pd/model/task/fitting.py b/deepmd/pd/model/task/fitting.py index d9db44aff5..6e96b7b081 100644 --- a/deepmd/pd/model/task/fitting.py +++ b/deepmd/pd/model/task/fitting.py @@ -211,8 +211,8 @@ def __init__( if self.dim_case_embd > 0: self.register_buffer( "case_embd", - paddle.zeros(self.dim_case_embd, dtype=self.prec, place=device), - # paddle.eye(self.dim_case_embd, dtype=self.prec, place=device)[0], + paddle.zeros(self.dim_case_embd, dtype=self.prec).to(device=device), + # paddle.eye(self.dim_case_embd, dtype=self.prec).to(device=device)[0], ) else: self.case_embd = None 
diff --git a/deepmd/pd/utils/multi_task.py b/deepmd/pd/utils/multi_task.py index 680dc53c79..321883c12e 100644 --- a/deepmd/pd/utils/multi_task.py +++ b/deepmd/pd/utils/multi_task.py @@ -96,7 +96,9 @@ def preprocess_shared_params(model_config): shared_links = {} type_map_keys = [] - def replace_one_item(params_dict, key_type, key_in_dict, suffix="", index=None): + def replace_one_item( + params_dict, key_type, key_in_dict, suffix="", index=None + ) -> None: shared_type = key_type shared_key = key_in_dict shared_level = 0 diff --git a/deepmd/pd/utils/spin.py b/deepmd/pd/utils/spin.py new file mode 100644 index 0000000000..934fb3762a --- /dev/null +++ b/deepmd/pd/utils/spin.py @@ -0,0 +1,30 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later + +import paddle + + +def concat_switch_virtual( + extended_tensor, + extended_tensor_virtual, + nloc: int, +): + """ + Concat real and virtual extended tensors, and switch all the local ones to the first nloc * 2 atoms. + - [:, :nloc]: original nloc real atoms. + - [:, nloc: nloc + nloc]: virtual atoms corresponding to nloc real atoms. + - [:, nloc + nloc: nloc + nall]: ghost real atoms. + - [:, nloc + nall: nall + nall]: virtual atoms corresponding to ghost real atoms. 
+ """ + nframes, nall = extended_tensor.shape[:2] + out_shape = list(extended_tensor.shape) + out_shape[1] *= 2 + extended_tensor_updated = paddle.zeros( + out_shape, + dtype=extended_tensor.dtype, + device=extended_tensor.place, + ) + extended_tensor_updated[:, :nloc] = extended_tensor[:, :nloc] + extended_tensor_updated[:, nloc : nloc + nloc] = extended_tensor_virtual[:, :nloc] + extended_tensor_updated[:, nloc + nloc : nloc + nall] = extended_tensor[:, nloc:] + extended_tensor_updated[:, nloc + nall :] = extended_tensor_virtual[:, nloc:] + return extended_tensor_updated.reshape(out_shape) diff --git a/source/tests/consistent/descriptor/test_dpa2.py b/source/tests/consistent/descriptor/test_dpa2.py index 72c0967a78..ef840bf9d7 100644 --- a/source/tests/consistent/descriptor/test_dpa2.py +++ b/source/tests/consistent/descriptor/test_dpa2.py @@ -17,6 +17,7 @@ from ..common import ( INSTALLED_ARRAY_API_STRICT, INSTALLED_JAX, + INSTALLED_PD, INSTALLED_PT, CommonTest, parameterized, @@ -34,6 +35,12 @@ from deepmd.jax.descriptor.dpa2 import DescrptDPA2 as DescrptDPA2JAX else: DescrptDPA2JAX = None + +if INSTALLED_PD: + from deepmd.pd.model.descriptor.dpa2 import DescrptDPA2 as DescrptDPA2PD +else: + DescrptDPA2PD = None + if INSTALLED_ARRAY_API_STRICT: from ...array_api_strict.descriptor.dpa2 import DescrptDPA2 as DescrptDPA2Strict else: @@ -214,6 +221,39 @@ def skip_pt(self) -> bool: ) = self.param return CommonTest.skip_pt + @property + def skip_pd(self) -> bool: + ( + repinit_tebd_input_mode, + repinit_set_davg_zero, + repinit_type_one_side, + repinit_use_three_body, + repformer_update_g1_has_conv, + repformer_direct_dist, + repformer_update_g1_has_drrd, + repformer_update_g1_has_grrg, + repformer_update_g1_has_attn, + repformer_update_g2_has_g1g1, + repformer_update_g2_has_attn, + repformer_update_h2, + repformer_attn2_has_gate, + repformer_update_style, + repformer_update_residual_init, + repformer_set_davg_zero, + repformer_trainable_ln, + repformer_ln_eps, + 
repformer_use_sqrt_nnei, + repformer_g1_out_conv, + repformer_g1_out_mlp, + smooth, + exclude_types, + precision, + add_tebd_to_repinit_out, + use_econf_tebd, + use_tebd_bias, + ) = self.param + return not INSTALLED_PD or precision == "bfloat16" + @property def skip_dp(self) -> bool: ( @@ -286,6 +326,7 @@ def skip_tf(self) -> bool: tf_class = DescrptDPA2TF dp_class = DescrptDPA2DP pt_class = DescrptDPA2PT + pd_class = DescrptDPA2PD jax_class = DescrptDPA2JAX array_api_strict_class = DescrptDPA2Strict args = descrpt_dpa2_args().append(Argument("ntypes", int, optional=False)) @@ -383,6 +424,16 @@ def eval_pt(self, pt_obj: Any) -> Any: mixed_types=True, ) + def eval_pd(self, pd_obj: Any) -> Any: + return self.eval_pd_descriptor( + pd_obj, + self.natoms, + self.coords, + self.atype, + self.box, + mixed_types=True, + ) + def eval_jax(self, jax_obj: Any) -> Any: return self.eval_jax_descriptor( jax_obj, diff --git a/source/tests/consistent/descriptor/test_se_t_tebd.py b/source/tests/consistent/descriptor/test_se_t_tebd.py index bb4a5db6e7..9cdca9bde3 100644 --- a/source/tests/consistent/descriptor/test_se_t_tebd.py +++ b/source/tests/consistent/descriptor/test_se_t_tebd.py @@ -17,6 +17,7 @@ from ..common import ( INSTALLED_ARRAY_API_STRICT, INSTALLED_JAX, + INSTALLED_PD, INSTALLED_PT, CommonTest, parameterized, @@ -34,6 +35,10 @@ from deepmd.jax.descriptor.se_t_tebd import DescrptSeTTebd as DescrptSeTTebdJAX else: DescrptSeTTebdJAX = None +if INSTALLED_PD: + from deepmd.pd.model.descriptor.se_t_tebd import DescrptSeTTebd as DescrptSeTTebdPD +else: + DescrptSeTTebdPD = None if INSTALLED_ARRAY_API_STRICT: from ...array_api_strict.descriptor.se_t_tebd import ( DescrptSeTTebd as DescrptSeTTebdStrict, @@ -146,12 +151,14 @@ def skip_tf(self) -> bool: ) = self.param return True + skip_pd = not INSTALLED_PD skip_jax = not INSTALLED_JAX skip_array_api_strict = not INSTALLED_ARRAY_API_STRICT tf_class = DescrptSeTTebdTF dp_class = DescrptSeTTebdDP pt_class = DescrptSeTTebdPT + 
pd_class = DescrptSeTTebdPD jax_class = DescrptSeTTebdJAX array_api_strict_class = DescrptSeTTebdStrict args = descrpt_se_e3_tebd_args().append(Argument("ntypes", int, optional=False)) @@ -243,6 +250,16 @@ def eval_jax(self, jax_obj: Any) -> Any: mixed_types=True, ) + def eval_pd(self, pd_obj: Any) -> Any: + return self.eval_pd_descriptor( + pd_obj, + self.natoms, + self.coords, + self.atype, + self.box, + mixed_types=True, + ) + def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any: return self.eval_array_api_strict_descriptor( array_api_strict_obj, diff --git a/source/tests/pd/model/models/dpa2.json b/source/tests/pd/model/models/dpa2.json new file mode 100644 index 0000000000..f83e319de3 --- /dev/null +++ b/source/tests/pd/model/models/dpa2.json @@ -0,0 +1,57 @@ +{ + "type_map": [ + "O", + "H" + ], + "descriptor": { + "type": "dpa2", + "repinit": { + "rcut": 6.0, + "rcut_smth": 2.0, + "nsel": 30, + "neuron": [ + 2, + 4, + 8 + ], + "axis_neuron": 4, + "activation_function": "tanh" + + }, + "repformer": { + "rcut": 4.0, + "rcut_smth": 0.5, + "nsel": 10, + "nlayers": 12, + "g1_dim": 8, + "g2_dim": 5, + "attn2_hidden": 3, + "attn2_nhead": 1, + "attn1_hidden": 5, + "attn1_nhead": 1, + "axis_neuron": 4, + "update_h2": false, + "update_g1_has_conv": true, + "update_g1_has_grrg": true, + "update_g1_has_drrd": true, + "update_g1_has_attn": true, + "update_g2_has_g1g1": true, + "update_g2_has_attn": true, + "attn2_has_gate": true, + "use_sqrt_nnei": false, + "g1_out_conv": false, + "g1_out_mlp": false + }, + "seed": 1, + "add_tebd_to_repinit_out": false + }, + "fitting_net": { + "neuron": [ + 240, + 240, + 240 + ], + "resnet_dt": true, + "seed": 1 + } +} diff --git a/source/tests/pd/model/models/dpa2.pd b/source/tests/pd/model/models/dpa2.pd new file mode 100644 index 0000000000000000000000000000000000000000..650f0c144e7c2dae9a265e61ed73d4eb29fa6dc7 GIT binary patch literal 119535 
zcmeFZXH=Ehwk=8!5ky4=5m8YMLK~WIis&d!a`@C0GyY_wep8Kcxv22a@;rwPFV~)^ApUZlK4;cw5$uB=oy*XWI zj5UmO40KGn^fgosyj|!G%=C?{xb7PoYjEkC>6z-N8XK!xdArak8*AJ*G%ztWHoI@? z?d9!4@#}-$E|do9KOgb(mLPHP79&;jcJ!9<=9D0nAd}cHvBA+>o5F=c-PFoRUl91)5B~Op zzy09vbMSW@_&*&7e$CoEzh!Mr9X(SGV`V*6D-B~4E&~ly3q#`vCR{vR|M(3Lm%gg0 zv5w{MG(~g63 z5_T=e8X)j!LaWENG8}ogeZ)SIfF6e*#yd!5z=PiFCpxbrgUyA~)UzKW!PWYp`igN8 zJTl0+CC=gn(`C0>0?5kn6`ywF35g`sXAr;W%v291-6^0_}aR6#|Ky@c-RVl))XptM}cgDb2q`CoP1;QjO*=bC>Qc#^B! zE}agB0U7oU?^K#0hdcXu7-t>)iZa(hHgkC2`Q#U> z#xbA^JksB|cNrzUo<=@8Jo|IGzwE|*e~iEeiA_J<_!q`&)j4R=^-vYwy|jNcL#+c; zPDytO&qv_v2cJyGb<5EplOu4Utr#s*NmYv6BI@YsA;MBQqnU z8_mo^NpnKGQTGV-X*S_jw0!G*%s(R=EoM{XbJgpR_mqPEvQ0GXzuI8o;?RWWB26y^ z3s0lo#T>E^%$>k-jX^cra00S}uFsDLBC_9T4SslX0{5$?kc9B}!wbVYO0CKkpfkQ! zx;4BDE3V1dPLj323q`AFnbqN+e)X5#mhX>@xap@6&-*BNxYRWVp+y?~k?u`X$Ue7g z=jRM3cwpSv>L&6DZ05JwidcuDqWQU~n=4~5VB(%~f$AIDFJ3Xp)0_GEI=sJsij@Bj zR{vXr759Gt^BwQuQik}@1=Y~fXIw$%^kYq-AGh2T#Z+I@5~y`9rnG;m}a?On7T|Y(^>73oD0r(J(pY!nvn`;h#d~Avn)g&fLZV&mUh( zCUq)@vbdN>?pKpRPVKVFa!x)JIJ6XB%`e08{?)9Wfp9*rz`Mtn_Qg8 zayk@$Qxo7$jE5}ma^1M#OicA-@_gA=2Z8)E7bV;MQTUwdX}i+}cw#3jnfi1g1SSur z(dO2HvF&E9^)zBn-&UcL$Xx|Z6}uc=l>;EoQL&}YAQ$cXYfR^9cwzSKs%q}Xa;Uld zkwSMZ7n$YWgv!(v0hjD`<$_%~$SFFv<$G)s78YE1IqKvCr@H5O$^@ctQ22JCeYZY7 zB2{z`oJs{zCZ(#J{bg9eMVZ z1AIzf>~g9=v6k73*GCfZHFr!SSA8CiiCFoCaVFrM7v0h+7b?MO(^NpEybp$SHb0*b ztN^~Do=1I6De(BJP@WpGx!kiT|EBj;7M_r{A1SEI$ABcWVZMF&5TKy`vBj_kGum2* zS>G3d$)RIMRV*`bo5pRHGatgCHuA_o#+ooLZ~abs+O-;L{p%WZCklYIzNmbYdlDX$ zyW3@_9sxh>3&xeEDnYDi*;!gG7nWxBx{;A4!ILUSlUt1Wz{@UoU(U!Jx_ts3*^+0V zIi=Ri;VA;L?Hcb?&rZgtGt$Zq4dq~xKNFmBvKIOpEY^>FEy0UDRU@Rn4`KWE^+H|y zOd!A4EX({g55zORrhGPNMS9sD`#YS`D0SB`pmcjM@JcA2a3SP^&^MFNT!vh1;@w+K zl~@K6-L6H?Srg!sz!N*E6#~T5Q@JxiA$X34u#d}SV1V?bQlCl&YUyrz!#1CUWd6aW z^WPB<*wdFXoGXRkcIl4EebumA$mJZ5K@E^UdN(({FAhv7m7*B$SK%W2RmHRRdC(La z=yjbs3ZMJVX!Culh5J(%4l5ll0C75_eZhwsQO{9>vcW$H=_aSLOOMBb4gb4FM)w+c zbFfu`Mxh1oYFwxB_i!ST0nsZ5Fbgd^s2&|mdBkx1?h7O-{@chu|JIh!D_J>m2q8qcJbAxe! 
zuQ!ogkkY9N%_1N*W74!!b`GuzH~IXKFGa?}@_Faeg`jwi_f(@#5*FFsuo`|<3JZg? z{6>7mu;bA?+PzuHaE;UW(+3Y9tbMSty@RU=BFvODEnD;9(xgvu%EvTZCKwFH4z{8H z)t9p-FJsYIelV8WD*^Ilukd7i4TZdtZ@TNG-k==t-unDh4L(1%a%8!x1#N`Noli)I z;y|c1)5%j|@a~}cUda=PF`x3r>Axj`eUrhB`9>d1WuY)xO}557msf*(HxN*irGoYP zAc9u3Y-antWZXPY@g<8X2+m4AU<)B+<3iuH$P;R@$SKp6yO^JjsScjsC7cQHE%D-{ z((4*rThp0%dpZX)qSmVQq$MHgyYc{?E=)F8Q&ytw z+?@LbxpYKWQJi@+3lkmuIvizU zadGi$( zJ+I*Owmlo%C<@^At?8y1T5B9wGdgL!$reN%N#7lfEI`R@PXYB*C1f_!ka1P5gY-jM z`ywYoP;}vgcSv3YlzJ@9jSiGTP=m`zHU|M3KAajpv^NjmQc~aZVk*XcpQSeyisoT- z%iOTz8-KV{{mD4or~!^>=`hclHN(zzvF3A2$*8L^v$iti2V1`~6~3e>gftQrk_@vl z7%o|7v7$|bcdChvw+3S1%}|`|CgoHJkGDQd$SJ@NQI<@%PnSdC{n7LJyR+fdLS(iN zNhS{7IG&TM(uQuUIa|jBe1Jpt&?)l^`7koknIP{_f(0g;D`#jjaQ89JwQWZn@os|a z;p_wX_$7=SXzeQDb-;%-#n>V|wx^Ejp@%N|XCD#Njnsx?t1_om!z{5ec;8I|)}Wi3 z>|MI(Fz_lSTvB{TK*eN*G^IPC@F*@o{flufUYCrz+dY?wVMD?;R9Boq`RMB9UaCxN zJ!_XTNmhQX$Adg9?#BX>OFTIcuuNhJ)fFBFFpx$_*w-r7S?O2K5f zlJR;`955*9hMl<73Y#=^TFu?V@KL`e`@{x6SkDc-b9*Kjn-88=`aWOm;rR?Bv^* z6XSs-H`c4=0<&Qc8!xgK{MI=*d?AOBbTq9%2TKe5w|-ES|8R(Qr?5VIFP_ zs~21*tAe1(GZhnqjxg3JU`7{PgT}sBbYz%{aBS-0RSKfte%oE9v;T7pb|~!E6}q2+ zz834p*2rFf#hDOvpvXbWZ5zLtJ*~lUz7I*H3$MXr;}MS-xhQlFeEO1twGyW|WogPq z5D&ej^nLI-3TbwU^7Q)_VS~K$CsMu`AV0FZgIy{Qg(SJE-rOq#c5lY?Bqu*?=;B_K-o!t zkBukOFt>C@zTv7nQYq2&&y`D7i;t50bYQk4Bd^!D884J%(MiAewzJmuznWh#oxk!aqG4 zGE&%R<-r}-Vth z>W}_#a8l2`20w4n9dmV-1)=Q&*&X~@5XiUSvMS>a)M7yc23>J*wxu~OW;Gp89^TUL zpHq#BV?*a#nVm5vXg1e3D+^<~H$EYpslgw##&>I`i_wp3;Y;|lGT?Zjol>D)iLwE= zV-8&LfN+A$rV#fE|+w^5qP;caKrrxj6%U)Y=UF-oWG5`=k90XPMt@kgM$6`H0uGjLn*U-OTe)Nr6 zI69Lp?rHTd!VMh(y=Tn}&?JA=N%iJq2-tN&a>LD1oKy5UEFBVy1G!gqIbLSLa;1wc z$DFTDnxD$&cB2ltlXxFRL1Wfq+p0O`J7t`^+c+c5d^r2P1 z(BPE~am_^&tt?n@=o4JQw4EY(BWmTFlqnig1jW4hQQT8qa{o_nT!$qQ7C zK8tmUCnNh4#*%_P-?nf-S(pj z^p6egp$_wfL#5AeODrK|=F;QXqjY$&uTy~2F%Ls7Y!GHX7y%z#jC{gFv#>UcGpP}p z;G^&Ap|m@-kf^!bM(>Z8fqsK*^oF$zpE=c8(<- zjL-Lkpf<-0!scRFIA<_>OE?E-Iv*72KsM69zsVrQR0;_KhEk&8>8NwnBO%p06OMj8 
zq<@P$3htMf-O3^qL9NV!kLHmAEWDj0AdymtjI=>gVecD1nI9YRa9Op-^Gg(Pw7W zaD_G!i;4A($;{mmJW-`Ct?sE>00`uf_|N122N=!$7cO)BIRm2vn=AdMKAPV-ze8#|IZ7cR&k=YEc?~ z-m`pXDj^(0Y2Hg8y`7D5-$JNu6Wt-qL55U`B_AI9(|(ZM*@A3K3iHDbSzs@v9^fQX zgjBm9@TEGW!HZUL`fut4+~~=nSf`Z-t@i376+*cvT@bSS#(F8r&JJ)Ylr;gArAZ*5Ag?t$8Y04A(gXVw#3f^L;cYYOfSozi&-qcKWGWtf2{2A zXFD1KvgYIY5oC(g;X zVVD!i_x}1E=%em-$CPI1QtjBmpGyFrc^h@Z#$3Fk#5ZebmI&(1(zbBgq_9rjd_y{H7_m_~qUs3D>U-`U z7-)frJ$}zlmbc->$J=5UIg((mm}G~Has-T=Z-{fC9>MDA_3NTD9UyL%N!m>4Mphd_ zsvl<&PMjUnuE}eK6;+jsNuwiRaqCI>ogW=PS*Jg-#vhI{exsk57uWd?GtQw8$3bc# zo(af&?dG4=v4|Pw_bJV@CV`?)%(m<8yPv4D{hry4j16^_wRpAu1out@Ili8xYs{BOgZ`8yNf}m0<1s&85(= zX5iGfPNU>;M<*|lfZLawkY2CmE#YGXj??T{7@+HdC`Gb056f=QIZS_JSw0Ch1@rSJ z-nE10fNX5f>tW#KUEP>dI1J(sjtXdxwL+4Z@IduIBl^A8DW`UP4U@7H-!{H0M{yFC z!-M-fU`2&*A?s`g-hD6@)9|zjXOI6dw~s4^fM{Khl|!Y_R!5x@o!bo@nw$Kx(z|dG z)9$Lbx53Vo5Z+SVQXJ-d`k>#w76s24zo6%?Lct9_fmzNycKdNI9Hat{q@6aNSSI26=AN!CB=+=w|o2G?Zrk` z=A}&JQ;TiiOKi$dBrI>49zodWx`I-|2rQ!eV**^_5yb>QrQ0EZwcZ}=t#pYlagq5R33i$F;l zObH}Vf=3K^tI&L}S1X0CkEaFAwCdofhF8Lm>xCH2QLjF#m4Ww#rjz$S?)+(sqrYbh zRZ~*~J}qT^RU;zFZ1nKKA5+Z~5*vR~%?pWiH0?eHeVAAz=nzX&f#m9vEDgrB(0g*n z!Gq3K=xKM&>0{>$Sf33Q=RMXB1i6z*U%$n{(c@KHC7io}>)n9=vV|GcafU}y{+NQ& zZYK5m{RG&vI(zWhXfHhLW2d)EaWFRxm^#Aw#^4iXABW-aYK>2_$N-*^lgp~#CeT1M?6r2llL%Tf6iSnj@gTKB9gY>RX5sxKo_UmRMW$_2B3qW6oA zCe^J`MqNl8{t*mNfq-XUDg64k);6J$fnwb{6%B96}I3jeDC+u!6oUS1)+(uIj z@4HjQD7_~^F{4#iw4n#=n!a?C?<#~1t zCe#9qE9LFf*3BSB(_89(Arq6g29LJN*W!Zyh2fi$U8rrBYbSiG0Cro-3s=*%rtS;KYwpgz7I_|TC;Sql_CvNa_VP};CGwe z^_TMm9M9KgcHY~8Q;+4zpgtBDmVzpZy!$ZZn?*k5>2%z!61-cKllXVy!iIWsWtgL2 z+jFfBu0N#-{QpP6=V$lM+dhVAE!{U5eq zyc}O+?`$3fn$25zJnsJKHNR-O{+-wS7dU@Vko=YcKONdrG&8V>`fqK|O6g$M4}EZ|%P&_ZjVGMmQ(UTp3ya+edoR_(kJpE;IDYW~lXn^POlA4Vf3zZo zX;&G}eJ>#N&IaS~{Hv~dv0ltFRpHf)D#dGl^bJ>}n{h&3Yu6dq8i*b6QxaWl!;!;* zsu!sTAzhoWZU2Ex*v4<`C}r$}v-z`O!5?#QJ$AFy^EP5N{^hprncF7lDLfV>we z8858dzTFB%qOSQ_t7ULoZ}r-i{3=x8|9G%lydD*wk+ewI5HKKv>w{op3touHYuY@W z1$-YjiH#_?K!RjyYEoVGPaiw{@2v6QPom)_i7h{sis#8mNP2jRf3Ud`@Q`XNl{;w_ 
zDxHni`aUxX+YP-2-R|}QM>1>NQ(}`#wRLm@zy4eN!oOx5UsVYwzT|yAbhR0*9*5uF z7|{Yu*+TYo`)iSs)HAW<;23HaNnD(fXoJxPQeW1>a8S%+X5k<AdBCG%BDt)VM11~$z~V;Y(C$iw zMs`XoUJ^R3IwFyW^%~9!jw9s|_p;V0biN-NTsHNNTjheU)2!VAn^EBOVMw^Sw-5-k zibwaBy#?oU=36dbY=siKt>SE7ZGh*RZV7W&GD?5+BRM@%1>N!tV@BH|aVc1JIH$oD z#J8whI>fbsjoMZU`^XxkF;T^23xug(SFi+~Dz^nhwaMp)_`( ziUZov;MBtf5BsvoB?%xsWdbA5f4xd|iE8GKjj;w!~V5op6G}XISN6R3ZSw_;vsSW79 zM>_Nz@qpol=dT>(N`c3B2d#!yEDp8lOc{4JBZYFn?%O=c_%iB|)JM+-Wa3M?ee+@+ zDy*E}J7(C1iqDVp?V>J)9~pWOyC)A6Noa@-T@ldgq)EFh=0neVS1cV>5nLL41Z3Z< z@EX&;GoJ`;$ozd$RPT5x-iRzVuYMhhgaPxjWto|H$@Off($O~5RyuqpAifpsjt*M7 zzpMlT$wO~=*NzcRk|gd&!yu-kgkR0OAMWWEBsOeqgX2+oB{TQBL2>gGP4iL)=<>To zO5YEK;m?#D3>3;VqGgEW_?H<9sn}B3xOy*wkF^C9#4PZoC;<}M zUd0~~u0X}tr^3!W@PqTR^o=nWN3lfyp&Lfz0?VLLQWa|hdbcYEU6?3D7P1JSa47tF z9FP6!as00=^M?}nO}~xbNr)P*{%1`0Bi$>2z%6EczcvNR&E&Z@5wgbc)3hyT7fec)XMS2kn}*0S3Ec_KlQ}4B4>c%G5bMxkG>9}BJ|LNYAP)O|Gb!^Dj0SKcV1r$5mh-0-aywF0x z;5UoYvouXeW5rhAI@Jr02HPlhT9v?lI{&j&42YR)Z1Fyu%3vjAA!f^ld^mHnB_exA zHfqwl#d z?!xn&I(FmWrfp~$vta?N!|z?Jav6d3k9A>RcYZ>~9}_R-dU~LXVIy z`icvgw>4=xhkm}w?mJS-!xy zMj0mdJ^_|*WY-J*sKj=syV6D*yfCZ`7)IqvV3f^hpL}^E-ikUolVg&Elb;wF?gV+` zxx>fEGyB4D^NEuFIhh=cW}>*4Pv?NAbtvi{-*f@D^27QYb`zQ4(Zb6^6M3*}g=;|K zy*KpS5prc(jKeJ4Cx!W$m?`+0igw3q*qL`J`bN_TO@nJ0?T)M)`JOPRDQ?JO)?O9z3$fyYFVL{!ozKTgPy3p=6uk%A_dA$@x1I+ z2?YvD>e=Vcp6DjE|FH~TB?{-u4Jb~B<5H$z+K>n&->40TYvcf6)l{5g@BN?EPiRtDcQ zk9zAE!-I8p0s%TT((Cezyq{epQQ zaKodu<8BegZYt1?(l3T@Idx}AYJ;IFE%^!R9ohvj1 z(m*$P_tZAWHW(z_4Dr_~gUNR_32#1D!&0i+0l6I+C}quggvzlRZY2-j_$HbRdRoU{ zMIWt!Ox=;|SMAG?fV8H6=$I};!L78woR64s1w8`Gbwei%W z=P?6G17$?I%D{_r$%XJqyCqxTRweLkAZ%b$ECurvkxBb|4*02$k$tVG3dbFJczcfO8>-2s&uhlq-`$mVX%u}EJ?YKcwU1GUiH;U)08 zKM&T1RL-oQ34?}EQ-NphQ7HOK^}fLeqH=_Di|hhz3}%~llsTp4LZ{jdd{AA6U%Qy; z1DK;hI%(Up&mBcr&cq_iGFORu=kx-T8VK0FI39eZ*Bn1LKE3xag@8?INk_I0=E2#B zTfBlFV?pH+$xZj8h47_MOU#kJ1~#WF2u#HZg8esHu>Doj3_Qb zHir;9j%)eoQJ^Z3O08sAx+hO1Ff=m6uD>U^|#KKVh!IR?>Zg7-3c7|c6AW~vFu zK)QamD^V6gsF(1?qI)bJG=8|*j&CVM!=%M8MR)RG?s&RhGm#rVdQ*^P>1{41PgTYz 
z6P0)My&ju`18dQBD15uyUM6R5aj6|gIysIKqjfu8zPJjPKfaeZ$NrPLQL1iI~ zo$t(382<27bEdTfqz+oEDHpqe^iGeLMKYPdvG_UM%Q_dGvW2%%D-z@3ZNDE1vDxt9 zC~MUbw?;fw_BrfoaTp$r8@Ieao&sjCH!Ppr8~~~ZA_Xq7CLsq|TrHZGo0WV|(NEa@5$TW;Q|Pj(g--Vd_RQl9}zHd{dAC6eG-5U2Vj;zh?OH zeRvywskp8FVlo%h++N7Lcjdwv%cWf(RmxzkPSZ`y_aP3R?q;6bo(cBe`U<B*bJo5c~2`6m5nQH=!$4z(}s7h~|QIpO3+ zB}5W&61&c}EVKw3mKJbn!Wr-GF9)BOV=w)9qZn%-HnZHm$DEXo%hTkpw_ZiVbou68 zvf&<}T3x>Rx@I|c+>7$xv!xzr##~Z8H{^g<Tc7DDnm+qOWaFuYg%V|}Hv1~f<)50Z&IgG!HHa!bn)@H=8{&zckl zEc-a`+WMtomvhNnnO7$2^YgZQ+UEk-_L0L<-yD%bu(M@VG#>iU*AGXP2?=zW! z0kL?Xio4Q&Bmi-LnXNjnVJ2jUO=yR7Rb$h|DFZ)+T(GIO2I{32yk*hqay&i;H1=R> zM2-$RIo@@A&hrKjtjv>E=?{QykF1CXoi39?JaOIrXSZU>@IVmH`2G4kpaiHUMQ&q@3Jh30zXxV>Bs9!0Zcq zXr+CMvEjz-?)t!X*km@5QbAS*cDLW3VfM?xc&>o^;thpZyuS8oVMh|!>?S!%_vH=R z+?x3uQcw#N`=_Y2`3sOxN~JYCnFA%CXLC#p!|>2&x47jOWstLze<>QGA#GkQ;FD=J z&Wp8bC(4#09@40H^Dz%v zVx6&tT~a!x9n(9LgkDy9qOiGl3~Q(#pzXIET~FO1GsEw0ZBYW+oR0E+M4t^ul4WN5 zi)-+uGW)7YQV}Z5&+&g#s{rHL^9Pp}gOFMHxZIVIQZ!V~j3{Bt$3seT4=xkg=*sik z=G@LycMD7F3vOBpbx3dvQ zzQ&A@w`ZcdIt`cqt#tetzb5KhUjYhK;_~%H)(B5(uC#ZB;NF|5=P#`~qgZCq-g}A> z@F{7;s4b%m^*rC#>q-#s_~!59mOqgVV#-<8ynR>E_oN8z&Dadsz-cXIyT1VLYz+?x zpv{191Q9+4!vfsIu2;lvNI<&nqX&o0N}=z3Zz@Y3u`d|*7%#t;3!#^6l)oDVL)rzB zGdqlfF;a9t3;+H!pc37vP&il#u?M90f6pw%rH?}?pRS}qC<`k^JD)oy(N?jO+Q%c6 zELqp3-IYim^!Z?^GVvC-M31PMSQ#wX*RBj%7DASflH8X^fxykxCY5<24zA`+&S~;&BU5q^`4hG)Sz2+w%i{ z=v9I3*k$S$C$jOzXs3v=NF^+%Y;5vAYYSI;66hS{5if%GlzT}&(DSZ!GVM&qP3?C( zPpyQ5TgWN($F6z!k#nplS2PFN-!s~E+vdZw>vG3;ooquUlRhPr?ZFsyaK~Y;#4;?u zwCTGEdmNmcf2XT!Q;vrfF3E2>5CU%=DTjz&bq3*7&T)szVyH4zrumRn0kX4?g`^(T zLdkL9QMD=un~a;5Tf`$^^TC%q)+*V+Wz2a>P_!5YGY))7c8&#TlsJR$h#0AWYD%@EGy{ zjV1lBObQ6zlyi#yhUp-B@kJa{QW?@6RrXY8uE2{2Md;pZ6+_OQ)?@82Qqes3Izd!0 z2lJm#-Qy!khsg(T2t4{3(6()0`_PX#RI7RSDIh!n&uNrg(|DJGYHVGXrymv}PlN62 zMWS+xZ)%Eht*iz0uIu^DIa$NbkNY$%`^@nYmDRp*I--J^-0Q;m4}SQ8hU#8%Rz0M? 
ze6KUt|;h#Ljzac_6#phzv)Ml}|(gqe!yoKFVcz0Iw9rX2%GL}$%1XgAl zyL0#3fLyR)Zr8&arLP8V7lZ}bzTt-HUGw$&071dBdWl=B97HEf7R#V?Jr!01d}nDc(M)Lgo1m z-6{QU6snIi5-qBMtu6MFgd5F2!T1wv{98py!=~ujNVaO!cXQm4#zj_IJMT)(akmBX z1t)@!NKfH3i-oBi@#4j=X8#>U2`|@wswknXTgcEJs)Vr^jLR{az){vGSu{r)p?1@7 ziqnTQ4A(jp+Ah`%8&3NZUcW8_HO7d7Hk&-;<|T~S6QRoX3LYd_A&{|OnmL=5iJ!~} z=bf~A;HbG0&GU+I)Xq$w|KwJRa?Ii%V~JNcs&+m)v6;v=UkNs*d^J~(eqIrg=EO+R z?nrjLz^wvpoAP?EE7Tx)B;9D0Y(5yD9B}aLZNks5vvSPQ{T9Jo53^qoJm?r z16KO6(G*Gyz>m^nBrk5KA(N&@PJm%P-kDTWSHE0~RiaZnkIhu!u@OF+mcV}antk$l z>Z4L*G1jWre_W4AvdiC1S!2;Y(Bm}E^-Or5VmCLPI{|lUhnACx%9JfnF9#K!9s@hx zY~k|Hh|~wGU(03nf!2HRz3v-Kf13Pv6ea((#9vpG7->&R3;DEw^Q+C5IY%jNo8@o?hq!&}QS?eJ~?ZU);% zKaiN66ML~*2_)(6hOYxE!18NJ>S5+$ES98du{InA3;Ao2^qq}pudcN6N@)ZidnGz; z$}7ZDMwakvgC1zguV#8FH5mjp?|73a)Qwd|s^b~9Be42P^csCaCqA4DekKr@jiF~5 z$=83Bf@htva<_gMrfFAhq|(g*UwY@vZ#m6BPsQ&jN_e^cR}>{FU)0Dogfn2>v~y49 zd>tC8Kj65v*b772H>}H~^N{kAI=Bf}!YSKC_E+}^5W2p=!abUT8~P!!`O{ly?i+hD z6c&b|b#XIl+6X?)j>20cW}uQa<9gkW3D|;p0xx#Vz_#uk(6~Uk4rbhZN|F}u~QwW@Le;*?Q1#6 z3~=A{w10yMWA~FZV$%UnRMm`n455#uhgrq?Ahb7l*w52y}A{JMM4^71s%b- z$I>gqJ`H>=p6ed3>48Z*ZkdiK;zpCIoLUUpahPd;_ST6aY|WH8+mue^?v!Kr1)0l% zz^XBIVJ;W5w%yLp4sV3SoJl>U%oK1N_1yT3w;$h&QFmT;>jAZp{T40vhw--bH7`f0 zA!t$LuI1*iLz`u-k>CM$G&i5#+oH?@B5q@E8Od6aG*aW)9CU+Ii=udRL>!P)y{?|y z_z%&0JW#YVZUIZ9 z^RAA>t4+mA>4j{>Dc{@O@&Uc5ree1q$7_*N~JX8tUSX!+%`Kg9s1*ZkS zw@a@q{B+mfQI!0)@c)FOs95`T#VhQ~zxv*XZPWtwkIcp=Zr z|9x*H+AlE0J9V~UnrPd@7$ks(U@&W3eJQ+IxJTBo(uq4>o|rY6uR+7r;i;te{lt5N zF@6_$2{0c2RY2IO9%jy-v6k%agGXl;-}#KTp;Ymb!LgidT$-xxaVA zmZYBhqc59qlTK<9EmJVQWL!CBP3&4eJ9o*9wRRrtb>GEAI&_P3OGB#V{AEQ+k=ZTl(8)Gf`99*e z`*=OrDcT?5p~-;zF^60|i$k!Xf%J)Peib}loVxmA*a8dRaYSt^&%hXI+jv&idgNDK zxzi+E48xDOV)CQn;2v4p%KTIfq_2OTBq5S2V&rBP#hDFoqQTFI;qqH}B^Oej+ZY9} z1Nmvsi5K9z%j#7s$;3&&!Ie}q=}zp~ChV#<9gB~QWhc)!7UL4Hv$!&KGM3ny>)0ws zfxxb## zZYiXlZ#`I(KyUg(7$~@%zV87y~J^>c=FIj|{-HF%D4R)~)=Yp3J zyWA~e-Qsb6yWp4E9y~;PxizYz1D#si`ZqUL<7Qf2bJsE-OlUzX8^vJUVyd^Zo+A$2 
ztFB)kJWzwy%qkrVmxuB4`|Io+dn)l>yr2}LWCmQyDJyw%vm9B|MWPLe-%(~Yqc~>S znuEus!;7_Fwu47TORi&5JFsZ^Ssi8V1{0og#kNiPaEvFO?{f$mIDp7mqbRh9K(xeGIf7F1QWtQ|h*$rrWT3SVucx{r+OS1O~v375gl+@b$ zv>loDhwJMShX%iPJ-?$U`IqDP?^E!9QKA)5(qPdphihAdB$x|TU?)R5Q7jr=)?d#g}ONF3(-PB)ZC>vK(^E&Aaf|1ven|AtH8+HqQ zXQ>zK!4T3e?p31v!041G^)bI%IAtIn)PAEHqrzQ;NUPeRX7ed84&zeftT{widMFyH zu6&e_mX0BQ;le22-1ku^$Mu>-Vx|6+JGHrGg+ZL86fCcM_sW^XOu`_rR- zM^W;79`y$q{x?O*>oe)3x7j{J$mqt%o}SN;x71C^M64#ykI8D!<@MvyHLJazx0j&n zOpvwM(*=x={2QMwS_WgX7-%*tC{fCN@;HELK&rCGBB%nV+k2>-iIZo=Y{Qk$y1wwRKFyAGZ#D}3nE5oPSB=FY zJ{gm2b=b|Q=BM{G1A~JPY&4EegFEvA-w(YChqLP=D}!T1W|K?aXzn@zYRpO1R@M48-peE@&J4)(CfuL63mR1v8CY_|yJU zXzhbK3{Ag(G?tKqW4^u%6ZSD!T5v|`d^tkKC9OZKJ%MnCKKU&hSsXqVDq2&2UW)r3 z2g%=W55c2~Dteb}Nsy4k_=%q)9~8pfU$I^^Md^{@((reL6Cx%3Yd%*!gz8u9h`l0? zmz5<`3I)@_ez$&k*e-&2k(2rC+!cg9q%%JRWX&AT@*zZ@L7EQ-|cHlo;L!(%NbahTpFe4RJj1`0+mP3@_&gFL3pHeHia=-yv9 zq3)IeTYBowBt^dhyLavazbtw{h-cE;S1%0DEFOA#<5oQcYc-U-c-e${CKl;q&$IDk zKc)LV?-aGL0 z$H{wJc3`Qj3X7!7Cy^*vUKHn-orS+{P_{g~or8{1DXh9)QQ&eacg5DC08=g6*-xyuA-gY! 
zasX{Ms`TY|FyGgQb>S!3k|7NckrQ-Ikh21 zP6cEu8*Ma7L?W|6LZH>-B;<8u8Xs2;Ckl<<^d^#ufz`R>g5+cn_%?YSD5G}64EsC2 zM?w*7$V}aNTZm3=W_dE+%2y={P46PFWs5b76ZToFH+h74EMFNu#YyeJ`p}Ln>NPRAN2v61ApT#R*REy?~+3|7b2WK-6Y;M#Vv zT3=O9aG93UTMi(w;DxmSf0=5K+H!fL;X^)3r759BSPh0T#@?m;7L42V1}HsumxB>m zT5FVY0j%j{H5qs(AfNCm**nu*ysg~cbCSOjXrpCwTWynINBLd%nUxB{ePKs?eyIxL zKNPs-v}Qom_O%aM{k9OoJJfHYlz>0a1k$hZx8g#u!0QxuZ#;Q|x#8aXSh%J}P*PNF zaQSl7a@y`jShB02uzys55&F#e0%Pg0{POXUn?{w`r+i2_?F>OmG59;QWIe#__M&DA zfn=n~YHdpldj|}$viv_Aa?owQ{8UzZ5x!hL;xWZn3HlucfgCCo*xh$={h3`BxNl#G zevb)Yf68CO;cN}s2rWF}?T$n03mLbzpQ(nxJi9!*hBDOsKFG>SL^S<&(wKCP6+&z8 z$-w=FO^_Qu=y%T`v)z!Mg9_276v0}VDSqF`!YID0eNYL?oL@?Xb6&kjMGH_i; zgd5i5r@X`yup=Wri?Nt+GVJmntCb_6m_)r#NPjtW@072&nnDy^Zy;O#*%CO%@Ve^e zi2_`jvzL6?OZ7};;UQWMIAoho7M(FxPz+H|_Yigm_=-j${ zid(T3ZnafRMfVX)x4q4&d8Rd}vb7@kkzhNL7D8Qy?+~5P@%*jli^{Ro#$<*}zKqC* z8>y&vN5S049ovl{P3T~ltkfj<$Ly|_r2|B|q(y6%`kp{5n(c{brP}X`-h9%Zp zaLD(g=QlsBHzM7MWHZDA+Kk6HG*aLty`fXLMm`uMzTwN^%|b01`VNuH8Mq<>&mGiF z@#y?c>76dAc&qEOxpH4TuysTj>^YSI*J&M6uP_sF2zLXyitTNvnAB&t&Xo#|Mu$ZC zu^P{~YA(}9kO-3E%|}ZpLayN#FO}Q#Ab@r{f5kl;(rE@0_6TL-5n7W-lUs2p?Be9D zaJdF|3xp_C+oS@moVA4}Lo~!DtIOw12fN_Wv@jShV4>6p-@1OwP{=mUl!`%jq>m1Pg>3BL}NMJNgfa0 zj~`d>r3%L2nG&uI!2+C1b+?$F55)5tw~7^<5}?0WJtbzj36oL+E@)4ZP`beB#+{J{ zxOV!XF8!S>%%412vq?CqF2uc-4(^M@9yjsQTSQgWafLH}`}-X1uD1VWu%i&3))z2s z^UJ{L6C_RR(rhd}bRv6KMgg+)4c%V1_lNjq8jo4)F1j4EH;ac8>0+`?( zGD=B`gH9=9E6B}TEKb1VK07mB6Hl1& zjdRmXff>-e6gDzC+<*?#p(W!x$`PsFxs{F)TN}CkG16C}fM#D`3!iKSq-p%>lx?X8 z&b_yPnqBe$hrZU!p|=ZxF2jn`Ix_+L8C{2JrIRr3!(9yz5f5}$9Z~ln&g;1-O0~BO zNl0&4bT{&BG{{*#2pbM>gJ~hN7Iyh+yc)yA%wIxp~D@a!oNhlG4~=Flm32B!m6`{BAgdq4l!-a*wLXzm!9e%1Jy}XC4VrKKlwu+e zz-i6<2^}x0AnM@7KK>i|sP!(UTvW6IMV9XsD{c}5QWay)gNNBrCnb>~K{z6=eIR=l ze<~I9tuB6*$@E1Ehiblabh(&QdQzk306}^&770;~MBXeFStn%GihMJ?UjIQwocoQ$3DzVEQuP=$*tR7Cs z8$DSC_KTrlK0#u@!baQ^&_-hG$OeX_67G8?85m?9U+Q5Qi&KgVueXJN0JaC#mC`$I zzzK%6HDnp^FR%*AG-MJ zhh;9vgjbf=WCmgVr>%@1S6boG_6U31hE}-PmeC$pntt!T^6Bu(rVj%vu&ApVeiB3G{SZ08*X}?%%}uDm8bZ&93Tii 
zpPzi!c#`pIRfMY3VhLJw?g)Q$BLmYkBis8fK8CaP=P%HUCId}a?Vcl*gzG75!{7EK z5zFFhBDXk^0GUcbJyxBWK$h`?&gwZ);^{SbuPU01ri$88{9B69C@4wS;bt`E%2gib zuhD~&(?K?F9tr4Z_-z)fsxfc0B%q%t${iVez;S|!gmK3f?=;vGL`~Pw0d;>(d~tR* z=2>_)sBCX?duTzBsp$rb{6B-hZco{b-CtwCPP13La`_R~a^ByaHEV_8)Rb*M2=#-e zJmb)<$E6_lW!uqrjggQedV)upI3E7l-p@8h(U8)}K3><`4*!fO`S&`+zbJ%^(rABE z%h%qw7LdBC=qVG`>+_rx$i9RrJ^)8;B>vv_^x2le9PcHF3KXEBetAe14uuBnj-a4qnl*ufSpcn~MK)J;rYFO8IIr z4E&iiW%5z6NX61zxVLE@iv*8_E5>=Du>+qE$9-$8D|u-A0MR+AlLbT{^4g5KIeiv=BU}ia=W&`Z1v@yA+Qu!~h z#=j&=OcK<#suU3E&!BqpbN9y~e(xKjSu!?z)>nv~oy}INa$7F`Gos}5f14;d zAjcP<>-7;woo7P582v$&%-1~Gt`B#(brzLMR)b)}elw15;W*$Ub(YS(9c<-_Hw+(? zV9XDNG}Xae^3Y6y5 zhs}s6#Wlq%OSh>y;j@65diRM2OqsK(do0k0rHZ+a)85SB;oK9Vo#)G;(x$EH$CGT( zmwsP-nGSLNO4s&bLh^WXubX7?<9=)^3)%JX(G-mNX_IdyvX4{;cHI8Z9Scs!q(ms} z+n`PGvtZahqCj)h@M7P^N_ZZezU8vs1O#;Y4PG!C0H?7mFC}*`yvxs|7m0>{UGl#n zO8)&y{MU(+?mw%OC8P?#gqO7sBI`jaTT4bUrVHO+k=wSURR%AFi?-Zby#wJYYu?h= z3V>tm=!4AWDyV)>-BLjazf8Sf*@&f$LF_T!sKXZCxNqOb?@Bez$g;z|Tw9|PbVD0O z3cVYFopVd&{?lb>-9Expmh>6J2T2F}_-mp6q`^k*xm29)Q2b>1HV13_o__LWE635% z3reQtZBYF2*QbTrAha7C^D0>G$F`u(_PnunY~a$FiP+l;pSf!#_FwJ6!RHZd&wp0p zmF;x=_a3w(b^H4d>u%8))Tp`Ntf2}?->cI-)yA-&v*X+g;{Ve+W%8c9)eBFpSVK?M z_28i^K9*9UE%55B+{(FYW$@rQ?PqF)rS_tCQMVhH@cJR1&n+Fe6q z(Phxn-#g|!2r_DFuA$1V_++i^W%$P&@Jkd-T~qu7{q{dY7OR3_tA?yhVn`<>lRwz` z?Qto3kJVlamaN3#h|h=CzLrBP!{wU3M^#vJ`f&X2(kAR7bMDn=Zv>YQN?Sh@n&ubx zs(I_Zh<=h-@_PtL#o$X7+gaDT@C5twV4;W>&^@lNa>oB3TJOIhO8#;7{`*8pv|0U6 zE#)TQ%IPWCvM~UC&o(zxXBzP9V(mFQ(^}Ad9_Jb+Sp_-NU7k^DZQ#%)1Lxg4kX12U zCSAA+pDjC0aPW^~L!yC4JgE$gXc)YjsOqrv%D3T12}B?KkM{V(o83SQ?_D;J6R<^( zei>b92kdzgsP)k#3WZ3`5>Kl}kgy(e9PF%w<6#u?KkE81Y3-;bFI6hIq{wPcn~y;U zXPX(-d?l=2F+QHa+ze3~(_E5TycGkZWC3V*q2WaaF z33$yPa+my8VCI?C&(Ik_jp!>TqK35~0V_Mqwol?4!T8oVzFz3!%@!Q2uR`t^Udi-0 zJB(d^I1*dmhZ`=J0*ePGVJX6Gtk|m)zZ_vYoX+cnZ+#;ezfcfTK6{g+S!YJEQ;ET^ zXW}ys8h;Z2r5@0d{ru$0zV_h?0N%;r|~*$$$8*e{od*uS5y4 zp8FpVC930P9&n}?A9h)H&Po#qVd9ph%Zye)bL(0C=EE^i`@zH<{;nOAxb1!qNk+jT 
zM?Je_qS{2Ouo}b|F#%Waei+d?TnUs@_MgKhOq6mv(?BL{O?RfDC4fiM3?S9dULgX<{ zb&MHErWQ8&vZWRU>^A7m@>U>I`$0KghFUB;l6lQzTPLcqd~q=1pTM@U6CAnsDzPH% z;z%reKS01|mB-8(Fk8~{ecP{kkfB?ZmzgMl?}x(Ly@M*jthj>l5MjWNNZC3mUs;d; zdztk&M9F`?8vkXYB>3oO>J04$-1E3SHThXL@`p_(URN6@`p-!xQlAo=Po^MUJ>dp8 z_NbqU|I0Y)im)8GeZ3P+&A)OMu!q2>%c+NlTM0c+T24z5^=tIrf9^$AcO-r&U{kwR zRD#Q`^T?{g1&V`)}>pQ(PNoe^>GhvdZ3hJx5MlLgD;N6+r z$uHj7Xm*)>{mL_K6fy`_=HJnb3bzKLxDs1|^SS*Ez0r91eWFu=!6^dM1sC4nJ~Ojw zAagPpJ-FVCFGgcIcWhL{uS-LB$r_7s+K{1)bfpcXhZ84)B-_BX!?1b_Mu5Z^wN7Zd zBg_v<%(FaAM;booxG!}dfbSQ@RjHpXc+@-Ks@9)Oh&__M>&Q(*aGY$Y%6*!Nb@a~{ za%D2d);z8UBf=2#ej!!2m;p;Sg4xUc}+Xu^h`Z+`qGpOeR z&!s>tWTxD=(3b+^*T0PZ*;5U9t$nV6)^G5uRDh|UWH*YLwBIoe>w~w~#;YPkiEU4S z!b4e7E&R$KEjX%K4Q|yWo1_98@Vud)YQM7y-1lBrW1Q&)xhEXPT(cdReX%7yQ>_Q< zy6TVqU?hweDaQ^8J@v+0G|e48X@j88tfOvA43pgJl=mW~umOhIcQjJ6MgyHkVel|1 z9c~*%c3Y}xA6i72TMIGx4% zJQV~+=ynu1_n^HP$H12(t?0Y!xJkEYJO&i*uk}-D1m;~~JcHzeaC@s(^DeV7l(3NA z;1?Xk7&{*Q{3Dh4^vuOpI_@6q%rEg1+3JVaH)&U25>lq17)=BEo+7lg+A?(PP8~!Q zUrV2x9>OmQ>*$0e=m?BNtYtSVGT<;Q!aq34FIl4!qam>D@@j}N8POERF@fFR&g>*j?!L&Z; z%QaNG@u?mC!p~EhZE1jwUW&P0#w46B=xIBhQ3*fUgHJs$&qIdKIhN-X$6$m^rT1K7 z7wj;-*4UqH2gmCvrFjPbA))#kqU7)HsDJW>|2Lv!KWo}OJN=(<{YRYI_EQ76l69Tc zSHBafUZgzOVK;_*Qyy|1IxvlU#X=-gxBS4EylaQ1xW_P%!sEn~$N7JKYyPH?a^}BD zlq}^7`3ww^aFDD!*LAuS9$guKJEdNOKb32eADa|Insdc7>Vxhmq3`>LmAMer9+EX$ z&?mw!c8Y`-HL$L|J($Jpj;>99jQy8?Q zg8R3{z>SZqCR^?oLCvl+>*a4MLH)2o0pGn;NSL&z)={#DI6jL;uGvh~`sSroOAsks z_EW>k;*t2w@mjp>n;NVL(NWpCHycjAuvOAcu7;QC3-j8Hxi}=;G%QP516*~Yasr)# zxH53;=tA9F5ct{f87Rxd1seGZ{NSO=@o3`lE_QWCMf?n5m1p9-8#*gsyr zB?S+Bt9!K*^U(5+qxHm65)^y$zEwY+1BLFP)Eo7MC|_(x=G#;auI``nt=g(UMreKaP_co{S%YfDfCSV} zo~ES^&4=~8HaCCH5*SL3T3x+g2;LRhC${w=+~X{GXh4?>lYw*Jh9--!UrJ{)M1`9s^G%-b(Dt1yoH$g8ogY;ca))%%R7 z=x5}QjD9c3!)|ZBw%fD0IQz*x^O3qKbX?F{*B~4+BJpNdD=ao0o@=MwPl@!uyDCsI(=K-i&iJ{LV{KLw1qcA_nfd!COV(V3pV%4@$;jL(e5 zZk@40%BGif;b+_A^f#Py2O?@@>AQzthak{P;6{4l>bng(N(z6w$r$pwk 
ziff0m*qsbe3QP$+&k%}(!-JPBqf4Mgvh7a%U2ft3tXIrtO*?C0Z?o;Kw*=AC8d@^jG!cd>a@6(Ew-eus~6RlFg!iRD_CV{)U`49%X`cvFA9o9P%abzmU$ zBkZilF4c2*^=T>0@sDv&_~!xTaN|kist8EEcsNdTtpG`e-RaL=3SdO^OF=|SAx;IJ zJ{L@p0Y*1cZnBkCLTr}AjcWqLV(`MTp;^T|I8yz*rPy0R2HkQ7t`aUJK{U)9Is ziNTekg}aq7nQ-ads!k5Jkl%X57f8a1J$KbzDqY}M)fI;bQUTF9+UHWoTn+aqa;$G} z*2A9Kw7CU>w2AxpCu#I*DR^bv9jkd{g-LrFpT&0P;P$k3J#pgGl7IW@K>W85e48`J zLwmCVldk9;$>~pno)htOOtJ6KxT(PSK1~DWEk>|)1|tfVs&rYq8RK<#@952yRA5gY zoIA)<0iPCPMtf;O;jYJCEsD5Q5Oe7|)%++Q=mN#Lt8Ws9Q&h{;+o;93O{2pNFA-&P zb&fc2r2}U>ljZ&1Ht1VulfLXO2r!;gG1&nO)_1;3u0G%pO z(`(LCBC_eUSsj;M*^^MixQ~ZW3&Gh#L$c=T)u6*R_R;-v5z@c^#cgd8jTc|}-JSIe zgNrfaI(kfXSh2o2w^btn((eBl8NcjP%+G~5he1JU% zO7oT8xxhx1P>R~t25d8(8hk`NPX~sMo$2bUgdxj4%yl{qX!LWr@5e*p`n)`5xkr<5 zFqDXWJ;6|j0}OF?tj`J{Y`XEv(b-Hek}F`Gmn?u=+7p7j991yS5-lZ}6N9I|n|(G* zBM2j9Wuy~M!Ur$JoPPZA1f>F%ou}8!p`4{HE=x5JQ(E6hRn-ua$yE|+UzO#-%l`d` z>Fo(v_MK&ZK0y#vgbUyEr7FfvFY3E*UuT0+o%9?xbsShnU6)obYQQK;*N~USzTS*fWCZs{$W!YFpC^gl-bvT{UrS_fs$qT&ijVav&C09T2)2Mem(_n zPBz^x`&|TmO^aIvuQ)=<>eh(~{dP?IvuGE&S&OTZd#jRH5iUm2IgI~^K$@)yr&K3> zaOZvYuhhFrkQBC`_QW4OII;0rnzgAAmxkZ{Q7EoJYx~<9<`)P7V5{t)TA?Kx#m?=hK3TFUUfgP0MgyQ{Lv*%~W*g2)Ls6j);8Ln1yjyw8FLGHQv=p z2dFJCp5)xj#un#{<7{J27_sYw|Aie7;akkFN3E9|;DqQ7y4Y(;c(1@ku1u#CE@r%! 
zRg=ia6du}Bj?yKd?oy_&QTG7{zB?t|A}@e=oS7@nNx?~pgYQ!0d$3ygP`}7JLh7pK zM$i-qI8QNOh-{C-`^xd!c1O!_=e0_;mYr#^hj^#YEvBQY&xMvnk6bLjQDNBoG#T9V zli42%7K7Vj>3+BEC6MxPftj9gwe&R^KlMsTfiYpWmi5pU6t)Yi{^631BB>vve@r57 z^B*qRG4Brne(jEn!_AOAU%q%XwF+lt{BzmgWN?rM{TrIv;_Gy-FFm7M0k(Djo$*p0%r)|tEkioWP2APGdL|M*1{!)AD5CJujUy}91Z(kk)7tS{`dPqL z_HE>&vJcEERuxtlCBP0V`88g4eN>TR49_Erho9=FxTw?oa4!c7z4Cw>3@=VkB(~Or z>lcbFwct`5o#tR%Vygslk36S&`f|K#FfXxJrWz-m?`&9kehQ~H61ftX-xEk;&AqyQdzQDJ1F-lB=;}bq0d!Y&tpE}K>HwCi+!LPeK%>c z9VH{-*pmZ1B-JRC|8<7kjZ}ya8D%uDzV`*S-O~N9T60k4-JkbUfBbNse7-%0DiL0a z%MYnY5ajL|{*L@V5s*`u)zoVn50QgA_Kh**WBL{FDs#_9lQs$4QwPo(KhnJLZl_9G(bI0LN z#I&shjqR76h+K*5E#Dc!nIg9=*y*|`jiD#t-@)z%C!qY;T&JnIR@cr=W zn>@cLTyPZNb~ZY{ zA5L!Js7C?<+f6}ucYn~N3^$2|!sl1~&X0tPfHee}0*na~BF92}*Jd5G`OFgox4co^ zL5D2)RUNe3>@)MWA*L{1>7u;NWDJ*HzFVYnDn=2GK&|0>o zybQjPSDVz?7D2RYpI}55V%Emn^V>Y~@%0%&v5x`eu=R@3;b)?;=vX9N;AL5bT9w-? z%Y7^0TYa>}zV8mG;>h7;b07+Hp0Lv%p3eb$hvB_rM1#+T@i#Y#sRbve`nER`Tj&(} z;|_Gn1bLP2Bm3iB7n0kk%E`u-L!foqRRtqU=-G04JGBJ~^^@{3HZ~u|l^h*7iij`* zbK{`rD|0;Vk<%nZ6OEjO-*QLWyiqYZfcj^36oe1BemTpSk6*2Yk9(G-V}Bgu!jzQ> zkiLIRBN=wW8%%OdBQ&G*Z!hRrs*#ZC=qvhmVm*DdU zrxN%I9BAM22_TojCliPA6buj&lC#`bjiI^;SEL*2!0bstocd4z@QuYVkQcOp@4=ho zUu|vhs9wcew+~qu6K&cfY?gs%N^M<#$N8YC{N>@uFX=cEld8pkA`dRFekW;BG!pxj z%rA3N);J+jFzl070#pHta#Z2v=pg^P;Y+0#sCEsHsyuB)X4f*BzGAIKt9Mj8HBujLMd5TOlnZ`D1;Bz)CFFMeaX8aB4P@2mdRff7|mq6?9PR9QSK zvI9O~{5|lGr*bH4af|DiWk$79)A`yabKDzC{LkMzNEKIpIyisT8 z1^M@7O>mg5McR=i5=0t1b13QkQ0H5f!v&c}P}qF5TirW}HH-}`c?qXHh8we<^r%fW^9v)Y~Vb-?o@cCCycS~%X7 z7#*?5#jgGBI@POH`0CT}1L^Qsv>|O1zCkD)c8}+dQx}xMP6@fQ6Ep3QbmjE^=+Oi) z@|R?)|B;7vyX>D9pNK?$!}UXRP=^YvwEKR&%0#6(&sev4LW{ebbc(mZhnOnbkmaO9 zkblL@yw(JX#mz6-rY%51(V-w$zqWkj7W$O=_-@@_-LAi}V*Gm@;$IX({5H?Jgv(R_ z!TiaWDYTxqRbw~OUC)ic-H%hN}s+E=E7f|_tSOytnFpi$g) z&G^Hv=NYK;RYj%PAOmlt_$WUK2txUbeRckxGHVE}#u~E(sK3W%oyAE`;%kTM@SL0t2 zC7)kUt3ME(fO&;Ty)C?>NEWHYpZYQnosA=6-m3KBGwKTmb#GCX|1+ZG%zv9GF&=je zdTdjOcUf8}b#=$#^q@sh>X$mO5&q$lD$@@OH2T*}9+m>l7yi__GkF+whmXmKtRE=J 
zww~2*83XH*EY=zACg>ecQAv?V!}RAO18<8OpzaSx=V(+d7}8>!n?ePep4-+te6$NC z=l#3o`l>-xuJ_mXqE4KAp6d9frUtjvS(ZO6ZNoK(hoM5->p`p7xluJe9~6^@+2}nA z(CmS2q_%MbD67Tpm>=$evaKKPtXWpz{Z-~CDXzizXhpy|)}j*(Ofp^-KFUOns#n%Q z3f4fXF?&?y(+-hh&x)V?r~|RFr;A78iT3h!*SlfoTA`0wL@RuIA9Vb9`$*iT500LY z_p@^xfVqZqgLlt&U~bOV?W(s@;T^M`Qr4dom@jwZnAu)|TVu7Kls5h&TJ|?Y$-iHT z|2k2!H;3B9R%!zKvTiBM?(l@+*%Jq()55^zLE&twUpvgaIVCt1+X*gDbK0k}V)0f} z$o{V*<@o+Izv3osCv;?7@5nt8gW0Z$Moq)w2tS$DW%OoS} zvC?j>Im7wTYWlOzg%CvLh}KfQ`%Z`?$rDQDPUS&fsmfRJ7cFp!?d&Ul$qp!Qd-Ler z)*x)>yw0DdF#$iPKcvbn^?(Au+tG~3WT;e@9#?g$fGw}$^m|!Hi6Uv?(s71x@cbNr zwOa^3O~a{}%27i0!1!@XV`v+GbV}b}F4Tor#DZQwH}3&f7omhOqXwu^&YEXa>BM!` zf^>1=cu35pzTEdT6Qr}dBV!yJ!8$s&>RV?&DE{!fntY!GJE?jn^X40{q?vozJF4-o z$M0{5k~9A!qU6Z8Ef$<~gHR*OllF!6BRcQ7VV9tki{^r2&7YNmP^U<|l(lM%7`PmL zQ0rwIJjr_#Q%Ti~S3TbHIz$ItS6{z~j6@q3x)0>5U0ANsI@%>aN(lP{_Y?*V!^-|wQ1JLO z?xyqimS7>mW9B_;He0&k{#SB&p^z%fO1SpxM!gx(b}YAk5hwtK*cj^a@>)z2JWgNS zSc@-Y9|S!r&A|YTwVqMMQ9L6qGPUFSEL!K@WY;RH{_E`h4N>xsv-jU8O75Cm2xgA$ zfU4()d3vN$=*4Kt`z*8gsu^xir$oc<8DEY9B3r7y4*Xb>(N_1u4WgcOIhYGoAY=GiqWFV0EZ=y2vZkyDk55yl&30#Fu`F+$ zSBnoc>0FYmSkDDdna?vc)cwft`1a_FW+JWg^tQ*-^O0C%N+~n6s|rhGC<|@ag7GlB zje6Zo4@SN6(&g6c#;2q{sk`mtP^D4O_UBqV{_mIV-w-AL$BFsZ?W}(%N+hmn%1UUF za4zomwz;3qaK*C1U5>sHv-)S|S4cTHI(w+QgRKb#30Zj>)9%99512L{{$_T62cB2k z)wtKc4Ii%-sMeYjQNg8ahTFty@seldpKFpNP!l^>5~yEIq#F<>QSxs;{QrX}`47MKFOKT} zl_(+BbN>UPM6K_KfZIkM80p;Lx>Z~cV-lr83iMT&;^nI|Xfy)!n}6ohS8@^0I3Bmp zZ^JD`%H`Ue-KYgOEjVtB!D?K=nK79L%%?D(k19+6YtRfldy1$Qx`!4r_DsU|+S2<9 zb{$x#swV8YqaT%RCse*Ap2Hy>WmW^TuX)VrMN-k^n1$|o z_y|;)`I#P;DaK0k`kSvVx8gt%ncHWl9xUhDEhF(}6s6jQYc87?;I`wTyI*7{VH)3v zn9{*EklZd&LiaNr$=Ijg4~})fpJlOS#nw#t=u@Zo@_+}NdVkqyB4P-IHm7A`66%O1 zfoI{J<~*2kV|sg?y##KakCp#rL<0KPrK@wth+^E?{4}-R_z`8!a%nRHHGcW?JACZK0@?v)D(4sIA+hT9X*vQmFKeCU zbuI$_?Wdx4##iDK$vMLv!2^)bQ72wSO%zq@4;8Qw#jPJsXTwILiqKo>_a*s=5WMuq zQ{Y}=792U0Q{ooV4Osnx>tl5_z!OEaiyA#hZjrjYZcoH<^;)m49V~#&otL#X!CAk>blb)c+y&ECx?|lPS^7se!Di%RNhi)W!R1Qp+3V3w#gI$fflGVuiI$_X2e+oP 
zVK$76dpg~t>%eV6e@fdLyYWGp<|2b;4M<0K-VCl(6+w&0Eg(qJG9#s%;GG zFlRe;y64OQ?ue2aH&^JuV_BAKzJn>a<+|v+0&_b&C*LM3M`SI@HiFEi9&P_?OaC{8 z6v6){QDVyWa6?0osL+N}>|<0b#KW|^9~lb2z~2szbK*vY*bxxUYE;mUEfkF|)LW`B z?9?iU((PP4<+yJP`Bop)q-e6a+Ej^=+A6DaIu0P%$X^{Xm@nNwPK*$agg4KA*>}pPK*_JJ-F8oMp_p2mW4SOLm6*s= zW71O5KiPTgN@X2X>vm8+5o^Im<>{|nwHv_G|8u4#p}72Hd%yVQg(}dTk4a4DPelbC zvv2&m@$gk^AolI;I1En7&!frmgupl=zdDi$>%CD+f|+?hcREDIkA(zmLi=Wp5~Ihc z#C@!ciC08K@!YrM%xn~R`<&?jVPU3#%qcy-dSs_m3`T)0JoYr};jC^YoR5ipc5bi; zvJ7T^311*y0XykO1+7A1He{{#b#^;G2zYuk?qVvCw~uAdRuGlv?HAvN=9Iys69!I- zSdTt0Ec>rYwL>500j_9&T`+C5D!8YekK|GBa%jI+V54n{K$>tIt``|gGrVa;IPdS3 zte^;9%lBo{w;|3)ow@wepM>@y*Tsh>lELG$?6MdILao=H#!=8AIz?oUo){M5=)SBM zq2Ds#{&n(f9)eQ-`7Pj?Y-A}UaT;8@SVT-YD3gw!MiZbRBG%v6lfhqh{HaxZEKH<{ zj7G>31HwXUtBZ+<0Oc{?_N&{74&=V`y&K?+Pp;rkcE7ND?DDtCwU_4%}~qf_19}W4;ow??^RtYzc!7V{`g> zc0+7YRx0$qlM7GZp0Fv5Bcb6}Mt?LH3=w`mr9 zw<>9qV=jgxYogXmdCK@$!_q7NUNB^xsJqPdxei2MJZWwXvL~EN0vl4HDNy%NoiFT3 zKDrBRe_5td2mTlL)iJ)QhB4P{*^lOhXcg1+Ufezow(s)cU(6&%{k4azcu&T{#evKG z{A>|;RPst+n`#z{zM$k}AFY9yl#BbT%nRV>kR)T^S;82qCZHrfTY}=DKSd(Wx?_5w z|Idg&iJ<#Y^Qn7ZI{ua|84&I-08uK^G0}TfkZQX4+<_qzQr2U6qCX_z0l8-uTXbGw ziNsrqvgtHn41Os{^NS#1+|`b;8iCY4nOe? 
zM3{EokPK~3hd!&86PtP2cwSm%oBSxzLAz-2UN<2XvjuLl77G)5tMj%yo+XyR(V)8e z!?d|5o-fxH{W2SU9_*>Bz7q*Qzq$&&-`6_^D~K)L1qf@7<8`|GbzCC+i$6H!A|sR6Q>@@NEO?9DTMh@+}wKXuVZ4 zeM8~JmK3V1x<%kYI`!wfQXy`e8n+SnmWCxOXDjoi?V#)7p2(N<^>{}9z6|}@8oZR= z=y2zR4IX}8pLm#51?)1~Av*|cxn1pn)$=%FV4i$yUeQVpWKX8|1o;}nj}Z3ihCf9> zHFacq$7Ujm-Z}Pl_l^t<4SpwRu~r7BV|a5bvWP-*M66LxXfpJSG;xh+7{ltyN=)+m za2bnt6#ra#iJ}5b7B8no`whUCBXu#opPaH}x>&OGfsY zn3g0XqEmd^J{SCq)!wHh=3pv!b)aQx8Y-=tiC=w6xT}^9YWJT);O5f(W3ur9V>^7e z8YZ;CIg1EhsSzS&&@=4}o>3^(ApeX{C>b{2H=0qt$ibf@YkBsc3b8#W@68Rv8e+b8 zXrSP-Co#2dfhm@lsV5aU&>u=X zZMVjt-autr+khL225ZdF4iJtiCW#x5Q%itF&pAgu>mAe!cbN*8Mc|0Ri$8;};^0}5 z<#H5}nHLtfoDos)0JEp9W>j63$noWBoFf(CSV29BbE>ttnDD=Nd+&Iz|F`cyMU+Zb zg-}RTW+JV_Dn%k0Q5hLoQ6dp#XU~vL_TGCO-uB*mWMqWWKxygsImCANNOXZ5|hqKKFSk!dSr-z0Y`3@IOj|&)fDCKJ+09O_vw^ zOs*6G8D~&#LQp3`n68)?1t)^Aiu-WJ0MRkF~^+wkwo3_r-$7zQBg*uwY#3 zomio{8wa#BcOtmFTG6PEOnOg39@0*5bh*E+f=55%^{t7NgjFA^+q-;AU{Cba#OT)& z^e=z4meDJQrFJP{2Y=>)U4-M1=BXg)T*){@eWMZhMfJC>RuP-JX+gbbjPW?bum3q+ ztqBG+orQ8IYT!~%g63791T3MGxac7k3oQfT*L^P!B2`fR`G%Gt{2J9zD7KV=%yBoR z&upoHv8B3G18&y%Em>7;ogf|fVs-nA`;xGoKBG=sHxX6e&7Uicd=BZBcV~8m`Qo)b zHV14qh|P1yE_}9Nq=j;I}ILHvGYC5s)s-4jxv+4dE)-a8?1Wmj+lD=l=z$94R|B? 
z6chEKI1Cz2zw%2b2(@({Bqa}*<5G*_9=9tUu%NeZp~)x>kMsU=;+V~Y@i#&b<0xAp zX>?JIg}wuv{k0@>^GR@pEP6TVY!~r+u7K95Byd;N8X2xg#06=!tGzR+NPSgzGl$3% zkhqv$kjdbWqXDoRtoY+jWNjWfY<}9p$6;iV~C|zxV9m&kA&%Il6nx z+-o%cHYb08ofs}q+>#eImxgyURknd+9QrU*e zO^;i8Lzwc(k)_ns z94=aZnTX9DQmM_s3799=SbU9em8qQxiZQrMkSN@y&Du-3K((`X2mN$B>^{2R;9P4E zqy#MMD+AVCtNd9 z7r$7N5}CuPbzRDaq#dD_L?{AUSn(&Y_}8Gmg7n9_?b*03`ULk0i+o@#C~W_=l7oUaD@^-$ zHN!8{cHL&vYFM)y+>sKI2YWc?%CmpG0+~Zi^B3Y;;2^mygW}I7$Wf%Vy=R($^lU9f zI$bqzL^#yEuB!^iVq6W31%q(i#`hpU@txhG^Q!$I@g0&A66ZYjCjxV+YVSY)W(C(L zhqJb}zXHC4Z;y# zpFE*F&FfkO&T3&zTX*E3pVmz_hL%$B6i#pWDpUj=wCnD&(?rCw)9$%hqJf$gE-P!( zS%Ve6+fJB@JwTy;A(<00fmnSq+=J~!E)?v$T5P({57-&$n%8f(!Rhk~JA?la`w)A< zjIGy`AgsyxW^QO6@a_V~62l{*1*!>Uk~eojs4c>!9w449pCsz5r;mJdq)i0)6A6i?J-LA5d&)dw#Z zc%M;Ulq%bYO6Db2YmT{a>e108lcj2)II;ii=#o1YKPnG)*oi!J?uhH!O4`v!VMU9F0qi#HUqQEU$QRqmV)`I z!xELoZSbzJfIq~c7B9}O+T3EQ1|g`Rtskz(Gs!Q@G&X9nb(Ea^RB#&l*E>IbC}#&Y z91N;FLx>+QuJ+dud$JSUkF7TcNyJ?AH_Ek)cqkv@TBTJ6^#?`&ynjv^6!8IRMd?0~;P-oGJgW{3_s<`4lTU%g^+kEh7X&#Hd+V0$ zN-42{V-)5ihGSZyj&NGjIJOLlivHUP`Y$I4hN z01++C$A*>!hbeqdm(*4Ys^^5|4#o$g7u~7KQq2hD`{c|utyTpe2iMI*3r3-7!Tt3n z!Ld!rh(*~2Pl4y+8+HlL1YGEkS-f?q6fgVVjeRy*f|icQ?RWf+{FhhbUlJuI^KbLI zJKo||?$PB7#8volAU13uyB&TRtY$j#d<3akop<|hO~F4SN<{w0M9HT0(Pt#5HuMd+ zs8D*Ggcaj#r6z(!$jK&boatE!SM&FmpS;xta>=s$+iuKaXQSNRkA&0Yi!T@b^v}1j zU_s;Y=4UxN*yq=iovuWk$^qAh%7kp}zUYGFNINY3;xOqbB4I*m$lB#wzEFC3+ZEN7 zG_>vvsL&u{i%VYl?^1Il+R6T`K;uBv619I;%J=u;1@f|nYaN8#E9Izef=w0X2xUzK z?XG|~JlpK=FLyv(Os0iUA_=i)SIG(Wm!gT;c;5@E;R-XNw})_8cy3EYR6k_d zeZDy4OZceGHuCo@RuZE!8jq$s60kWnjB(Ry0zbr)e{%QvhZz5Fh?4)f690XoWYJSl zFUO`21BYCnc%`-w*s_Q-@0%`==~$>;muf*vQ>Wl$#u!LDkePhix);fp(;wQljX+U~ z*J{28qGsHYl)JLT+iaFav;0vL6n5UYkYU}7N8U3JHYBz|gwxgcN3_PU$7SYbur3L& z>))|gWosj%d}71D83bdn2>US!wYM<6yJh_&wu8-MJ1&^X((;d3q zXTxz;$i`c?Faf-FJ)TZ=?uOAT58a&qjpq49GVhvcz!45ls)*Y)*k1fjxL70^wtXz5 z;fQat!oQ{0&|04 z?`0<}PT2*z&#%{Radk#>u2bS>9|$W%)2HRjZLd*0Obh`&!3iA>38Jh#g-(2*L6QMA#Cvbc5P&N8I+sT-;uJT(rR?}7HL 
z%NLlLdqA-J#(C~z!^jiYoAUV70Onp8Ieh!XEbx^d4|FM8#Nt-);da2@6w| z8y>5hSbb!~;X&Xxu#Y?}@Im$;DC^%4CI5cl|38S5|8%T>@vHt{i4tNx_kTf@9DX%3 zBqE%PuCKe^uByF)FSidxtTz^byxu`+suB`tjd$xfUM~jD$Px#JvQA7b928t^>B2eM zFQN?3`#^;4cAjO?mS+YKwA$lc!K(Y7AEciSW7jnM$q zhf{osecXd4iI11`qDmo5JQ;39Rf1re(GgF!ezd&uvi4$8JQOH)(6gyX0fRpZo@({FKqdFhh)5Dep>g^eGUX8*>pcWBT^_*vYgUQoXclG* z(3AUCHG{&=v`Cw?B%nOvd^y0=6=)`ZXI|~CMyt%Z2tU1ijMFWbyu=WNj?}%5zWtr( zdH!`nyhaC%+$ypplkWzWD;17o{tX}-+%K0LTMOUzH8SMwt^-Q13Jc27ArSvzXs(bq zgu8Nj%J)@u!Ma*MZ}Xo@qKIst=W9nach0k1@m`99Fq8KiEOzx!MfM#PSn8l;-fy|f z(hZ)qziHVaR7}vqwS{NSx&wCK-;{FiK-nxy1I@3c;P&AAuLHH!u&6^=Tr65pfcXvU zW2-K7nrofPnl6Gz9=1wbuXbW(_{nh7z1jGq;GnfH(RA^!7qQPA%Ef%jLFj%E13Z^r zt)4nR27kJKR+B9_q4sLcbH2!AR6m{g*@QD6twxQ;!eUFX!i4V;o7f=CmTX>n-;_oO zaW~Ao%b((pC#J8uZ*;@^`Vk4yj$C{&op6ApP=%!JQhQpgY9RGh=pG%mTx2loQo9iy zkC$!=*16G-K}KHh_O`k>Sbb7$t~pl%%E4JeT=|5PKl=@L5vdjO9{Zn4pC#69d#-ca zJ!yn3G}k7>O#9%o7#y=*4@D2^d)fVE4iKyD&HC<2(SF{cuFVd+)iyTL}}YdzPU*@H7VVgfF6;~*Lsf4wfP7OK_KJQAJ}>c^_AKRzcb zz`y^Nqg4x0G!Wf1Nj;qZ*Y@gfh?4*GIR5KT@P8pn1n88O^D@%m_G+$DByS~<8(-#f z*){@u%$~}28w~=>rSNzAD&~m@5Bbujky_MRjWOo7NyRDU$NL7VGtj`1Z{3k-n@;v~ zzh^ff3}Bh8MV1L9l%R2Q+qIsJi+j?nj~Qn|E)V;Jl|?*s-lQHc7a726m7#jogM*O! z_{poseo3gMykMbeP!8{H#Xg*vi-p1aer59b=(7b%YPyUhTQD-$zoiR4DBt$5 zmmmg_DmhLciY+CS(lRmcQ%Nu*y}wGVx)TSKf#-Z~8B+6eZiwVH!ci;k8xHsC;rglD z%#}N9FlYSo;@#VspvNU>pWstYM4IRvE;!hSda3^TJGeR^PAyP=tB&+eh? 
z;TY&TV$yw~10G&@a1o9NqWfndhhLQgxUWs%c=UB*0G8&6k?*hO8)MS`X^ube)UVb1f){`>sn}ac`!gE6IOrX(x~=BG zRJy-u*l+wbO1b!PV-`Nn6J2|aF`$h$;5eb%@Yl`#H-(f_|3jigewJ$c7)2QlaVc^| z3)i9OXO7_ev1PD2TCka0+>Dj#C$?W9ZngW@9|RV~O5*d;53~Wx#G1g{e)|B?;r?hu zxml5busr0^d%zty5MxdrmMPF849R-@8g;(gRS9W zfAnb@Q1SBE`vZNoa0o80KBj#NbqA*h_&$|Ffhl#>LxUFR(i6B6niB^w6G~FrUx>Sm zn|X4-WP(_7O<~LRJg|r3$}RQ{xKD=sNn27Vx^%MG{Isn^$AN31PWy7vv}y3G{@4rP z)RM~_7fQmNfu@e(K6S9GpeX#ywswpT6^fiWPzO@4)6d;s@*`3OQkD--=3rGEJqsNb z37bw|6tq&UM$7EnZ^cQ?kn@Vq?e;rejN_)4>mxcj?gqPDFBcI6&!ZmF70X!IQnOHW zVz34hxCQdJ?al#{Cmx4o2+y2^=AOOK_QRjtVV4G+gnUv=t?a%V zng>O#k2v5B9mASCy2ViQI3xA#ObYx+Z5;a%m5wb3(h~Bv7vs_EpLbqpZ9@9o9Hp}k zM9p)@UFj3`-pE+!b=R$56!R>cvo42b9-)2rSL30oWO$;- zX45BbhsSBfww^A|LY0&5R&I$IsK&dR=e<%xG<&~{o0G~wepJt9M`|-9^)_Adm(4(l z(Q%H0TcgnV$q8GUE3c73e`msbwssT{&e$`Q-;E#t2Ckg$!^r>+~+jA~7%?opc* zVs&`jHr+GjIDI&kU+HWiaFBe3%=cH|aIKDk3SS*OYP47@4HE)py=PB%j@Q8KyYOe% zmzz-P%=7CH*bx1UM@@~WuVB=ll{Z%R^H4`}ck5}fuHPcC8Vz$qrb^?8+w1*tX`8K z?C1UN1t2@tt#vyh0l$fz9ucW{jXP#HtHc!&usGm{o7YEz+|X?j<=>wO@fAKQ z>K_|XFst*GU?B-;63^NUhjyb^u&k7pg)11%TR%DXz8@Om4fgpEqk4>ob7SPO1;j7$ z>g!P^V}z(te6h@xX0N{@kBzicn2duyy=>F3_1CR3H_{VHES7 zn21A7puT+~V!ELag~>MZ26pR$vVEt`Pb`J2364f|#Pe?{5w!1zbtHDW>T6cyIYZt- zM)(l@96zupGsfx#5}h4eM*dhr;yMxHnNU^&E?L_*=IYAvi5<(y&5TH>7wR)Un_hsf zeix+cS^~kB^YG;Ky(}2g|9PifBoQ{ma_dhOmLYq5Q$fd{LO5rs z;N8JXM+$`E(R?1#EKcU(icMFe&3reAb4ll)7bhXh4&;1(m`JhFa`7AttOwm|EEL^R zwfOk#D|!>6lirZpko=~!4M-Gfhm(KALCIRx?VE%rw9hmnYQr%I%4etQL&F-eG`ao2 z_pEf-Yh68d)QcGF^KQF#bFl-&s5u{4mX<)He5Ff3K@O4bA}n(6S3W#96uGd{Pzn@N zlS|tJ8!*Ro`(>51Y77g#6%kB0#`a34uiNb)+*b}Yx{RAyC>-MdVtjWu{!-p`d0wXt zDk_W|0!)j+xq|fQ^I9#OXmz7b_d~e2$+}utl!p?Tl}s1+HbNv9qwmc#$#_+B zJ37xD3?r-0;-3-@Fpe#sE#f)?fl+5e&T_3C{Y#t0&Z>37*I!M{yq1;V=}FEJ@x~So z7#v`izDp3Wj;4k0h!L;U<}dd?4wRru9i>H%X&Vu#%6CaHu^J!#VRp6-iG!0XyezYk z;pp&t$XI>67>1sW9e6{G14IZfGkl70!t)RPLR{~6g2&1E-)okgNb!vBVun*A8V0?x zeh~8#59p;|s%J>UJU$mnd&x@7eo(E}r4b5D*XWNK5vdFmr%mj8ZyyxrF3l}8w?sBK1}K;M}BsrK!-jXkokIk+Z2%s 
z;^Y_ZapzG6M3`1*?3{~(CsVF>zfb34Xxqku@2Dmus|{_jzSjUWG91R!T#d+#GkTW{ zGa#=gbMX%0ez`zTB`TJdhi2sbmF@k+dy8?86i-Aw#17NYliVM}ZN|J%; zD;tGuX3m=5GUMQg`Cm+Xp*MKmsiug%_82B0*m-%T&I;wT| z?s`OMA^!wDb}cKa#S6F8dA3|?gKhQ~K2^06=^a(uX$RjC0{7pHbPBdngqyBVF1Irp z+2D4|Ps=RCL)={Kc?~d;&;RT3(^@=yp>~Rj>k0C37m?{wltO?U&v{PaZg|PLwM9XZ zAeu;=-}&mi(B-#~rUNMzIDZGMloIZYtld-BpWpKakHMZJ<+cr=N)n>8KVOOGS9?#! zvDd)j+jAoU5e48oP9~bSr3p(MHAeGx7J{mF`O~JaB`{DdmUtl|59y*IkH&(S7d42n z-4lcPR+k!9K8Q0xshv1_^)qQP+1Q=fl zX`IY{0apvh6;dhkVc+ye*9bBXAc-o+Xj5CE+#{ZyyJt#~m5FzNR=g5i$s0P1%+m1N zU8b=bEW@Y+p&j3ej=Fe=qwHRvE?lI%VwiWL3MIvz+z)HEgP_9x{>jB);OwQR$jvQ* z&=MDBa)H;dkCvaB??o$kNO|>3iIhX^)r&FYrIiqIF1!43MIv}#xDZ~h7zeU+2P{TX znjlSimFQjNLwlNt`R<(rffqoTJ0n_x&o(@@uMtFX{x`!Hm)DYz$Z(z-JK<-0W4jd~^6WvCTuYr36Ezv+zj`!7#a`Br02G8dH?O%?VfJPb4&48w(YhrX>@G^74r zsspjpW!NbBsS#4+;P~r|Az!&DNKv^-B`Vzllq_N)M0XeCtetqKOp8En!m!Iqt{dF0 z=-z+ZQw+IxQG3xe3$<)ZCuTG3!RwDOeQ{k6G9)JYC+#8u-PRBMd7s;0|MLZl9D?vQ z5;tbpI93QvRKM0;=i<>Y(lau2524yW!qU5=u?iQr`N^G#?SpTveJopz3W<_MZb9tn z9Q>Misu;hRf`yd%SF@ZJSP6+~R(kdVmZRlW_n)eOX-35d(uu{e<^5#K+wy9V$~`5T z`7s%%B%2l=cV@r`=HBlCnt}K$cp;!*fpFXHc=CAMDh1Xo6PQ^()x+b`gckYbG90_% zv2|B+I#Qf{$nI$0v$x*A!rG7OLL?jz|-!)fZ$LXZkT=|K>Usg)FCAJdH z-aj@glMBYOIVL7{j}#0mi;~t`%?7^ntA`FpwPDZ2bK(0Ldhx@jn?0LV#dzqZpb2w+ zDc)LpCv9YiD4(@LH55{ZOQ(a`msrDrmrvkTmQohdDX*vfnz>X`V?iOa#-zURq*0FR-U%%BacT1J4+DMPK|yyw}Xs_T2M`!$2iJ0e$0IOh|6r z2Ie_<&NeXWDUn90-v7sqjwlBHFSOuqh?4)PL;QA1G zE^&VY5{~mRMh5e6O;t_!s^M?^*pUAF?4KFnT#c~sE%t?NCV9h}4=4Ys$^0J*A^$C+ zgwn6;txqD{T~&N zDnO{yXW&G{G{H$djxD$7!y_737z4?dp`SS{&oU(j{I*^`L<-Nt8+#ACSX2)~`5h68 zMvp;E4LRkqEv^?U^H;CN$jm??JwvKzPY;-RzdC=JGX-*P7lpE49tT;Wn?Gs=LxD4l zD4oq!fXFoHjH*IDPXGXn5QPAe(g4`MNO);}Bcx5~OM|Xe(aWdOp>Zf!PvMn74Pt6hBX+4^}#cYW1K?t?X z@)Pr%9-R8S-u8hWo8I->yHzMYFSlQyViHoH{Wi+paKqDI)Fn$^5OL0nh8c%A5l^Lp zA$wtxyLk}E|HMgLUqRq&_BN}y(G0V7gTJ}y0)STHBvZ!mc)YK(HL;4eosctq?rt9* zhq!xv>ROJWc;SWfTl4Zj5GXUIe7k23rL4@SUowbq1cz)akSBm_+lyUJc9GG34hE_x0gaF7(a;nYokPu`PWFV$Z%z-_Ar${_NK(SOa#H{TKXCKft^M 
z>L*$He$a9EY_5_yAx!i?wZ=^`h9VClIrq#~U^Tm8q@`aeR9V=gR zM({PC@(qQB4pffVrSYq-58n%}?C$3cMs=@Z%6l!ui1-aQCnJ$Dcp1o7cBpv(3>(>p z5`_maA!cOcY~m!wo}Kc^r}o1AO24Li9<}2~-sKUn8ef1 zOS6t!DE{?S{EegJ)c+Mx^0c3!;t0{Fi&EcGctE=ZS4my8odR=k@^?VG2t_S^*t4T+ zPP-5GKT^--sTly#qq-Bj>bvptk#CNDSy`BNGpy^0M-|-rcqWp+{tc8Yw0F=GTcwNg zjW+&J1cx2kcU|1w2Z{cf&uC?OP^s3thnLuLo}98Xc~#N}4YyuqvfQ3P9tULLx!r=i zvu`{&1d||K=b=R-M>R}dU0QQ1htCx3u{)Q;|$JzUz z6D3nk`}MvL#6qdxGj$$vA|W_e=MkxH1gW1c@2Wob4)kbaDzBzjW2&MD&6h{>=s8** zHzV5%4qH^*A2yL-t!4j2jc_-b36yDPhP7hVxjjd&WDlX_?%A)AE`vBbMRvs_tOX0y zWCEt_n()S#!WwQ@LPunjxRIS34H=J@MOcmu!krs*!8?bSP~f;s!SZq&9#r_n>Qq#R zPxxc0ls`4#H2VK~#Q%mU`QJ{=ziwy! z2T>An_I-vV6A4C#CWEZSiC&09Ba>|DFml{5-sCbM;-HyWNJ3P@P|BZbyNjw5Tq>mM z1oVo~PhhNHFdiXJXTF2-Bf$@SmuuB3t^hXPlVt&K25`Zw$&&X(0{Y}I#)v)df${Ut z*;eE8@K@@0X}`5WIB9%j*SS`IJoRD5g!cY0Ms8hosMhVothd<)8jT%5nEPX9X~sb7 zCu^jhXElz^=?11pmV#i9u;p|4VyI_dsg5M{6C(#CW9||@h5vQe{0&j^&ri-jS&;lY zQBvQ3;3W(D7730mhbdE*-(nMuvbF2gR9Kah^jIWT@UGgv**nl&BWZFu$oMOuq+&VkA$D!JW;;{^C3wlFDL@!uv&RFw9`qi5~z<*_%Q02rh#QfOqx#fHpqz8v@ zwQWUw_2|(_iw`q6bwf2lBD4uBw*~b*>W{<6gjM9t@jM(He!(?8*h4tKI{LZa^n-DI zn>VFyE<8BN0R@CW@yZ9PEfg@p+yL5VWg2@3V&u)njI35xA24Z*J6`4| ziez%ft>Ov6slWTOh+cmXN?xKEh>jql+*4bQtfY_Pq0aLYM%$_(+v;${Wr<SWO)p-MsZ zU_n_oDBpeal}DllEMFE+#Qr;h2|{UgM@3Pxc{wMCT@+0@_x$RMz}xf*Qua3;+vW|cWtL?KZ9|kb)dD`51 za2^UZ9RE5tUkJ^~3+`efRd9m5wN^RT2U;x7guI{Z$I38a=UDSDtBIO5_ zcgy9YOVJwr4&_uh%=vw2ro0=DmwrCf;4p=g`aKam!L=aD8sgO)(vA$OwXYRkJ%x=+ ztp(2}`*7W0gL$~I9+lI*xvIkJK#>1#HC@**vh1q9af7-AOCt_cza2;i^K!$G3zO}r z^2yz>IH4J5)2}@fYwZN1Bs1ox{$t>4OJnlBv>jbU3xAjG>jkN#n%~1_ZScsOyzbJ` zT>NSLLxMK22SSoQXJ#$cpvU$R=ZN-x5J@K6tMob#jQ()=84~xghuegX&(C^vREoHG z(>D))lx~sbog|8m?^GPiRk{fwZ|*$3aWg76Gg9y7n!q15+a#Yx%C^2A(ejcOkR5}Zgwro^<>rH38h=7kJJ}J zGIgeO^G5`umc~A|nr`%MN~v@RAVEyPcc~{nNq8*qMTc%?37mRzQOGT}5#HEwOjtcH zLe{;szbG0=$eFjWhy@&0bk~V?>}#1dhc*LZaQ*a*R`QQRXkRti zX1&&kM|Aom)GaT=7mYN@zN$fxl;P#{RSJRJn8Mr+2Exj-v(82EHVI4y57KwDv|;b5 z7|97G!oaEFa@%0mI zZ7w?JKC+B?Jx<&}*Ci)}f-r~9{|Px| 
zH6(-!T(?;Wh7>nb3z=7Gu(d5QYH*7qycN!>k@alEy0$ho?t@j>aQNN4YB1py`(kMA z8(Ir8&wKS^id&#e|N0IY{U-QHU&-E-mIq=tSf2{KehL>14C^WD6QIwisLeqy9)1j# zIY^8|V?rQZ5E)Yi${aoY^Yc~0QoU$5VcJKCymSXQX$YIRbl6%<2&FFORfzEKudBc- z<3&=p&eWqzbnIc#Q?;O5_H8xe2$81nagi~^Gnoij>ff*kF2pwu9U%vMYf!3Ii<+sc z8XA^I-b&BZVy2tAuho_oSmtT=d1hV>hon^PC29D@x}Q+m{T<ZtSvSrGSuh?Rj)kb! z?F>ND-q5l0|JDh)Y%kic?+cuq7RyKd>QU#l)OWepB}l2>OBcI69WL~3>^?S~f$RIy zuY7)z0XqbQoXCz=;B~{DX>t)Y_>K4I$B_rINZWLjGD*H2)yc2c%x*)7-0^91|7rz@ zf5`tByq=F1uWjhXj}cLW=LDyUZ&kz182h(DL%HC6RkZ3kZy_dSOS$aQ4ucONi;I4* z0`R$I{s%L;3~b|fdfGew5>oMY9pBFi=x7W#v{1-_(Xc|Qm4k)Ye=YVksU#beBJQ-k zen5DU9rsPJoUR7;pOcywN5U{kw9_z$xUNrZH<^96HNn<#XO8Ww`Ix}T8=!9=0`l(O z(^u&o@uZ-_k5@ zPum!)C6UlVefO)OkFPLrFGED8QxSRvJrPmeEXQ(*xyKCK%klM>OM8538&J$ptu&1Y z;iW!r9`ezm0*5r-91g!6j3-uF#~Qwuqtq|egb|uz@Ue;)-;jHZheU5J**3oh)#XSx z`ERKh<3fKmydoOhl0@jDtx34@w$_xkgc$K^`^9RWkppL~^e#&C6Alj-F6J0b!V<3G zs7kk1iPe$Uc5#Tc;ZQsI=x6Cy$i>4hr}8TvG~zEu6||F}J0LwBtts16xEEc~%_iASonRn?TMZ8%1+Juh^PgaTSyWnZ4mg-5$$k1nuf zVAIdSG)L_u$ea(oz;Y}Nw0VcF(zRB>lm*kkxq0Hb9vG-T*YV5>Jgniz+L>SnIIhU{ zma7%dRhtxMCe`DoeKK=R#T7UxPHL_4tcK{fp}`M$YEXoUmapt!3@T0K{pnB1!_E31 z;*C6w@a2fKRiFl;51Zz=XYe`=C+{oEUgz{cB@>sVnqUV4w~gGtjfsFgO>?@B>?Ts2 zD2C3B8ztgCclx_lM~D%?xg`->jtVFY471y*S`T?$W54V!ghB(?>91Y~YEhw6>cYom z64v;ac3O9~fsDbF*-BghM7FDNDJqshNY}{bDNh&3%4g;3dK3v%5#&z4?Oy^}!!I|z z`>iOw_smYlBo6`$rt#nUtOzGkVz!c;7KWA&wDZ!8H! 
zgq^0C^L;RC<5S>oB0YqT=gtjZLhoi!{#HlqMhX&<_C*05ahRkYE_+;oI9~|~F*Aff zx#Cjl;Cyof=6!s)F700pedg(oI>XiA;w3zoLz#!ZoXPdic>5vng)zU`?jpPcf$qsM z$*|%~l8-%Hj8ZS?cQX?(?uqa&=bR5^FmI)4#CDZPVY?BV*Pzsd9y8(QoY#8cJ>M>g z;SoomY2e;nN8r`(?Q$pU)Wd=9qwo37#3c9-y6(LvqZs^Tj^7X6MRb(Hnx1dVBGgNn zKfb(4jRHT-WA#rn3PAt#cSi%qCMk26mwA=bQRW)sBT-f@ z=x#nLV>45LK739C{tEd(h&Oz0+^L43@U`VSLJD{-?95L=pF}7-G4}lG)(GHPfR|uh5>C(<6g?%a(I5{`p!@zPpJE$H-1qg3#F>&+VTdovC%48-$A?>#@V=i z%0If{VMpH)SEGFJ4!A7&ep@BxyKQ(zYL)@lNhiMN6v0T7A@X~5ts1%5_6C?;Z^UY##IPe5%=*7C4IZzA6Ps<{{ad@?40lJ!ULrD; zhf=C%2U#-a+_W*IE3AaDZ%U0mC9bgQAr<=gMFJdX+8ey}SORLC*H1Z1T%T{dbzX(9 z6hT9*tY-P-O|Y%G>QbIO0<8nr7h6pbyAD3nlDwq>5!y#})@3@e`lG6CMtBumY+Bt< zX;gtF_j$$Q-Z`M|T;7iEo2@VxtRxqBH4H94kvQ$)*Z@Y}>t3I!Qqb||lV`JDeej@n zP)$|Xn!S!$UyP0=0a>ZwRbh|LMEJi z8c|id<_JBPKF!DPH3MF{F@=D(7Em@(`(D6T1QSV=7d7ka;EAriCdK{+7+n}>7yR6X zMl~EocU}g=nf(o8>KIQ|3t<|uCz3zBhWSPEkkv~Mk!E)L*sY(X=+I7UTbS*M`}$A*ChsCaunj7| zF{;FgwGcBspO?Uw!(w`;^3mKSe!B5qlJ7$g*Lz+an zbSvQ&BXe7>Euk+)JHDW4U4o*{XK!nk5v+m7@QK`em2iL=`oEmhCtP%r+H5uVFG6R7 z^Oe#g!N4tNfRX_|UusnVLHS;~Pc zR$Y8gC>lXx+nj6>yDJW#)czoJ+YTg4S-)>GB_Y$nlnyeR3VeT0r@xI@2WtC}(Yz&G zci!vPZ|Hxt!N$j)!1T}pY}{TDm8eq#_Lp_IXk&BXxlhos`Y$=i@$Ko2+zV z7%$#kl?q>N^ZYm$Jn)^bWkio*71o*RPaZWS#I+0Kr<$G@f~L|ZW}_ojaFZ<~p->|n zw@~jY76)I*SlmlzalHa~(sP!oY4SjIXwyq%vk5YprpfP+h=Q}eQ2f3u1N^bON0t$C zz|u$e;M3E&nDOI7rh9TbZoK?-wueZ6-22oe|B*-z^0Wq=klfRR$2awaWxrRUakoy2 zJT)SN#=)Oy^NkQyYNIsAlz`FN9ZXK=no#7q`KIp9P}HJI(ULFCg9!FZS?QGosXSGF zkL5K9KlJ7$Dl7%yQQMa;BAYF^9%ekPsqmkAAgZxsU-Fnk3)s3i8Rt3d{(hpk5TQmYxbK=A{Er; zj?=@0Bv^{Q8?%SC07!aQ^g67HQAj6y$ML3c5NA48^fku?4znlu%kR!0NX8)D+`4=$ zp=fkeeA0k#Hhj|^o$7JF`Qr)p_r>tCV@gUxFc4=A^|wS?RN>Z(ZmhmbiSR0Un5{_1 z7x!ySK$z0|`>3l^CmBJC{31{EiH)O{uaUsdn$2RH_w}%yB@T54xUkX8V%=8z5nwv-U|<4F01sOeDsb~dtY}Y z4;${^z7R$@llI;EL)}4}0W=eyVQpniFg&mjRZFcJb#S-W^*rQp`9|LESpalh&gn9oDR^%3q&xr<^qjbj z;$tALui(CBdk4;Gp5Er-kpl*55|c)OIavJiV*WJa-_U~J5GDVpL;OV{#F8;zvnPh+ z$`6LJPl1jLsCJ+vGvf6-42#I^-V!$lLEk$HrnJ_9<Wj%0YZ)(xN>Yqa6vEA<$Eq_z57+!QE2T;jGKu7}FZ 
z2_5C1<(T@FO_zRO7o1yvye__@2+gG=9!9(LLcCJ)*KIqe@YFT=%u8BrK;E*VVahm$ zk>y%0=T`<`K!YtS{OStOIBf=pptgPoj#h<6%{Cze4 zk|+_U8)l+s{Dw})&MAr(w}GDgFO0y5)u{7#df?CvQu;h!+Z1$>lH7C1a#gxx%JHdH7ZG#Mn!jXwWuv zmRGzp0B3233PRn9Tsmgs-R25^D^C6fE&0ci{Of3mor^}@+1u?x`#)JT3T~OL-viY88BoX%_M=Yz6hc5F-ki4tN%p7MS;V5_2RD_NJ;j;sQIV z^emAY5Si$2v)DX{ltUDIuGcpsYxJgh&9iw><)|Ln-kAhyz1!|>@6N;>^cCAFnud^j zzh?1rbrsQJJ0a7vgJ^jTrlFyAG4kcOk>2$#MfIyqTH8!Z&`~XT{a$$o`bIvKPu-M< zt3pyt-zTb(eJ%V(>(eq6(i;1qVfhZXk10vBD7NAjrMrzTLPVihbyME zyjH8ve=z9{hL6LK@=d>k$2+q0nlAJ~yW|k_M0_WxJ-J)V>Cp{Z&4iB z32k1%u2Wn+1RGNm>ciKE2W-C_XmQO(ao#N(6zFYQD81bcmIf5HMVl=Qm7Ypq4iqh z{0jymwdp^Yo8O=%|7~Lay07&Qv_$UCFDtSpOUyd0h_n-JSif@igZ2rca%=0Kp(4@+ z#k=}%OW)kLey9(L>yze4Tp4Z_;QIQuq77%sFYMo?_zq*~s&-l4 zX~png4^K*cCBl>@moMZK%(Ywm=7w8z)p&eACW{doNR>@J@(Ds6K)>tZ=S;d55P2M_ zU-j4v$#-^Lvij@+K%?;?zM&2GWmpWpsvpCv!K}|KGU{+v=)j@-h7&kouFT09+=3?F zMY`Ru$Nrd{-=HObJ~@AKGWk1N(j}N#+15>3Mm-)z?ns!E3{Yz&9(5gtT_wXV+hxAP zmqMpqBu;;8`~LEd-fYZI@EwH|y z!RkhK9Bk(mHn#~LcZwiLta9t_w{h4kclZy&te47EL ziz_CE-3de(?mk|Nwlw%T-rC2RFaTYI^6z~=FjKlJM8jQEXM!`Y$%sNeUHx_IZe&2mF1Ru(U z@dS>y1Bd!yX;HFP*um9faXp|G<0m9uXdL+)RQ4OR2kW0? 
z7l-4UFd#^pt9Gps3cI^J7B}Uhki^p7RHHg%v#@d3dLDz42EhSuCn`Xp&Z54SKOMKX zQRQ>h=HLRQ+ikn3cHq(PE2Q6(gsXv9rhSUKu;nOiNiu<*T@RXjPp{tqUFI*8@0GQ{ z)(g8QA}xOunC3 zj!$Y#iG4TU1NX3e=6#}DlvyqiXG<`Dr})VNzw<^qUxU)&km6b{Q3KXp(jkBBHVPN&||f-+!7x0^WEryny+h+6K4!;7AC zFYQZF?%V0KWA#K8;GUMT6H_xjTRkVidaD&Z>c0tpple5N?OBnO$w7RuvH$kg`_-`T zx~cAlZ7S?}sQ;2|Sr76q@7CoSOM|C*Z5QH2RWqqbk zJjVCFFDKQv;2Mp{N`p!V}OC)J^)f z9PUI)`!}{{uhoN4Q;gdS)p}&PV1GW&yBne(G3>qdx)$!9YEX4;tp`)#18+92mBY&7 zqXR5M{rKyMAFt=NJUBX28(nzd3D`XmrMNhl4>_KD)!#X_p~KzRk@ECSXg2j;V7R1= z(4Q+NE-pr*EVEMUEL}ZJ>oQp`>ck=QsncuZaVeM=NVIIOG+`O@gp{dw2fSW?HP1za zh?IIgtg4{n(YB_q^4R=$09)bh8$jjih(6 z2*1a{1yRAMlbuj;d8~2YbRcdhH|hu9u7o{O$qxf!qJZ=9tfEeph$1bNfTj*US}k2?gzU< z$+3$+>NzX^IHCLoE%~R{@vpbw|ALm-2A%ul;yr-;tvf#C#kHZl9^h#ukCR9;?q7XQ13hW>kttTRB|>tY)d1^ zwr((R+#znl7F|^@my(gIKBTP51|d%-aqvcQ3()OW<;dIW0&{(?e0IP3v17Nf=v~H6 ztkCR8kumCrrlU71wpriBeE*4X`(_=app@>2o2gIcbM_JSkk2wa&OAdf@a*-~z4v3_ zq21>A(}aip9bL`pSQeu1c(p09>e7qzR2yrgI%yaXug&`^djJGEE8?i8 z>_PCd(v?dzAF#M=l7aOi;lV!AcDi)Q8)6Obx?X4W!ReHR;;JV-C?lA#LSH!oMrS^U zXp5v`Zb_IcOIHbotfqW>;ZliQiA#k7&7>TcKE-eC1L%`hGwl(=)54At)#pP#-oJh+J11CNly ztBi&4Q&-m5=|v*EJK9@1x+M?EkN9*^^Jb&f*&o-=Lki)q2g8jkE`gUFh>+ zr(Vu4|IB=-cB*lo4MsdMth8OUI*$Oh&Ef?15Xn&*)G<}zWtePuD{}7W6TE+Ht(Ti{ zbW)k0*&fv2ig{GucnrKYfb7#Ag4YT^1z}Nqi-a3C94vx^wEq#{+2@6pB7|s zq^=fYr;aadBJo0ZBbw8nS<}$3(~O~SI2V8I$a%8kat1IA6#GzzW#hrCx~ne~5|PvN z{p&YjK3KC5a+xVJo(PnFxMkr^CS)2a4|>t0W6rOpYn!^_&_>erv}ar_o-7%%4)lva zDt0znm)m6^At!ErqNo+`aGH?bTS^Dt*}|1!>IQuA<%lY8O%~o0IeOZzybv_MEZjS#f%#_GnnW89xf~T~r zeU{$f)n%yHnqG`iTze>leG0*jqJfh-!wJ%>j3b6stzo6U(a}(;9QVIkxO$l(2P)F& z_~so{adXq%_7A(Vkvd%CF%@+Pd{0&W)aX-$KLdnqN)LsAImy|oiqt6RF6AxV-H-^u z^?jjb6bWEe@*#((!UK{HnkZuefeP8!**nB$LbkuVU+%3;n4xIh(RwQnR&pPI-gev* z;nuk3VCSAa!|N6v>eT?`7^KX$pJo@K_2V7$yj#H=fFT;G)$%r+^9(^gx=KM$}IXs zf|#Mt67#QetogP)F)7!ETf_&S^u2z7%GN)}m85b(>u!SryFm$D9e#1$&D;~#7q^c! 
z7L>tK$K$z+2Ws%9uY|N3TP8d^{`7LR2*yl$_*OduQ`%|_Ltoc>qv$m zqy1q`@3U#(p=y!UHjxBus_)X?R9NHYyP6i_G9}<6l8vH-_8|7?vggwNNFX^8o%ifn zI&6Qb$x`fGg$!!S_vL83F^KmTZ(T$px*lkDW?+s3j)e4XLxpSzh>=t4iq63+RzjaG z%?*&!+<))ID{o~3-SFB&dBD_W`AQj8oAA3WU1XhZaq;cj8gCSp=1=`P#*fiL;1 zfV06w5WK$2`%BFE$h9S#ZHsmxe%CMgW);{BUaioyjXxgDPWG)h60$sl8(+Fq4ioS5 zyUPtv&)VV>j{U;P(=mAW_QZCNN9MTY*bA@kG70GP_{{vv`cyo%O5VUtSqTSLE>!CO z(t%D9r=@hOMobvd*fQ;v1Bnc37i%kWAVPz)V6MFZgW?ZYe%=)a?t_+IwS?Kww>*yl(9#r`EKea%N`w#7h&$C$wCR!Ns_*%$30n@!h09e`7dQzA75^=8i>e{?ip_j|D-f z1V@sJVLN7a?RGx+q#o^6BE5g6TVjV2ji7T!5uT$AEP76D0Ba1gv~u#{!28lNZ2yTo zV2MwzU7X2*bH~k%JP*{v-t%XZr)YyA*Io7(yJRsa3v=zC?9arT`P6*M-Q~E`vz(o# zzX}bvZ{8N6N4PWpjGvCHD~A~(J)rtVI5?6;R2}8W!GsX&`&rJ_IA*fDY4Ug?R_^~I z-ux*QJS%y(bFNf^iuOvTaBBuo|H`e9rHDrfp@LmEjn$#^22Dqq+C>bSJjs|%#Lsxp z6@g>l3)t)xRyayk2dR0MogwCgE&an#@?r`iUdw5<-Q6S#RKCW9Rh*86H}V&AO(tXE zCzav{Lnj|dRKC#PNF9OZZ?1@Vea!_)sTZ|cnOPv0&d{oHFag7BCe#Fo7^lTk*A(+x z3ETKWE~9+|NuYd|F7$0(II2zCU7${Ug0DIQv@{K>K|O%hL7sS@FN~WoGE5g@$6Nct zEq3*Ak0dKK>1{dQifudeZeKRxvQC*8csChL4BzwLTu%dKkiKEsS_#I_(`i5Ne+|=( z$|L#Ob@1WzNbIYjTr^(}v~T&DiLI4#onIfu?TYHRhJx@jEVq!o##wumlO^& zjO+19!Byy5Asi^xk^pH|KLoC=W&_KaP8R#`l}K^$nTC}dfvlK)845a|f*Col?sfWI z#>FkSc9Q!O?|~|5qwmK;RJnZ^8!ZVz7Ip3uPRL-4GO zawAs6)HYbx)xy)yMSe%_SE0zNTh-4WUU)Q&I=Y#<5N59DN*0dgqH0*~_@llEOr5QL zT<4SsWvx$h?K3mM=>=^`EsT}uHj@hVlcUAJEVjb|El-zfFb2x?wtkD{|Vij-|2{W9Dq&&IZ}b!^Y%C zQ=3{nvT@v-PH<*mm;6sBFVyE_o&XlG|2+n z8;X_%JP4m_5BV6bsKfIkThB<|G(#iTmaGXPP5nca%IfOTb~KDg<|!S@fU}N93OhIp zP%gAfj+HnUkgaj}FmpTs=6F9t;j5{**HqR|=UE<%Mv_!84`w6P2wVLg%@VBJy=PO@ zu}XOIVum%TJsT~)I3%FVo#G?b*pLN=y#U&`?Q90psIUepr(VPth z%hDYKC$qs`d}TWBHW8E8sPRp0YbTnel4QmD5@tWT;hI_Z({S+U=TVsxoKvlh{<_xw+nJ z``CyC(v$f+-L&&yO@miVm^>HmEH&ZEd!I} z_l+9gOELB+RZ(V7J`Pvv>Oc2Jq*YD%yin!=sh7*kwqFiG=_(t~UcGqu68fNHRyh$* zWzYqXkB1?tp`ZQjkTxQ3UD5FMS`B`gTUGb4O@{J?(?1!B{gI#Kx(S6)6&C%te4EiW z7L*>+y1E@m1y!5tfja|4__3xPmI(Re$K+SRr|(Z;#VOA8x*(3I~lm7LzXjbd)_W zT)F;}R)8TAuB9zEjuPF3CEMwboT7S| zalh&S>A`eZ%9>#)J>G`ofwf5jPC4*6P5RjVnliZm;nKO_N7g7&yJz;yuR_Q@ugG!d 
zb^yvqUHG>0AMrpm9fsP3X_Hz&iP_u3%`o%1bCbAjIALlr|6tzN8-}$;O$XLOSvtayYcg1eeD)5*Kuy1^uf!QV^7W@PPO7~{=RJn~O8Whkvv{dHcB6BfW zfc!^2e$nrL@7hnw3f{jHjCv}gb$m8D6^sGdSxr}dCW4l$kzko zH>sUk+q2-|4Qibkz8Hw^BR@W!K@>+n*+d9vRieyP+?BJ*mEbzTZJ|M!#C5B+D3wv` zBZuI+nZHjvmUEbqeA%0c)LXYoP(QJUhN=PI^rmzq=@?|bxkTLi_9bOfy9PtYxO|-i zMFEy%OE-##m_pzy9ojs_GH7`%`DwvD1L-UFw=L-hqrk|%58{DIC_tZfYT!fzPhlmy~+k8m}_Fu|7(lv`9g+=n$!wtfNA+gkJDYO<`&V1$VBjSv{ z*j@h;XxofDQ&DbHg!z$#2usX`=^Qj2Rhcg1$;X4Lf_@v7nkbkPSor=0A#W75>3XkN z3&$Oo4yJrhg9N3HipY!t%vtxR`7vCA{Bwr%!$#T2%4NKI+A9{6&5q@5C*pIgghcN* zzc)cfwv%+%WUFyGcb`?BWjbE{xc+mK^h4Zgdc!){GXt-(2ji9HROqf#l9pY|#Le}M z!6|$}m|#WR<-S&euDRl;_U$?FqYIRrSsdU$yIsFQOa4)Z_=`fwfqS+s90ot|b;DcL zgZswd$~9SHg6FWui;bCsYz&v}zSjMueTTF4zm^M`Yq5vRwM0*^A1j+!9KOH&TQ#(2_x_^OKZ;9R%n$`6$e(4sP9fv+3^RdSqr`V8|H^$Fige9>c?3u;rP72kD;o z_^9BauHry9B;S{MrFfG_0Cz-dYL#A)q*qC2ZWsj>?_aXX)8(*V)kN|bbvGQ$&-LK% z%R=#@f)(qA_fRb@L6Ll~1^La5?rQoCB6sm#`?gjG5kkW=zmy^zeG!9 zk8jiEZW#l=hwG2TmME|2C&vHuvHONy%O+=c5w-fNwmU*^gh&4v=r?91{C|#?{0ULA zp~@#+{?!xT$R^CM)wM$KSsKfxkG^m;qIsY|ssl#bj&I>jEI&KBhcUrXIHg~UW{XFl{(3j(g_MUVINWY4aCIj$Mo^`MMcpV%pE{<1cD}&(O z+u~vb3ZeAb7qR2cM2@ZF6_MPT4A9g0(G$)efWn1B6RgJlz`WbGQ!vyEjkj=|jtucZ zHqLXmi?u3%^tRB&uYEyqA!CxEp}hz70-ef-!v?Tw_CsA#RT;{E)wx;EHUOv28?ud_ z>_@UAcjk}i4T7qyCS<>=z?RY586Sg*+&9s%@cq_3pl+|f{{CDo__x-zt9zUOu@b)_ zO8&7D|2k13E-XA{WK)Ce*Jk4=rke2bkc)yc@joZbC2_=uhvU%+J`MM_TFf8H2q;HJ}?A1~TkDYeW6{>_Np{ftM zgcIzQeSxNjHT$q#Q(|OLv>ba6c^;xDA|#U4)Qn9#`;q1)yXr=ECkC^eg0NFPME11D z9TCzR2(-*`+2R@kYd81o-21*11K%`$-nYL2^;mYPb0#2eriyW(6==lZ54NJ&kxg*T z?ws6^ zbpDn;^&6su|9?f4JWJf0|8lYrb*5t-OAf!sbIyWBbm#9N%Wy(tV#GU8F>RI;T5o`y zMs9Ap^>)a-IcVVjwHlAH^YJqAmVm~*;?BvZ4NyYMk|Q%Q1xqK5Z9aQP!i>u4=-9Y+ zEIL?UXY*+o7r(vGia+)ixCJMB?U;v=*X3!q^PNVdwJ}qQ-O`B~l4l<}tql@=!e>mR zS_3Fnaxua}kob66FAbF@62e!3_F1{l0dW49NT0)3!eqa)UO95P3_UoEcl~f^#+qHN z{M@uN*k&t2Re8GxRhBZ}b8~lK^|4*wRF1c!@549u#8%w?*h+pwl>BY>{(Yk4o^{b) zF7|Awy6}!S-%kYAa<|SgK5hg4)|ZmzgG7H@*=-{JMHj@0U)|p<*84V477r%ceWTAo&;smyG*TRj|@CKbuAM6}u^RB%eg?>tV#{2J;;P8o#@?E@b 
zFk?t-m^(NFXPCz~sGjA(8uP>Ko8HA(z3;GOGyCU1Uh&@$CI4+={<@v@52A$KMRMvP zK}lWNc3wUyV+uMY=+#RF#^I*iz_+WqW$erz#GV4j))@ChnAQ+oI&x3*XXWsn zyn%EuhN$ z`^%m9KrXd#?Pvo8daZ^(2&l&gepx2h`7&|%Jel}qLJ*_&pb5D`-BG;X29s=wz$&#* z;|UVe$iY=mDE6WPWI5F7RS9#F=pB0Z3Ww`(dC#@ew<>yJbc)_K>nRb+dppbfz}srV z8gzeGo3 z+<>>Y^Ay?9ApCy3nAE$Z7cK#%O=epghF>>0qGM8to^LXI`3U!}ec$X)dOz%j4S8AZ z%MCp!boNm7;Rii)|K^m);ro0a_6)%cf*l0nVFgQU?_pzrK#7QHnC zuGE7{t-g4Y%_e zZICCoNhXOs3l#ctIl^-4z_Nq<)|FRr*ufShqI?{ocbBt_%)U|-l~>=ex-p0eF8m*E zh}K|0-I<9)l8-@jNbz7zTm)zwSO{QtAfm2(79y#mgJGuwxB8858E8dcY2v9)6z%64 zG7O0(n1WvU8S9-jm~+97dEK4J(0#7{Y-G6>Jn~C+9_^|invEq;%ijbv6YZ`A*9w3# zq}rd{s|3BHc^qQ$>Y-HRg87kLLg$@#hlZyl2?E5fa?Xm?qv7$i`qE8ZD7EI$Lqp~Z zn{q==Ul#9%#61h0j+vEs#I-&cxAuatvKG@$_I&J$(_|Z8AWYLv*WHpLijLW@p4~Vv znT@&;j2{ervN6)|iEe*ICxo~UzLXni07*aEhD4n%?A32HTs@SHnKb)%j9TQweabP8 z9czT1E2sYm$@4}eQUW5ykM}`lvC}q|8-s9q?ikEZyoL1(Z|NyG+A;a8N_P%VBv@Ce zrJniJjuGAD3|uRH;C|_&zvk&ycw3xv>BP-u5ZkBZCGemU-tdnJn92@-p_gC^^FS7) zNk?R4ZfV9+jp(L%@iNGfc7K^*oQ$9OkNZg^j^MM3ub0F<@=>K<^y;bnG*I-uV0v$o zXj*sn$0owbx7E7cw_4ysP7Kqr z!d$$jv2{zd9HCqNan+^Qa~Kq@>a0cm(t%#NLe}|h5nA2uXtPm{1pQ<^7r~pvYdRvZ z;nvpydE$p>=v2C4&h3WD1LD&CbWWY>aV%l6@xw4DI44md>mK=yM4(@2?t&Rnysvq#;z1wL3O%wBm!2}0!E?%5twpjS z?CAi}+EOPdkZ|C?L6AR|NxLj^d#f=wcl)-bg=qT436L??kdth3BnmZ^A;ThE&QoH1=f z4=IWsOmc)|Ik%pxc@)YhN+FcqyaA8NyMU(VdoV@;QchIhKEZ@D3tO8wNF8qQe zIj9;tM0)Zb4o&}Y%KS|s<=nqXlvJpF|B;p00&>Q8Z(bzC&a6_~uQLm@p(5>_FNOQ# zp>ugd>QHYeJfp4YFA;kKb{CAJOuPwmjX7akHI{tz;!U2YU}?Z@VW&FEjnBX~iZ5^0 z?0q41?GdX`S{Z8Z+0gcn&cZMIquFeAn$h`k(3T_DW00SxgJ#Gg4R~6T;}j^pG2dZh z)N!U6z0&RD0;Y2DWF((X?Nd8cIN6!8|AH#ctO^TK@Wo-^Gwr9R#~N|hQwjYb($-*-;M!>gSF_bf*1K=hk%hGaz*mX$uCwRjs2 zx#cS!fvlDAzUZcY9eFHn5ux4JR+0naKMJCLvgBf+hkrt?Q#tZU_LQebyTYxhA=#aW z3Fk~LA2klfV&JE)+#kr5jV?`Go6cVI1}TA8eA?P(fW$Z`{wn=-xX(00_)aSp} zA`Vod+}XBXwOl!bySWPTOoZ%#>4RLCM=&U!a6F^+xD<6&GaBj0YcS$mOJ$S11M&ok zBp>N$!Gs+bS08-|2Trv|Zmf%WC>3hna_3bs9F6*#vvxQGipu(5ly>IA$GXRYFJgV6 z?MGn#gM0R9*mXFJ;Z+H~AUQ97jl39ihg}mc`DTJr_N%JN;0ok^TWMmdmjyPaa<=Mk 
z8{sKgR?}C)?w!}&Q{ps#IeHJ1Ggnsz;Mxlj?z*2EMDTqbi`L_Gw7RABGUn?`z?c~4 z$xb52Sn}e<=sjNWP9e80K|UXC=QgMol2SqD-hPMnE%iV(ctNqW#UIL_k*P)!6;4ix z@nPmeWgu8CZpbE@j74UyW-pH9fyUb)i3i6e@ECtCd738TWS-e&22!F!ljV2OwK)ww zgGiBAXCkJ@e-gQKB@y9NET6M#D5$B+P^+A)#A_S=lOtye!Le7*ecUqv>wLs_e_twx ziksq>uWu5Cb7ZkUbdMFnAaAlE2O(Y^jM`ji=n{&Krzp!0-$=%ZP^OFetiHHOguQ5i z5Y1OaiHu&%4#rQ1Ikn%9cp^(hZh!WHKoC4=DOFl!ja7ZKI;x?*(7+U+!dw&v1_`nq zH_|J?>G%MN9ajnjCMNbMduM@(?&ymDG7;O5JiKVo9D`sXAS$*E@we&ua!GpStTse zMD9D8o{uR4BnQsNl!5n&{&o}kO1vbps_ecc5Eg_qSyFCQqFLC9smZG6@WbY!r{UdD zjJrpEC+}SWyqIvgx`!~2VVY~d^KDBRM(M~4DTqX&eEFUh7SB{L(4C%jI2!>T!6t)z zbDl7AyhONUR}V)&EvQr9mTgANQnI%uY-m;)fXz;L`Z6`6OZ z4aMijz^gR%Xj{E>Xk`xAUa}gGM+8$YwjXeTv8qFAf(o86o!EQkYgGY$3bJ6eC#nId z}NYX-i$>cclD*#YA>x@*{;TOrldfzl&y-s087!v{!obCKEE(JCaS z20MkO=NnVYF?J8xY!@_x;-+&2c|s`|WixWwqc0uaQ%OaKD8~_7>3()@uPnlMjIyIz zEC@Fbm@N6!Rzo8nC&RXmOf>2H5^U#~4y&Y5+hRUGgNXu9b@O>AK%+>v=QmRE$uA&f z+#2*Q9u_{w2I=b+!->$h+}6>RaQjtqXnklDe5N3e)9fN__<5&nj#2tSWR;)Euw53+ zIm$c}FbV|LC(1u}t4AV!kYsJ?|_$$OiLYuDR^oggMVUTf?`PYtgLOB)3Y)1Hwl?m3<}Fmx!IRRF(0G z=?Di~`ghVyL$yF%)I5BasvNYm+&vqv#-OmOXwArOhkeFZL8WR<$)MieCkL0Ya(Xj+}t?K@2XpqF8eDV*G{t$h6k)<^U2PBXuP z?Iz8a8^;sz#M{;RVW}dZ*u4*L*15n$Y4zd6b`$*W&wmMxvw=yfx8etH4wPCB$#upx zL8iIMmLCZ*AS_$|NcyAz^oI0}ND8Omr_b6aSGEx3&UvNcpu#eEHDSC>#Bz?(l29GdC6!&GI&1smmiQANO%dIEX8TQI;uF8796?M8{f~Fi}B8#mL1WVfQoMK zLms8T z%=(#KCg2UbEn6b$q2A?A><5lI7^^*3?$x0No=V4?6{13Mj!uIyM5hYNU;C)-xSfe+ zTyGh{%?8*+Mtb~DWx;|sX_TKN!hXR@bc7=(f%knCOutCMVO%6U5#^QI6)8v?1NN;mioA*hpKz? 
z&F3iY*KQfskpK;hL2p!*TcL5gpo8=GdVKX#mhCi|0PafpR-7c;tcGEBK+A#!O)Dt59^gB1I|5@p zpGVMIXW(F7C^_jH1kwj}+!2K7!M@DbPxZyT;Aev_^?9a3IOeNT_rSLd^@S~XgAe6n zMuLCn*gw_^~-*+d>|6~Q;r}W-LOIC%??$%_#WDG+RU9Py$#5X9qXV0O@ zDmOTFJvc>rs1b*+>7SXmu7MpjOC8LFSquT(etvK+6TBGpX->J85*=T??6K8A91vsC zvTDhOzElzkXX+X}$u6k-tuPTbEt>mv`P#yr2X=SmsvB@bFs@bOZY;7sSUfCh;Dkz0 zvRt&k5CthWxDARk(fjGzUMlBY2#G@9fKDT$y7}#Q7qJ$M%hsl_Ofn zrBnMlRfE%`TcKgKjYwlAAkh9e8C`C4x3?5!!6@?Xc}u|JA|X%x>4=!2RoGRgp2Ww^6XzLmseOBI#Al!5Lzv9E6YxgR6TW1wl$n35LfJj)pM&VbVMi7-<9Ve> zn9og`YIF;SxV7H-p`<);dLN`L*Iy1l7tVd9;z#Nf;{2Rg)8V0xXqKrux?T_Zo`&VIsujR$%5i1;4?d^o zK64pxv~S4=#@U}cntq1iL5KFsd9F$DtgOU^Vk{lE#!gm!GJ1oTB)%B?IJCpun&x_~ zjxH<=b}5qIEr;Z|ukkcJH9*$?-GQ)%#hMRNPUF{Nf!#8r?>yyOn5Ej}dTw(D+%XuS z${G!ZZE}8e{#`j>CV2g8>_{f0k84%(RA!;{kVY`cLIRxBnti@Q(F@-zEd_R~t6-U( z|DtSYAxucNnQRauh&v~HWp)W=!Z}%AF^?rjB%jadR3qYDvIS%}fN-p>Z;t3JP zPDRjL#=k!HE*sb-zKiX(j|MSe*7^D2JpA@}PFdKs4E9Vo>-IYOLAD`}@2kzV zp!xjTQ_i_abT{RC)KZd*olGOo?>)~Y2*Iu*ll45z?3*DeDX+#k_t^#ER~3*gCM!0qbmtfhi^IVepgaSb!J3+o61I^i)(zmFGp+)`dC-2>< zaF-|g2;T=s9MeeY(s@=1sc(0*?KCPysg*p%a??~uwY4UZkS>Jfao)%0s(f)LgLfO1 zgey8zez?JNKN~{0u4V=4=7Mlskg!6cH>4e5+J7md9`!Cg6>Z~+heUy&hxC$4&_w+` zd*b)1KT0CMVHo~Vhxm&^2%EZTYD>*Gu#mVgNjejTvzKFQNdlKakBsdLx!PxpX}jt4 zvStn$1A63DbH8Huve3bX0vGUS+PnN};cvDX|3e| zw%PW#>HT;=GP%UwzlksvIn7bXkO>2JG=wy90SGo-u*GqZ-|m}|29!lplmF3h3Fc5F^mb?t=kOSHyw=KMH+~z>{I){ zH#Wet%kCZw_I_9u+Nk8F)&%?xpBv8+?rDi%x6~vSyo1G$o8{}v+fjLaIDy=ja8n$3 z#E~u;jiPkwA?0heK(?<=Iw`Rek7w1Fidk2{Y|}H%Z@y8uWK*qWHTni(RqG;4jdJiZ z-R1TJX&E>Z6vd`;G8G-a_`m$VISOSypEGN_P!G3?$eiULPvaPCdAZG1B2RVG{O!^Y zg~;HU)c;DYhOp|kd?{p95BzZ($F3$cVDb4V$}+7obV^R3)axAv`V{AGouc>PDr|Sf zpk@HX?wuAtHeL-nbzLJz-whzCi_M88+Ah2%dPcgwIv&HGWQrPYWMIMJ_~+Fo-Ej1C zq75TQ8#X_^KeL{c4eFEri(`jV{#c3M5GDUuiGQ6ad2bJ?6$B`??b;8^)PWwPQ#jSJ zn~)ux7xucK`f3KL^F`7#V~Zgz&QE{WF(TXeb?w3F(qzK<(!1YwUpmCcg|SOxDa;Z{ zaq6dffmM#$jJxbDge*NMJ7C-m%Clc6xwa+a?D;9#69rGOtleMB$gc?VS5BVaw6_km zgrXxf2L~}F=KYg*g_S^}#zj{3pp}q?)IDA-%qDPWqeUI24tP)&f06(r;mMr}ulSo; 
zM0T9Sd#l@7NEx}fMO3g2^z#be^c6>7#HtzSx;p5&u1mn-+x(WVr#tWrc|P@5 z$c6C(wiFlqN3gN*#H6ukHK-gr8Su?>41>Z7O81raz&-PC>ngd+zhoY4+;C6b1U+r-ofvk+M*=FJ~Air0vZu6$YU_ zfz+qxZa0za>b^z&W*t=W413hkmw>u0kev(dN5_lHOz|vvNOygvbcBF3c~nC;9c#-+ zgBu~QYyzv%>Gmy~_Q*)M!^nBHdTtCBj@;KDkPU*&UKy=TN-=OR|3|^QgM+XUAZ5Wx zh?-X11*6{F%Yf#Pm26+LN<6t=t2<=51L8%(4}_4#z?bOoUP|d!+;(Sym9Elb>| zBY>{(YikS+*2iKaYWJ&(N)|Gd&P=^()^lxpv}fuD}*rQBO$JMbxZKJBEDwo~ z^Y6vN&$kZLvnRveId17@y>u*DX0DJQ9|uwoVdb^$WANiq`L#50U)Zk3f0~Z55?i~! zhcn)+g5&d~xmSa2P(;s#hVyO>lH1d6-I_quq*pWibQuQWRm4vq;Zh!TyOf&u=O6r>{E`zFYE0y{4ilyCYEy}AAwQnI6+aDmGD^$%VL za2}Op*~Zrkw|rJVY;|hDz2qC-7o;oj-qO)dlHf7o;yF3?Q>Gi_PdG@+1SUY{kSSOfgd^+!d4@Re_}=8qCUzaWmMtiBkj zu7vUYtgOH*RoVl6*V<6zk$ufO&3Y*LNV|=GS2ZZeZhIVYstbEg-7$|Ph>hJ{uiO3~ zZ|@z?_22h@%Sg&-7bQ|jBGItQ(-uOaB9)z_D58=iL`HV@mc7@<_O$ojduBGQB9X55 zd3?{~cb?~Ub^MOw_+8iS`pbVlyvFDCdOe@_$KxTr)f*SAFK^#`Xb9>_KdI`j6V=I{ z)|D|QfA~QTWt;A_f#89A`tw>H@OfLttMdkF=(L9-IF~>pYYRS$S1pvmJ>O?)re;p~ z?(IE6*0nkm$=W@^x32@jZXG6Hc{+v5i`>1gOzm(y^-9W0;V|UW@Z}EwN(QPUvFp+e zu~<>|K;SOnk+OHwB$r!RD`rQ_7u-Hm2xsCgFtw^4!Wt?RQrE=x_?L+i^_g=u_hqf!DTD1EeN|;F)u^>0VM1Hp4G(zzz6uR@VQi3$N}paaSf`2YINwl*KS$m? 
zNGFDsYw^^T+uKv&!20JBZ_gR9-`3bUzMGKgjelVcC1l{;ArDOLiE~zWI;}E5xEJj$ zUQF306@Uas;?)p_O8lxPE|J92jt3hw>ewH5;z2j^wU+We81V7U<;&~@hiBnA9GAP$ zsg&GdprZg;2DA2!D#xNZx5>(n6EPrnK}cSwry0E;y2kIiQ4cmvAAhBE6~gEbI%6Ve zhX@G`K2fz20qf%PX_th9;4br9p$|3>;BiZDYBps7syR1ZYWr>vDVntW<>Ew|($N~H zv!apkUQ#CTSWg%BD6WY(xNhWyFi?dJbsczPi4w(n|E11sHNh5m!Gn z+!>SY!m|x-hXe|0;ZBp2_Ic^|OE@+1FD~-5zEd#7h z4D$+Ktw%%B>ZNbztDyae1K<5$p;xo`^C6gU?=l#kZZ38 zb)}1TE3eB5@>fzO&W&(-U6YzYX z06`tulN}u8_@X-7p;@jFSbJPf1$H(AZ$rx-wcVBIr)x6#eXR@+ERJ*TAozzxC00w( z6Yijq74~>}knnKX-*!#?MGO|UkIjdQ4Wg^eK;e@U?SGv$e;`W!+jjitNAUk3O74T9mH4*J;IKqsKD}C)q62 z*~vlvnY|IW=GMKre>X-I#sAW`Z$j>uWDAzB zh@F3p$*3aA3q%7kKXK?07#anKN;g)auF`pJX}ymyf3J`E6VyWKBjdp5lyy+;a6cj3 zX#&ZAtf^U*5bcyFq2C+@@?l?jv8-54G@crnEEu=wz&F`iu8)e^;dSoY7oRgnaNo7n zdymuGpoKN1e9WyFCpBiT{TLgqlISYwm+FTM2S&kOq6kgOS;}RP5(2f)Ou&W~we!id z%5G_plxWmFESv)d?`y6m%ax*@>MCp& z=zXCKt59Hg(85}>8$(sB)8!4zK*Y~9M(9C13Nrn2`?k~%xg%`)N0f*pp#N~f`$Hi` z@ZTg#rX?vjt-UjWEWVO|M}R+m{&3m+S6Lcly3?#aP_~6*c`L7GyuDGzu>Gyfcqn9N zjQLZu*+Kg=BZ-~c++aM;#WzPF8>7|RnaGKlfi1f7!Da%fcq22xwIsL*_44QB#C`Ii zP@RYRSZ@w;?NMge=qQD4oBLixp7TecbVaTLmRL}uD>K_)U=N1E^rbj1s6gfBp%q#}Tk5pTp)u5*k4K?3*9>wYcdkJIgfr0cnnj@U z+ge=P>LBR*)eSSZp8aW0kPh2VIaZSDbHMn}-Oz_u5igr`D1O@Z2;bWGJPxl70p>Gl z%p{IHXwqbTbjdyeKV_N?r<;0!r-qs3i5=c};&Peca8nI_@0-z=%gn;ZA02)t5EaG; zl=QCBeMGFv>4EYdfg(JooG~Ro+W@MP=A8piQ-S7Lfa0)G3=|!|yt^o<0xAwZxn3;R zjv1P39xD7zASf=k@rgeWxD};7>Grr{09o9D5<+Wt`GgsDFINSAZ)y4nr_J$No85B* z?nWFtujyR9nui-o!f}$0k1*h)NHq1E1Yn^lUEaZ*11baL3Qz9)z_B?l?(JJL!14C{ zneu{aY#x}bNv~9cJ&Q|K52f=Ve^+GbOmHT0FPOOR!df^dYnD?FX_!>Fe^m39HU9YM zF2o>Lj_huV6cK8TaKmD(r9Z|BjyzeNY|E92Q_bXOwj8O)6gj#8+FiNG(9E8Y8ukR<+jdEY3ZwzI`t!Z7E7MR|X8RY# zM~>(*Dn8H{UJW9&gJP+VYH;_#1fN^d&6vwa#rxcysOfngjg@6g1~IFW`+Wp%>EZN# z@uRODvguwbHVrg`OEg;Kg_m>j=+@Ft=@Z1^)qdbvGi3+%1S+!5KJo*G__-HnBQxQ} zCGDPo>lH8>sALcY;AgUU zyibHuh!^F-RT{tJ%+&~DtXE_*x(aZ}J5Di)$sJqlPY9Iiroz-1Ux_S%^JbBA&Ndwm zgEPCF*b0Ygkl)7E^Ud`zc%P!^|4=CvPdd;GKl|>5O&{#lTUaBAPU5((Pgel+lz)|Y 
zHT@25i`h3vUTec<*Wb65J3_JE_v)7uUtD1Km!s+uli85tcjvnm-yLM25w@kfR*#a* zA*bcNiTEnMue^$hQTU^nKP7MivFy#n%0aVmP@$|h^W`tVn#(O*nyitKbCM)%MtHx- z&R!4n)Im7DO4iK!4537%=wqdB6pB{&g`N_s0ULApolGAAq?&`??4goyo!@F%1JQ*`Im1xdBgqGSEF<769+4t}(ju6P+mw)(<@L0l@J#eDinVD$M(x ze_4powmQnVz8J433ZJn)K5ra9zylL5dFjw9JgfVGYO_~4tiI>Cp!~BGl;2<8dNlbh zxOlBQIXPrOQ?*S;1T=x_*{oMzCf{Rat?1oIxe8R7%E?p{r1V*a+J5&Iblbe3nr`|4K~H(2S==w5RI&GvULU+BCIoz=6EvP&`l{GT2@{(+nD$63^z3 zmBCHMDUn~HM2tjwa!aBBk&&6T<5c5^N_^<)=wG>&@YYIvU~Yf43>vaFnMhk!qgCv| zdco9s&?n9B*UoD}X3bdIiU3b|!>yyFNnFYcIdWq*Ek@$iQ@rnWIehT0>p`iM-ZI=9 z9O5ywkO@@>C+hdArXz)~1DB#>8Sd{XFW~(}!gu`D{wvm17(!o@WGbJCHzIYy9Gx0K z-ZLo9h&>x4Hy2Wb7Py0Dffq|mFhQa!SL|e?LEIs8ic9lJD-6qWsEAGz@m8aHUplE< zfO~tgL3e@|e&Jz=E-`rr{&tmF-c?zk-X3+)halnRm@cdrm}aACs@JvD`dCcZLhG~E z9EoyIDVjx6e6d(ncwvGd;tCD+9B$>WhOPsPJMIzp2n}YveWoWXaWN{je50%aLS8aU z&+jXMddu@R8Jnu{TCDCEqiQtfIow`l_>c>F`yae=O(8r^46Gu*>?8;+y9WE8NxE9qIbm7v2Q0{}#Pc z1CKrR1@pD5af|G2wMR|m*mHa5Ue2T8kSaI5)yjj&&ZOl?h%Sl(i;hr6-YfY4oxfj8 zUrC3g43<;URQ4cX+0p-Gt{NUC2RsWY&BA2sxlNbGUjb##<8rU=NbFIcjoM;$KAA1zDW6=`gyjoHM)?PKINY!I;(Q!HLQYBFuPH3X}HuE z5AO^NwKp$<`Mq`G4-RF4T?#kHm(C2Rpo~k_|51cNtm)xT{6jJ0T&>ugX*~Gye3c2l z6oZP4#%Iib+yOB+ogW{)nAg8Kr7p*Z9vMwIr_8jg*=j>gn6QH`FlPLkMr)QJCF-^ zW;&TR&SztX5>KJt!D`4VN7#d0wGiTrFXqiCB~lqS|f0+1^SF0Ozd~9 zN7m!UeMb&eVA7Q0vzUG9(9~pH$T=Q~wF3v->^^!!H&)cxZ%Kw;nYXvSsd7-7;o4*n zp$~j%-F|tjAPppQMgxRzk}!c=Z}P;+1b8i^nPE?mjPEtJ*&g`ihza$Fy8>M*(TwgV z)2LqshVaurfA!244sz&kX4SNVFN*8x-_D1k)N+Ls$;1j;t*TxP(>cTQF?YQVwG=3d z-}aq;vIOV_@2g%iOTrq7af3sk0BmjgEYUb!;e+Xj42c%if)9(W_%80rty zqkZvqksS?5z(jsMx;fSyOAk1|dnca>hquufbE}s^8}E+}ZVqQmUgMYID=t zSFOmHdRIvF!$aU7CT}GXPpFks@h`X+N}*lx@;bW)5x?j*=Hhvt@G437zTW-o6|%`s zd^r6q8b-vIbF)ISAV1P)?et<2RQ>wels|5OA8pA)6*kOZRxrZJUo{`vVzk8Gm&Jk< zzvKBM77>_W+md~z{0Y2u4>3IPIu+BOW)(e%CFiq z5EFTDi0xPswr9!=T`^3@qWbHvLY=DM!1l|Wa|crKvrM*c*JK2qYVOZ(X-kLGme)7# zmKLIe&as`b90^1`$_`G-hE`FaAI=+EKEFyUAN45cjl3ip;^YWw>@`%lDEPxPK5*gMEHXEWxu(_ps;&4S94xI9IGi9Hn>m$PrAl}`b*2;*Y5mBFkJ;(9pvTur6XWp 
zqv(y0_6pdV^VQGgMJc!)?2X*a5sx!2rxuR&SAk?Vi&65qa>4>)d9xZHK{P1DvV73; z0-xqKfs=dc;O5}eqUHQ6_#*$-;@h2gJmowf5M5b=6lZ0#=-yT0x^(91r$|ENa+;#| zzHBL0o(Q#jAoLn|>Pl_W%UZzj%C+6>sf`%EoUyI0DHW2}<~iBEAf%YxXtf&&gx3+% zg*|&(aC?_QP3JXd2+?+lGdWxW?Y=dmT+(4EnmD%FI$Dn(?e}H$ixk7A`)?)wRbrG*gg`<9@Bw(J<_(W16nHUj{Qm67~G$gx|-JqrTgm{BfSO zDxEVv4ReE6TE{Y85X{2ezBP$_81wFu%l9D+izf~RMOIcozPHHN>Vq|K|MU1jWugaS zp!Y2y+9tRlVP@rUvKXQnf%nbF95^C9Cn9^W6jXQGR_+gKL|qT+1M@`uSbD^zv@1FB z_`a0yNsG28tZz0lPZDs#ZzESHw-GXb@qO|4<%8o<|6bnXJ+x(zzL6;6NqACzbfFv@ zA@;|uV4Du^C0A6uxp*p?*tenJXqCqu1M`+0e$7|PFlBF$Q2V56tfu_5sy-dcFme9(!zk za%uz)UbH_MJk*Kfo`*$TgTDRs2mEgeA^#FlLMOUh&|KJ$dg5nA=-Vr?+^UJGT&Eh? zjRm>MavG8GyFygO5TRp{NIPJ0u>y{_G<`Dn?g5Xjb%8_$5;s-ImAam(gY-z_&){Oe;`T(|81h=;lj4fmi65bBmFV@oNF#PZx0Y@Tx>(TlPamkMy)7Eaz8;m zR|ye@hJp3cLCCzb`AGf2cHAp*Sy-c|9K&JiC}UF#OrD=QtH~AvDizwgLEecNee~{F zss$m{DSv=d@>ZHyGc2*rG#&Q?Owu;R%q8c$5ESKhLhX&b0{>08q79>B(> zd%Y@*g>W}KrCPVR66eNW?Xe4PgGWzfMfUtiM=dGNcL!H{&|XmPsm-1&!i*~KVBx2J z;6BF`8eJKK(%vrz45T}-!DlE!d?gR|7PKunY$XuSeeRoHZXLzu5tfyiULqIlRl{Pg zWwA4DMi3sDBk^4GBVMAUe)fUv}ApmdsV z3#xVg2HRj z`50L*!)go2@;50;i*^I^1D~ao&>U?39+O^lCm3X8hGItv4--Cy@4}yIN^wOq_vC%H zWc1rHvf4yF4pE^#%Ymk&m}K!Y;jL&9z6n)fW3JA?yV)o12?ln;;TD6&-GPRPmp4@du9*!T%$ofJ14d*S!K+y!Jktsz_Zb0eGCB)DzF%!y5oIruXOan)l8wTFA~WTOpYu^+F?u4`CJp|! 
zlK()I{O9QX`$S1e*jb6hjehJGyj;K9k_J4Rt$D`Jw!?iVnrAhdQK)UWQ)R3bQCRE9 z-TJ*Pkl_BLi=Vj*Wi1|fYJcoO^V!ukhW-E?y0!NUi+3$i(?%~{(H1QIF_J6A+X)vd zs3>Z6dU0PKr|a|L63Dsx3K%c!;7X9^9L}YX4BEMq&D0!BdOJxG7h5cS5qz$HR};;mD@JMO@9r{ABd9w z$H4sacGiCpC50nDo7O*d!%MRTt2>e|y*-?~%S9_NulV&?XZzL=+=ExY~Yj3g7 z)1epqKi#8KJy?xgiAS5tEeGIjIvLkF&K{)Nc0`6FXdD>_GLH52H$lvY4CiB+?PyEZ zEFgQR7e!yXXR>iuV)QQ|O=i1M5VYrXVI>@DctcaqB$a33Nq5sZj@!HoAO(5E( zBT*gavvSZrPJ8s}lPrkiJ|pAA*M><&x8htP38K#B=w891N(_8sMcPY9Z2q=t{*Zd} zj|b-;TuA;-lzixA6u00SMb-x;we-#+V~{8#^Pi4x)l@qa*+=-1!euhUhJQrYHR)w&`53~^iGRvHd_hx2;+rbea#l=S$HzY*%<>(!|5vfXseNL&u!yeszDsx zTIgfv(hiB?w|F}4H^BZqZYCG<+rThA?nKZ~4DNNw>DWcr2aIwuMgvMcu<5*u+2ZG1 zcra+t)hH8;>65(j(h`Fxd2APBk99XT%gR4-Y#e~lr@2cm>J2#R?k}@U+y`x*L^v6n zRl=)lna29$b+{BNpk>`pIMLf_ZDHk$#cnNyq5k1Kc$TBUJJZku0X5364}}wPh?nn9 zIVHBi-jK`Hh5PGa>RzqE`At3ex5fPjqU3+h#=lII1m0TaqzEBo1!l4DTtsKd(jh0mHEK*pBrqm=>~mv$y{P z*40vR>=z&$2n6CXGd~bbwkydywkW6ISlV;u=dqbs`=YX~Y^o5)cZg6kJ!ylWMti?! z5{{Ute2cXD*&1vntAmc^cEc{OuT_3Vt-#JFetZkTMWdiv=u~4VHq%`Gp%&4JmHQ<~ zLKliL_s3Mtg{2`VxuZUFoTCAUHwCbEapeOmc6;2fi6t_tQ>hmvI#Gk}MZg)AQdo>| zQRgvl$2vpjd6HokDrkJ|xX$hbMM5qTk6AN-Kir5c$-NKndP`4P(U#!$EfF_9+xDV^ z*h#u0&pyD>{CuuT{SHXEpnOP4mH4xo?|!CFYC$Fy_s0yM5(&=~McU%qvY zNCwl=e!XeV2mD?=Su2{(!5~$8XMu)hY^r}Fo7vxvLc?SiVy1e~=xKnq*o_{1D>Snm zOWV;?+dQklwHX#@OV_h{Yw={(W$XIyC0H**=*=m+kW!t$)W4?|gp&hiG_O%Ma3J6CN=QH-e0ZR6TLKuS~eLM+)luaYiMnjKG2AK?Te0ohWDCEysJX z3YOfX#pRA9K>S=6#muWzFnmAfmVK=Uv^7HWOX+h!eUv@3ORogU{i+PF5#$mj)w%m( zBO!3=yENa}MnBPH%k0`_83~7^mP*`r5^itG1#0eB8)4f;{YWuhLcV?b{=o~_``32- z15xtdw&OoPg8v6mVwgo9Aj#JP)pv)MLURVNZJ;Uh$-QATzg$(Tu&WYP?WA@eB-kM} zDeh>yOagWluemenF#?{!B2lxM(a^HV`L{cH89Gz1OmCg>$8Av$i`Ix#nhC9Emx3MQ zagt8Se7!RnI=#4VnJvXart(nVuwfD2eO;AS&|D6ETMX^;#9T1vzB|{4i?z^_(c=Ei zrUzw*loCdNWn;{p-2N-MOvIJbWU^WbCnt_r#>RAD@V#%^$D6Wn zE_Zf4Gp7pqE(D(Kp71~}9uxUsrb0X(&?)Y2kO3PPc`BFAcH<3_&rz3!Nr;e;e|f8} z0m^O0VmuPs@OPnzCzsA31g_aJ)UH*-*@sgO{xr29?vPC9P(s49t%i@)q7nN$IoayA z^~2#84-P%qUJf@qOdW*F^Fc@xnebDYdZ<&9%&+wl9H<1#7QJ1`9YlTSrD3t$X3khi#6vYqDmt*Fe+1o>B2scsb 
zi%ZI91CgHYC(pusIKHmC7#63W4QEK^uG^e*K+Q&OKtL&suo2q4>KR=MOSh&BR+chB zQh&nXm;nhRRrS{-KWE~8`9QI?m)WRZ;7TnpTnzjAt$y0d6hXj#NB0YlQ^7}I+F{pt zHAdDJ9K1pp!m`ONJd3$ljsx9fCrA+qQ1N|l^dx;cW`2~KijNAy*=X4d3^rAe@#yOYs^_E8%$B*Upmw;i5^gvi|3^b`M!6r8n^ zT=x|&fT+QCHxo4j=$@sb4BA5=CmaVgR*Z;u+-e1@gU$!M8#*VTML2dHPNigR4J`pt z->sp$JPXitsPUG(Rt!cDQ*L4Uk_C%b&DZIcGLZA^dr?P&8lX7SuwY+NhsurRH7sj{ zsfMF;x}A3hibp(s=OdpBhtwneIgR77w$5ymFwvrrW9Yst%3B8CsrT@QUoS%_eikn! z8i~#yY^FQPs(}~J9o^Z|h(^EOpW~`4z)BD2U3Ss0u*IU4X5RqfMPGlJg4^x*aPXOR zOjD=QMAJ;WB*}y5_UsYx-xh@5;n2Lua>2(Cb3DaI3=kyiNj~+yMeK=_LHpe)IN= za)LVzIly#5GYdGk>vHESmEgOGrLgl$b=ZB7a>IcTscKPEsi{b(LnQC-#1{$~$nWXT zI206)TvUUMN~C09vpH9}N%0N#o+RDA;g$p9?z_q(HZox#SEZ^vq#jKNV=1M#RpH=; zSC12nBGETxcdYDiA=G(uJ{T}8g@!Jrpsz1eP#O5U8CBEZR;Oj(&gB|7_a+<(p9K$36ggOK$-{I}aL+4^0juzVT1v%0{QX+^GY!!i z&PlWg*mIGH6Z$4dCSV(ke#uD=eT#^@CEv}ZzpKV_v7bW6OVZJ~?X}*%fI?{SZ0<2h zdxxh&ctbdSO0np=nnhgYTjUCRXt+PU08-3nrkM`Lfx-h#$5Qb;DCZiE{2uEC-aTFu zTtw@r_KU68JM&Di7g;!ONZW>s&!2reR}qgdMT9SXB7Psqk4%a7O9nh*HCMezAZHEh z5AIu8`{NhMwKlz_c6c?Xp)5^^TR$viMqSmY#-3Ks&A)rgA#eBalon$p-nHA=PS`o1 z`1bmO&!R-GwQMG%!Gkyp zVc$~mg$$qe{2>x<)^NZ0{c<{J7MM%ZiX>prZL@#|eG=ZuC~6$;O9h!^)fxXR;>_N9 zjb}AU0NhB4UR1)a*!-T#P^+^7BK_sRaTAEvsK1()!_iW3q06f786`oycu0FwRSp*4 zNtfIzRSoTHQtLE7tFXzulA{C2Sj4YhKI zq`g{&F9Pbp>Z1H_Q@a*aFuvrdY?F=Mtqwx!F1cvmHQnhcnglG2dy8+<7XmxQRwrhS z$9RjugzTIak#jpb%S3xN498O=-6tjp8d910hsEVENT$|(&&WlD`Re`J zm*tTG&+IoGc%2GB<6KU-s81&Jzh*h#Otj+9Jr}Iwx>y5Qt_90Ed^n`w%D=H$7>oA(issFVc;R3y5a@zl$7hG>IuXhweq`w8T3@^f?Q}}BV0am(3^u;4+~5=2V#5+;9b2O%QAjSO@2^}ljgik>WJpI)oNegix{B^2+ zCA=ILh1D;rzUsic3%=JV!x4@Z7pr;mSE1|qt+9TLB}Bqup*hpGLbMT(srz*=5Mhsu zTHmfr5M-m-9lpsEhqp{ppVmtNm&qDigF^-QEUAv>-t!FjX!@i^q^b@jCXZ0%sbmr1 z(JMB)dQ(v^W#^>6J7H>6o;77`UINw1HUnQa6@jxwEJ^QEIaKeW>)uO4#E3PE2a$az zK^K>}-A%(Xyz)9%t;O^mGHS@^d&d?)Qs%SSvQ-kyy#KtK^R*qnFpYZ!y~#m|Ey@Rp z0~;6zDo^{xFHgG5(eDtjJO}~aB?Um2n`N4*m#wV>;?_YtT`e($q%QT|)(8*(wggAd>ZxC6m zPZk)nx^=v#ibqNhSMmK{bFp!;^|M26H405+laB`n;5s!$jt7q?;_$BalQiW>y?C*!J?`0AUK}qV{ol 
zI0IKxw@T9xW^l(Od8H;}9iZdvLf3n?ToeXv|16#YRFS&~!WON-{aQzepFJB-(MV>A z?<^s7pBwit5zd`eNBC-9U&zAW@y*)qk99D2i_i6LA|6a4xx(O5QWdPz&pA9)?Ldjt3--33$;11vA-9_aJw1ZZpOhWG{dEmo z+r~z|UXli}+<_JQ`ke5(-!TnxiF6#hw`KRkhZ$h7cd43hl{mi=Xo4LK3sF-}`MTAO zWIQXe(Z{^!J?M9+YVb)BagM%qzbk`{;Y8;8*UryefX9}uxGFW8m5p3f6ikNmBiJGQg|v-8+&Vp=hH@O1Y2IS{5FzGR`aY%R!} zqwekUq8JaJ*rIR4nFRWeg6^c>uLsNUT~gGU*|3+cpjlNp9qd`FlM~(DP)fR}fL%Wk z_eVwOO25kohUUtBQVluae~E5quc9Y%4MxTv-b%u~xqL^&XzTIJQK_Q*oMl z=XD#To7dnKY;h zRq6WtHUUVFzF(go2#2L$bC&*x>FCSYuIph`2CWwq{qC-(;DCUaxyWiJc)dNRGc`k0 zA%1e0kPed|dhk=Ds%SW5Z@r`MR2GkV&cjx-ogbiTM>9vo)gl--u=%9oOT?=+e`Ga% zScX$#TiB~apFsgPv4OQnU@L7T8{-{;I@B5aCUYwB&BRG%sqKl_YAn%2sAO@|>sh)A zu@q3+c0XmvNHsIUF zRviaiJo&ZdQz4iy|Km!1Q!dP|Od9RiEJ6EOT~!YrBBOXv&F*+pAQ&}Nm`&;t4z6)( zQA=!@P)#!DQ@WN31yftR$4{Ce+Y_(Y6MiuuKh)&&RdPltnH| zGqpUp#TJgSZd842Co?drH`Dt);YWJPa`HPXg%3D296cg7oCJo{d0;a+AG5y;2uM1L-+aLi7gll#-aaV;>h7NI!qb@;7p3HJ+PDH}ACXLIs4MVE zhmN{7RT&OmJl#t_nudJSZ>9EHl)?HTw;i=QCAh)cJHnnq!lgH|?^5XUP|wG(uU3eN zF?#2!6!6gtqh`NGE8MijS3GPD=fX;H=xEKQpR(~V!zp>wjoA&ixIaB=-j$ABef-A* z)goc!Rzl0GLnSb%p7ZhQr8h9gnz)JQxh*{Dp*qCE^Bzyo@09;mcnh(h0q*w zY=*fr4_NH9otXq>p@a1xxt2>RUNEwH>p=>G+sbu|-W`cRI!lt|J#`<4sV7!MiwJpS&B1aszTTfDq!f)w(aO71 zzNVqTlahah{#0{KY4BNk z%~!KG_^+DCVN?A_cTG$m>YJLKdZ7LMKYskTLc~8Qh3w)LeqG%A94gLH<()S9hCI(- zG$o%~!V#0vtGkN&v0_n)X70^A{sje&W_T`WsVl>x*jV z-Vkq8JL4m&xwj2w`tDnm(-U42-g37(N(hs#q-z2ex#spYy@f_ zvuX}T=7ZxW!Ndz`kNy%Mja(Mp4si?&m&( zNy_2WddCR-R&=@PW_daAY|>KZbnKB%#LgI> zf8?g z-a37R1MOP1?WRUx-7uQ}88`;&neKEv68$LRW8l3pT>-+MyNbL%mSBLGpH1;XBCyo0 zE(~PHfL+vs;lutJnB;x4D&}G@G;O`KwAr#1di6|a^p@*Uf$O+_{JjROir>6EJK2q2 zu7r&&X5}t9m%40pSw9d}7a0Y< zTnO64`f2*e2=ZyXcGZub!p*Hd#ixn3?}+%rb&9*msF9~@UruCG{jKwK{C_(Y|N3Nb zdsIj_JAB5P(`NN&^}12rEw6$6RS!Wn)D_tF z&{1wOoBS>`f{9>r?Oapd(&)|-PS4uq(KBZaa^P>g`sC$^0%7k9(p*1IjS zE7MrH$>@Fb)p49Hw3^lb+zBor+Zw9}Yamfa@rcn-C>)>eYBm;YK(%Q#u1)J7@#XRA zr~R+S@ayR9+BU~FFlg-dNt*1$Z^6zKzd~N4TqnQRbFo$k84paPicbYEWy&6{;Yuho zTH-ye6png@WqmR-rEv0*;EZlNVWX&17XEFt3(|QQ<-bsO0QDM$>^b%b*bq{UN^DL8 
zH?@l_j7~AI+>veNY}bV^ew6=Ey5t46LM@r=RylBGg#S0USQ#wE*eobm|EF!!ANmf& zH{yS!M(`PE61fq7%Z@Pm$KerF@|zue?(J}rn7@B4<;YX&4@eWGZ@s|afyQZfckS~} z1Bq=N6hGJ7P_DnGE+eQ0GV5bXm?)~SQuplS_OcETPu}%t>Qps$^nbX+F4T)piuUIz z3XX$Z)z`;C+diUxg-Wc~NDdl0?XdQpXoMc7!~IT-)hOe`%16;k8dm&bVyMK(mS9nJh_irA}xo+rzkt4I>N&+?b)h_GnndwT{b|{NsO0XBJ z4*!0md8ZXPk`%vE-0g%Xd&R~-rV*})I#i}1zUdg0+WeqeDhQQXlP5218A7SU4Z>#} zs{i@|l|9mIV5&wUxg5-DqucfI*tR1gx6%r>x^D z;JWuwn$K^EOp16nxfIGWWa~K-!FDAKyDT=|$Wry&h?pA&2K)X9gqw^MBg_xC_* zl7Hr%=y9Yb+cOGJ^Kr$md2Yd$$O8HhbkzPtGZZG!`ZEwYqHj;@#Id-wAcK;{XzkZ> z*x)05tKhA~de)aWqz|>g2@lD+E@GZjDO9vy5Jm-|iJG5zh>KeA!rdPV-y6|T<8sl2 zQ5$S?G%zFYtpU4(D;wSZMDoc~e!s_ZWw0STDo67<18yux2A%%d2#$^R++i~R*?s5z zw?^k5TulDXmhf@;jc+h%{E-R)vpa6unH_Q#S@uj3gX*aB?+T8@6t2BZ}dR# zJ~a)`@KH3OJ|Fihs2juCH%1?vAHpEMOY?~gLvU&C>LFwEYPcR763BR@6vGV8U1xG` z!gn;!EIe4!U=O*mGxObtkSIj6<-TGq6tEvX9CoZ$|-(3TF#C>YeU7$x#p5Qam)knUlGxx#PR-VI!^;yVk z<5<@mTY$2+mj<#ed*SX1Z}AvmNcWOmfIe$9@vmj}hwkJ5Z94vCzT~TffVQ;MF!tXw zVjbQn16f{~4W5Al6gmEov-wRMe7O3dM^`c$hmGH_Ut}kI3s`xi&Yt~%AuH}9HxK)Q z!swpu&J_(%bLvJaccdMtSzX$`#l8Y|NP5kB%{0KTEo`DYDeUos?(XYrS}icOc;w42 zm3GK$BEM$P+5wt9Jm(Y?GGP5Q_fIQ{Ot>57J(Th(2X~*E)V!C~h2@fstX4GvVC!@} zIx2<`iQaif<9@UbRBHxg*d~j>=tAUxS7aMj1(Fp;yOiRoOeV(cVMiprel`_D)(LY3 z8oCRn#c-xyINqAC3k8Kvis~3=qn)eES0TYH(DGb;etWSL!kpMu2mMGOXL4PkfFmEP z0-ETn2_iF!rq6*=kYJ6^Itws3R6(*~=^M4eP7qrpPvmDQ#zpg zK5e4wiMiK`thpi9w!~#IM(B1rTWKHY_q6AdD>-1`CVEp^d&1ksf6`1n| z0zozIWZ4$eC&2Q_GWV=U9nxIY5nm*1ba#CYcD9o(LOx%Gu3H}qK&(ZpOStqA4Bi~Q zHx-wSl(gxiK|E~`mZ!GDBb$cS?)(q-a8;oV$5R2yr)7lb*fFWBJ`sf}xA+`-mI6Lm z7U7>NJAgbRpqpc<2+s=~5+77Z!`%-1g2;=WKwD9kze#2c3_UivHo_W({;mlxe`}3^ zhBEVvO%(}7>EvoR?Hk0W>YF!PinQUUz^`?Q{jE5|*dcX<5Tnwyc1-8)Y{U1G6Jn={ zj%)i%PTGB@gkexey=)UZ5!;Z<`TO-sWp?`BK9Aw-}(&0?qpoDIb{&Q^s2^NzSGwX`twW4hq_>D3!N?$rA+Tf7p- zU$;g0x|ibfys36)K_b?`f=9J{R|Om%I&M0kp9|u{M@3e3>!F^&9K0upfMgM!0Fj0) z0t3?#)p}U_*M8*w4W9_k*GV%i|gf*w1qad$ti z$CtGHTeeYV;r5N1Da+IL5XV#|z+ci1ZyF8t?dZGkM65j3TuVHXqBG~7PuJs77Pzfy z`4S}A$$~t$_Cq+Sl(Z(>2QK2DCQXK_fQgKKv$$6^CNGQZlCEn*Et$uPE9KSjGc8up 
zfx8u-zX_ndz}kV&+x89D3$-JO{k@Jli!Vq={mOXsbgj52#!ZXYt^sThynOK=c)#rE zvY*70Lq`MiL{TPm=V#kb2kS8X0os^J^x@@Ij_KxK?T~H8&erg01YIxJ_qP#E zkq;u`xfFZGp$nas!lBXuO;;d zB>lgC#s9A*^$$+@zY``;D-#Vl0!CrxYk5vq$_h#@@V%~~8iCU!EkQ@E0X)QTXi-k3 z8@+>+Lr10;!7us!4B56Xc!}fY#etAE{M(!IPh}O}f0HoTQ>jzYQny8A!2_x#=9vgXNltG#lnZaJa8tw;@)Jy0X7FY?3SSDyQ3t zPlPj)H(UBKFZV)ZC{1O^&M3uAjL_hwoec&Y^%3J0DG*oPc*r6<33{lBK-z0*AoR=T zLd%XuwA9{h?wl6^$0L$1SBSKL#pDTU+aY7rJ7=W##V-Ij+wampXUf8Jhdr6Et|kB# zN9p@lQgPs(C~z~PBomk?q;Dwa+hfUg>$Jss6sjPlYC>kVd4z6d6kz zmPDFSA|(=brcTz4!iA z&wBOrSTAST@4eSvYm=KI+f*5fuhiIRyD<_@dGHF^y(RnifKQbRDJg z)WYbtI|FvcSe}-dpN)-HaUWbwuA=D%&!4BhH==&Be9&p%1oTlH$cgf5#Bz@Nol6(f z!DiPa-KHlP&qN5A@bEEFVEv1-kngFu{0pbf5?zpvZfMJ}ajW!aC}oXv&kIo}83 zI7*5Io!K3#?blKyueP7(9l5#oTYxMTWh$@FG zd@kpl6PftJb}Gw7x)8lb3iAX zLrn$x{9fj~nB4{H4_7JA+K}%{?pEduS34%oq;=M2gu~|EHhcHOX~6ry_U=S)J7o2&#?K*14zuh|6Ux8}TeQ$hYioiwB!|Rv*ZP?IGPq)`k z!RvoUTi^JRN(RxF_xeL)peu8=TYy9qX;5eMtdV?gpDp_3#Xm0^cF4PlFScdEyYbwg zNo1Unnt!mLrbGe8i7R-O3Kha#@zsU(gI?eg9Gq$z9u7&-yirNRbgVodWN+YF2th{D zcbG+0sNU_V5@k(ZcbSfkRvzW>NPpvq#1AjXSU8jE7RbN^0lLd^De}FmuYPIE;+yC* zz4M&mvRF`cUUqi7?=?K~r%rrKr2qnD)A_glAdkoA8IvF9hr^zg*~`0g!$D`g(6q#* zOsp$Vxcyp+frd54pWa(d?0F<}g7-`*6SsegkosD4{ji zd4n|Bui2}(a)(G9)J%4xeOePbY^r|jDppGBkuwDrk@nmc+n(SIdj@Jc9Pkm|TZcww z%Y&x!>u|#TL(HAv1n4)Bv*0A_km&EXue>K?9A;NBWKBu8Ot-4rq+^~8?znod&b>Sv zyuUpfka^66A9H3tnG6@AwqeHHKN^VB!H?38>Xt%cj53zI$wZ~j%6(P`3UH@##ia22 z5-cKD!L-INqDaA3{&{iDa9nG_g=m#D{IgwKRh2Z67}h##LZd$fb%EeB0hdy-&Tl@i zxMw~t+`ekCWz%W+-Njj@IuQVDJ8UNpIg*o(!M7T}Wy*nOvSq{8cVtYAI5%jcdGDmF zGj11vXe`!xJk|sDjB2B8vZSutdXd)>0d)xZI#D}iRfB;B>CrO5ckyP*GD$u1ee?3S zwE!-N!4kVoy7Y9i(=JptEu>q8M}&K)0>&E1PW*SBXJ0B%BcWin` zA^!0_DDvV$G)`UHe@$zCI#%x!C|#*s4VyNfzgvcNC}0}-L7TH2ByPnk+`3eXu|qK) z6YH`tZ;>$jt%-6R3J|)IpO_EjZ9I`~5l=u}=Bv?i$qvvhGAc-tgdkB8;C3H?G*_3)Af~y4w8t&{}sg=Ee__&{PFc_KzM8$8Z$0fi!L5-7bKBpTxa_> zj0i2u!nPENy@3OTaQ5Ti#RTaTNYYhTc(K0}ra96?W0Om9<(0^km6Cx_INM_&dSe3` z`5&$7JLU_|>8&aEB4V&<%Y{WdEbB2gwBl>CUj~-S 
z68Q0d_C3*%iD4#zTE*u|u;SgTWglFhqP^%*<)=#7c!d#s^^`*>Jd~N)7f+fKh#gtj zX2gCE_iAn3r?tT`O;!>rhQ@sDX=WaX!dt zGBrf$0cdocVMsed#;}wy7fk+mj2sWH9h4xIl;x^p`%mw#2mkBtq49nB@JggXSFz3> zuk~Ht(0`CD(jAO%C}%$fy5*9eTMie2_@mVgAq7tm{w$x;RLMcDMQRV;>pTR{^&tlq zLn}TrUAB?SkW6emlKAzfRym}mM!OA+RO9%mTlJg#=^)oW)0{mb79=|QC~-#&HkOf(k9 z27|YiIA_NL4X#u1@ao8Hc$u?VZTsr5 zaa*rX!FC3S-&dXfbFvx#I9<~`s8)b$SA>u4dJ={PCA)NTBVyn+d*snIq~dSG;MeBR z7f$R%Ne0A)-%u=cVBn88ONG5327;5tRsB11&iFobAy*sOz<#d8rBlXVgpHZg z3c6hlpe(iZ)o<@|Fj@5O*qK$O$k?(iZL41qh$eZhH73 zJ5n8MtXu(#_xCp&Ua7_C?2*L0%|+lpux*RqwWsju$Zz-6pRyrxiCO2m7h&k47%=(9 z=phPxO&3$K&BKMCygP2@C1KpvH)tSY90ibFqFAnwR7Jo0UM4oYo0WxQp$6uh=?z3Lv# z#G`urxxb|lxnu8Gdy#ci`wO2=f722?oUb+aUS=X>oGK2@%xJ*WbzY+D_LsvPPXF{V zc|-{>AF*%om!Y|C)sb3<9N4tP?#6HOKEJ7vy5*}wHtfF`<*NH50_iy+=kP!$a0;yI z?IQaz&CAbv8Ry+m9dU2)QnZN@G5DP*m>&AnC11f$>R_MW5ZktP){ z7M~o7g#^CH*UITx7_()Amht!*^s`I4ydgFdR=$0GXBSTcZsU%~P$f;D;9bd=vlqzc z`{$h_0&Gpd`PR0B{Zb?NSoPTSa+78i+xot*xmpgUXHCULzo){n<06fgmN_u$wc~fq z;sOkt(^`}#oevwAE`BPIngyZDBBVSND`BH%2uPIX!?}6KxpxIrK#o!LfegJmoU!@db>e&?6bK3(T(-6q zHSDui7`-S&4*KS`5C70%LqJ+jynPltb{&O`$!7GD&j<>YCF2sE1V3A!NJYug<(5G^ z6ERRbe88Hc46?s`HZTs*!KY(Cn9r{Jp0KxM=0Is^?t+)QN^}A7T`MUARq= zU6?Jh8P*LPDpkV`Z#80e#3W%E|HU_A#))_!)_IHK3I>d{x_sW%Lx;sHj|LC;M?!*W z&R(hG474~Eu!pCc+(7vDY=^m61#CBqR#WWA$A^j{dm?!XV4a}s0-vTUAa58F|687p z-_Smd zS`pOA+nc5PF|ltjdrqP0BXBA_xbBN%DyqBbmVe!w2{&G3=5El z;W!s&tV3A2-mD=HcG^+@~3QBPV{L)*y;Yjo5CpFG3nukhH5Z6Gh%v zuAt`^;KGD_&GNKXs9KTr(wTJ7k}%WW%05Xx*q>%3zB_am|JXqR>V@<3pBhm2X) z{cZDdhBP->uPQD#lr)It`^^4~RTtx>dmf8#?resG)iF7qq{+fpwpX5+TN+Sbw4L7P zQvu&z8f;(6EQMsvx?gM>!#J1Ci6djvJ>VO3(dPF}!*~@ZCztUWyk5XnpvYDSPh;HA z3;%3{&DTtFQ+0~BRf@0W`%ZM5&8%4Xn|!G^ z37^WjSq16NBYPvvTjA^UVyv%!0m{4&p8Lpu#?Wsc72kO&8rddiIZ8##L3iQQ z=CWNGsARtTruN4gn92=0`i49s%>S+`@@HWp7L_m9)@WD@POEMmn!B15rkp4--`V~g zzFcZrkGZX2^z5if>Zxu}I(|YTxiu5Bq(1ibz9l^s>k?3J@Q>1PCtoD~SGc}}W1 z&whE?QWAt(K_dr6yo%76>AttxzZp;R9!#^Be~kT+!KdurWn-y$YwIKKF7(>^gG=&p z7e;be3%~8Z4yPZrH{3~LVw3t2%`I2E|9NMyI+Oa(zmEU*nbdWY+m4<5y}&zEVkV9w 
z_-eODhK!pZ91>Eo&`d1`bd>(s_s|zVE$NTzu`Iy#6PX94-ouDbA%hwSw|QQ64jQipG&B7+u37vcpSVcz=Qcxb)1EOGi&A6jmY zdIK817|j!~XyfrN2#Kn)VQ99&A;Zc%ZE_|QEbDIb#bp3`v~h!_{3xn*t6#FzWMcd> z!NHV;)ev-8*8Vh47p#iBleTnT73^79x4vJe4UXO!IHAhZjl4^DF%Od>3y~AWa<>B+ zz`K;)UuH!KXdiF57TuW$=Lc5rOFWc|cHfTqtWxiU(H4GQ5Ay7t{Iw}Zi8MZ=8)z6X z9*kmdC7=}g|5R_JcK_mfL*O5y+pqPWv-8m$rjULwewAmPOT>V?W@09U%nc2-hG*{zDue) zow8-hWMkpdW#b}6J#u?XJ!0+Zf_&7~nwnq#S`C~=P zL#{JRB$PB`z*=Va}CXS~*&=SS!(CmWK}ewr;9AX#EI!!>fm< zoEq?(D_nW7I}7$)JblSgv=!N|2}vCrYR6tbiw)oJ6k=xP1k==|8SOKj)rG9`ahdUa zUQ>DwTA1j$KF^HD5aqw#?&;{kM@y zOS`Tkq{x89F^5KZhbwSsX~B4e_*KjcTPthwGyqpb?&6woX~RZ&Bg6gmq%~K7>4(s#kRMEcH%PG2X_NCNVWcNiPMaQ?eJrY7UAvKf+3Bv z$I9EE;CRQI(i6TFSQM$xbGOzLC$ln@Z9kFQXrJC)K$!~obY^{kd|4C7F?%uvPgTKr zCEj`Kvx5J5Ke9TL`v2aK|MvjF%Gln{;s(upo;lZHw!1X*1)j9yG_k|%hyT-=)QjHL zc^ic)AZ^q?UX-l{Z$I%Aa?~$|#(5vZ=AG>Y>!#{`f?Ro~& z>)3VPn?%D3wn4vm<2(!=i+t*7T@IaDE&__1=n&g)W574r3bXECRa3|@y|<8^uuDJ( zmI@W`krTSr;zDqIa&&v}wl|NA622 zt*!ui>+jtWGZY~92csyUHLHwoFbm5F0oH1-6{GxIb91-Ht5;5=^7@K z$-dmY`T+yT=DC21NIO(XnQAXoh=oa5DAQur2i?-L>%YdeB6H}SV9aPIwgk@P2`ATK z$ENJXA=mREr%kf#fm-a6hCNjx2Zo z{B0vSX7-f)<9odVJ_tTq#xesUP;bATZVlPCE0a5NsryXd)Ko%lVbIf`WFkaLm6q&!sXMz*Baz=BXAJy zuG%d>%kdcx7xbTLqL`{2IAJEzg~(*AEu(dYmCEZgiI zZOx3#tj?R6np#+2lxFcaBWpAJ|5J!8LtSJy=kWZ0itOCH3|3VvyZI6FnZ>=_+T<1= zzhrw~)F97o=8`1s$7*fXA=~CvS4kq-ECN zEvuB!4&5nu`%+oHqG%GFO-xl*a=*m9vl_9(95Wae6|^?^%}eaBc&xeVNfs!KB!@Zh zPT{{xR`{PL_AhMzck6tp{@uU-P%AIW8eKIvCdTdUjIIAJ@}~AF{2xA5+2*scStzpE z7+5pW-NR@0Cn?{V6HM<|^SupTo4!&yRZ0iO=Uwd2Oc{RjG?%^4xP3Jq zxn8BShm_pI(j3=!q3>EjWp1)n%JpEpdg=V$R_tfF+sdvI`!a4BanGKzz z*N05|Os}^wp<}IUgY6;pZj>zZ4|u+_8M$sSCTnIp$u!l&0RrZ=*y4SEeUu&pe&p=n zG+keh8`D!%^4_Olz^dkZd(+93(gksoYkA4UGp~crwFAw#==LlB=I>;Qcu3gtEqgQg zDEB{^d0!3|Z<+=sE>z$F&%&_m>n*4#|4`e2*$v8^I}(ONI?$$?O~C1BH@3{~-}yDL z8GAUm9vTkTqswxs9p`2rqmU9qs^@7P94(aPty@|Psb+Wi7JX`maQc+{UIJc?cPQ)sC3@PY2BHPyS~0!GgIz}Q9Wz&;!~MuwS@L(178`UVeg_#+F77Fr-YrM z!$5m^&(E{1q=la1>(Ld*Xm~&=;6CkC0=&yDT6N%KD;|;!;BOac1GliSxnk@M!0!2t 
zCau*851HCW_{zJmC7Ml7jXW`}5LuMDo2?n9weoyvWDrekX1KoRz) z3%xb{R7|$SCN8ZdkDXi6j~Rat>c-t(*Ix-8tVaPp?J3XJA{1omcP_6X-(L$Hw;t-P z!G5u3zP(pk@nT_<6dwnIU!iSL>iI%!(f(X~`*;r;Itm&2-@OgL7Au@&i+GA(KIrdQ zbef52qfDcvAa_ydn3kY1gPC0Wn=5Uy~|M7D#x3)fxGX?MZ=`*giDB8JoY}$ zZ&}NzK@rns@q#OuV63wF!>7DX=#I12=`$nS{N{CkYNra}rmdKPLsAcUT)y~mD7*(Y z`rW!bA|GOi&zlF6kjuo zM(KLFPdp3|Sw6jSe|RV4co&?Q5i5sn%_dPRo7(aAp*EV7)pJ}vFymXBUJlzeM?6O? zQ*iSMyBny{WHF|ZuyyY{^J zK%Q(jWz78Pi><|-Pgb^x%~Zp1sPBdQBjok2X{hZ%{xrC@#4PY+2W-lBF`03%hx3J9 z!N(?bKTkB0P8kxq8AEWzz)4iZP{j5NWHIWTbM@%HZg-IS|jNY82Bme zDIZx!1Y5KucHq<7(t^X=n=vFbVAmFsFQ5CtwSTd9RB@ua|`dWRgn<$Yt!PDt$`^u+II}5x&0~B^94V?)0pHm2{iVFXZ>* zu2KBs?9D>t+j;rto?&-*CpDXNRrV>CE9{nb?aF~>hq&{5$2#!5$8vVs=O(;c?;n5L zu^V3Vz4359#DG;}X<331oiLFz8u+m$7aebz)n32%961w)9rQ_q6UntEGmZAmC^OD@ zINCsm-mNz4%S_YJB_$~F_;V(_vW2jq^UuNQ_o(WAe{zvc$>i{y);5$CeX=e!umlym zcbr^xBcCnP*@Wu3dh>I1a;RnfEjN ztKi!#YR=o}gAp3n_o;T(LWuac-(9LPFlDP<^RT!FqzpODT`rQgT-+tRh9imptku?> zu(!KrVt>uf%v8rl%h=A?%FN!(?il$)MTK^gCT7m=yZoY=wV9oC$n)yZ+VR@RlB#W>D&0LN}SdnJFgGG27&0K{=cstEpgGE@0W`2xC zcn8h=q%ZG($x@c2X}~H@@-$%;Cy83Jijz#Qu!@sZZ?KA!TwPhkNwW7?#Ywh)tl}hH z8mllYGlq#Yw)^tl}i!Mpki>ZwIS5$+wSH zoa8&qA}&wzeZwkF^8LUnPV)W6Do*m9WffPTS0z*f>Rjhz%nO zL~LB3K*WX>YothQ*i#^4!eF7h}a0GK*UBQ1tK;QC=jucPJxJx zT-InCu~AHch>dazL~PVjAY!A10udYC6o}Xuq(H>RD+)wxyrV$G#xw;YHh!>1+ldXf zIV|7H+ldV>3Pfxyp+Lk2KLsK-)>0s1V*>>uHY6wzu^~%=hz%tQL~Hb7` zL~O)RAYvn#0udXT6o}ZMQy^mF83iIXswfb#(eO9oW3g%?I@+li@zG1gh>(|5j2IcG zVnoRl6(df*QZXXs7ZoE`=FX+ak7!v)#fX<>6pSn2%Ni<1 zyogXS;zgW-{k?pZp<=|#b}B}^?4e@Bi#in}UXD;P;^hPtBVJBZG2+FTiV-j7RE&7B zp(V#LcMDn`77QZeEsnu-xGNmPt@$)IAyOFjkrdr4bL z#fXkEs~(5>CbBY1ghv F{~OeTgaiNp literal 0 HcmV?d00001 diff --git a/source/tests/pd/model/test_autodiff.py b/source/tests/pd/model/test_autodiff.py index 1bd9dd0d0f..8442844a24 100644 --- a/source/tests/pd/model/test_autodiff.py +++ b/source/tests/pd/model/test_autodiff.py @@ -60,7 +60,7 @@ def stretch_box(old_coord, old_box, new_box): class ForceTest: def test( self, - ): + ) -> None: 
env.enable_prim(True) places = 5 delta = 1e-5 @@ -86,10 +86,10 @@ def np_infer_coord( ): result = eval_model( self.model, - paddle.to_tensor(coord).to(device=env.DEVICE).unsqueeze(0), + paddle.to_tensor(coord, place=env.DEVICE).unsqueeze(0), cell.unsqueeze(0), atype, - spins=paddle.to_tensor(spin).to(device=env.DEVICE).unsqueeze(0), + spins=paddle.to_tensor(spin, place=env.DEVICE).unsqueeze(0), ) # detach ret = {key: to_numpy_array(result[key].squeeze(0)) for key in test_keys} @@ -100,10 +100,10 @@ def np_infer_spin( ): result = eval_model( self.model, - paddle.to_tensor(coord).to(device=env.DEVICE).unsqueeze(0), + paddle.to_tensor(coord, place=env.DEVICE).unsqueeze(0), cell.unsqueeze(0), atype, - spins=paddle.to_tensor(spin).to(device=env.DEVICE).unsqueeze(0), + spins=paddle.to_tensor(spin, place=env.DEVICE).unsqueeze(0), ) # detach ret = {key: to_numpy_array(result[key].squeeze(0)) for key in test_keys} @@ -133,7 +133,7 @@ def ff_spin(_spin): class VirialTest: def test( self, - ): + ) -> None: places = 5 delta = 1e-4 natoms = 5 @@ -153,10 +153,10 @@ def np_infer( ): result = eval_model( self.model, - paddle.to_tensor(stretch_box(coord, cell, new_cell)) - .to(device="cpu") - .unsqueeze(0), - paddle.to_tensor(new_cell).to(device="cpu").unsqueeze(0), + paddle.to_tensor( + stretch_box(coord, cell, new_cell), place="cpu" + ).unsqueeze(0), + paddle.to_tensor(new_cell, place="cpu").unsqueeze(0), atype, ) # detach @@ -177,36 +177,35 @@ def ff(bb): class TestEnergyModelSeAForce(unittest.TestCase, ForceTest): - def setUp(self): + def setUp(self) -> None: model_params = copy.deepcopy(model_se_e2_a) self.type_split = False self.model = get_model(model_params).to(env.DEVICE) class TestEnergyModelSeAVirial(unittest.TestCase, VirialTest): - def setUp(self): + def setUp(self) -> None: model_params = copy.deepcopy(model_se_e2_a) self.type_split = False self.model = get_model(model_params).to(env.DEVICE) class TestEnergyModelDPA1Force(unittest.TestCase, ForceTest): - def 
setUp(self): + def setUp(self) -> None: model_params = copy.deepcopy(model_dpa1) self.type_split = True self.model = get_model(model_params).to(env.DEVICE) class TestEnergyModelDPA1Virial(unittest.TestCase, VirialTest): - def setUp(self): + def setUp(self) -> None: model_params = copy.deepcopy(model_dpa1) self.type_split = True self.model = get_model(model_params).to(env.DEVICE) -@unittest.skip("Skip for not implemented yet") class TestEnergyModelDPA2Force(unittest.TestCase, ForceTest): - def setUp(self): + def setUp(self) -> None: model_params = copy.deepcopy(model_dpa2) self.type_split = True self.model = get_model(model_params).to(env.DEVICE) @@ -214,7 +213,7 @@ def setUp(self): @unittest.skip("Skip for not implemented yet") class TestEnergyModelDPAUniVirial(unittest.TestCase, VirialTest): - def setUp(self): + def setUp(self) -> None: model_params = copy.deepcopy(model_dpa2) self.type_split = True self.model = get_model(model_params).to(env.DEVICE) @@ -222,7 +221,7 @@ def setUp(self): @unittest.skip("Skip for not implemented yet") class TestEnergyModelHybridForce(unittest.TestCase, ForceTest): - def setUp(self): + def setUp(self) -> None: model_params = copy.deepcopy(model_hybrid) self.type_split = True self.model = get_model(model_params).to(env.DEVICE) @@ -230,7 +229,7 @@ def setUp(self): @unittest.skip("Skip for not implemented yet") class TestEnergyModelHybridVirial(unittest.TestCase, VirialTest): - def setUp(self): + def setUp(self) -> None: model_params = copy.deepcopy(model_hybrid) self.type_split = True self.model = get_model(model_params).to(env.DEVICE) @@ -238,7 +237,7 @@ def setUp(self): @unittest.skip("Skip for not implemented yet") class TestEnergyModelZBLForce(unittest.TestCase, ForceTest): - def setUp(self): + def setUp(self) -> None: model_params = copy.deepcopy(model_zbl) self.type_split = False self.model = get_model(model_params).to(env.DEVICE) @@ -246,7 +245,7 @@ def setUp(self): @unittest.skip("Skip for not implemented yet") class 
TestEnergyModelZBLVirial(unittest.TestCase, VirialTest): - def setUp(self): + def setUp(self) -> None: model_params = copy.deepcopy(model_zbl) self.type_split = False self.model = get_model(model_params).to(env.DEVICE) @@ -254,7 +253,7 @@ def setUp(self): @unittest.skip("Skip for not implemented yet") class TestEnergyModelSpinSeAForce(unittest.TestCase, ForceTest): - def setUp(self): + def setUp(self) -> None: model_params = copy.deepcopy(model_spin) self.type_split = False self.test_spin = True diff --git a/source/tests/pd/model/test_descriptor_dpa2.py b/source/tests/pd/model/test_descriptor_dpa2.py new file mode 100644 index 0000000000..12017bb840 --- /dev/null +++ b/source/tests/pd/model/test_descriptor_dpa2.py @@ -0,0 +1,208 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import json +import os +import unittest +from pathlib import ( + Path, +) + +import numpy as np +import paddle + +from deepmd.pd.model.descriptor import ( + DescrptDPA2, +) +from deepmd.pd.utils import ( + env, +) +from deepmd.pd.utils.nlist import ( + extend_input_and_build_neighbor_list, +) + +CUR_DIR = os.path.dirname(__file__) + + +class TestDPA2(unittest.TestCase): + def setUp(self): + cell = [ + 5.122106549439247480e00, + 4.016537340154059388e-01, + 6.951654033828678081e-01, + 4.016537340154059388e-01, + 6.112136112297989143e00, + 8.178091365465004481e-01, + 6.951654033828678081e-01, + 8.178091365465004481e-01, + 6.159552512682983760e00, + ] + self.cell = ( + paddle.to_tensor(cell, dtype=env.GLOBAL_PD_FLOAT_PRECISION) + .reshape([1, 3, 3]) + .to(device=env.DEVICE) + ) + coord = [ + 2.978060152121375648e00, + 3.588469695887098077e00, + 2.792459820604495491e00, + 3.895592322591093115e00, + 2.712091020667753760e00, + 1.366836847133650501e00, + 9.955616170888935690e-01, + 4.121324820711413039e00, + 1.817239061889086571e00, + 3.553661462345699906e00, + 5.313046969500791583e00, + 6.635182659098815883e00, + 6.088601018589653080e00, + 6.575011420004332585e00, + 6.825240650611076099e00, + ] + 
self.coord = ( + paddle.to_tensor(coord, dtype=env.GLOBAL_PD_FLOAT_PRECISION) + .reshape([1, -1, 3]) + .to(device=env.DEVICE) + ) + self.atype = ( + paddle.to_tensor([0, 0, 0, 1, 1], dtype=paddle.int32) + .reshape([1, -1]) + .to(device=env.DEVICE) + ) + self.ref_d = paddle.to_tensor( + [ + 8.435412613327306630e-01, + -4.717109614540972440e-01, + -1.812643456954206256e00, + -2.315248767961955167e-01, + -7.112973006771171613e-01, + -4.162041919507591392e-01, + -1.505159810095323181e00, + -1.191652416985768403e-01, + 8.439214937875325617e-01, + -4.712976890460106594e-01, + -1.812605149396642856e00, + -2.307222236291133766e-01, + -7.115427800870099961e-01, + -4.164729253167227530e-01, + -1.505483119125936797e00, + -1.191288524278367872e-01, + 8.286420823261241297e-01, + -4.535033763979030574e-01, + -1.787877160970498425e00, + -1.961763875645104460e-01, + -7.475459187804838201e-01, + -5.231446874663764346e-01, + -1.488399984491664219e00, + -3.974117581747104583e-02, + 8.283793431613817315e-01, + -4.551551577556525729e-01, + -1.789253136645859943e00, + -1.977673627726055372e-01, + -7.448826048241211639e-01, + -5.161350182531234676e-01, + -1.487589463573479209e00, + -4.377376017839779143e-02, + 8.295404560710329944e-01, + -4.492219258475603216e-01, + -1.784484611185287450e00, + -1.901182059718481143e-01, + -7.537407667483000395e-01, + -5.384371277650709109e-01, + -1.490368056268364549e00, + -3.073744832541754762e-02, + ], + dtype=env.GLOBAL_PD_FLOAT_PRECISION, + place=env.DEVICE, + ) + self.file_model_param = Path(CUR_DIR) / "models" / "dpa2.pd" + self.file_type_embed = Path(CUR_DIR) / "models" / "dpa2_tebd.pd" + + def test_descriptor(self) -> None: + with open(Path(CUR_DIR) / "models" / "dpa2.json") as fp: + self.model_json = json.load(fp) + model_dpa2 = self.model_json + ntypes = len(model_dpa2["type_map"]) + dparams = model_dpa2["descriptor"] + dparams["ntypes"] = ntypes + assert dparams["type"] == "dpa2" + dparams.pop("type") + dparams["concat_output_tebd"] = False + 
dparams["use_tebd_bias"] = True + des = DescrptDPA2( + **dparams, + ).to(env.DEVICE) + target_dict = des.state_dict() + source_dict = paddle.load(str(self.file_model_param)) + # type_embd of repformer is removed + source_dict.pop("type_embedding.embedding.embedding_net.layers.0.bias") + type_embd_dict = paddle.load(str(self.file_type_embed)) + target_dict = translate_type_embd_dicts_to_dpa2( + target_dict, + source_dict, + type_embd_dict, + ) + des.set_state_dict(target_dict) + + coord = self.coord + atype = self.atype + box = self.cell + ( + extended_coord, + extended_atype, + mapping, + nlist, + ) = extend_input_and_build_neighbor_list( + coord, + atype, + des.get_rcut(), + des.get_sel(), + mixed_types=des.mixed_types(), + box=box, + ) + descriptor, env_mat, diff, rot_mat, sw = des( + extended_coord, + extended_atype, + nlist, + mapping=mapping, + ) + self.assertEqual(descriptor.shape[-1], des.get_dim_out()) + self.assertAlmostEqual(6.0, des.get_rcut()) + self.assertEqual(30, des.get_nsel()) + self.assertEqual(2, des.get_ntypes()) + np.testing.assert_allclose( + descriptor.reshape([-1]).numpy(), self.ref_d.numpy(), atol=1e-10, rtol=1e-10 + ) + + dparams["concat_output_tebd"] = True + des = DescrptDPA2( + **dparams, + ).to(env.DEVICE) + descriptor, env_mat, diff, rot_mat, sw = des( + extended_coord, + extended_atype, + nlist, + mapping=mapping, + ) + self.assertEqual(descriptor.shape[-1], des.get_dim_out()) + + +def translate_type_embd_dicts_to_dpa2( + target_dict, + source_dict, + type_embd_dict, +): + all_keys = list(target_dict.keys()) + record = [False for ii in all_keys] + for kk, vv in source_dict.items(): + record[all_keys.index(kk)] = True + target_dict[kk] = vv + assert len(type_embd_dict.keys()) == 2 + it = iter(type_embd_dict.keys()) + for _ in range(2): + kk = next(it) + tk = "type_embedding." 
+ kk + record[all_keys.index(tk)] = True + target_dict[tk] = type_embd_dict[kk] + record[all_keys.index("repinit.compress_data.0")] = True + record[all_keys.index("repinit.compress_info.0")] = True + assert all(record) + return target_dict diff --git a/source/tests/pd/model/test_dpa2.py b/source/tests/pd/model/test_dpa2.py new file mode 100644 index 0000000000..f441007cad --- /dev/null +++ b/source/tests/pd/model/test_dpa2.py @@ -0,0 +1,333 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import itertools +import unittest + +import numpy as np +import paddle + +from deepmd.dpmodel.descriptor.dpa2 import DescrptDPA2 as DPDescrptDPA2 +from deepmd.dpmodel.descriptor.dpa2 import ( + RepformerArgs, + RepinitArgs, +) +from deepmd.pd.model.descriptor.dpa2 import ( + DescrptDPA2, +) +from deepmd.pd.utils import ( + env, +) +from deepmd.pd.utils.env import ( + PRECISION_DICT, +) + +from ...seed import ( + GLOBAL_SEED, +) +from .test_env_mat import ( + TestCaseSingleFrameWithNlist, +) +from .test_mlp import ( + get_tols, +) + +dtype = env.GLOBAL_PD_FLOAT_PRECISION + + +class TestDescrptDPA2(unittest.TestCase, TestCaseSingleFrameWithNlist): + def setUp(self) -> None: + TestCaseSingleFrameWithNlist.setUp(self) + + def test_consistency( + self, + ) -> None: + rng = np.random.default_rng(100) + nf, nloc, nnei = self.nlist.shape + davg = rng.normal(size=(self.nt, nnei, 4)) + dstd = rng.normal(size=(self.nt, nnei, 4)) + davg_2 = rng.normal(size=(self.nt, nnei // 2, 4)) + dstd_2 = rng.normal(size=(self.nt, nnei // 2, 4)) + dstd = 0.1 + np.abs(dstd) + dstd_2 = 0.1 + np.abs(dstd_2) + + for ( + riti, + riz, + rp1c, + rp1d, + rp1g, + rp1a, + rp2g, + rp2a, + rph, + rp2gate, + rus, + rpz, + sm, + prec, + ect, + ns, + ) in itertools.product( + ["concat", "strip"], # repinit_tebd_input_mode + [ + True, + ], # repinit_set_davg_zero + [True, False], # repformer_update_g1_has_conv + [True, False], # repformer_update_g1_has_drrd + [True, False], # repformer_update_g1_has_grrg + [ + False, + 
], # repformer_update_g1_has_attn + [ + False, + ], # repformer_update_g2_has_g1g1 + [True, False], # repformer_update_g2_has_attn + [ + False, + ], # repformer_update_h2 + [ + True, + ], # repformer_attn2_has_gate + ["res_avg", "res_residual"], # repformer_update_style + [ + True, + ], # repformer_set_davg_zero + [ + True, + ], # smooth + ["float64"], # precision + [False, True], # use_econf_tebd + [ + False, + True, + ], # new sub-structures (use_sqrt_nnei, g1_out_conv, g1_out_mlp) + ): + if ns and not rp1d and not rp1g: + continue + dtype = PRECISION_DICT[prec] + rtol, atol = get_tols(prec) + if prec == "float64": + atol = 1e-8 # marginal GPU test cases... + + repinit = RepinitArgs( + rcut=self.rcut, + rcut_smth=self.rcut_smth, + nsel=self.sel_mix, + tebd_input_mode=riti, + set_davg_zero=riz, + ) + repformer = RepformerArgs( + rcut=self.rcut / 2, + rcut_smth=self.rcut_smth, + nsel=nnei // 2, + nlayers=3, + g1_dim=20, + g2_dim=10, + axis_neuron=4, + update_g1_has_conv=rp1c, + update_g1_has_drrd=rp1d, + update_g1_has_grrg=rp1g, + update_g1_has_attn=rp1a, + update_g2_has_g1g1=rp2g, + update_g2_has_attn=rp2a, + update_h2=rph, + attn1_hidden=20, + attn1_nhead=2, + attn2_hidden=10, + attn2_nhead=2, + attn2_has_gate=rp2gate, + update_style=rus, + set_davg_zero=rpz, + use_sqrt_nnei=ns, + g1_out_conv=ns, + g1_out_mlp=ns, + ) + + # dpa2 new impl + dd0 = DescrptDPA2( + self.nt, + repinit=repinit, + repformer=repformer, + # kwargs for descriptor + smooth=sm, + exclude_types=[], + add_tebd_to_repinit_out=False, + precision=prec, + use_econf_tebd=ect, + type_map=["O", "H"] if ect else None, + seed=GLOBAL_SEED, + ).to(env.DEVICE) + + dd0.repinit.mean = paddle.to_tensor(davg, dtype=dtype).to(device=env.DEVICE) + dd0.repinit.stddev = paddle.to_tensor(dstd, dtype=dtype).to( + device=env.DEVICE + ) + dd0.repformers.mean = paddle.to_tensor(davg_2, dtype=dtype).to( + device=env.DEVICE + ) + dd0.repformers.stddev = paddle.to_tensor(dstd_2, dtype=dtype).to( + device=env.DEVICE + ) + 
rd0, _, _, _, _ = dd0( + paddle.to_tensor(self.coord_ext, dtype=dtype).to(device=env.DEVICE), + paddle.to_tensor(self.atype_ext, dtype="int64").to(device=env.DEVICE), + paddle.to_tensor(self.nlist, dtype="int64").to(device=env.DEVICE), + paddle.to_tensor(self.mapping, dtype="int64").to(device=env.DEVICE), + ) + # serialization + dd1 = DescrptDPA2.deserialize(dd0.serialize()) + rd1, _, _, _, _ = dd1( + paddle.to_tensor(self.coord_ext, dtype=dtype).to(device=env.DEVICE), + paddle.to_tensor(self.atype_ext, dtype="int64").to(device=env.DEVICE), + paddle.to_tensor(self.nlist, dtype="int64").to(device=env.DEVICE), + paddle.to_tensor(self.mapping, dtype="int64").to(device=env.DEVICE), + ) + np.testing.assert_allclose( + rd0.detach().cpu().numpy(), + rd1.detach().cpu().numpy(), + rtol=rtol, + atol=atol, + ) + # dp impl + dd2 = DPDescrptDPA2.deserialize(dd0.serialize()) + rd2, _, _, _, _ = dd2.call( + self.coord_ext, self.atype_ext, self.nlist, self.mapping + ) + np.testing.assert_allclose( + rd0.detach().cpu().numpy(), + rd2, + rtol=rtol, + atol=atol, + ) + + @unittest.skip("skip jit in paddle temporally") + def test_jit( + self, + ) -> None: + rng = np.random.default_rng(100) + nf, nloc, nnei = self.nlist.shape + davg = rng.normal(size=(self.nt, nnei, 4)) + dstd = rng.normal(size=(self.nt, nnei, 4)) + davg_2 = rng.normal(size=(self.nt, nnei // 2, 4)) + dstd_2 = rng.normal(size=(self.nt, nnei // 2, 4)) + dstd = 0.1 + np.abs(dstd) + + for ( + riti, + riz, + rp1c, + rp1d, + rp1g, + rp1a, + rp2g, + rp2a, + rph, + rp2gate, + rus, + rpz, + sm, + prec, + ect, + ns, + ) in itertools.product( + ["concat", "strip"], # repinit_tebd_input_mode + [ + True, + ], # repinit_set_davg_zero + [ + True, + ], # repformer_update_g1_has_conv + [ + True, + ], # repformer_update_g1_has_drrd + [ + True, + ], # repformer_update_g1_has_grrg + [ + True, + ], # repformer_update_g1_has_attn + [ + True, + ], # repformer_update_g2_has_g1g1 + [ + True, + ], # repformer_update_g2_has_attn + [ + False, + ], 
# repformer_update_h2 + [ + True, + ], # repformer_attn2_has_gate + ["res_avg", "res_residual"], # repformer_update_style + [ + True, + ], # repformer_set_davg_zero + [ + True, + ], # smooth + ["float64"], # precision + [False, True], # use_econf_tebd + [True], # new sub-structures (use_sqrt_nnei, g1_out_conv, g1_out_mlp) + ): + dtype = PRECISION_DICT[prec] + rtol, atol = get_tols(prec) + + repinit = RepinitArgs( + rcut=self.rcut, + rcut_smth=self.rcut_smth, + nsel=self.sel_mix, + tebd_input_mode=riti, + set_davg_zero=riz, + ) + repformer = RepformerArgs( + rcut=self.rcut / 2, + rcut_smth=self.rcut_smth, + nsel=nnei // 2, + nlayers=3, + g1_dim=20, + g2_dim=10, + axis_neuron=4, + update_g1_has_conv=rp1c, + update_g1_has_drrd=rp1d, + update_g1_has_grrg=rp1g, + update_g1_has_attn=rp1a, + update_g2_has_g1g1=rp2g, + update_g2_has_attn=rp2a, + update_h2=rph, + attn1_hidden=20, + attn1_nhead=2, + attn2_hidden=10, + attn2_nhead=2, + attn2_has_gate=rp2gate, + update_style=rus, + set_davg_zero=rpz, + use_sqrt_nnei=ns, + g1_out_conv=ns, + g1_out_mlp=ns, + ) + + # dpa2 new impl + dd0 = DescrptDPA2( + self.nt, + repinit=repinit, + repformer=repformer, + # kwargs for descriptor + smooth=sm, + exclude_types=[], + add_tebd_to_repinit_out=False, + precision=prec, + use_econf_tebd=ect, + type_map=["O", "H"] if ect else None, + seed=GLOBAL_SEED, + ).to(env.DEVICE) + + dd0.repinit.mean = paddle.to_tensor(davg, dtype=dtype).to(device=env.DEVICE) + dd0.repinit.stddev = paddle.to_tensor(dstd, dtype=dtype).to( + device=env.DEVICE + ) + dd0.repformers.mean = paddle.to_tensor(davg_2, dtype=dtype).to( + device=env.DEVICE + ) + dd0.repformers.stddev = paddle.to_tensor(dstd_2, dtype=dtype).to( + device=env.DEVICE + ) + model = paddle.jit.to_static(dd0) diff --git a/source/tests/pd/model/test_forward_lower.py b/source/tests/pd/model/test_forward_lower.py index db6497b605..1d924e2d3d 100644 --- a/source/tests/pd/model/test_forward_lower.py +++ b/source/tests/pd/model/test_forward_lower.py @@ 
-140,22 +140,21 @@ def test( class TestEnergyModelSeA(unittest.TestCase, ForwardLowerTest): - def setUp(self): + def setUp(self) -> None: self.prec = 1e-10 model_params = copy.deepcopy(model_se_e2_a) self.model = get_model(model_params).to(env.DEVICE) class TestEnergyModelDPA1(unittest.TestCase, ForwardLowerTest): - def setUp(self): + def setUp(self) -> None: self.prec = 1e-10 model_params = copy.deepcopy(model_dpa1) self.model = get_model(model_params).to(env.DEVICE) -@unittest.skip("Skip for not implemented yet") class TestEnergyModelDPA2(unittest.TestCase, ForwardLowerTest): - def setUp(self): + def setUp(self) -> None: self.prec = 1e-10 model_params = copy.deepcopy(model_dpa2) self.model = get_model(model_params).to(env.DEVICE) @@ -163,7 +162,7 @@ def setUp(self): @unittest.skip("Skip for not implemented yet") class TestEnergyModelZBL(unittest.TestCase, ForwardLowerTest): - def setUp(self): + def setUp(self) -> None: self.prec = 1e-10 model_params = copy.deepcopy(model_zbl) self.model = get_model(model_params).to(env.DEVICE) @@ -171,7 +170,7 @@ def setUp(self): @unittest.skip("Skip for not implemented yet") class TestEnergyModelSpinSeA(unittest.TestCase, ForwardLowerTest): - def setUp(self): + def setUp(self) -> None: self.prec = 1e-10 model_params = copy.deepcopy(model_spin) self.test_spin = True @@ -180,7 +179,7 @@ def setUp(self): @unittest.skip("Skip for not implemented yet") class TestEnergyModelSpinDPA1(unittest.TestCase, ForwardLowerTest): - def setUp(self): + def setUp(self) -> None: self.prec = 1e-10 model_params = copy.deepcopy(model_spin) model_params["descriptor"] = copy.deepcopy(model_dpa1)["descriptor"] @@ -192,7 +191,7 @@ def setUp(self): @unittest.skip("Skip for not implemented yet") class TestEnergyModelSpinDPA2(unittest.TestCase, ForwardLowerTest): - def setUp(self): + def setUp(self) -> None: self.prec = 1e-10 model_params = copy.deepcopy(model_spin) model_params["descriptor"] = copy.deepcopy(model_dpa2)["descriptor"] diff --git 
a/source/tests/pd/model/test_null_input.py b/source/tests/pd/model/test_null_input.py index 5d67491943..29d2f84eea 100644 --- a/source/tests/pd/model/test_null_input.py +++ b/source/tests/pd/model/test_null_input.py @@ -23,6 +23,7 @@ ) from .test_permutation import ( model_dpa1, + model_dpa2, model_se_e2_a, ) @@ -32,7 +33,7 @@ class NullTest: def test_nloc_1( self, - ): + ) -> None: natoms = 1 generator = paddle.seed(GLOBAL_SEED) # paddle.seed(1000) @@ -60,7 +61,7 @@ def test_nloc_1( def test_nloc_2_far( self, - ): + ) -> None: natoms = 2 generator = paddle.seed(GLOBAL_SEED) cell = paddle.rand([3, 3], dtype=dtype).to(device=env.DEVICE) @@ -100,3 +101,10 @@ def setUp(self): model_params = copy.deepcopy(model_dpa1) self.type_split = True self.model = get_model(model_params).to(env.DEVICE) + + +class TestEnergyModelDPA2(unittest.TestCase, NullTest): + def setUp(self) -> None: + model_params = copy.deepcopy(model_dpa2) + self.type_split = True + self.model = get_model(model_params).to(env.DEVICE) diff --git a/source/tests/pd/model/test_permutation.py b/source/tests/pd/model/test_permutation.py index 135c5ea819..88672457a9 100644 --- a/source/tests/pd/model/test_permutation.py +++ b/source/tests/pd/model/test_permutation.py @@ -416,7 +416,6 @@ def setUp(self) -> None: self.model = get_model(model_params).to(env.DEVICE) -@unittest.skip("Skip for not implemented yet") class TestEnergyModelDPA2(unittest.TestCase, PermutationTest): def setUp(self) -> None: model_params = copy.deepcopy(model_dpa2) diff --git a/source/tests/pd/model/test_rot.py b/source/tests/pd/model/test_rot.py index 85c90dc60f..84a0d3d724 100644 --- a/source/tests/pd/model/test_rot.py +++ b/source/tests/pd/model/test_rot.py @@ -176,7 +176,6 @@ def setUp(self): self.model = get_model(model_params).to(env.DEVICE) -@unittest.skip("Skip for not implemented yet") class TestEnergyModelDPA2(unittest.TestCase, RotTest): def setUp(self): model_params = copy.deepcopy(model_dpa2) diff --git 
a/source/tests/pd/model/test_rot_denoise.py b/source/tests/pd/model/test_rot_denoise.py index 74d5d41791..4a1841d10b 100644 --- a/source/tests/pd/model/test_rot_denoise.py +++ b/source/tests/pd/model/test_rot_denoise.py @@ -18,8 +18,9 @@ from ..common import ( eval_model, ) -from .test_permutation_denoise import ( # model_dpa2, +from .test_permutation_denoise import ( model_dpa1, + model_dpa2, ) dtype = paddle.float64 @@ -112,6 +113,14 @@ def setUp(self): self.model = get_model(model_params).to(env.DEVICE) +@unittest.skip("support of the denoise is temporally disabled") +class TestDenoiseModelDPA2(unittest.TestCase, RotDenoiseTest): + def setUp(self) -> None: + model_params = copy.deepcopy(model_dpa2) + self.type_split = True + self.model = get_model(model_params).to(env.DEVICE) + + # @unittest.skip("hybrid not supported at the moment") # class TestEnergyModelHybrid(unittest.TestCase, TestRotDenoise): # def setUp(self): diff --git a/source/tests/pd/model/test_smooth.py b/source/tests/pd/model/test_smooth.py index cc50043ad8..f907e6f4ee 100644 --- a/source/tests/pd/model/test_smooth.py +++ b/source/tests/pd/model/test_smooth.py @@ -20,6 +20,7 @@ ) from .test_permutation import ( # model_dpau, model_dpa1, + model_dpa2, model_se_e2_a, ) @@ -189,6 +190,36 @@ def setUp(self): self.aprec = 1e-5 +class TestEnergyModelDPA2(unittest.TestCase, SmoothTest): + def setUp(self) -> None: + model_params = copy.deepcopy(model_dpa2) + model_params["descriptor"]["repinit"]["rcut"] = 8 + model_params["descriptor"]["repinit"]["rcut_smth"] = 3.5 + self.type_split = True + self.model = get_model(model_params).to(env.DEVICE) + self.epsilon, self.aprec = 1e-5, 1e-4 + + +class TestEnergyModelDPA2_1(unittest.TestCase, SmoothTest): + def setUp(self) -> None: + model_params = copy.deepcopy(model_dpa2) + model_params["fitting_net"]["type"] = "ener" + self.type_split = True + self.test_virial = False + self.model = get_model(model_params).to(env.DEVICE) + self.epsilon, self.aprec = None, None + 
+ +class TestEnergyModelDPA2_2(unittest.TestCase, SmoothTest): + def setUp(self) -> None: + model_params = copy.deepcopy(model_dpa2) + model_params["fitting_net"]["type"] = "ener" + self.type_split = True + self.test_virial = False + self.model = get_model(model_params).to(env.DEVICE) + self.epsilon, self.aprec = None, None + + # class TestEnergyFoo(unittest.TestCase): # def test(self): # model_params = model_dpau diff --git a/source/tests/pd/model/test_trans.py b/source/tests/pd/model/test_trans.py index 3fae49d598..f050596996 100644 --- a/source/tests/pd/model/test_trans.py +++ b/source/tests/pd/model/test_trans.py @@ -110,7 +110,6 @@ def setUp(self): self.model = get_model(model_params).to(env.DEVICE) -@unittest.skip("Skip for not implemented yet") class TestEnergyModelDPA2(unittest.TestCase, TransTest): def setUp(self): model_params = copy.deepcopy(model_dpa2) diff --git a/source/tests/pd/model/test_unused_params.py b/source/tests/pd/model/test_unused_params.py new file mode 100644 index 0000000000..bf92171da1 --- /dev/null +++ b/source/tests/pd/model/test_unused_params.py @@ -0,0 +1,92 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import copy +import unittest + +import paddle + +from deepmd.pd.model.model import ( + get_model, +) +from deepmd.pd.utils import ( + env, +) + +from ...seed import ( + GLOBAL_SEED, +) +from ..common import ( + eval_model, +) +from .test_permutation import ( + model_dpa2, +) + +dtype = paddle.float64 + + +@unittest.skip("paddle do not support unpacking grad_fn.next_functions") +class TestUnusedParamsDPA2(unittest.TestCase): + def test_unused(self): + import itertools + + for conv, drrd, grrg, attn1, g1g1, attn2, h2 in itertools.product( + [True], + [True], + [True], + [True], + [True], + [True], + [True], + ): + if (not drrd) and (not grrg) and h2: + # skip the case h2 is not envolved + continue + if (not grrg) and (not conv): + # skip the case g2 is not envolved + continue + model = copy.deepcopy(model_dpa2) + 
model["descriptor"]["repformer"]["nlayers"] = 2 + # model["descriptor"]["combine_grrg"] = cmbg2 + model["descriptor"]["repformer"]["update_g1_has_conv"] = conv + model["descriptor"]["repformer"]["update_g1_has_drrd"] = drrd + model["descriptor"]["repformer"]["update_g1_has_grrg"] = grrg + model["descriptor"]["repformer"]["update_g1_has_attn"] = attn1 + model["descriptor"]["repformer"]["update_g2_has_g1g1"] = g1g1 + model["descriptor"]["repformer"]["update_g2_has_attn"] = attn2 + model["descriptor"]["repformer"]["update_h2"] = h2 + model["fitting_net"]["neuron"] = [12, 12, 12] + self._test_unused(model) + + def _test_unused(self, model_params): + self.model = get_model(model_params).to(env.DEVICE) + natoms = 5 + generator = paddle.seed(GLOBAL_SEED) + cell = paddle.rand([3, 3], dtype=dtype).to(device=env.DEVICE) + cell = (cell + cell.T) + 5.0 * paddle.eye(3).to(device=env.DEVICE) + coord = paddle.rand([natoms, 3], dtype=dtype).to(device=env.DEVICE) + coord = paddle.matmul(coord, cell) + atype = paddle.to_tensor([0, 0, 0, 1, 1]).to(env.DEVICE) + idx_perm = [1, 0, 4, 3, 2] + result_0 = eval_model(self.model, coord.unsqueeze(0), cell.unsqueeze(0), atype) + test_keys = ["energy", "force", "virial"] + ret0 = {key: result_0[key].squeeze(0) for key in test_keys} + + # use computation graph to find all contributing tensors + def get_contributing_params(y, top_level=True): + nf = y.grad_fn.next_functions if top_level else y.next_functions + for f, _ in nf: + try: + yield f.variable + except AttributeError: + pass # node has no tensor + if f is not None: + yield from get_contributing_params(f, top_level=False) + + contributing_parameters = set(get_contributing_params(ret0["energy"])) + all_parameters = set(self.model.parameters()) + non_contributing = all_parameters - contributing_parameters + self.assertEqual(len(non_contributing), 0) + + +if __name__ == "__main__": + unittest.main() diff --git a/source/tests/pd/model/water/multitask.json 
b/source/tests/pd/model/water/multitask.json index 83524a8b77..2786afca59 100644 --- a/source/tests/pd/model/water/multitask.json +++ b/source/tests/pd/model/water/multitask.json @@ -10,7 +10,8 @@ "type": "se_e2_a", "sel": [ 46, - 92 + 92, + 4 ], "rcut_smth": 0.50, "rcut": 6.00, diff --git a/source/tests/pd/model/water/multitask_sharefit.json b/source/tests/pd/model/water/multitask_sharefit.json index 246b5992f7..934ef04998 100644 --- a/source/tests/pd/model/water/multitask_sharefit.json +++ b/source/tests/pd/model/water/multitask_sharefit.json @@ -91,14 +91,14 @@ "stat_file": "./stat_files/model_1.hdf5", "training_data": { "systems": [ - "pt/water/data/data_0" + "pd/water/data/data_0" ], "batch_size": 1, "_comment": "that's all" }, "validation_data": { "systems": [ - "pt/water/data/data_0" + "pd/water/data/data_0" ], "batch_size": 1, "_comment": "that's all" @@ -108,14 +108,14 @@ "stat_file": "./stat_files/model_2.hdf5", "training_data": { "systems": [ - "pt/water/data/data_0" + "pd/water/data/data_0" ], "batch_size": 1, "_comment": "that's all" }, "validation_data": { "systems": [ - "pt/water/data/data_0" + "pd/water/data/data_0" ], "batch_size": 1, "_comment": "that's all" diff --git a/source/tests/pd/test_finetune.py b/source/tests/pd/test_finetune.py index f82f7a8cd0..769ea6f6d3 100644 --- a/source/tests/pd/test_finetune.py +++ b/source/tests/pd/test_finetune.py @@ -197,7 +197,7 @@ def test_finetune_change_out_bias(self): self.tearDown() - def test_finetune_change_type(self): + def test_finetune_change_type(self) -> None: if not self.mixed_types: # skip when not mixed_types return @@ -284,7 +284,7 @@ def test_finetune_change_type(self): self.tearDown() - def tearDown(self): + def tearDown(self) -> None: for f in os.listdir("."): if f.startswith("model") and f.endswith(".pd"): os.remove(f) @@ -295,7 +295,7 @@ def tearDown(self): class TestEnergyModelSeA(FinetuneTest, unittest.TestCase): - def setUp(self): + def setUp(self) -> None: input_json = 
str(Path(__file__).parent / "water/se_atten.json") with open(input_json) as f: self.config = json.load(f) @@ -311,7 +311,7 @@ def setUp(self): @unittest.skip("Skip for not implemented yet") class TestEnergyZBLModelSeA(FinetuneTest, unittest.TestCase): - def setUp(self): + def setUp(self) -> None: input_json = str(Path(__file__).parent / "water/se_atten.json") with open(input_json) as f: self.config = json.load(f) @@ -327,7 +327,7 @@ def setUp(self): @unittest.skip("Skip for not implemented yet") class TestEnergyDOSModelSeA(FinetuneTest, unittest.TestCase): - def setUp(self): + def setUp(self) -> None: input_json = str(Path(__file__).parent / "dos/input.json") with open(input_json) as f: self.config = json.load(f) @@ -342,7 +342,7 @@ def setUp(self): class TestEnergyModelDPA1(FinetuneTest, unittest.TestCase): - def setUp(self): + def setUp(self) -> None: input_json = str(Path(__file__).parent / "water/se_atten.json") with open(input_json) as f: self.config = json.load(f) @@ -356,9 +356,8 @@ def setUp(self): self.testkey = None -@unittest.skip("Skip for not implemented yet") class TestEnergyModelDPA2(FinetuneTest, unittest.TestCase): - def setUp(self): + def setUp(self) -> None: input_json = str(Path(__file__).parent / "water/se_atten.json") with open(input_json) as f: self.config = json.load(f) diff --git a/source/tests/pd/test_multitask.py b/source/tests/pd/test_multitask.py index d59990dcca..72ad251068 100644 --- a/source/tests/pd/test_multitask.py +++ b/source/tests/pd/test_multitask.py @@ -30,6 +30,8 @@ from .model.test_permutation import ( model_dpa1, + model_dpa2, + model_dpa2tebd, model_se_e2_a, ) @@ -40,6 +42,13 @@ def setUpModule() -> None: with open(multitask_template_json) as f: multitask_template = json.load(f) + global multitask_sharefit_template + multitask_sharefit_template_json = str( + Path(__file__).parent / "water/multitask_sharefit.json" + ) + with open(multitask_sharefit_template_json) as f: + multitask_sharefit_template = json.load(f) + class 
MultiTaskTrainTest: def test_multitask_train(self) -> None: @@ -227,6 +236,46 @@ def tearDown(self) -> None: MultiTaskTrainTest.tearDown(self) +class TestMultiTaskSeASharefit(unittest.TestCase, MultiTaskTrainTest): + def setUp(self) -> None: + multitask_se_e2_a = deepcopy(multitask_sharefit_template) + multitask_se_e2_a["model"]["shared_dict"]["my_descriptor"] = model_se_e2_a[ + "descriptor" + ] + data_file = [str(Path(__file__).parent / "water/data/data_0")] + self.stat_files = "se_e2_a_share_fit" + os.makedirs(self.stat_files, exist_ok=True) + self.config = multitask_se_e2_a + self.config["training"]["data_dict"]["model_1"]["training_data"]["systems"] = ( + data_file + ) + self.config["training"]["data_dict"]["model_1"]["validation_data"][ + "systems" + ] = data_file + self.config["training"]["data_dict"]["model_1"]["stat_file"] = ( + f"{self.stat_files}/model_1" + ) + self.config["training"]["data_dict"]["model_2"]["training_data"]["systems"] = ( + data_file + ) + self.config["training"]["data_dict"]["model_2"]["validation_data"][ + "systems" + ] = data_file + self.config["training"]["data_dict"]["model_2"]["stat_file"] = ( + f"{self.stat_files}/model_2" + ) + self.config["training"]["numb_steps"] = 1 + self.config["training"]["save_freq"] = 1 + self.origin_config = deepcopy(self.config) + self.config["model"], self.shared_links = preprocess_shared_params( + self.config["model"] + ) + self.share_fitting = True + + def tearDown(self) -> None: + MultiTaskTrainTest.tearDown(self) + + class TestMultiTaskDPA1(unittest.TestCase, MultiTaskTrainTest): def setUp(self) -> None: multitask_DPA1 = deepcopy(multitask_template) @@ -266,5 +315,83 @@ def tearDown(self) -> None: MultiTaskTrainTest.tearDown(self) +class TestMultiTaskDPA2(unittest.TestCase, MultiTaskTrainTest): + def setUp(self) -> None: + multitask_DPA2 = deepcopy(multitask_template) + multitask_DPA2["model"]["shared_dict"]["my_descriptor"] = model_dpa2[ + "descriptor" + ] + data_file = [str(Path(__file__).parent 
/ "water/data/data_0")] + self.stat_files = "DPA2" + os.makedirs(self.stat_files, exist_ok=True) + self.config = multitask_DPA2 + self.config["training"]["data_dict"]["model_1"]["training_data"]["systems"] = ( + data_file + ) + self.config["training"]["data_dict"]["model_1"]["validation_data"][ + "systems" + ] = data_file + self.config["training"]["data_dict"]["model_1"]["stat_file"] = ( + f"{self.stat_files}/model_1" + ) + self.config["training"]["data_dict"]["model_2"]["training_data"]["systems"] = ( + data_file + ) + self.config["training"]["data_dict"]["model_2"]["validation_data"][ + "systems" + ] = data_file + self.config["training"]["data_dict"]["model_2"]["stat_file"] = ( + f"{self.stat_files}/model_2" + ) + self.config["training"]["numb_steps"] = 1 + self.config["training"]["save_freq"] = 1 + self.origin_config = deepcopy(self.config) + self.config["model"], self.shared_links = preprocess_shared_params( + self.config["model"] + ) + + def tearDown(self) -> None: + MultiTaskTrainTest.tearDown(self) + + +class TestMultiTaskDPA2Tebd(unittest.TestCase, MultiTaskTrainTest): + def setUp(self) -> None: + multitask_DPA2 = deepcopy(multitask_template) + multitask_DPA2["model"]["shared_dict"]["my_descriptor"] = model_dpa2tebd[ + "descriptor" + ] + data_file = [str(Path(__file__).parent / "water/data/data_0")] + self.stat_files = "DPA2Tebd" + os.makedirs(self.stat_files, exist_ok=True) + self.config = multitask_DPA2 + self.config["training"]["data_dict"]["model_1"]["training_data"]["systems"] = ( + data_file + ) + self.config["training"]["data_dict"]["model_1"]["validation_data"][ + "systems" + ] = data_file + self.config["training"]["data_dict"]["model_1"]["stat_file"] = ( + f"{self.stat_files}/model_1" + ) + self.config["training"]["data_dict"]["model_2"]["training_data"]["systems"] = ( + data_file + ) + self.config["training"]["data_dict"]["model_2"]["validation_data"][ + "systems" + ] = data_file + self.config["training"]["data_dict"]["model_2"]["stat_file"] = ( + 
f"{self.stat_files}/model_2" + ) + self.config["training"]["numb_steps"] = 1 + self.config["training"]["save_freq"] = 1 + self.origin_config = deepcopy(self.config) + self.config["model"], self.shared_links = preprocess_shared_params( + self.config["model"] + ) + + def tearDown(self) -> None: + MultiTaskTrainTest.tearDown(self) + + if __name__ == "__main__": unittest.main() diff --git a/source/tests/pd/test_training.py b/source/tests/pd/test_training.py index c3d65c09df..8958dcb165 100644 --- a/source/tests/pd/test_training.py +++ b/source/tests/pd/test_training.py @@ -24,6 +24,7 @@ from .model.test_permutation import ( model_dpa1, + model_dpa2, model_se_e2_a, ) @@ -195,5 +196,21 @@ def tearDown(self) -> None: DPTrainTest.tearDown(self) +class TestEnergyModelDPA2(unittest.TestCase, DPTrainTest): + def setUp(self) -> None: + input_json = str(Path(__file__).parent / "water/se_atten.json") + with open(input_json) as f: + self.config = json.load(f) + data_file = [str(Path(__file__).parent / "water/data/data_0")] + self.config["training"]["training_data"]["systems"] = data_file + self.config["training"]["validation_data"]["systems"] = data_file + self.config["model"] = deepcopy(model_dpa2) + self.config["training"]["numb_steps"] = 1 + self.config["training"]["save_freq"] = 1 + + def tearDown(self) -> None: + DPTrainTest.tearDown(self) + + if __name__ == "__main__": unittest.main() diff --git a/source/tests/pd/test_update_sel.py b/source/tests/pd/test_update_sel.py index e7b1acf6ff..10342357c6 100644 --- a/source/tests/pd/test_update_sel.py +++ b/source/tests/pd/test_update_sel.py @@ -31,7 +31,7 @@ def setUp(self) -> None: return super().setUp() @patch("deepmd.pd.utils.update_sel.UpdateSel.get_nbor_stat") - def test_update_one_sel(self, sel_mock): + def test_update_one_sel(self, sel_mock) -> None: sel_mock.return_value = self.mock_min_nbor_dist, [10, 20] min_nbor_dist, sel = self.update_sel.update_one_sel(None, None, 6, "auto") @@ -45,7 +45,7 @@ def 
test_update_one_sel(self, sel_mock): @unittest.skip("Skip for not implemented yet") @patch("deepmd.pd.utils.update_sel.UpdateSel.get_nbor_stat") - def test_update_sel_hybrid(self, sel_mock): + def test_update_sel_hybrid(self, sel_mock) -> None: sel_mock.return_value = self.mock_min_nbor_dist, [10, 20] jdata = { @@ -76,7 +76,7 @@ def test_update_sel_hybrid(self, sel_mock): self.assertEqual(jdata, expected_out) @patch("deepmd.pd.utils.update_sel.UpdateSel.get_nbor_stat") - def test_update_sel(self, sel_mock): + def test_update_sel(self, sel_mock) -> None: sel_mock.return_value = self.mock_min_nbor_dist, [10, 20] jdata = { @@ -90,9 +90,8 @@ def test_update_sel(self, sel_mock): jdata = update_sel(jdata) self.assertEqual(jdata, expected_out) - @unittest.skip("Skip for not implemented yet") @patch("deepmd.pd.utils.update_sel.UpdateSel.get_nbor_stat") - def test_update_sel_atten_auto(self, sel_mock): + def test_update_sel_atten_auto(self, sel_mock) -> None: sel_mock.return_value = self.mock_min_nbor_dist, [25] jdata = { @@ -118,9 +117,8 @@ def test_update_sel_atten_auto(self, sel_mock): jdata = update_sel(jdata) self.assertEqual(jdata, expected_out) - @unittest.skip("Skip for not implemented yet") @patch("deepmd.pd.utils.update_sel.UpdateSel.get_nbor_stat") - def test_update_sel_atten_int(self, sel_mock): + def test_update_sel_atten_int(self, sel_mock) -> None: sel_mock.return_value = self.mock_min_nbor_dist, [25] jdata = { @@ -146,9 +144,8 @@ def test_update_sel_atten_int(self, sel_mock): jdata = update_sel(jdata) self.assertEqual(jdata, expected_out) - @unittest.skip("Skip for not implemented yet") @patch("deepmd.pd.utils.update_sel.UpdateSel.get_nbor_stat") - def test_update_sel_atten_list(self, sel_mock): + def test_update_sel_atten_list(self, sel_mock) -> None: sel_mock.return_value = self.mock_min_nbor_dist, [25] jdata = { @@ -174,7 +171,50 @@ def test_update_sel_atten_list(self, sel_mock): jdata = update_sel(jdata) self.assertEqual(jdata, expected_out) - def 
test_skip_frozen(self): + @patch("deepmd.pd.utils.update_sel.UpdateSel.get_nbor_stat") + def test_update_sel_dpa2_auto(self, sel_mock) -> None: + sel_mock.return_value = self.mock_min_nbor_dist, [25] + + jdata = { + "model": { + "descriptor": { + "type": "dpa2", + "repinit": { + "rcut": 6.0, + "nsel": "auto", + "three_body_rcut": 4.0, + "three_body_sel": "auto", + }, + "repformer": { + "rcut": 4.0, + "nsel": "auto", + }, + } + }, + "training": {"training_data": {}}, + } + expected_out = { + "model": { + "descriptor": { + "type": "dpa2", + "repinit": { + "rcut": 6.0, + "nsel": 28, + "three_body_rcut": 4.0, + "three_body_sel": 28, + }, + "repformer": { + "rcut": 4.0, + "nsel": 28, + }, + } + }, + "training": {"training_data": {}}, + } + jdata = update_sel(jdata) + self.assertEqual(jdata, expected_out) + + def test_skip_frozen(self) -> None: jdata = { "model": { "type": "frozen", @@ -185,7 +225,7 @@ def test_skip_frozen(self): jdata = update_sel(jdata) self.assertEqual(jdata, expected_out) - def test_wrap_up_4(self): + def test_wrap_up_4(self) -> None: self.assertEqual(self.update_sel.wrap_up_4(12), 3 * 4) self.assertEqual(self.update_sel.wrap_up_4(13), 4 * 4) self.assertEqual(self.update_sel.wrap_up_4(14), 4 * 4)