refactor DeepPotModelDevi
njzjz committed Jan 11, 2024
1 parent 43f9639 commit 6c73100
Showing 2 changed files with 8 additions and 309 deletions.
29 changes: 1 addition & 28 deletions source/api_cc/include/DeepPot.h
@@ -595,34 +595,7 @@ class DeepPotModelDevi {

private:
unsigned numb_models;
std::vector<tensorflow::Session*> sessions;
int num_intra_nthreads, num_inter_nthreads;
std::vector<tensorflow::GraphDef*> graph_defs;
std::vector<deepmd::DeepPot> dps;
bool inited;
template <class VT>
VT get_scalar(const std::string name) const;
// VALUETYPE get_rcut () const;
// int get_ntypes () const;
double rcut;
double cell_size;
int dtype;
std::string model_type;
std::string model_version;
int ntypes;
int ntypes_spin;
int dfparam;
int daparam;
bool aparam_nall;
template <typename VALUETYPE>
void validate_fparam_aparam(const int& nloc,
const std::vector<VALUETYPE>& fparam,
const std::vector<VALUETYPE>& aparam) const;

// copy neighbor list info from host
bool init_nbor;
std::vector<std::vector<int> > sec;
deepmd::AtomMap atommap;
NeighborListData nlist_data;
InputNlist nlist;
};
} // namespace deepmd
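
Note: the header change tells the whole story of this refactor. Every session, graph, threading, and neighbor-list member is dropped, and the only state the ensemble keeps is numb_models, inited, and the std::vector<deepmd::DeepPot> dps that now owns the per-model resources. A minimal sketch of the resulting shape, using a hypothetical Member stand-in rather than the real deepmd::DeepPot:

#include <string>
#include <vector>

// Member is a stand-in for deepmd::DeepPot: one model file, one internally
// owned graph/session, all attribute reads hidden behind its own API.
struct Member {
  explicit Member(const std::string& path) : path(path) {}
  std::string path;
};

class EnsembleSketch {
 public:
  void init(const std::vector<std::string>& models) {
    if (inited) return;  // mirrors the "do not initialize twice" guard
    for (const auto& m : models) dps.emplace_back(m);
    numb_models = dps.size();
    inited = true;
  }

 private:
  unsigned numb_models = 0;
  bool inited = false;
  std::vector<Member> dps;  // replaces sessions, graph_defs, atommap, nlist
};
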
288 changes: 7 additions & 281 deletions source/api_cc/src/DeepPot.cc
@@ -1269,188 +1269,12 @@ void DeepPotModelDevi::init(const std::vector<std::string>& models,
return;
}
numb_models = models.size();
sessions.resize(numb_models);
graph_defs.resize(numb_models);

int gpu_num = -1;
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
DPGetDeviceCount(gpu_num);
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

SessionOptions options;
get_env_nthreads(num_intra_nthreads, num_inter_nthreads);
options.config.set_inter_op_parallelism_threads(num_inter_nthreads);
options.config.set_intra_op_parallelism_threads(num_intra_nthreads);
for (unsigned ii = 0; ii < numb_models; ++ii) {
graph_defs[ii] = new GraphDef();
if (file_contents.size() == 0) {
check_status(ReadBinaryProto(Env::Default(), models[ii], graph_defs[ii]));
} else {
(*graph_defs[ii]).ParseFromString(file_contents[ii]);
}
}
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
if (gpu_num > 0) {
options.config.set_allow_soft_placement(true);
options.config.mutable_gpu_options()->set_per_process_gpu_memory_fraction(
0.9);
options.config.mutable_gpu_options()->set_allow_growth(true);
DPErrcheck(DPSetDevice(gpu_rank % gpu_num));
for (unsigned int ii = 0; ii < numb_models; ++ii) {
dps[ii] = DeepPot(models[ii], gpu_rank, file_contents[ii]);
}
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

for (unsigned ii = 0; ii < numb_models; ++ii) {
if (gpu_num > 0) {
std::string str = "/gpu:";
str += std::to_string(gpu_rank % gpu_num);
graph::SetDefaultDevice(str, &(*graph_defs[ii]));
}
check_status(NewSession(options, &(sessions[ii])));
check_status(sessions[ii]->Create(*graph_defs[ii]));
}
try {
model_version = get_scalar<STRINGTYPE>("model_attr/model_version");
} catch (deepmd::tf_exception& e) {
// no model version defined in old models
model_version = "0.0";
}
if (!model_compatable(model_version)) {
throw deepmd::deepmd_exception(
"incompatable model: version " + model_version +
" in graph, but version " + global_model_version +
" supported. "
"See https://deepmd.rtfd.io/compatability/ for details.");
}
dtype = session_get_dtype(sessions[0], "descrpt_attr/rcut");
if (dtype == tensorflow::DT_DOUBLE) {
rcut = get_scalar<double>("descrpt_attr/rcut");
} else {
rcut = get_scalar<float>("descrpt_attr/rcut");
}
cell_size = rcut;
ntypes = get_scalar<int>("descrpt_attr/ntypes");
try {
ntypes_spin = get_scalar<int>("spin_attr/ntypes_spin");
} catch (const deepmd::deepmd_exception) {
ntypes_spin = 0;
}
dfparam = get_scalar<int>("fitting_attr/dfparam");
daparam = get_scalar<int>("fitting_attr/daparam");
if (dfparam < 0) {
dfparam = 0;
}
if (daparam < 0) {
daparam = 0;
}
if (daparam > 0) {
try {
aparam_nall = get_scalar<bool>("fitting_attr/aparam_nall");
} catch (const deepmd::deepmd_exception) {
aparam_nall = false;
}
} else {
aparam_nall = false;
}
model_type = get_scalar<STRINGTYPE>("model_attr/model_type");
// rcut = get_rcut();
// cell_size = rcut;
// ntypes = get_ntypes();
inited = true;

init_nbor = false;
}
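
Note: one detail worth keeping from the deleted block is the device assignment. gpu_rank % gpu_num folds an arbitrary rank onto the visible GPUs, and the same gpu_rank is now handed to every DeepPot member so the whole ensemble lands on one device. The rule as a one-line sketch (hypothetical helper; -1 signals CPU):

// Round-robin device pick: ranks wrap over the visible GPUs; -1 means CPU.
int pick_device(int gpu_rank, int gpu_num) {
  return gpu_num > 0 ? gpu_rank % gpu_num : -1;
}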

template <class VT>
VT DeepPotModelDevi::get_scalar(const std::string name) const {
VT myrcut;
for (unsigned ii = 0; ii < numb_models; ++ii) {
VT ret = session_get_scalar<VT>(sessions[ii], name);
if (ii == 0) {
myrcut = ret;
} else {
assert(myrcut == ret);
}
}
return myrcut;
}
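
Note: the deleted get_scalar encodes a useful invariant: an ensemble is only meaningful if every member reports the same value for shared attributes such as descrpt_attr/rcut or descrpt_attr/ntypes, so the helper read the attribute from each session and asserted agreement (after this commit such reads presumably go through the DeepPot members instead). The same pattern as a generic sketch, with hypothetical names:

#include <cassert>
#include <vector>

// Read an attribute from every member and require that they all agree;
// the first member's value is returned.
template <class VT, class Source, class Getter>
VT read_consistent(const std::vector<Source>& members, Getter get) {
  assert(!members.empty());
  VT first = get(members.front());
  for (const auto& m : members) {
    assert(get(m) == first);  // mixed ensembles are a configuration error
  }
  return first;
}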

template <typename VALUETYPE>
void DeepPotModelDevi::validate_fparam_aparam(
const int& nloc,
const std::vector<VALUETYPE>& fparam,
const std::vector<VALUETYPE>& aparam) const {
if (fparam.size() != dfparam) {
throw deepmd::deepmd_exception(
"the dim of frame parameter provided is not consistent with what the "
"model uses");
}
if (aparam.size() != daparam * nloc) {
throw deepmd::deepmd_exception(
"the dim of atom parameter provided is not consistent with what the "
"model uses");
}
}

template void DeepPotModelDevi::validate_fparam_aparam<double>(
const int& nloc,
const std::vector<double>& fparam,
const std::vector<double>& aparam) const;

template void DeepPotModelDevi::validate_fparam_aparam<float>(
const int& nloc,
const std::vector<float>& fparam,
const std::vector<float>& aparam) const;
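
Note: the deleted validator enforced a simple dimension contract: fparam must hold exactly dfparam values for the frame, and aparam exactly daparam values per atom (local atoms, or all atoms when aparam_nall is set). A concrete example with assumed dimensions dfparam = 2, daparam = 1, nloc = 3:

#include <cassert>
#include <cstddef>
#include <vector>

int main() {
  const std::size_t dfparam = 2, daparam = 1, nloc = 3;  // assumed dimensions
  std::vector<double> fparam = {0.1, 0.2};       // one set per frame
  std::vector<double> aparam = {1.0, 1.0, 1.0};  // one value per atom
  assert(fparam.size() == dfparam);
  assert(aparam.size() == daparam * nloc);
  return 0;
}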

// void
// DeepPotModelDevi::
// compute (ENERGYTYPE & dener,
// std::vector<VALUETYPE> & dforce_,
// std::vector<VALUETYPE> & dvirial,
// std::vector<VALUETYPE> & model_devi,
// const std::vector<VALUETYPE> & dcoord_,
// const std::vector<int> & datype_,
// const std::vector<VALUETYPE> & dbox,
// const std::vector<VALUETYPE> & fparam,
// const std::vector<VALUETYPE> & aparam)
// {
// if (numb_models == 0) return;

// atommap = AtomMap<VALUETYPE> (datype_.begin(), datype_.end());
// validate_fparam_aparam(atommap.get_type().size(), fparam, aparam);

// std::vector<std::pair<std::string, Tensor>> input_tensors;
// int nloc = session_input_tensors (input_tensors, dcoord_, ntypes, datype_,
// dbox, cell_size, fparam, aparam, atommap);

// std::vector<ENERGYTYPE > all_energy (numb_models);
// std::vector<std::vector<VALUETYPE > > all_force (numb_models);
// std::vector<std::vector<VALUETYPE > > all_virial (numb_models);

// for (unsigned ii = 0; ii < numb_models; ++ii){
// run_model (all_energy[ii], all_force[ii], all_virial[ii], sessions[ii],
// input_tensors, atommap);
// }

// dener = 0;
// for (unsigned ii = 0; ii < numb_models; ++ii){
// dener += all_energy[ii];
// }
// dener /= VALUETYPE(numb_models);
// compute_avg (dvirial, all_virial);
// compute_avg (dforce_, all_force);

// compute_std_f (model_devi, dforce_, all_force);

// // for (unsigned ii = 0; ii < numb_models; ++ii){
// // cout << all_force[ii][573] << " " << all_force[ii][574] << " " <<
// all_force[ii][575] << endl;
// // }
// // cout << dforce_[573] << " "
// // << dforce_[574] << " "
// // << dforce_[575] << " "
// // << model_devi[191] << endl;
// }
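
Note: the commented-out routine documents the model-deviation recipe itself: average energy, force, and virial over the ensemble, then report, per atom, the root-mean-square distance of each member's force from that average (compute_avg / compute_std_f elsewhere in this file). A standalone sketch of the two formulas, independent of the deepmd helpers:

#include <cmath>
#include <cstddef>
#include <vector>

// Element-wise ensemble mean of M same-length vectors.
std::vector<double> ensemble_avg(const std::vector<std::vector<double>>& xx) {
  std::vector<double> m(xx[0].size(), 0.0);
  for (const auto& x : xx)
    for (std::size_t i = 0; i < x.size(); ++i) m[i] += x[i] / xx.size();
  return m;
}

// Per-atom force deviation: sqrt of the ensemble-mean squared distance of
// each member's 3-vector force from the ensemble-average force.
std::vector<double> force_std(const std::vector<double>& favg,
                              const std::vector<std::vector<double>>& ff) {
  const std::size_t natoms = favg.size() / 3;
  std::vector<double> dev(natoms, 0.0);
  for (std::size_t a = 0; a < natoms; ++a) {
    double s = 0.0;
    for (const auto& f : ff)
      for (int d = 0; d < 3; ++d) {
        const double diff = f[3 * a + d] - favg[3 * a + d];
        s += diff * diff;
      }
    dev[a] = std::sqrt(s / ff.size());
  }
  return dev;
}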

template <typename VALUETYPE>
void DeepPotModelDevi::compute(std::vector<ENERGYTYPE>& all_energy,
std::vector<std::vector<VALUETYPE>>& all_force,
std::vector<std::vector<VALUETYPE>>& all_force,
@@ -1466,57 +1290,12 @@ void DeepPotModelDevi::compute(std::vector<ENERGYTYPE>& all_energy,
if (numb_models == 0) {
return;
}
int nall = dcoord_.size() / 3;
int nframes = 1;
int nloc = nall - nghost;
validate_fparam_aparam((aparam_nall ? nall : nloc), fparam, aparam_);
std::vector<std::pair<std::string, Tensor>> input_tensors;

// select real atoms
std::vector<VALUETYPE> dcoord, dforce, aparam, datom_energy, datom_virial;
std::vector<int> datype, fwd_map, bkw_map;
int nghost_real, nall_real, nloc_real;
select_real_atoms_coord(dcoord, datype, aparam, nghost_real, fwd_map, bkw_map,
nall_real, nloc_real, dcoord_, datype_, aparam_,
nghost, ntypes, nframes, daparam, nall, aparam_nall);

// agp == 0 means that the LAMMPS nbor list has been updated
if (ago == 0) {
atommap = AtomMap(datype.begin(), datype.begin() + nloc_real);
assert(nloc == atommap.get_type().size());

nlist_data.copy_from_nlist(lmp_list);
nlist_data.shuffle_exclude_empty(fwd_map);
nlist_data.shuffle(atommap);
nlist_data.make_inlist(nlist);
}
int ret;
if (dtype == tensorflow::DT_DOUBLE) {
ret = session_input_tensors<double>(input_tensors, dcoord, ntypes, datype,
dbox, nlist, fparam, aparam, atommap,
nghost_real, ago, "", aparam_nall);
} else {
ret = session_input_tensors<float>(input_tensors, dcoord, ntypes, datype,
dbox, nlist, fparam, aparam, atommap,
nghost_real, ago, "", aparam_nall);
}
all_energy.resize(numb_models);
all_force.resize(numb_models);
all_virial.resize(numb_models);
assert(nloc == ret);
for (unsigned ii = 0; ii < numb_models; ++ii) {
std::vector<VALUETYPE> dforce;
if (dtype == tensorflow::DT_DOUBLE) {
run_model<double>(all_energy[ii], dforce, all_virial[ii], sessions[ii],
input_tensors, atommap, 1, nghost_real);
} else {
run_model<float>(all_energy[ii], dforce, all_virial[ii], sessions[ii],
input_tensors, atommap, 1, nghost_real);
}
// bkw map
all_force[ii].resize(nframes * fwd_map.size() * 3);
select_map<VALUETYPE>(all_force[ii], dforce, bkw_map, 3, nframes,
fwd_map.size(), nall_real);
dps[ii].compute(all_energy[ii], all_force[ii], all_virial[ii], dcoord_,
datype_, dbox, nghost, lmp_list, ago, fparam, aparam_);
}
}
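
Note: after the refactor this overload is reduced to the loop visible above: resize the per-model output slots, then let each DeepPot run its full pipeline (real-atom selection, neighbor-list shuffling, session run, backward mapping) on identical inputs. The forwarding shape, sketched with a hypothetical member type:

#include <cstddef>
#include <vector>

// Stand-in for deepmd::DeepPot::compute with the LAMMPS-list signature.
struct Member {
  void compute(double& ener, std::vector<double>& force,
               const std::vector<double>& coord) const {
    ener = 0.0;                       // placeholder evaluation
    force.assign(coord.size(), 0.0);  // one 3-vector per atom
  }
};

void ensemble_compute(std::vector<double>& all_ener,
                      std::vector<std::vector<double>>& all_force,
                      const std::vector<Member>& dps,
                      const std::vector<double>& coord) {
  all_ener.resize(dps.size());
  all_force.resize(dps.size());
  for (std::size_t ii = 0; ii < dps.size(); ++ii)
    dps[ii].compute(all_ener[ii], all_force[ii], coord);  // same inputs each
}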

@@ -1564,68 +1343,15 @@ void DeepPotModelDevi::compute(
if (numb_models == 0) {
return;
}
int nframes = 1;
int nall = dcoord_.size() / 3;
int nloc = nall - nghost;
validate_fparam_aparam((aparam_nall ? nall : nloc), fparam, aparam_);
std::vector<std::pair<std::string, Tensor>> input_tensors;

// select real atoms
std::vector<VALUETYPE> dcoord, dforce, aparam, datom_energy, datom_virial;
std::vector<int> datype, fwd_map, bkw_map;
int nghost_real, nall_real, nloc_real;
select_real_atoms_coord(dcoord, datype, aparam, nghost_real, fwd_map, bkw_map,
nall_real, nloc_real, dcoord_, datype_, aparam_,
nghost, ntypes, nframes, daparam, nall, aparam_nall);
// agp == 0 means that the LAMMPS nbor list has been updated

if (ago == 0) {
atommap = AtomMap(datype.begin(), datype.begin() + nloc_real);
assert(nloc == atommap.get_type().size());

nlist_data.copy_from_nlist(lmp_list);
nlist_data.shuffle_exclude_empty(fwd_map);
nlist_data.shuffle(atommap);
nlist_data.make_inlist(nlist);
}
int ret;
if (dtype == tensorflow::DT_DOUBLE) {
ret = session_input_tensors<double>(input_tensors, dcoord, ntypes, datype,
dbox, nlist, fparam, aparam, atommap,
nghost_real, ago, "", aparam_nall);
} else {
ret = session_input_tensors<float>(input_tensors, dcoord, ntypes, datype,
dbox, nlist, fparam, aparam, atommap,
nghost_real, ago, "", aparam_nall);
}

all_energy.resize(numb_models);
all_force.resize(numb_models);
all_virial.resize(numb_models);
all_atom_energy.resize(numb_models);
all_atom_virial.resize(numb_models);
assert(nloc == ret);
for (unsigned ii = 0; ii < numb_models; ++ii) {
std::vector<VALUETYPE> dforce, datom_energy, datom_virial;
if (dtype == tensorflow::DT_DOUBLE) {
run_model<double>(all_energy[ii], dforce, all_virial[ii], datom_energy,
datom_virial, sessions[ii], input_tensors, atommap, 1,
nghost_real);
} else {
run_model<float>(all_energy[ii], dforce, all_virial[ii], datom_energy,
datom_virial, sessions[ii], input_tensors, atommap, 1,
nghost_real);
}
// bkw map
all_force[ii].resize(nframes * fwd_map.size() * 3);
all_atom_energy[ii].resize(nframes * fwd_map.size());
all_atom_virial[ii].resize(nframes * fwd_map.size() * 9);
select_map<VALUETYPE>(all_force[ii], dforce, bkw_map, 3, nframes,
fwd_map.size(), nall_real);
select_map<VALUETYPE>(all_atom_energy[ii], datom_energy, bkw_map, 1,
nframes, fwd_map.size(), nall_real);
select_map<VALUETYPE>(all_atom_virial[ii], datom_virial, bkw_map, 9,
nframes, fwd_map.size(), nall_real);
dps[ii].compute(all_energy[ii], all_force[ii], all_virial[ii],
all_atom_energy[ii], all_atom_virial[ii], dcoord_, datype_,
dbox, nghost, lmp_list, ago, fparam, aparam_);
}
}
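
Note: the deleted tail of this overload shows where the per-quantity strides come from: forces are mapped back with stride 3, atomic energies with stride 1, and atomic virials with stride 9, all through the same select_map gather over bkw_map. A simplified sketch of that mapping, under the assumption that idx_map[i] gives, for output atom i, the index of its source atom in the model-ordered array, with negative entries marking filtered-out atoms (the real select_map carries extra frame bookkeeping):

#include <cstddef>
#include <vector>

// Gather per-atom blocks of `stride` values from the model-ordered array
// back into the caller's atom order; idx_map[i] < 0 skips atom i.
template <class T>
void select_map_sketch(std::vector<T>& out, const std::vector<T>& in,
                       const std::vector<int>& idx_map, int stride) {
  for (std::size_t i = 0; i < idx_map.size(); ++i) {
    if (idx_map[i] < 0) continue;
    for (int d = 0; d < stride; ++d)
      out[i * stride + d] = in[idx_map[i] * stride + d];
  }
}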

