Skip to content

Commit

Permalink
introduce EmbeddingNet
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz committed Jan 8, 2024
1 parent 4f19ea3 commit e12c10f
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 18 deletions.
51 changes: 37 additions & 14 deletions deepmd/descriptor/se.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
defaultdict,
)
from typing import (
List,
Tuple,
)

Expand All @@ -15,8 +16,8 @@
get_embedding_net_variables_from_graph_def,
get_tensor_by_name_from_graph,
)
from deepmd_utils.model_format import (
NativeNet,
from deepmd_utils.model_format.network import (
EmbeddingNet,
)

from .descriptor import (
Expand Down Expand Up @@ -169,21 +170,43 @@ def update_sel(cls, global_jdata: dict, local_jdata: dict):
local_jdata_cpy = local_jdata.copy()
return update_one_sel(global_jdata, local_jdata_cpy, False)

def to_dp_variables(self, variables: dict) -> dict:
"""Convert the variables to deepmd format.
def serialize_network(
self,
in_dim: int,
neuron: List[int],
activation_function: str,
resnet_dt: bool,
variables: dict,
) -> dict:
"""Serialize network.
Parameters
----------
in_dim : int
The input dimension
neuron : List[int]
The neuron list
activation_function : str
The activation function
resnet_dt : bool
Whether to use resnet
variables : dict
The input variables
Returns
-------
dict
The converted variables
The converted network data
"""
# TODO: unclear how to hand suffix, maybe we need to add a suffix argument?
networks = defaultdict(NativeNet)
networks = defaultdict(
lambda: EmbeddingNet(
in_dim=in_dim,
neuron=neuron,
activation_function=activation_function,
resnet_dt=resnet_dt,
)
)
for key, value in variables.items():
m = re.search(EMBEDDING_NET_PATTERN, key)
m = [mm for mm in m.groups() if mm is not None]
Expand All @@ -196,29 +219,29 @@ def to_dp_variables(self, variables: dict) -> dict:
return {key: value.serialize() for key, value in networks.items()}

@classmethod
def from_dp_variables(cls, variables: dict) -> dict:
"""Convert the variables from deepmd format.
def deserialize_network(cls, data: dict) -> Tuple[List[int], str, bool, dict, str]:
"""Deserialize network.
Parameters
----------
variables : dict
The input variables
data : dict
The input network data
Returns
-------
dict
The converted variables
variables : dict
The input variables
"""
embedding_net_variables = {}
for key, value in variables.items():
for key, value in data.items():
keys = key.split("/")
key0 = keys[0][5:]
key1 = keys[1][5:]
if key1 == "all":
key1 = ""
else:
key1 = "_" + key1
network = NativeNet.deserialize(value)
network = EmbeddingNet.deserialize(value)
for layer_idx, layer in enumerate(network.layers):
embedding_net_variables[
f"filter_type_{key0}/matrix_{layer_idx}{key1}"
Expand Down
18 changes: 14 additions & 4 deletions deepmd/descriptor/se_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -1358,12 +1358,13 @@ def deserialize(cls, data: dict):
Model
The deserialized model
"""
if type(cls) is not DescrptSeA:
raise NotImplementedError("Unsupported")
embedding_net_variables = cls.deserialize_network(data["networks"])
descriptor = cls(**data)
descriptor.embedding_net_variables = embedding_net_variables
descriptor.davg = data["@variables"]["davg"]
descriptor.dstd = data["@variables"]["dstd"]
descriptor.embedding_net_variables = cls.from_dp_variables(
data["@variables"]["networks"]
)
descriptor.original_sel = data["@variables"]["original_sel"]
return descriptor

Expand All @@ -1375,6 +1376,8 @@ def serialize(self) -> dict:
dict
The serialized data
"""
if type(self) is not DescrptSeA:
raise NotImplementedError("Unsupported")
return {
"type": "se_e2_a",
"rcut": self.rcut_r,
Expand All @@ -1392,8 +1395,15 @@ def serialize(self) -> dict:
"precision": self.filter_precision.name,
"uniform_seed": self.uniform_seed,
"stripped_type_embedding": self.stripped_type_embedding,
"networks": self.serialize_network(
# TODO: how to consider type embedding?
in_dim=1,
neuron=self.filter_neuron,
activation_function=self.activation_function_name,
resnet_dt=self.filter_resnet_dt,
variables=self.embedding_net_variables,
),
"@variables": {
"networks": self.to_dp_variables(self.embedding_net_variables),
"davg": self.davg,
"dstd": self.dstd,
"original_sel": self.original_sel,
Expand Down

0 comments on commit e12c10f

Please sign in to comment.