Skip to content

Commit

Permalink
Fix typo in smooth_type_embdding (deepmodeling#3698)
Browse files Browse the repository at this point in the history
(cherry picked from commit 86b0bf8)
Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
iProzd authored and njzjz committed Jul 2, 2024
1 parent a6acc34 commit cdbb70d
Show file tree
Hide file tree
Showing 6 changed files with 21 additions and 18 deletions.
10 changes: 6 additions & 4 deletions deepmd/descriptor/se_atten.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ class DescrptSeAtten(DescrptSeA):
stripped_type_embedding
Whether to strip the type embedding into a separated embedding network.
Default value will be True in `se_atten_v2` descriptor.
smooth_type_embdding
smooth_type_embedding
When using stripped type embedding, whether to dot smooth factor on the network output of type embedding
to keep the network smooth, instead of setting `set_davg_zero` to be True.
Default value will be True in `se_atten_v2` descriptor.
Expand Down Expand Up @@ -152,10 +152,12 @@ def __init__(
attn_mask: bool = False,
multi_task: bool = False,
stripped_type_embedding: bool = False,
smooth_type_embdding: bool = False,
smooth_type_embedding: bool = False,
**kwargs,
) -> None:
if not set_davg_zero and not (stripped_type_embedding and smooth_type_embdding):
if not set_davg_zero and not (
stripped_type_embedding and smooth_type_embedding
):
warnings.warn(
"Set 'set_davg_zero' False in descriptor 'se_atten' "
"may cause unexpected incontinuity during model inference!"
Expand Down Expand Up @@ -188,7 +190,7 @@ def __init__(
if ntypes == 0:
raise ValueError("`model/type_map` is not set or empty!")
self.stripped_type_embedding = stripped_type_embedding
self.smooth = smooth_type_embdding
self.smooth = smooth_type_embedding
self.ntypes = ntypes
self.att_n = attn
self.attn_layer = attn_layer
Expand Down
2 changes: 1 addition & 1 deletion deepmd/descriptor/se_atten_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,6 @@ def __init__(
attn_mask=attn_mask,
multi_task=multi_task,
stripped_type_embedding=True,
smooth_type_embdding=True,
smooth_type_embedding=True,
**kwargs,
)
7 changes: 4 additions & 3 deletions deepmd_utils/utils/argcheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,7 +417,7 @@ def descrpt_se_atten_common_args():
@descrpt_args_plugin.register("se_atten")
def descrpt_se_atten_args():
doc_stripped_type_embedding = "Whether to strip the type embedding into a separated embedding network. Setting it to `False` will fall back to the previous version of `se_atten` which is non-compressible."
doc_smooth_type_embdding = "When using stripped type embedding, whether to dot smooth factor on the network output of type embedding to keep the network smooth, instead of setting `set_davg_zero` to be True."
doc_smooth_type_embedding = "When using stripped type embedding, whether to dot smooth factor on the network output of type embedding to keep the network smooth, instead of setting `set_davg_zero` to be True."
doc_set_davg_zero = "Set the normalization average to zero. This option should be set when `se_atten` descriptor or `atom_ener` in the energy fitting is used"

return [
Expand All @@ -430,11 +430,12 @@ def descrpt_se_atten_args():
doc=doc_stripped_type_embedding,
),
Argument(
"smooth_type_embdding",
"smooth_type_embedding",
bool,
optional=True,
default=False,
doc=doc_smooth_type_embdding,
alias=["smooth_type_embdding"],
doc=doc_smooth_type_embedding,
),
Argument(
"set_davg_zero", bool, optional=True, default=True, doc=doc_set_davg_zero
Expand Down
2 changes: 1 addition & 1 deletion doc/model/train-se-atten.md
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ We highly recommend using the version 2.0 of the attention-based descriptor `"se

```json
"stripped_type_embedding": true,
"smooth_type_embdding": true,
"smooth_type_embedding": true,
"set_davg_zero": false
```

Expand Down
14 changes: 7 additions & 7 deletions source/tests/test_model_compression_se_atten.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,27 +37,27 @@ def _file_delete(file):
{
"se_atten precision": "float64",
"type embedding precision": "float64",
"smooth_type_embdding": True,
"smooth_type_embedding": True,
},
{
"se_atten precision": "float64",
"type embedding precision": "float64",
"smooth_type_embdding": False,
"smooth_type_embedding": False,
},
{
"se_atten precision": "float64",
"type embedding precision": "float32",
"smooth_type_embdding": True,
"smooth_type_embedding": True,
},
{
"se_atten precision": "float32",
"type embedding precision": "float64",
"smooth_type_embdding": True,
"smooth_type_embedding": True,
},
{
"se_atten precision": "float32",
"type embedding precision": "float32",
"smooth_type_embdding": True,
"smooth_type_embedding": True,
},
]

Expand All @@ -82,8 +82,8 @@ def _init_models():
jdata["model"]["descriptor"]["stripped_type_embedding"] = True
jdata["model"]["descriptor"]["sel"] = 120
jdata["model"]["descriptor"]["attn_layer"] = 0
jdata["model"]["descriptor"]["smooth_type_embdding"] = tests[i][
"smooth_type_embdding"
jdata["model"]["descriptor"]["smooth_type_embedding"] = tests[i][
"smooth_type_embedding"
]
jdata["model"]["type_embedding"] = {}
jdata["model"]["type_embedding"]["precision"] = tests[i][
Expand Down
4 changes: 2 additions & 2 deletions source/tests/test_model_se_atten.py
Original file line number Diff line number Diff line change
Expand Up @@ -751,7 +751,7 @@ def test_smoothness_of_stripped_type_embedding_smooth_model(self):
jdata["model"]["descriptor"].pop("type", None)
jdata["model"]["descriptor"]["ntypes"] = 2
jdata["model"]["descriptor"]["stripped_type_embedding"] = True
jdata["model"]["descriptor"]["smooth_type_embdding"] = True
jdata["model"]["descriptor"]["smooth_type_embedding"] = True
jdata["model"]["descriptor"]["attn_layer"] = 1
jdata["model"]["descriptor"]["rcut"] = 6.0
jdata["model"]["descriptor"]["rcut_smth"] = 4.0
Expand Down Expand Up @@ -894,7 +894,7 @@ def test_smoothness_of_stripped_type_embedding_smooth_model_excluded_types(self)
jdata["model"]["descriptor"].pop("type", None)
jdata["model"]["descriptor"]["ntypes"] = 2
jdata["model"]["descriptor"]["stripped_type_embedding"] = True
jdata["model"]["descriptor"]["smooth_type_embdding"] = True
jdata["model"]["descriptor"]["smooth_type_embedding"] = True
jdata["model"]["descriptor"]["attn_layer"] = 1
jdata["model"]["descriptor"]["rcut"] = 6.0
jdata["model"]["descriptor"]["rcut_smth"] = 4.0
Expand Down

0 comments on commit cdbb70d

Please sign in to comment.