diff --git a/deepmd/dpmodel/loss/ener.py b/deepmd/dpmodel/loss/ener.py index 52c447b60b..97b3924c95 100644 --- a/deepmd/dpmodel/loss/ener.py +++ b/deepmd/dpmodel/loss/ener.py @@ -98,9 +98,6 @@ def call( atom_ener_hat, atom_pref, ) - if self.has_gf: - drdq = label_dict["drdq"] - find_drdq = label_dict["find_drdq"] if self.enable_atom_ener_coeff: # when ener_coeff (\nu) is defined, the energy is defined as @@ -114,9 +111,6 @@ def call( atom_ener_coeff = label_dict["atom_ener_coeff"] atom_ener_coeff = xp.reshape(atom_ener_coeff, xp.shape(atom_ener)) energy = xp.sum(atom_ener_coeff * atom_ener, 1) - if self.has_e: - l2_ener_loss = xp.mean(xp.square(energy - energy_hat)) - if self.has_f or self.has_pf or self.relative_f or self.has_gf: force_reshape = xp.reshape(force, [-1]) force_hat_reshape = xp.reshape(force_hat, [-1]) @@ -129,43 +123,6 @@ def call( diff_f_3 = diff_f_3 / norm_f diff_f = xp.reshape(diff_f_3, [-1]) - if self.has_f: - l2_force_loss = xp.mean(xp.square(diff_f)) - - if self.has_pf: - atom_pref_reshape = xp.reshape(atom_pref, [-1]) - l2_pref_force_loss = xp.mean( - xp.multiply(xp.square(diff_f), atom_pref_reshape), - ) - - if self.has_gf: - drdq = label_dict["drdq"] - force_reshape_nframes = xp.reshape(force, [-1, natoms[0] * 3]) - force_hat_reshape_nframes = xp.reshape(force_hat, [-1, natoms[0] * 3]) - drdq_reshape = xp.reshape( - drdq, [-1, natoms[0] * 3, self.numb_generalized_coord] - ) - gen_force_hat = xp.einsum( - "bij,bi->bj", drdq_reshape, force_hat_reshape_nframes - ) - gen_force = xp.einsum("bij,bi->bj", drdq_reshape, force_reshape_nframes) - diff_gen_force = gen_force_hat - gen_force - l2_gen_force_loss = xp.mean(xp.square(diff_gen_force)) - - if self.has_v: - virial_reshape = xp.reshape(virial, [-1]) - virial_hat_reshape = xp.reshape(virial_hat, [-1]) - l2_virial_loss = xp.mean( - xp.square(virial_hat_reshape - virial_reshape), - ) - - if self.has_ae: - atom_ener_reshape = xp.reshape(atom_ener, [-1]) - atom_ener_hat_reshape = xp.reshape(atom_ener_hat, [-1]) - l2_atom_ener_loss = xp.mean( - xp.square(atom_ener_hat_reshape - atom_ener_reshape), - ) - atom_norm = 1.0 / natoms atom_norm_ener = 1.0 / natoms lr_ratio = learning_rate / self.starter_learning_rate @@ -184,38 +141,66 @@ def call( pref_pf = find_atom_pref * ( self.limit_pref_pf + (self.start_pref_pf - self.limit_pref_pf) * lr_ratio ) - if self.has_gf: - pref_gf = find_drdq * ( - self.limit_pref_gf - + (self.start_pref_gf - self.limit_pref_gf) * lr_ratio - ) l2_loss = 0 more_loss = {} if self.has_e: + l2_ener_loss = xp.mean(xp.square(energy - energy_hat)) l2_loss += atom_norm_ener * (pref_e * l2_ener_loss) more_loss["l2_ener_loss"] = self.display_if_exist(l2_ener_loss, find_energy) if self.has_f: + l2_force_loss = xp.mean(xp.square(diff_f)) l2_loss += pref_f * l2_force_loss more_loss["l2_force_loss"] = self.display_if_exist( l2_force_loss, find_force ) if self.has_v: + virial_reshape = xp.reshape(virial, [-1]) + virial_hat_reshape = xp.reshape(virial_hat, [-1]) + l2_virial_loss = xp.mean( + xp.square(virial_hat_reshape - virial_reshape), + ) l2_loss += atom_norm * (pref_v * l2_virial_loss) more_loss["l2_virial_loss"] = self.display_if_exist( l2_virial_loss, find_virial ) if self.has_ae: + atom_ener_reshape = xp.reshape(atom_ener, [-1]) + atom_ener_hat_reshape = xp.reshape(atom_ener_hat, [-1]) + l2_atom_ener_loss = xp.mean( + xp.square(atom_ener_hat_reshape - atom_ener_reshape), + ) l2_loss += pref_ae * l2_atom_ener_loss more_loss["l2_atom_ener_loss"] = self.display_if_exist( l2_atom_ener_loss, find_atom_ener ) if self.has_pf: + atom_pref_reshape = xp.reshape(atom_pref, [-1]) + l2_pref_force_loss = xp.mean( + xp.multiply(xp.square(diff_f), atom_pref_reshape), + ) l2_loss += pref_pf * l2_pref_force_loss more_loss["l2_pref_force_loss"] = self.display_if_exist( l2_pref_force_loss, find_atom_pref ) if self.has_gf: + find_drdq = label_dict["find_drdq"] + drdq = label_dict["drdq"] + force_reshape_nframes = xp.reshape(force, [-1, natoms[0] * 3]) + force_hat_reshape_nframes = xp.reshape(force_hat, [-1, natoms[0] * 3]) + drdq_reshape = xp.reshape( + drdq, [-1, natoms[0] * 3, self.numb_generalized_coord] + ) + gen_force_hat = xp.einsum( + "bij,bi->bj", drdq_reshape, force_hat_reshape_nframes + ) + gen_force = xp.einsum("bij,bi->bj", drdq_reshape, force_reshape_nframes) + diff_gen_force = gen_force_hat - gen_force + l2_gen_force_loss = xp.mean(xp.square(diff_gen_force)) + pref_gf = find_drdq * ( + self.limit_pref_gff + + (self.start_pref_gf - self.limit_pref_gf) * lr_ratio + ) l2_loss += pref_gf * l2_gen_force_loss more_loss["l2_gen_force_loss"] = self.display_if_exist( l2_gen_force_loss, find_drdq