From ead2a38ba06197111952096088d008c9bfa694b8 Mon Sep 17 00:00:00 2001 From: Anyang Peng <137014849+anyangml@users.noreply.github.com> Date: Tue, 16 Apr 2024 17:21:32 +0800 Subject: [PATCH] fix: UT import --- source/tests/pt/model/test_descriptor.py | 2 +- source/tests/pt/model/test_embedding_net.py | 2 +- source/tests/pt/model/test_model.py | 2 +- source/tests/pt/test_finetune.py | 43 ++++++++++++++++++++- source/tests/pt/test_loss.py | 2 +- 5 files changed, 45 insertions(+), 6 deletions(-) diff --git a/source/tests/pt/model/test_descriptor.py b/source/tests/pt/model/test_descriptor.py index 7d21d1c13d..ff1fd0c959 100644 --- a/source/tests/pt/model/test_descriptor.py +++ b/source/tests/pt/model/test_descriptor.py @@ -38,7 +38,7 @@ op_module, ) -from ..test_stat import ( +from ..test_finetune import ( energy_data_requirement, ) from .test_embedding_net import ( diff --git a/source/tests/pt/model/test_embedding_net.py b/source/tests/pt/model/test_embedding_net.py index 63a3534c74..77d14db2a4 100644 --- a/source/tests/pt/model/test_embedding_net.py +++ b/source/tests/pt/model/test_embedding_net.py @@ -39,7 +39,7 @@ ) from deepmd.tf.descriptor import DescrptSeA as DescrptSeA_tf -from ..test_stat import ( +from ..test_finetune import ( energy_data_requirement, ) diff --git a/source/tests/pt/model/test_model.py b/source/tests/pt/model/test_model.py index 493d6e2cc3..71ad64d99d 100644 --- a/source/tests/pt/model/test_model.py +++ b/source/tests/pt/model/test_model.py @@ -51,7 +51,7 @@ LearningRateExp, ) -from ..test_stat import ( +from ..test_finetune import ( energy_data_requirement, ) diff --git a/source/tests/pt/test_finetune.py b/source/tests/pt/test_finetune.py index 8f299ce542..3ea6fdeb1d 100644 --- a/source/tests/pt/test_finetune.py +++ b/source/tests/pt/test_finetune.py @@ -29,10 +29,49 @@ model_se_e2_a, model_zbl, ) -from .test_stat import ( - energy_data_requirement, + +from deepmd.utils.data import ( + DataRequirementItem, ) +energy_data_requirement = [ + DataRequirementItem( + "energy", + ndof=1, + atomic=False, + must=False, + high_prec=True, + ), + DataRequirementItem( + "force", + ndof=3, + atomic=True, + must=False, + high_prec=False, + ), + DataRequirementItem( + "virial", + ndof=9, + atomic=False, + must=False, + high_prec=False, + ), + DataRequirementItem( + "atom_ener", + ndof=1, + atomic=True, + must=False, + high_prec=False, + ), + DataRequirementItem( + "atom_pref", + ndof=1, + atomic=True, + must=False, + high_prec=False, + repeat=3, + ), +] class FinetuneTest: def test_finetune_change_out_bias(self): diff --git a/source/tests/pt/test_loss.py b/source/tests/pt/test_loss.py index 17b05dadc6..66460dfef1 100644 --- a/source/tests/pt/test_loss.py +++ b/source/tests/pt/test_loss.py @@ -32,7 +32,7 @@ from .model.test_embedding_net import ( get_single_batch, ) -from .test_stat import ( +from .test_finetune import ( energy_data_requirement, )