Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz committed Jan 5, 2025
1 parent ffd0694 commit 45f222b
Showing 1 changed file with 33 additions and 48 deletions.
81 changes: 33 additions & 48 deletions deepmd/dpmodel/loss/ener.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,9 +98,6 @@ def call(
atom_ener_hat,
atom_pref,
)
if self.has_gf:
drdq = label_dict["drdq"]
find_drdq = label_dict["find_drdq"]

if self.enable_atom_ener_coeff:
# when ener_coeff (\nu) is defined, the energy is defined as
Expand All @@ -114,9 +111,6 @@ def call(
atom_ener_coeff = label_dict["atom_ener_coeff"]
atom_ener_coeff = xp.reshape(atom_ener_coeff, xp.shape(atom_ener))
energy = xp.sum(atom_ener_coeff * atom_ener, 1)
if self.has_e:
l2_ener_loss = xp.mean(xp.square(energy - energy_hat))

if self.has_f or self.has_pf or self.relative_f or self.has_gf:
force_reshape = xp.reshape(force, [-1])
force_hat_reshape = xp.reshape(force_hat, [-1])
Expand All @@ -129,43 +123,6 @@ def call(
diff_f_3 = diff_f_3 / norm_f
diff_f = xp.reshape(diff_f_3, [-1])

if self.has_f:
l2_force_loss = xp.mean(xp.square(diff_f))

if self.has_pf:
atom_pref_reshape = xp.reshape(atom_pref, [-1])
l2_pref_force_loss = xp.mean(
xp.multiply(xp.square(diff_f), atom_pref_reshape),
)

if self.has_gf:
drdq = label_dict["drdq"]
force_reshape_nframes = xp.reshape(force, [-1, natoms[0] * 3])
force_hat_reshape_nframes = xp.reshape(force_hat, [-1, natoms[0] * 3])
drdq_reshape = xp.reshape(
drdq, [-1, natoms[0] * 3, self.numb_generalized_coord]
)
gen_force_hat = xp.einsum(
"bij,bi->bj", drdq_reshape, force_hat_reshape_nframes
)
gen_force = xp.einsum("bij,bi->bj", drdq_reshape, force_reshape_nframes)
diff_gen_force = gen_force_hat - gen_force
l2_gen_force_loss = xp.mean(xp.square(diff_gen_force))

if self.has_v:
virial_reshape = xp.reshape(virial, [-1])
virial_hat_reshape = xp.reshape(virial_hat, [-1])
l2_virial_loss = xp.mean(
xp.square(virial_hat_reshape - virial_reshape),
)

if self.has_ae:
atom_ener_reshape = xp.reshape(atom_ener, [-1])
atom_ener_hat_reshape = xp.reshape(atom_ener_hat, [-1])
l2_atom_ener_loss = xp.mean(
xp.square(atom_ener_hat_reshape - atom_ener_reshape),
)

atom_norm = 1.0 / natoms
atom_norm_ener = 1.0 / natoms
lr_ratio = learning_rate / self.starter_learning_rate
Expand All @@ -184,38 +141,66 @@ def call(
pref_pf = find_atom_pref * (
self.limit_pref_pf + (self.start_pref_pf - self.limit_pref_pf) * lr_ratio
)
if self.has_gf:
pref_gf = find_drdq * (
self.limit_pref_gf
+ (self.start_pref_gf - self.limit_pref_gf) * lr_ratio
)

l2_loss = 0
more_loss = {}
if self.has_e:
l2_ener_loss = xp.mean(xp.square(energy - energy_hat))
l2_loss += atom_norm_ener * (pref_e * l2_ener_loss)
more_loss["l2_ener_loss"] = self.display_if_exist(l2_ener_loss, find_energy)
if self.has_f:
l2_force_loss = xp.mean(xp.square(diff_f))
l2_loss += pref_f * l2_force_loss
more_loss["l2_force_loss"] = self.display_if_exist(
l2_force_loss, find_force
)
if self.has_v:
virial_reshape = xp.reshape(virial, [-1])
virial_hat_reshape = xp.reshape(virial_hat, [-1])
l2_virial_loss = xp.mean(
xp.square(virial_hat_reshape - virial_reshape),
)
l2_loss += atom_norm * (pref_v * l2_virial_loss)
more_loss["l2_virial_loss"] = self.display_if_exist(
l2_virial_loss, find_virial
)
if self.has_ae:
atom_ener_reshape = xp.reshape(atom_ener, [-1])
atom_ener_hat_reshape = xp.reshape(atom_ener_hat, [-1])
l2_atom_ener_loss = xp.mean(
xp.square(atom_ener_hat_reshape - atom_ener_reshape),
)
l2_loss += pref_ae * l2_atom_ener_loss
more_loss["l2_atom_ener_loss"] = self.display_if_exist(
l2_atom_ener_loss, find_atom_ener
)
if self.has_pf:
atom_pref_reshape = xp.reshape(atom_pref, [-1])
l2_pref_force_loss = xp.mean(
xp.multiply(xp.square(diff_f), atom_pref_reshape),
)
l2_loss += pref_pf * l2_pref_force_loss
more_loss["l2_pref_force_loss"] = self.display_if_exist(
l2_pref_force_loss, find_atom_pref
)
if self.has_gf:
find_drdq = label_dict["find_drdq"]
drdq = label_dict["drdq"]
force_reshape_nframes = xp.reshape(force, [-1, natoms[0] * 3])
force_hat_reshape_nframes = xp.reshape(force_hat, [-1, natoms[0] * 3])
drdq_reshape = xp.reshape(
drdq, [-1, natoms[0] * 3, self.numb_generalized_coord]
)
gen_force_hat = xp.einsum(
"bij,bi->bj", drdq_reshape, force_hat_reshape_nframes
)
gen_force = xp.einsum("bij,bi->bj", drdq_reshape, force_reshape_nframes)
diff_gen_force = gen_force_hat - gen_force
l2_gen_force_loss = xp.mean(xp.square(diff_gen_force))
pref_gf = find_drdq * (
self.limit_pref_gff
+ (self.start_pref_gf - self.limit_pref_gf) * lr_ratio
)
l2_loss += pref_gf * l2_gen_force_loss
more_loss["l2_gen_force_loss"] = self.display_if_exist(
l2_gen_force_loss, find_drdq
Expand Down

0 comments on commit 45f222b

Please sign in to comment.