Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Chore: refactor dpmodel #3663

Merged
merged 44 commits into from
Apr 15, 2024
Merged
Show file tree
Hide file tree
Changes from 36 commits
Commits
Show all changes
44 commits
Select commit Hold shift + click to select a range
02e546c
chore: try remove dp model
anyangml Apr 11, 2024
4c026fb
chore: try remove dp model
anyangml Apr 11, 2024
81573d5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 11, 2024
3ad9637
fix: import
anyangml Apr 11, 2024
c64e520
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 11, 2024
acd1b99
fix: import
anyangml Apr 11, 2024
db2751c
fix: UTs
anyangml Apr 11, 2024
9612b02
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 11, 2024
a7384d2
fix: UTs
anyangml Apr 11, 2024
30bd378
fix: UTs
anyangml Apr 11, 2024
6c9185e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 11, 2024
24859ac
fix: UTs
anyangml Apr 11, 2024
754e18a
fix: UTs
anyangml Apr 11, 2024
56accdf
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 11, 2024
9f437ed
fix: address comments
anyangml Apr 12, 2024
ed34307
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 12, 2024
72a34fc
fix: import
anyangml Apr 12, 2024
5355428
fix: UTs
anyangml Apr 12, 2024
257936d
feat: try remove standard
anyangml Apr 13, 2024
a51a792
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 13, 2024
4b14d54
fix: precommit
anyangml Apr 13, 2024
5484375
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 13, 2024
3c5c795
fix: import
anyangml Apr 13, 2024
224ef27
fix: UTs
anyangml Apr 13, 2024
bd079b9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 13, 2024
53dbe13
fix: UTs
anyangml Apr 13, 2024
78ff90d
fix: UTs
anyangml Apr 13, 2024
7f99512
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 13, 2024
23fa34d
fix: try modify argcheck
anyangml Apr 13, 2024
16583de
fix: try modify argcheck
anyangml Apr 13, 2024
23fd899
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 13, 2024
4fe3661
fix: plugin
anyangml Apr 13, 2024
576e75d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 13, 2024
d34a7d4
Merge branch 'devel' into chore/refactor-dpmodel
anyangml Apr 14, 2024
3173978
fix:inheritance
anyangml Apr 14, 2024
1e54270
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 14, 2024
353ca6d
fix: revert changes
anyangml Apr 14, 2024
8de02ee
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 14, 2024
e2a80b3
fix: UTs
anyangml Apr 14, 2024
efcf951
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 14, 2024
0c100fa
fix:UTs
anyangml Apr 14, 2024
752132d
fix: UTs
anyangml Apr 14, 2024
ff30173
fix: tf register
anyangml Apr 15, 2024
7399d3e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 15, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions deepmd/dpmodel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
NativeOP,
)
from .model import (
DPModel,
DPModelCommon,
)
from .output_def import (
FittingOutputDef,
Expand All @@ -19,7 +19,7 @@
)

__all__ = [
"DPModel",
"DPModelCommon",
"PRECISION_DICT",
"DEFAULT_PRECISION",
"NativeOP",
Expand Down
4 changes: 2 additions & 2 deletions deepmd/dpmodel/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
"""

from .dp_model import (
DPModel,
DPModelCommon,
)
from .make_model import (
make_model,
Expand All @@ -23,7 +23,7 @@
)

__all__ = [
"DPModel",
"DPModelCommon",
"SpinModel",
"make_model",
]
4 changes: 2 additions & 2 deletions deepmd/dpmodel/model/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@

def __new__(cls, *args, **kwargs):
if inspect.isabstract(cls):
cls = cls.get_class_by_type(kwargs.get("type", "standard"))
cls = cls.get_class_by_type(kwargs.get("type", "ener"))

Check warning on line 38 in deepmd/dpmodel/model/base_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/base_model.py#L38

Added line #L38 was not covered by tests
anyangml marked this conversation as resolved.
Show resolved Hide resolved
return super().__new__(cls)

@abstractmethod
Expand Down Expand Up @@ -151,7 +151,7 @@
local_jdata : dict
The local data refer to the current class
"""
cls = cls.get_class_by_type(local_jdata.get("type", "standard"))
cls = cls.get_class_by_type(local_jdata.get("type", "ener"))

Check warning on line 154 in deepmd/dpmodel/model/base_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/base_model.py#L154

Added line #L154 was not covered by tests
anyangml marked this conversation as resolved.
Show resolved Hide resolved
return cls.update_sel(global_jdata, local_jdata)

return BaseBaseModel
Expand Down
13 changes: 1 addition & 12 deletions deepmd/dpmodel/model/dp_model.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,13 @@
# SPDX-License-Identifier: LGPL-3.0-or-later


from deepmd.dpmodel.atomic_model import (
DPAtomicModel,
)
from deepmd.dpmodel.descriptor.base_descriptor import (
BaseDescriptor,
)
from deepmd.dpmodel.model.base_model import (
BaseModel,
)

from .make_model import (
make_model,
)


# use "class" to resolve "Variable not allowed in type expression"
@BaseModel.register("standard")
class DPModel(make_model(DPAtomicModel)):
class DPModelCommon:
@classmethod
def update_sel(cls, global_jdata: dict, local_jdata: dict):
"""Update the selection and perform neighbor statistics.
Expand Down
27 changes: 27 additions & 0 deletions deepmd/dpmodel/model/ener_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from deepmd.dpmodel.atomic_model.dp_atomic_model import (

Check warning on line 2 in deepmd/dpmodel/model/ener_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/ener_model.py#L2

Added line #L2 was not covered by tests
DPAtomicModel,
)
from deepmd.dpmodel.model.base_model import (

Check warning on line 5 in deepmd/dpmodel/model/ener_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/ener_model.py#L5

Added line #L5 was not covered by tests
BaseModel,
)

from .dp_model import (

Check warning on line 9 in deepmd/dpmodel/model/ener_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/ener_model.py#L9

Added line #L9 was not covered by tests
DPModelCommon,
)
from .make_model import (

Check warning on line 12 in deepmd/dpmodel/model/ener_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/ener_model.py#L12

Added line #L12 was not covered by tests
make_model,
)

DPEnergyModel_ = make_model(DPAtomicModel)

Check warning on line 16 in deepmd/dpmodel/model/ener_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/ener_model.py#L16

Added line #L16 was not covered by tests


@BaseModel.register("ener")
class EnergyModel(DPModelCommon, DPEnergyModel_):
def __init__(

Check warning on line 21 in deepmd/dpmodel/model/ener_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/ener_model.py#L19-L21

Added lines #L19 - L21 were not covered by tests
self,
*args,
**kwargs,
):
DPModelCommon.__init__(self)
DPEnergyModel_.__init__(self, *args, **kwargs)

Check warning on line 27 in deepmd/dpmodel/model/ener_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/ener_model.py#L26-L27

Added lines #L26 - L27 were not covered by tests
10 changes: 5 additions & 5 deletions deepmd/dpmodel/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from deepmd.dpmodel.fitting.ener_fitting import (
EnergyFittingNet,
)
from deepmd.dpmodel.model.dp_model import (
DPModel,
from deepmd.dpmodel.model.ener_model import (

Check warning on line 8 in deepmd/dpmodel/model/model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/model.py#L8

Added line #L8 was not covered by tests
EnergyModel,
)
from deepmd.dpmodel.model.spin_model import (
SpinModel,
Expand All @@ -16,8 +16,8 @@
)


def get_standard_model(data: dict) -> DPModel:
"""Get a standard DPModel from a dictionary.
def get_standard_model(data: dict) -> EnergyModel:

Check warning on line 19 in deepmd/dpmodel/model/model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/model.py#L19

Added line #L19 was not covered by tests
"""Get a EnergyModel from a dictionary.

Parameters
----------
Expand All @@ -41,7 +41,7 @@
)
else:
raise ValueError(f"Unknown fitting type {fitting_type}")
return DPModel(
return EnergyModel(

Check warning on line 44 in deepmd/dpmodel/model/model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/model.py#L44

Added line #L44 was not covered by tests
descriptor=descriptor,
fitting=fitting,
type_map=data["type_map"],
Expand Down
11 changes: 8 additions & 3 deletions deepmd/dpmodel/model/spin_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,11 @@

import numpy as np

from deepmd.dpmodel.model.dp_model import (
DPModel,
from deepmd.dpmodel.atomic_model.dp_atomic_model import (
DPAtomicModel,
)
from deepmd.dpmodel.model.make_model import (
make_model,
)
from deepmd.utils.spin import (
Spin,
Expand Down Expand Up @@ -259,7 +262,9 @@

@classmethod
def deserialize(cls, data) -> "SpinModel":
backbone_model_obj = DPModel.deserialize(data["backbone_model"])
backbone_model_obj = make_model(DPAtomicModel).deserialize(

Check warning on line 265 in deepmd/dpmodel/model/spin_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/spin_model.py#L265

Added line #L265 was not covered by tests
data["backbone_model"]
)
spin = Spin.deserialize(data["spin"])
return cls(
backbone_model=backbone_model_obj,
Expand Down
16 changes: 16 additions & 0 deletions deepmd/pt/model/atomic_model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,37 @@
from .base_atomic_model import (
BaseAtomicModel,
)
from .dipole_atomic_model import (

Check warning on line 20 in deepmd/pt/model/atomic_model/__init__.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/__init__.py#L20

Added line #L20 was not covered by tests
DPDipoleAtomicModel,
)
from .dos_atomic_model import (

Check warning on line 23 in deepmd/pt/model/atomic_model/__init__.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/__init__.py#L23

Added line #L23 was not covered by tests
DPDOSAtomicModel,
)
from .dp_atomic_model import (
DPAtomicModel,
)
from .energy_atomic_model import (

Check warning on line 29 in deepmd/pt/model/atomic_model/__init__.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/__init__.py#L29

Added line #L29 was not covered by tests
DPEnergyAtomicModel,
)
from .linear_atomic_model import (
DPZBLLinearEnergyAtomicModel,
LinearEnergyAtomicModel,
)
from .pairtab_atomic_model import (
PairTabAtomicModel,
)
from .polar_atomic_model import (

Check warning on line 39 in deepmd/pt/model/atomic_model/__init__.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/__init__.py#L39

Added line #L39 was not covered by tests
DPPolarAtomicModel,
)

__all__ = [
"BaseAtomicModel",
"DPAtomicModel",
"DPDOSAtomicModel",
"DPEnergyAtomicModel",
"PairTabAtomicModel",
"LinearEnergyAtomicModel",
"DPPolarAtomicModel",
"DPDipoleAtomicModel",
"DPZBLLinearEnergyAtomicModel",
]
28 changes: 28 additions & 0 deletions deepmd/pt/model/atomic_model/dipole_atomic_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (

Check warning on line 2 in deepmd/pt/model/atomic_model/dipole_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/dipole_atomic_model.py#L2

Added line #L2 was not covered by tests
Dict,
)

import torch

Check warning on line 6 in deepmd/pt/model/atomic_model/dipole_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/dipole_atomic_model.py#L6

Added line #L6 was not covered by tests

from deepmd.pt.model.task.dipole import (

Check warning on line 8 in deepmd/pt/model/atomic_model/dipole_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/dipole_atomic_model.py#L8

Added line #L8 was not covered by tests
DipoleFittingNet,
)

from .dp_atomic_model import (

Check warning on line 12 in deepmd/pt/model/atomic_model/dipole_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/dipole_atomic_model.py#L12

Added line #L12 was not covered by tests
DPAtomicModel,
)


class DPDipoleAtomicModel(DPAtomicModel):
anyangml marked this conversation as resolved.
Show resolved Hide resolved
def __init__(self, descriptor, fitting, type_map, **kwargs):
assert isinstance(fitting, DipoleFittingNet)
super().__init__(descriptor, fitting, type_map, **kwargs)

Check warning on line 20 in deepmd/pt/model/atomic_model/dipole_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/dipole_atomic_model.py#L17-L20

Added lines #L17 - L20 were not covered by tests

def apply_out_stat(

Check warning on line 22 in deepmd/pt/model/atomic_model/dipole_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/dipole_atomic_model.py#L22

Added line #L22 was not covered by tests
self,
ret: Dict[str, torch.Tensor],
atype: torch.Tensor,
):
# dipole not applying bias
return ret

Check warning on line 28 in deepmd/pt/model/atomic_model/dipole_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/dipole_atomic_model.py#L28

Added line #L28 was not covered by tests
14 changes: 14 additions & 0 deletions deepmd/pt/model/atomic_model/dos_atomic_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from deepmd.pt.model.task.dos import (

Check warning on line 2 in deepmd/pt/model/atomic_model/dos_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/dos_atomic_model.py#L2

Added line #L2 was not covered by tests
DOSFittingNet,
)

from .dp_atomic_model import (

Check warning on line 6 in deepmd/pt/model/atomic_model/dos_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/dos_atomic_model.py#L6

Added line #L6 was not covered by tests
DPAtomicModel,
)


class DPDOSAtomicModel(DPAtomicModel):
def __init__(self, descriptor, fitting, type_map, **kwargs):
assert isinstance(fitting, DOSFittingNet)
super().__init__(descriptor, fitting, type_map, **kwargs)

Check warning on line 14 in deepmd/pt/model/atomic_model/dos_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/dos_atomic_model.py#L11-L14

Added lines #L11 - L14 were not covered by tests
20 changes: 20 additions & 0 deletions deepmd/pt/model/atomic_model/energy_atomic_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from deepmd.pt.model.task.ener import (

Check warning on line 2 in deepmd/pt/model/atomic_model/energy_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/energy_atomic_model.py#L2

Added line #L2 was not covered by tests
EnergyFittingNet,
EnergyFittingNetDirect,
InvarFitting,
)

from .dp_atomic_model import (

Check warning on line 8 in deepmd/pt/model/atomic_model/energy_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/energy_atomic_model.py#L8

Added line #L8 was not covered by tests
DPAtomicModel,
)


class DPEnergyAtomicModel(DPAtomicModel):
def __init__(self, descriptor, fitting, type_map, **kwargs):
assert (

Check warning on line 15 in deepmd/pt/model/atomic_model/energy_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/energy_atomic_model.py#L13-L15

Added lines #L13 - L15 were not covered by tests
isinstance(fitting, EnergyFittingNet)
or isinstance(fitting, EnergyFittingNetDirect)
or isinstance(fitting, InvarFitting)
)
super().__init__(descriptor, fitting, type_map, **kwargs)

Check warning on line 20 in deepmd/pt/model/atomic_model/energy_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/energy_atomic_model.py#L20

Added line #L20 was not covered by tests
28 changes: 28 additions & 0 deletions deepmd/pt/model/atomic_model/polar_atomic_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (

Check warning on line 2 in deepmd/pt/model/atomic_model/polar_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/polar_atomic_model.py#L2

Added line #L2 was not covered by tests
Dict,
)

import torch

Check warning on line 6 in deepmd/pt/model/atomic_model/polar_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/polar_atomic_model.py#L6

Added line #L6 was not covered by tests

from deepmd.pt.model.task.polarizability import (

Check warning on line 8 in deepmd/pt/model/atomic_model/polar_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/polar_atomic_model.py#L8

Added line #L8 was not covered by tests
PolarFittingNet,
)

from .dp_atomic_model import (

Check warning on line 12 in deepmd/pt/model/atomic_model/polar_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/polar_atomic_model.py#L12

Added line #L12 was not covered by tests
DPAtomicModel,
)


class DPPolarAtomicModel(DPAtomicModel):
def __init__(self, descriptor, fitting, type_map, **kwargs):
assert isinstance(fitting, PolarFittingNet)
super().__init__(descriptor, fitting, type_map, **kwargs)

Check warning on line 20 in deepmd/pt/model/atomic_model/polar_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/polar_atomic_model.py#L17-L20

Added lines #L17 - L20 were not covered by tests

def apply_out_stat(

Check warning on line 22 in deepmd/pt/model/atomic_model/polar_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/polar_atomic_model.py#L22

Added line #L22 was not covered by tests
self,
ret: Dict[str, torch.Tensor],
atype: torch.Tensor,
):
# TODO: migrate bias
return ret

Check warning on line 28 in deepmd/pt/model/atomic_model/polar_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/polar_atomic_model.py#L28

Added line #L28 was not covered by tests
26 changes: 23 additions & 3 deletions deepmd/pt/model/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,14 @@
Spin,
)

from .dipole_model import (

Check warning on line 33 in deepmd/pt/model/model/__init__.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/__init__.py#L33

Added line #L33 was not covered by tests
DipoleModel,
)
from .dos_model import (

Check warning on line 36 in deepmd/pt/model/model/__init__.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/__init__.py#L36

Added line #L36 was not covered by tests
DOSModel,
)
from .dp_model import (
DPModel,
DPModelCommon,
)
from .dp_zbl_model import (
DPZBLModel,
Expand All @@ -51,6 +57,9 @@
from .model import (
BaseModel,
)
from .polar_model import (

Check warning on line 60 in deepmd/pt/model/model/__init__.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/__init__.py#L60

Added line #L60 was not covered by tests
PolarModel,
)
from .spin_model import (
SpinEnergyModel,
SpinModel,
Expand Down Expand Up @@ -160,7 +169,18 @@
atom_exclude_types = model_params.get("atom_exclude_types", [])
pair_exclude_types = model_params.get("pair_exclude_types", [])

model = DPModel(
if fitting_net["type"] == "dipole":
modelcls = DipoleModel
elif fitting_net["type"] == "polar":
modelcls = PolarModel
elif fitting_net["type"] == "dos":
modelcls = DOSModel
elif fitting_net["type"] in ["ener", "direct_force_ener"]:
modelcls = EnergyModel

Check warning on line 179 in deepmd/pt/model/model/__init__.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/__init__.py#L172-L179

Added lines #L172 - L179 were not covered by tests
else:
raise RuntimeError(f"Unknown fitting type: {fitting_net['type']}")

Check warning on line 181 in deepmd/pt/model/model/__init__.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/__init__.py#L181

Added line #L181 was not covered by tests

model = modelcls(

Check warning on line 183 in deepmd/pt/model/model/__init__.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/__init__.py#L183

Added line #L183 was not covered by tests
Fixed Show fixed Hide fixed
descriptor=descriptor,
fitting=fitting,
type_map=model_params["type_map"],
Expand All @@ -183,7 +203,7 @@
__all__ = [
"BaseModel",
"get_model",
"DPModel",
"DPModelCommon",
"EnergyModel",
"FrozenModel",
"SpinModel",
Expand Down
20 changes: 17 additions & 3 deletions deepmd/pt/model/model/dipole_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,34 @@

import torch

from deepmd.pt.model.atomic_model import (

Check warning on line 9 in deepmd/pt/model/model/dipole_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/dipole_model.py#L9

Added line #L9 was not covered by tests
DPDipoleAtomicModel,
)
from deepmd.pt.model.model.model import (

Check warning on line 12 in deepmd/pt/model/model/dipole_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/dipole_model.py#L12

Added line #L12 was not covered by tests
BaseModel,
)

from .dp_model import (
DPModel,
DPModelCommon,
)
from .make_model import (

Check warning on line 19 in deepmd/pt/model/model/dipole_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/dipole_model.py#L19

Added line #L19 was not covered by tests
make_model,
)

DPDOSModel_ = make_model(DPDipoleAtomicModel)

Check warning on line 23 in deepmd/pt/model/model/dipole_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/dipole_model.py#L23

Added line #L23 was not covered by tests


class DipoleModel(DPModel):
@BaseModel.register("dipole")
class DipoleModel(DPModelCommon, DPDOSModel_):

Check warning on line 27 in deepmd/pt/model/model/dipole_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/dipole_model.py#L26-L27

Added lines #L26 - L27 were not covered by tests
model_type = "dipole"

def __init__(
self,
*args,
**kwargs,
):
super().__init__(*args, **kwargs)
DPModelCommon.__init__(self)
DPDOSModel_.__init__(self, *args, **kwargs)

Check warning on line 36 in deepmd/pt/model/model/dipole_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/dipole_model.py#L35-L36

Added lines #L35 - L36 were not covered by tests

def forward(
self,
Expand Down
Loading
Loading