Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Deterministic version of CUDA forces and stresses kernels #3693

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions doc/determinism.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Running DeepMD in full deterministic mode

With the default settings DeepMD does not guarantee that two successive trainings using the same data will return the same model parameters. The results will also depend on the processing units GPU vs CPU. Variations might also be observed between different families of GPUs. This document explains how to set up DeepMD runs to get reproducible results for a given set of training data and hardware architecture. It only applies to the forces and stress calculations during the training and inference phases.
mtaillefumier marked this conversation as resolved.
Show resolved Hide resolved

The GPU kernels calculating the forces and stress in DeepMD are deterministic. Calls to the TensorFlow API, however, do not guarantee that unless a set of environment variables affecting its execution are set up at runtime or if specific API calls are used during the TensorFlow initialization steps. The most important environment variable is `TF_DETERMINISTIC_OPS` that selects the deterministic variants of TensorFlow GPU functions if set to 1. Two other variables controlling the TensorFlow threading; `TF_INTER_OP_PARALLELISM_THREADS` and `TF_INTRA_OP_PARALLELISM_THREADS`; should be set to 0. More information about running TensorFlow in deterministic mode and what it implies, can be found [here](https://www.tensorflow.org/api_docs/python/tf/config/experimental/enable_op_determinism). The `OMP_NUM_THREADS` variable seems to have less or no impact when the GPU version of DeepMD is used.
mtaillefumier marked this conversation as resolved.
Show resolved Hide resolved

Adding these three lines of code in the run scripts is enough to get reproducible results on the same hardware.

```[sh]
export TF_DETERMINISTIC_OPS=1
export TF_INTER_OP_PARALLELISM_THREADS=0
export TF_INTRA_OP_PARALLELISM_THREADS=0
```
207 changes: 123 additions & 84 deletions source/lib/src/gpu/prod_force.cu
Original file line number Diff line number Diff line change
Expand Up @@ -43,92 +43,147 @@ __global__ void force_deriv_wrt_center_atom(FPTYPE* force,
}
}

template <typename FPTYPE>
__global__ void force_deriv_wrt_neighbors_a(FPTYPE* force,
const FPTYPE* net_deriv,
const FPTYPE* in_deriv,
const int* nlist,
const int nloc,
const int nall,
const int nnei) {
// idy -> nnei
const int_64 idx = blockIdx.x;
const unsigned int idy = blockIdx.y * blockDim.x + threadIdx.x;
const unsigned int idz = threadIdx.y;
const int ndescrpt = nnei * 4;
if (idy >= nnei) {
return;
}
// deriv wrt neighbors
int j_idx = nlist[idx * nnei + idy];
if (j_idx < 0) {
template <typename FPTYPE, bool radial_only_ = true, int shared_memory_block_>
__global__ void force_deriv_wrt_neighbors(FPTYPE* force,
const FPTYPE* net_deriv,
const FPTYPE* in_deriv,
const int* nlist,
const int nframes,
const int nloc,
const int nall,
const int nnei) {
// limited to 2 billions atoms and 2 billions frames
const int atom_id = blockIdx.x;
const int frame_id = blockIdx.z * gridDim.y + blockIdx.y;

if (frame_id >= nframes) {
return;
}
FPTYPE force_tmp = (FPTYPE)0.;
for (int idw = 0; idw < 4; ++idw) {
force_tmp += net_deriv[idx * ndescrpt + idy * 4 + idw] *
in_deriv[idx * ndescrpt * 3 + (idy * 4 + idw) * 3 + idz];

const int ndescrpt = nnei * ((radial_only_) ? (1) : (4));

// define various pointers for a specific frame.
const FPTYPE* frame_net_deriv_ = &net_deriv[frame_id * nloc * ndescrpt];
const FPTYPE* frame_in_deriv_ = &in_deriv[frame_id * nloc * ndescrpt * 3];
const int* frame_neighbor_list_ = &nlist[frame_id * nnei * nloc];
FPTYPE force_tmp[3] = {(FPTYPE)0., (FPTYPE)0., (FPTYPE)0.};

for (int neighbor_id = threadIdx.x; neighbor_id < nnei;
neighbor_id += blockDim.x) {
// collect all terms $\partial E_j / \partial D_{ji} \nabla_R_j D_{ji}$
// where the atom i is a neighbor of the atom j.
//
// Go through all neighbors of atom i, locate the position of
// the atom i in the neighbor list of the atom j and retrieve all necessary
// information.

const int atom_j = frame_neighbor_list_[atom_id * nnei + neighbor_id];

// The neighbors of a given atom are sorted by type and each resulting list
// is separated from the other by a series of -1. More details about the
// sorting can be found in https://doi.org/10.1016/j.cpc.2020.107624
//
// To illustrate this, take the neigbhors of a given atom of type a (in a
// system with two atoms type a and b) deepmd stores the neighbors as
//
// [neighbors list of type a], -1, -1, -1, ...., [neighbor list of type b],
// -1, -1, -1, .....

if (atom_j < 0) {
continue;
}

const int* nei_nei_list_ = &frame_neighbor_list_[atom_j * nnei];
int atom_id_position = 0;

// search the index of the atom i in the local neighbor list of atom j
for (atom_id_position = 0; atom_id_position < nnei; atom_id_position++) {
if (nei_nei_list_[atom_id_position] == atom_id) {
break;
}
}

const int offset_j =
(atom_j * nnei + atom_id_position) * ((radial_only_) ? (1) : (4));
for (int idw = 0; idw < ((radial_only_) ? (1) : (4)); ++idw) {
const FPTYPE cst1 = frame_net_deriv_[offset_j + idw];
force_tmp[0] += cst1 * in_deriv[(offset_j + idw) * 3 + 0];
force_tmp[1] += cst1 * in_deriv[(offset_j + idw) * 3 + 1];
force_tmp[2] += cst1 * in_deriv[(offset_j + idw) * 3 + 2];
}
}
const int_64 kk = idx / nloc; // frame index
atomicAdd(force + kk * nall * 3 + j_idx * 3 + idz, force_tmp);
}

template <typename FPTYPE>
__global__ void force_deriv_wrt_neighbors_r(FPTYPE* force,
const FPTYPE* net_deriv,
const FPTYPE* in_deriv,
const int* nlist,
const int nloc,
const int nall,
const int nnei) {
// idy -> nnei
const int_64 idx = blockIdx.x;
const unsigned int idy = blockIdx.y * blockDim.x + threadIdx.x;
const unsigned int idz = threadIdx.y;
const int ndescrpt = nnei * 1;
if (idy >= nnei) {
return;
__shared__ FPTYPE fx[shared_memory_block_];
__shared__ FPTYPE fy[shared_memory_block_];
__shared__ FPTYPE fz[shared_memory_block_];

fx[threadIdx.x] = force_tmp[0];
fy[threadIdx.x] = force_tmp[1];
fz[threadIdx.x] = force_tmp[2];
__syncthreads();

// do the final reduction
for (int tt = shared_memory_block_ / 2; tt > 0; tt >>= 1) {
if (threadIdx.x < tt) {
fx[threadIdx.x] += fx[threadIdx.x + tt];
fy[threadIdx.x] += fy[threadIdx.x + tt];
fz[threadIdx.x] += fz[threadIdx.x + tt];
}
__syncthreads();
}
// deriv wrt neighbors
int j_idx = nlist[idx * nnei + idy];
if (j_idx < 0) {
return;

/* Note the sign difference between the formula in the PRL paper and the code.
it is due to \nabla_R_j D_{ji} = -\nabla_R_i D_{ji} */
if (threadIdx.x == 0) {
const int64_t offset = (frame_id * nall + atom_id) * 3;
force[offset] += fx[0];
force[offset + 1] += fy[0];
force[offset + 2] += fz[0];
}
const int_64 kk = idx / nloc; // frame index
atomicAdd(force + kk * nall * 3 + j_idx * 3 + idz,
net_deriv[idx * ndescrpt + idy] *
in_deriv[idx * ndescrpt * 3 + idy * 3 + idz]);
}

namespace deepmd {
template <typename FPTYPE>
void prod_force_a_gpu(FPTYPE* force,
const FPTYPE* net_deriv,
const FPTYPE* in_deriv,
const int* nlist,
const int nloc,
const int nall,
const int nnei,
const int nframes) {
template <typename FPTYPE, bool radial_only_ = true>
void prod_force_a_r_gpu(FPTYPE* force,
const FPTYPE* net_deriv,
const FPTYPE* in_deriv,
const int* nlist,
const int nloc,
const int nall,
const int nnei,
const int nframes) {
DPErrcheck(gpuGetLastError());
DPErrcheck(gpuDeviceSynchronize());
const int ndescrpt = nnei * 4;
const int ndescrpt = nnei * ((radial_only_) ? (1) : (4));
DPErrcheck(gpuMemset(force, 0, sizeof(FPTYPE) * nframes * nall * 3));

force_deriv_wrt_center_atom<FPTYPE, TPB><<<nframes * nloc, TPB>>>(
force, net_deriv, in_deriv, ndescrpt, nloc, nall);
DPErrcheck(gpuGetLastError());
DPErrcheck(gpuDeviceSynchronize());

const int LEN = 64;
const int nblock = (nnei + LEN - 1) / LEN;
dim3 block_grid(nframes * nloc, nblock);
dim3 thread_grid(LEN, 3);
force_deriv_wrt_neighbors_a<<<block_grid, thread_grid>>>(
force, net_deriv, in_deriv, nlist, nloc, nall, nnei);
const int sqrt_nframes = sqrt(nframes);
dim3 block_grid(nloc, sqrt_nframes + 1, sqrt_nframes + 1);
// to accomodate AMD GPU
dim3 thread_grid(64, 1, 1);
force_deriv_wrt_neighbors<FPTYPE, radial_only_, 64>
<<<block_grid, thread_grid>>>(force, net_deriv, in_deriv, nlist, nframes,
nloc, nall, nnei);
DPErrcheck(gpuGetLastError());
DPErrcheck(gpuDeviceSynchronize());
}
namespace deepmd {
template <typename FPTYPE>
void prod_force_a_gpu(FPTYPE* force,
const FPTYPE* net_deriv,
const FPTYPE* in_deriv,
const int* nlist,
const int nloc,
const int nall,
const int nnei,
const int nframes) {
prod_force_a_r_gpu<FPTYPE, false>(force, net_deriv, in_deriv, nlist, nloc,
nall, nnei, nframes);
}

template <typename FPTYPE>
void prod_force_r_gpu(FPTYPE* force,
Expand All @@ -139,24 +194,8 @@ void prod_force_r_gpu(FPTYPE* force,
const int nall,
const int nnei,
const int nframes) {
DPErrcheck(gpuGetLastError());
DPErrcheck(gpuDeviceSynchronize());
const int ndescrpt = nnei * 1;
DPErrcheck(gpuMemset(force, 0, sizeof(FPTYPE) * nframes * nall * 3));

force_deriv_wrt_center_atom<FPTYPE, TPB><<<nframes * nloc, TPB>>>(
force, net_deriv, in_deriv, ndescrpt, nloc, nall);
DPErrcheck(gpuGetLastError());
DPErrcheck(gpuDeviceSynchronize());

const int LEN = 64;
const int nblock = (nnei + LEN - 1) / LEN;
dim3 block_grid(nframes * nloc, nblock);
dim3 thread_grid(LEN, 3);
force_deriv_wrt_neighbors_r<<<block_grid, thread_grid>>>(
force, net_deriv, in_deriv, nlist, nloc, nall, nnei);
DPErrcheck(gpuGetLastError());
DPErrcheck(gpuDeviceSynchronize());
prod_force_a_r_gpu<FPTYPE, true>(force, net_deriv, in_deriv, nlist, nloc,
nall, nnei, nframes);
}

template void prod_force_a_gpu<float>(float* force,
Expand Down
Loading