Skip to content

Commit

Permalink
fix: UT import
Browse files Browse the repository at this point in the history
  • Loading branch information
anyangml committed Apr 16, 2024
1 parent 23c7fdf commit ead2a38
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 6 deletions.
2 changes: 1 addition & 1 deletion source/tests/pt/model/test_descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
op_module,
)

from ..test_stat import (
from ..test_finetune import (
energy_data_requirement,
)
from .test_embedding_net import (
Expand Down
2 changes: 1 addition & 1 deletion source/tests/pt/model/test_embedding_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
)
from deepmd.tf.descriptor import DescrptSeA as DescrptSeA_tf

from ..test_stat import (
from ..test_finetune import (
energy_data_requirement,
)

Expand Down
2 changes: 1 addition & 1 deletion source/tests/pt/model/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
LearningRateExp,
)

from ..test_stat import (
from ..test_finetune import (
energy_data_requirement,
)

Expand Down
43 changes: 41 additions & 2 deletions source/tests/pt/test_finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,49 @@
model_se_e2_a,
model_zbl,
)
from .test_stat import (
energy_data_requirement,

from deepmd.utils.data import (
DataRequirementItem,
)

energy_data_requirement = [
DataRequirementItem(
"energy",
ndof=1,
atomic=False,
must=False,
high_prec=True,
),
DataRequirementItem(
"force",
ndof=3,
atomic=True,
must=False,
high_prec=False,
),
DataRequirementItem(
"virial",
ndof=9,
atomic=False,
must=False,
high_prec=False,
),
DataRequirementItem(
"atom_ener",
ndof=1,
atomic=True,
must=False,
high_prec=False,
),
DataRequirementItem(
"atom_pref",
ndof=1,
atomic=True,
must=False,
high_prec=False,
repeat=3,
),
]

class FinetuneTest:
def test_finetune_change_out_bias(self):
Expand Down
2 changes: 1 addition & 1 deletion source/tests/pt/test_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from .model.test_embedding_net import (
get_single_batch,
)
from .test_stat import (
from .test_finetune import (
energy_data_requirement,
)

Expand Down

0 comments on commit ead2a38

Please sign in to comment.