Skip to content

Commit

Permalink
feat(pt): calculate stat during compression if --skip-neighbor-stat (
Browse files Browse the repository at this point in the history
…#4330)

If `--skip-neighbor-stat` is set during training, when calling `dp
compress`, first calculate the neighbor stat.

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

- **New Features**
- Enhanced `enable_compression` function to accept a `training_script`
parameter for improved error handling and functionality.
- Updated the `compress` command to allow specification of a training
script during execution.
- Introduced a new testing framework for models using the
`--skip-neighbor-stat` flag, validating their functionality.

- **Bug Fixes**
- Improved error handling for cases where the model's minimum neighbor
distance is not saved.

- **Tests**
- Added a new test class and methods to validate the functionality of
models initialized with skip neighbor statistics.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz authored Nov 9, 2024
1 parent cb3e39e commit c12bc01
Show file tree
Hide file tree
Showing 3 changed files with 212 additions and 1 deletion.
53 changes: 53 additions & 0 deletions deepmd/pt/entrypoints/compress.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,32 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import json
import logging
from typing import (
Optional,
)

import torch

from deepmd.common import (
j_loader,
)
from deepmd.pt.model.model import (
get_model,
)
from deepmd.pt.utils import (
env,
)
from deepmd.pt.utils.update_sel import (
UpdateSel,
)
from deepmd.utils.compat import (
update_deepmd_input,
)
from deepmd.utils.data_system import (
get_data,
)

log = logging.getLogger(__name__)


def enable_compression(
Expand All @@ -14,12 +35,44 @@ def enable_compression(
stride: float = 0.01,
extrapolate: int = 5,
check_frequency: int = -1,
training_script: Optional[str] = None,
):
saved_model = torch.jit.load(input_file, map_location="cpu")
model_def_script = json.loads(saved_model.model_def_script)
model = get_model(model_def_script)
model.load_state_dict(saved_model.state_dict())

if model.get_min_nbor_dist() is None:
log.info(
"Minimal neighbor distance is not saved in the model, compute it from the training data."
)
if training_script is None:
raise ValueError(
"The model does not have a minimum neighbor distance, "
"so the training script and data must be provided "
"(via -t,--training-script)."
)

jdata = j_loader(training_script)
jdata = update_deepmd_input(jdata)

type_map = jdata["model"].get("type_map", None)
train_data = get_data(
jdata["training"]["training_data"],
0, # not used
type_map,
None,
)
update_sel = UpdateSel()
t_min_nbor_dist = update_sel.get_min_nbor_dist(
train_data,
)
model.min_nbor_dist = torch.tensor(
t_min_nbor_dist,
dtype=env.GLOBAL_PT_FLOAT_PRECISION,
device=env.DEVICE,
)

model.enable_compression(
extrapolate,
stride,
Expand Down
1 change: 1 addition & 0 deletions deepmd/pt/entrypoints/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -565,6 +565,7 @@ def main(args: Optional[Union[list[str], argparse.Namespace]] = None):
stride=FLAGS.step,
extrapolate=FLAGS.extrapolate,
check_frequency=FLAGS.frequency,
training_script=FLAGS.training_script,
)
else:
raise RuntimeError(f"Invalid command {FLAGS.command}!")
Expand Down
159 changes: 158 additions & 1 deletion source/tests/pt/test_model_compression_se_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,15 +74,48 @@ def _init_models_exclude_types():
return INPUT, frozen_model, compressed_model


def _init_models_skip_neighbor_stat():
suffix = "-skip-neighbor-stat"
data_file = str(tests_path / os.path.join("model_compression", "data"))
frozen_model = str(tests_path / f"dp-original{suffix}.pth")
compressed_model = str(tests_path / f"dp-compressed{suffix}.pth")
INPUT = str(tests_path / "input.json")
jdata = j_loader(str(tests_path / os.path.join("model_compression", "input.json")))
jdata["training"]["training_data"]["systems"] = data_file
with open(INPUT, "w") as fp:
json.dump(jdata, fp, indent=4)

ret = run_dp("dp --pt train " + INPUT + " --skip-neighbor-stat")
np.testing.assert_equal(ret, 0, "DP train failed!")
ret = run_dp("dp --pt freeze -o " + frozen_model)
np.testing.assert_equal(ret, 0, "DP freeze failed!")
ret = run_dp(
"dp --pt compress "
+ " -i "
+ frozen_model
+ " -o "
+ compressed_model
+ " -t "
+ INPUT
)
np.testing.assert_equal(ret, 0, "DP model compression failed!")
return INPUT, frozen_model, compressed_model


def setUpModule():
global \
INPUT, \
FROZEN_MODEL, \
COMPRESSED_MODEL, \
INPUT_ET, \
FROZEN_MODEL_ET, \
COMPRESSED_MODEL_ET
COMPRESSED_MODEL_ET, \
FROZEN_MODEL_SKIP_NEIGHBOR_STAT, \
COMPRESSED_MODEL_SKIP_NEIGHBOR_STAT
INPUT, FROZEN_MODEL, COMPRESSED_MODEL = _init_models()
_, FROZEN_MODEL_SKIP_NEIGHBOR_STAT, COMPRESSED_MODEL_SKIP_NEIGHBOR_STAT = (
_init_models_skip_neighbor_stat()
)
INPUT_ET, FROZEN_MODEL_ET, COMPRESSED_MODEL_ET = _init_models_exclude_types()


Expand Down Expand Up @@ -572,5 +605,129 @@ def test_2frame_atm(self):
np.testing.assert_almost_equal(vv0, vv1, default_places)


class TestSkipNeighborStat(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.dp_original = DeepEval(FROZEN_MODEL_SKIP_NEIGHBOR_STAT)
cls.dp_compressed = DeepEval(COMPRESSED_MODEL_SKIP_NEIGHBOR_STAT)
cls.coords = np.array(
[
12.83,
2.56,
2.18,
12.09,
2.87,
2.74,
00.25,
3.32,
1.68,
3.36,
3.00,
1.81,
3.51,
2.51,
2.60,
4.27,
3.22,
1.56,
]
)
cls.atype = [0, 1, 1, 0, 1, 1]
cls.box = np.array([13.0, 0.0, 0.0, 0.0, 13.0, 0.0, 0.0, 0.0, 13.0])

def test_attrs(self):
self.assertEqual(self.dp_original.get_ntypes(), 2)
self.assertAlmostEqual(self.dp_original.get_rcut(), 6.0, places=default_places)
self.assertEqual(self.dp_original.get_type_map(), ["O", "H"])
self.assertEqual(self.dp_original.get_dim_fparam(), 0)
self.assertEqual(self.dp_original.get_dim_aparam(), 0)

self.assertEqual(self.dp_compressed.get_ntypes(), 2)
self.assertAlmostEqual(
self.dp_compressed.get_rcut(), 6.0, places=default_places
)
self.assertEqual(self.dp_compressed.get_type_map(), ["O", "H"])
self.assertEqual(self.dp_compressed.get_dim_fparam(), 0)
self.assertEqual(self.dp_compressed.get_dim_aparam(), 0)

def test_1frame(self):
ee0, ff0, vv0 = self.dp_original.eval(
self.coords, self.box, self.atype, atomic=False
)
ee1, ff1, vv1 = self.dp_compressed.eval(
self.coords, self.box, self.atype, atomic=False
)
# check shape of the returns
nframes = 1
natoms = len(self.atype)
self.assertEqual(ee0.shape, (nframes, 1))
self.assertEqual(ff0.shape, (nframes, natoms, 3))
self.assertEqual(vv0.shape, (nframes, 9))
self.assertEqual(ee1.shape, (nframes, 1))
self.assertEqual(ff1.shape, (nframes, natoms, 3))
self.assertEqual(vv1.shape, (nframes, 9))
# check values
np.testing.assert_almost_equal(ff0, ff1, default_places)
np.testing.assert_almost_equal(ee0, ee1, default_places)
np.testing.assert_almost_equal(vv0, vv1, default_places)

def test_1frame_atm(self):
ee0, ff0, vv0, ae0, av0 = self.dp_original.eval(
self.coords, self.box, self.atype, atomic=True
)
ee1, ff1, vv1, ae1, av1 = self.dp_compressed.eval(
self.coords, self.box, self.atype, atomic=True
)
# check shape of the returns
nframes = 1
natoms = len(self.atype)
self.assertEqual(ee0.shape, (nframes, 1))
self.assertEqual(ff0.shape, (nframes, natoms, 3))
self.assertEqual(vv0.shape, (nframes, 9))
self.assertEqual(ae0.shape, (nframes, natoms, 1))
self.assertEqual(av0.shape, (nframes, natoms, 9))
self.assertEqual(ee1.shape, (nframes, 1))
self.assertEqual(ff1.shape, (nframes, natoms, 3))
self.assertEqual(vv1.shape, (nframes, 9))
self.assertEqual(ae1.shape, (nframes, natoms, 1))
self.assertEqual(av1.shape, (nframes, natoms, 9))
# check values
np.testing.assert_almost_equal(ff0, ff1, default_places)
np.testing.assert_almost_equal(ae0, ae1, default_places)
np.testing.assert_almost_equal(av0, av1, default_places)
np.testing.assert_almost_equal(ee0, ee1, default_places)
np.testing.assert_almost_equal(vv0, vv1, default_places)

def test_2frame_atm(self):
coords2 = np.concatenate((self.coords, self.coords))
box2 = np.concatenate((self.box, self.box))
ee0, ff0, vv0, ae0, av0 = self.dp_original.eval(
coords2, box2, self.atype, atomic=True
)
ee1, ff1, vv1, ae1, av1 = self.dp_compressed.eval(
coords2, box2, self.atype, atomic=True
)
# check shape of the returns
nframes = 2
natoms = len(self.atype)
self.assertEqual(ee0.shape, (nframes, 1))
self.assertEqual(ff0.shape, (nframes, natoms, 3))
self.assertEqual(vv0.shape, (nframes, 9))
self.assertEqual(ae0.shape, (nframes, natoms, 1))
self.assertEqual(av0.shape, (nframes, natoms, 9))
self.assertEqual(ee1.shape, (nframes, 1))
self.assertEqual(ff1.shape, (nframes, natoms, 3))
self.assertEqual(vv1.shape, (nframes, 9))
self.assertEqual(ae1.shape, (nframes, natoms, 1))
self.assertEqual(av1.shape, (nframes, natoms, 9))

# check values
np.testing.assert_almost_equal(ff0, ff1, default_places)
np.testing.assert_almost_equal(ae0, ae1, default_places)
np.testing.assert_almost_equal(av0, av1, default_places)
np.testing.assert_almost_equal(ee0, ee1, default_places)
np.testing.assert_almost_equal(vv0, vv1, default_places)


if __name__ == "__main__":
unittest.main()

0 comments on commit c12bc01

Please sign in to comment.