Skip to content

Commit

Permalink
fix: var name
Browse files Browse the repository at this point in the history
  • Loading branch information
anyangml committed Apr 12, 2024
1 parent 5300d98 commit 90ede06
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 11 deletions.
4 changes: 2 additions & 2 deletions deepmd/pt/model/model/polar_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ def forward(
)
if self.get_fitting_net() is not None:
model_predict = {}
model_predict["polar"] = model_ret["polar"]
model_predict["global_polar"] = model_ret["polar_redu"]
model_predict["polar"] = model_ret["polarizability"]
model_predict["global_polar"] = model_ret["polarizability_redu"]
if "mask" in model_ret:
model_predict["mask"] = model_ret["mask"]
else:
Expand Down
4 changes: 2 additions & 2 deletions deepmd/pt/model/task/polarizability.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ def output_def(self) -> FittingOutputDef:
return FittingOutputDef(
[
OutputVariableDef(
self.var_name,
"polarizability",
[3, 3],
reduciable=True,
r_differentiable=False,
Expand Down Expand Up @@ -314,7 +314,7 @@ def forward(
bias = bias.unsqueeze(-1) * eye
out = out + bias

return {self.var_name: out.to(env.GLOBAL_PT_FLOAT_PRECISION)}
return {"polarizability": out.to(env.GLOBAL_PT_FLOAT_PRECISION)}

# make jit happy with torch 2.0.0
exclude_types: List[int]
7 changes: 0 additions & 7 deletions deepmd/pt/utils/stat.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,13 +269,6 @@ def compute_output_stats(
which will be subtracted from the energy label of the data.
The difference will then be used to calculate the delta complement energy bias for each type.
"""
# mapping keys, to resolve var_name/label_name mismatch (eg. polar/polarizabitlity)
key_mapping = {
"polar": "polarizability",
}
keys = [keys] if isinstance(keys, str) else keys
assert isinstance(keys, list)
keys = [key_mapping[k] if k in key_mapping else k for k in keys]

# try to restore the bias from stat file
bias_atom_e, std_atom_e = _restore_from_file(stat_file_path, keys)
Expand Down

0 comments on commit 90ede06

Please sign in to comment.