Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(pt): DPA-2 repinit compress #4329

Merged
merged 3 commits into from
Nov 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 87 additions & 0 deletions deepmd/pt/model/descriptor/dpa2.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,14 @@
build_multiple_neighbor_list,
get_multiple_nlist_key,
)
from deepmd.pt.utils.tabulate import (
DPTabulate,
)
from deepmd.pt.utils.update_sel import (
UpdateSel,
)
from deepmd.pt.utils.utils import (
ActivationFn,
to_numpy_array,
)
from deepmd.utils.data_system import (
Expand Down Expand Up @@ -306,6 +310,7 @@
# set trainable
for param in self.parameters():
param.requires_grad = trainable
self.compress = False

def get_rcut(self) -> float:
"""Returns the cut-off radius."""
Expand Down Expand Up @@ -859,3 +864,85 @@
)
local_jdata_cpy["repformer"]["nsel"] = repformer_sel[0]
return local_jdata_cpy, min_nbor_dist

def enable_compression(
self,
min_nbor_dist: float,
table_extrapolate: float = 5,
table_stride_1: float = 0.01,
table_stride_2: float = 0.1,
check_frequency: int = -1,
) -> None:
"""Receive the statistics (distance, max_nbor_size and env_mat_range) of the training data.

Parameters
----------
min_nbor_dist
The nearest distance between atoms
table_extrapolate
The scale of model extrapolation
table_stride_1
The uniform stride of the first table
table_stride_2
The uniform stride of the second table
check_frequency
The overflow check frequency
"""
# do some checks before the mocel compression process
if self.compress:
raise ValueError("Compression is already enabled.")

Check warning on line 893 in deepmd/pt/model/descriptor/dpa2.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/dpa2.py#L893

Added line #L893 was not covered by tests
assert (
not self.repinit.resnet_dt
), "Model compression error: repinit resnet_dt must be false!"
for tt in self.repinit.exclude_types:
if (tt[0] not in range(self.repinit.ntypes)) or (

Check warning on line 898 in deepmd/pt/model/descriptor/dpa2.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/dpa2.py#L898

Added line #L898 was not covered by tests
tt[1] not in range(self.repinit.ntypes)
):
raise RuntimeError(

Check warning on line 901 in deepmd/pt/model/descriptor/dpa2.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/dpa2.py#L901

Added line #L901 was not covered by tests
"Repinit exclude types"
+ str(tt)
+ " must within the number of atomic types "
+ str(self.repinit.ntypes)
+ "!"
)
if (
self.repinit.ntypes * self.repinit.ntypes - len(self.repinit.exclude_types)
== 0
):
raise RuntimeError(

Check warning on line 912 in deepmd/pt/model/descriptor/dpa2.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/dpa2.py#L912

Added line #L912 was not covered by tests
"Repinit empty embedding-nets are not supported in model compression!"
)

if self.repinit.attn_layer != 0:
raise RuntimeError(

Check warning on line 917 in deepmd/pt/model/descriptor/dpa2.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/dpa2.py#L917

Added line #L917 was not covered by tests
"Cannot compress model when repinit attention layer is not 0."
)

if self.repinit.tebd_input_mode != "strip":
raise RuntimeError(

Check warning on line 922 in deepmd/pt/model/descriptor/dpa2.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/dpa2.py#L922

Added line #L922 was not covered by tests
"Cannot compress model when repinit tebd_input_mode == 'concat'"
)
njzjz marked this conversation as resolved.
Show resolved Hide resolved

# repinit doesn't have a serialize method
data = self.serialize()
self.table = DPTabulate(
self,
data["repinit_args"]["neuron"],
data["repinit_args"]["type_one_side"],
data["exclude_types"],
ActivationFn(data["repinit_args"]["activation_function"]),
)
self.table_config = [
table_extrapolate,
table_stride_1,
table_stride_2,
check_frequency,
]
self.lower, self.upper = self.table.build(
min_nbor_dist, table_extrapolate, table_stride_1, table_stride_2
)

self.repinit.enable_compression(
self.table.data, self.table_config, self.lower, self.upper
)
self.compress = True
19 changes: 14 additions & 5 deletions deepmd/pt/utils/tabulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,11 +95,14 @@ def __init__(
raise RuntimeError("Unknown activation function type!")

self.activation_fn = activation_fn
self.davg = self.descrpt.serialize()["@variables"]["davg"]
self.dstd = self.descrpt.serialize()["@variables"]["dstd"]
self.ntypes = self.descrpt.get_ntypes()
serialized = self.descrpt.serialize()
if isinstance(self.descrpt, deepmd.pt.model.descriptor.DescrptDPA2):
serialized = serialized["repinit_variable"]
self.davg = serialized["@variables"]["davg"]
self.dstd = serialized["@variables"]["dstd"]
self.embedding_net_nodes = serialized["embeddings"]["networks"]
njzjz marked this conversation as resolved.
Show resolved Hide resolved

self.embedding_net_nodes = self.descrpt.serialize()["embeddings"]["networks"]
self.ntypes = self.descrpt.get_ntypes()

self.layer_size = self._get_layer_size()
self.table_size = self._get_table_size()
Expand Down Expand Up @@ -291,7 +294,13 @@ def _layer_1(self, x, w, b):
return t, self.activation_fn(torch.matmul(x, w) + b) + t

def _get_descrpt_type(self):
if isinstance(self.descrpt, deepmd.pt.model.descriptor.DescrptDPA1):
if isinstance(
self.descrpt,
(
deepmd.pt.model.descriptor.DescrptDPA1,
deepmd.pt.model.descriptor.DescrptDPA2,
),
):
return "Atten"
elif isinstance(self.descrpt, deepmd.pt.model.descriptor.DescrptSeA):
return "A"
Expand Down
149 changes: 149 additions & 0 deletions source/tests/pt/model/test_compressed_descriptor_dpa2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import unittest
from typing import (
Any,
)

import numpy as np
import torch

from deepmd.dpmodel.descriptor.dpa2 import (
RepformerArgs,
RepinitArgs,
)
from deepmd.env import (
GLOBAL_NP_FLOAT_PRECISION,
)
from deepmd.pt.model.descriptor.dpa2 import (
DescrptDPA2,
)
from deepmd.pt.utils.env import DEVICE as PT_DEVICE
from deepmd.pt.utils.nlist import build_neighbor_list as build_neighbor_list_pt
from deepmd.pt.utils.nlist import (
extend_coord_with_ghosts as extend_coord_with_ghosts_pt,
)

from ...consistent.common import (
parameterized,
)


def eval_pt_descriptor(
pt_obj: Any, natoms, coords, atype, box, mixed_types: bool = False
) -> Any:
ext_coords, ext_atype, mapping = extend_coord_with_ghosts_pt(
torch.from_numpy(coords).to(PT_DEVICE).reshape(1, -1, 3),
torch.from_numpy(atype).to(PT_DEVICE).reshape(1, -1),
torch.from_numpy(box).to(PT_DEVICE).reshape(1, 3, 3),
pt_obj.get_rcut(),
)
nlist = build_neighbor_list_pt(
ext_coords,
ext_atype,
natoms[0],
pt_obj.get_rcut(),
pt_obj.get_sel(),
distinguish_types=(not mixed_types),
)
result, _, _, _, _ = pt_obj(ext_coords, ext_atype, nlist, mapping=mapping)
return result


@parameterized(("float32", "float64"), (True, False))
class TestDescriptorDPA2(unittest.TestCase):
def setUp(self):
(self.dtype, self.type_one_side) = self.param
if self.dtype == "float32":
self.skipTest("FP32 has bugs:")
# ../../../../deepmd/pt/model/descriptor/repformer_layer.py:521: in forward
# torch.matmul(attnw.unsqueeze(-2), gg1v).squeeze(-2).view(nb, nloc, nh * ni)
# E RuntimeError: expected scalar type Float but found Double
if self.dtype == "float32":
njzjz marked this conversation as resolved.
Show resolved Hide resolved
self.atol = 1e-5
elif self.dtype == "float64":
self.atol = 1e-10
self.seed = 21
self.sel = [10]
self.rcut_smth = 5.80
self.rcut = 6.00
self.neuron = [6, 12, 24]
self.axis_neuron = 3
self.ntypes = 2
self.coords = np.array(
[
12.83,
2.56,
2.18,
12.09,
2.87,
2.74,
00.25,
3.32,
1.68,
3.36,
3.00,
1.81,
3.51,
2.51,
2.60,
4.27,
3.22,
1.56,
],
dtype=GLOBAL_NP_FLOAT_PRECISION,
)
self.atype = np.array([0, 1, 1, 0, 1, 1], dtype=np.int32)
self.box = np.array(
[13.0, 0.0, 0.0, 0.0, 13.0, 0.0, 0.0, 0.0, 13.0],
dtype=GLOBAL_NP_FLOAT_PRECISION,
)
self.natoms = np.array([6, 6, 2, 4], dtype=np.int32)

repinit = RepinitArgs(
rcut=self.rcut,
rcut_smth=self.rcut_smth,
nsel=10,
tebd_input_mode="strip",
type_one_side=self.type_one_side,
)
repformer = RepformerArgs(
rcut=self.rcut - 1,
rcut_smth=self.rcut_smth - 1,
nsel=9,
)

self.descriptor = DescrptDPA2(
ntypes=self.ntypes,
repinit=repinit,
repformer=repformer,
precision=self.dtype,
)

def test_compressed_forward(self):
result_pt = eval_pt_descriptor(
self.descriptor,
self.natoms,
self.coords,
self.atype,
self.box,
)
self.descriptor.enable_compression(0.5)
result_pt_compressed = eval_pt_descriptor(
self.descriptor,
self.natoms,
self.coords,
self.atype,
self.box,
)

self.assertEqual(result_pt.shape, result_pt_compressed.shape)
torch.testing.assert_close(
result_pt,
result_pt_compressed,
atol=self.atol,
rtol=self.atol,
)


if __name__ == "__main__":
unittest.main()
njzjz marked this conversation as resolved.
Show resolved Hide resolved