Skip to content

Commit

Permalink
fix all serialization and inference code
Browse files Browse the repository at this point in the history
  • Loading branch information
HydrogenSulfate committed Jan 7, 2025
1 parent 5b663f1 commit cbbd11e
Show file tree
Hide file tree
Showing 7 changed files with 351 additions and 299 deletions.
38 changes: 35 additions & 3 deletions deepmd/pd/utils/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,16 +56,48 @@ def deserialize_to_file(model_file: str, data: dict) -> None:
InputSpec,
)

jit_model = paddle.jit.to_static(
model,
""" example output shape and dtype of forward
atom_energy: fetch_name_0 (1, 6, 1) float64
atom_virial: fetch_name_1 (1, 6, 1, 9) float64
energy: fetch_name_2 (1, 1) float64
force: fetch_name_3 (1, 6, 3) float64
mask: fetch_name_4 (1, 6) int32
virial: fetch_name_5 (1, 9) float64
"""
model.forward = paddle.jit.to_static(
model.forward,
full_graph=True,
input_spec=[
InputSpec([1, -1, 3], dtype="float64", name="coord"),
InputSpec([1, -1], dtype="int64", name="atype"),
InputSpec([1, 9], dtype="float64", name="box"),
None,
None,
True,
],
)
""" example output shape and dtype of forward_lower
fetch_name_0: atom_energy [1, 192, 1] paddle.float64
fetch_name_1: energy [1, 1] paddle.float64
fetch_name_2: extended_force [1, 5184, 3] paddle.float64
fetch_name_3: extended_virial [1, 5184, 1, 9] paddle.float64
fetch_name_4: virial [1, 9] paddle.float64
"""
model.forward_lower = paddle.jit.to_static(
model.forward_lower,
full_graph=True,
input_spec=[
InputSpec([1, -1, 3], dtype="float64", name="coord"),
InputSpec([1, -1], dtype="int64", name="atype"),
InputSpec([1, -1, -1], dtype="int32", name="nlist"),
None,
None,
None,
True,
None,
],
)
paddle.jit.save(
jit_model,
model,
model_file.split(".json")[0],
)
5 changes: 5 additions & 0 deletions source/api_cc/include/DeepPotPD.h
Original file line number Diff line number Diff line change
Expand Up @@ -377,8 +377,13 @@ class DeepPotPD : public DeepPotBackend {
int daparam;
int aparam_nall;
// copy neighbor list info from host
// config & predictor for model.forward
std::shared_ptr<paddle_infer::Config> config;
std::shared_ptr<paddle_infer::Predictor> predictor;
// config & predictor for model.forward_lower
std::shared_ptr<paddle_infer::Config> config_fl;
std::shared_ptr<paddle_infer::Predictor> predictor_fl;

double rcut;
NeighborListData nlist_data;
int max_num_neighbors;
Expand Down
Loading

0 comments on commit cbbd11e

Please sign in to comment.