Skip to content

Commit

Permalink
Change template variable of int1e density contraction kernel, very big improvement
Browse files Browse the repository at this point in the history
  • Loading branch information
henryw7 committed Dec 14, 2024
1 parent 213331e commit 2c79960
Show file tree
Hide file tree
Showing 5 changed files with 95 additions and 80 deletions.
26 changes: 16 additions & 10 deletions gpu4pyscf/lib/gint/g1e.cu
Original file line number Diff line number Diff line change
Expand Up @@ -252,11 +252,13 @@ static void GINT_g1e_save_u2(double* __restrict__ g, double* __restrict__ u2_sav
}
}

template <int NROOTS>
template <int L_SUM>
__device__
static void GINT_g1e_without_hrr(double* __restrict__ g, const double grid_x, const double grid_y, const double grid_z,
const int ish, const int prim_ij, const int l, const double charge_exponent, const double omega)
const int ish, const int prim_ij, const double charge_exponent, const double omega)
{
constexpr int NROOTS = L_SUM / 2 + 1;

const double* __restrict__ a12 = c_bpcache.a12;
const double* __restrict__ e12 = c_bpcache.e12;
const double* __restrict__ x12 = c_bpcache.x12;
Expand Down Expand Up @@ -288,7 +290,7 @@ static void GINT_g1e_without_hrr(double* __restrict__ g, const double grid_x, co

const double* __restrict__ u = uw;
const double* __restrict__ w = u + NROOTS;
const int g_size = NROOTS * (l + 1);
constexpr int g_size = NROOTS * (L_SUM + 1);
double* __restrict__ gx = g;
double* __restrict__ gy = g + g_size;
double* __restrict__ gz = g + g_size * 2;
Expand Down Expand Up @@ -317,7 +319,7 @@ static void GINT_g1e_without_hrr(double* __restrict__ g, const double grid_x, co
const double c00y = PAy - qt2_over_p_plus_q * PCy;
const double c00z = PAz - qt2_over_p_plus_q * PCz;

if (l > 0) {
if constexpr (L_SUM > 0) {
double s0x = gx[i_root]; // i - 1
double s0y = gy[i_root];
double s0z = gz[i_root];
Expand All @@ -327,7 +329,8 @@ static void GINT_g1e_without_hrr(double* __restrict__ g, const double grid_x, co
gx[i_root + 1 * NROOTS] = s1x;
gy[i_root + 1 * NROOTS] = s1y;
gz[i_root + 1 * NROOTS] = s1z;
for (int i_rys = 1; i_rys < l; i_rys++) {
#pragma unroll
for (int i_rys = 1; i_rys < L_SUM; i_rys++) {
const double s2x = c00x * s1x + i_rys * b10 * s0x; // i + 1
const double s2y = c00y * s1y + i_rys * b10 * s0y;
const double s2z = c00z * s1z + i_rys * b10 * s0z;
Expand All @@ -346,11 +349,13 @@ static void GINT_g1e_without_hrr(double* __restrict__ g, const double grid_x, co

}

template <int NROOTS>
template <int L_SUM>
__device__
static void GINT_g1e_without_hrr_save_u2(double* __restrict__ g, double* __restrict__ u2_save, const double grid_x, const double grid_y, const double grid_z,
const int ish, const int prim_ij, const int l, const double charge_exponent, const double omega)
const int ish, const int prim_ij, const double charge_exponent, const double omega)
{
constexpr int NROOTS = L_SUM / 2 + 1;

const double* __restrict__ a12 = c_bpcache.a12;
const double* __restrict__ e12 = c_bpcache.e12;
const double* __restrict__ x12 = c_bpcache.x12;
Expand Down Expand Up @@ -382,7 +387,7 @@ static void GINT_g1e_without_hrr_save_u2(double* __restrict__ g, double* __restr

const double* __restrict__ u = uw;
const double* __restrict__ w = u + NROOTS;
const int g_size = NROOTS * (l + 1);
constexpr int g_size = NROOTS * (L_SUM + 1);
double* __restrict__ gx = g;
double* __restrict__ gy = g + g_size;
double* __restrict__ gz = g + g_size * 2;
Expand Down Expand Up @@ -412,7 +417,7 @@ static void GINT_g1e_without_hrr_save_u2(double* __restrict__ g, double* __restr
const double c00y = PAy - qt2_over_p_plus_q * PCy;
const double c00z = PAz - qt2_over_p_plus_q * PCz;

if (l > 0) {
if constexpr (L_SUM > 0) {
double s0x = gx[i_root]; // i - 1
double s0y = gy[i_root];
double s0z = gz[i_root];
Expand All @@ -422,7 +427,8 @@ static void GINT_g1e_without_hrr_save_u2(double* __restrict__ g, double* __restr
gx[i_root + 1 * NROOTS] = s1x;
gy[i_root + 1 * NROOTS] = s1y;
gz[i_root + 1 * NROOTS] = s1z;
for (int i_rys = 1; i_rys < l; i_rys++) {
#pragma unroll
for (int i_rys = 1; i_rys < L_SUM; i_rys++) {
const double s2x = c00x * s1x + i_rys * b10 * s0x; // i + 1
const double s2y = c00y * s1y + i_rys * b10 * s0y;
const double s2z = c00z * s1z + i_rys * b10 * s0z;
Expand Down
36 changes: 21 additions & 15 deletions gpu4pyscf/lib/gint/g3c1e.cu
Original file line number Diff line number Diff line change
Expand Up @@ -179,12 +179,14 @@ static void GINTfill_int3c1e_charge_contracted_kernel_general(double* output, co
}
}

template <int NROOTS>
template <int L_SUM>
__global__
static void GINTfill_int3c1e_density_contracted_kernel_general(double* output, const double* density, const HermiteDensityOffsets hermite_density_offsets,
const BasisProdOffsets offsets, const int i_l, const int j_l, const int nprim_ij,
const BasisProdOffsets offsets, const int nprim_ij,
const double omega, const double* grid_points, const double* charge_exponents)
{
constexpr int NROOTS = L_SUM / 2 + 1;

const int ntasks_ij = offsets.ntasks_ij;
const int ngrids = offsets.ntasks_kl;
const int task_grid = blockIdx.y * blockDim.y + threadIdx.y;
Expand All @@ -207,30 +209,34 @@ static void GINTfill_int3c1e_density_contracted_kernel_general(double* output, c
const int ish = bas_pair2bra[bas_ij];
// const int jsh = bas_pair2ket[bas_ij];

constexpr int l_max = (NROOTS - 1) * 2 + 1;
double D_hermite[(l_max + 1) * (l_max + 2) * (l_max + 3) / 6];
const int l = i_l + j_l;
for (int i_t = 0; i_t < (l + 1) * (l + 2) * (l + 3) / 6; i_t++) {
double D_hermite[(L_SUM + 1) * (L_SUM + 2) * (L_SUM + 3) / 6];
#pragma unroll
for (int i_t = 0; i_t < (L_SUM + 1) * (L_SUM + 2) * (L_SUM + 3) / 6; i_t++) {
D_hermite[i_t] = density[bas_ij - hermite_density_offsets.pair_offset_of_angular_pair + hermite_density_offsets.density_offset_of_angular_pair + i_t * hermite_density_offsets.n_pair_of_angular_pair];
}

double eri_with_density_per_pair = 0.0;
for (int ij = prim_ij; ij < prim_ij+nprim_ij; ++ij) {
double g[NROOTS * (l_max + 1) * 3];
GINT_g1e_without_hrr<NROOTS>(g, Cx, Cy, Cz, ish, ij, l, charge_exponent, omega);
double g[NROOTS * (L_SUM + 1) * 3];
GINT_g1e_without_hrr<L_SUM>(g, Cx, Cy, Cz, ish, ij, charge_exponent, omega);

double eri_with_density_per_primitive = 0.0;
for (int i_x = 0, i_t = 0; i_x <= l; i_x++) {
for (int i_y = 0; i_x + i_y <= l; i_y++) {
for (int i_z = 0; i_x + i_y + i_z <= l; i_z++, i_t++) {
const double D_t = D_hermite[i_t];
#pragma unroll
for (int i_x = 0, i_t = 0; i_x <= L_SUM; i_x++) {
#pragma unroll
for (int i_y = 0; i_x + i_y <= L_SUM; i_y++) {
#pragma unroll
for (int i_z = 0; i_x + i_y + i_z <= L_SUM; i_z++, i_t++) {
double eri_per_hermite = 0.0;
#pragma unroll
for (int i_root = 0; i_root < NROOTS; i_root++) {
const double gx = g[i_root + NROOTS * i_x];
const double gy = g[i_root + NROOTS * i_y + NROOTS * (l + 1)];
const double gz = g[i_root + NROOTS * i_z + NROOTS * (l + 1) * 2];
eri_with_density_per_primitive += gx * gy * gz * D_t;
const double gy = g[i_root + NROOTS * i_y + NROOTS * (L_SUM + 1)];
const double gz = g[i_root + NROOTS * i_z + NROOTS * (L_SUM + 1) * 2];
eri_per_hermite += gx * gy * gz;
}
const double D_t = D_hermite[i_t];
eri_with_density_per_primitive += eri_per_hermite * D_t;
}
}
}
Expand Down
54 changes: 32 additions & 22 deletions gpu4pyscf/lib/gint/g3c1e_ip.cu
Original file line number Diff line number Diff line change
Expand Up @@ -149,8 +149,11 @@ static void GINTwrite_int3c1e_ip1_charge_contracted(const double* g, double* loc
const double* __restrict__ gy = g + g_size;
const double* __restrict__ gz = g + g_size * 2;

for (int j = 0; j < (j_l + 1) * (j_l + 2) / 2; j++) {
for (int i = 0; i < (i_l + 1) * (i_l + 2) / 2; i++) {
const int n_density_elements_i = (i_l + 1) * (i_l + 2) / 2;
const int n_density_elements_j = (j_l + 1) * (j_l + 2) / 2;
const int n_density_elements_ij = n_density_elements_i * n_density_elements_j;
for (int j = 0; j < n_density_elements_j; j++) {
for (int i = 0; i < n_density_elements_i; i++) {
const int loc_j = c_l_locs[j_l] + j;
const int loc_i = c_l_locs[i_l] + i;
const int ix = idx[loc_i];
Expand Down Expand Up @@ -178,9 +181,6 @@ static void GINTwrite_int3c1e_ip1_charge_contracted(const double* g, double* loc
deri_dAy += gx_0 * dgy_dAy * gz_0;
deri_dAz += gx_0 * gy_0 * dgz_dAz;
}
const int n_density_elements_i = (i_l + 1) * (i_l + 2) / 2;
const int n_density_elements_j = (j_l + 1) * (j_l + 2) / 2;
const int n_density_elements_ij = n_density_elements_i * n_density_elements_j;
local_output[i + j * n_density_elements_i + 0 * n_density_elements_ij] += deri_dAx * prefactor;
local_output[i + j * n_density_elements_i + 1 * n_density_elements_ij] += deri_dAy * prefactor;
local_output[i + j * n_density_elements_i + 2 * n_density_elements_ij] += deri_dAz * prefactor;
Expand Down Expand Up @@ -249,12 +249,14 @@ static void GINTfill_int3c1e_ip1_charge_contracted_kernel_general(double* output
}
}

template <int NROOTS>
template <int L_SUM>
__global__
static void GINTfill_int3c1e_ip2_density_contracted_kernel_general(double* output, const double* density, const HermiteDensityOffsets hermite_density_offsets,
const BasisProdOffsets offsets, const int i_l, const int j_l, const int nprim_ij,
const BasisProdOffsets offsets, const int nprim_ij,
const double omega, const double* grid_points, const double* charge_exponents)
{
constexpr int NROOTS = (L_SUM + 1) / 2 + 1;

const int ntasks_ij = offsets.ntasks_ij;
const int ngrids = offsets.ntasks_kl;
const int task_grid = blockIdx.y * blockDim.y + threadIdx.y;
Expand Down Expand Up @@ -286,29 +288,33 @@ static void GINTfill_int3c1e_ip2_density_contracted_kernel_general(double* outpu
const double Ay = bas_y[ish];
const double Az = bas_z[ish];

constexpr int l_max = (NROOTS - 1) * 2 + 1;
double D_hermite[(l_max + 1) * (l_max + 2) * (l_max + 3) / 6];
const int l = i_l + j_l;
for (int i_t = 0; i_t < (l + 1) * (l + 2) * (l + 3) / 6; i_t++) {
double D_hermite[(L_SUM + 1) * (L_SUM + 2) * (L_SUM + 3) / 6];
#pragma unroll
for (int i_t = 0; i_t < (L_SUM + 1) * (L_SUM + 2) * (L_SUM + 3) / 6; i_t++) {
D_hermite[i_t] = density[bas_ij - hermite_density_offsets.pair_offset_of_angular_pair + hermite_density_offsets.density_offset_of_angular_pair + i_t * hermite_density_offsets.n_pair_of_angular_pair];
}

double deri_dCx_per_pair = 0.0;
double deri_dCy_per_pair = 0.0;
double deri_dCz_per_pair = 0.0;
for (int ij = prim_ij; ij < prim_ij+nprim_ij; ++ij) {
double g[NROOTS * (l_max + 1) * 3];
double g[NROOTS * (L_SUM + 1 + 1) * 3];
double u2[NROOTS];
GINT_g1e_without_hrr_save_u2<NROOTS>(g, u2, Cx, Cy, Cz, ish, ij, l + 1, charge_exponent, omega);
GINT_g1e_without_hrr_save_u2<L_SUM + 1>(g, u2, Cx, Cy, Cz, ish, ij, charge_exponent, omega);

const double* __restrict__ gx = g;
const double* __restrict__ gy = g + NROOTS * (l + 1 + 1);
const double* __restrict__ gz = g + NROOTS * (l + 1 + 1) * 2;
const double* __restrict__ gy = g + NROOTS * (L_SUM + 1 + 1);
const double* __restrict__ gz = g + NROOTS * (L_SUM + 1 + 1) * 2;

for (int i_x = 0, i_t = 0; i_x <= l; i_x++) {
for (int i_y = 0; i_x + i_y <= l; i_y++) {
for (int i_z = 0; i_x + i_y + i_z <= l; i_z++, i_t++) {
const double D_t = D_hermite[i_t];
#pragma unroll
for (int i_x = 0, i_t = 0; i_x <= L_SUM; i_x++) {
#pragma unroll
for (int i_y = 0; i_x + i_y <= L_SUM; i_y++) {
#pragma unroll
for (int i_z = 0; i_x + i_y + i_z <= L_SUM; i_z++, i_t++) {
double deri_dCx_per_hermite = 0.0;
double deri_dCy_per_hermite = 0.0;
double deri_dCz_per_hermite = 0.0;
#pragma unroll
for (int i_root = 0; i_root < NROOTS; i_root++) {
const double gx_0 = gx[i_root + NROOTS * i_x];
Expand All @@ -321,10 +327,14 @@ static void GINTfill_int3c1e_ip2_density_contracted_kernel_general(double* outpu
const double dgx_dCx = minus_two_u2 * (gx_1 + (Ax - Cx) * gx_0);
const double dgy_dCy = minus_two_u2 * (gy_1 + (Ay - Cy) * gy_0);
const double dgz_dCz = minus_two_u2 * (gz_1 + (Az - Cz) * gz_0);
deri_dCx_per_pair += dgx_dCx * gy_0 * gz_0 * D_t;
deri_dCy_per_pair += gx_0 * dgy_dCy * gz_0 * D_t;
deri_dCz_per_pair += gx_0 * gy_0 * dgz_dCz * D_t;
deri_dCx_per_hermite += dgx_dCx * gy_0 * gz_0;
deri_dCy_per_hermite += gx_0 * dgy_dCy * gz_0;
deri_dCz_per_hermite += gx_0 * gy_0 * dgz_dCz;
}
const double D_t = D_hermite[i_t];
deri_dCx_per_pair += deri_dCx_per_hermite * D_t;
deri_dCy_per_pair += deri_dCy_per_hermite * D_t;
deri_dCz_per_pair += deri_dCz_per_hermite * D_t;
}
}
}
Expand Down
30 changes: 13 additions & 17 deletions gpu4pyscf/lib/gint/nr_fill_ao_int3c1e.cu
Original file line number Diff line number Diff line change
Expand Up @@ -112,29 +112,25 @@ static int GINTfill_int3c1e_density_contracted_tasks(double* output, const doubl
const double omega, const double* grid_points, const double* charge_exponents,
const int n_pair_sum_per_thread, const cudaStream_t stream)
{
const int nrys_roots = (i_l + j_l) / 2 + 1;
const int ntasks_ij = (offsets.ntasks_ij + n_pair_sum_per_thread - 1) / n_pair_sum_per_thread;
const int ngrids = offsets.ntasks_kl;

const dim3 threads(THREADSX, THREADSY);
const dim3 blocks((ntasks_ij+THREADSX-1)/THREADSX, (ngrids+THREADSY-1)/THREADSY);
int type_ijkl;
switch (nrys_roots) {
case 1:
type_ijkl = (i_l << 2) | j_l;
switch (type_ijkl) {
case (0<<2)|0: GINTfill_int3c1e_density_contracted_kernel00<<<blocks, threads, 0, stream>>>(output, density, hermite_density_offsets, offsets, nprim_ij, omega, grid_points, charge_exponents); break;
case (1<<2)|0: GINTfill_int3c1e_density_contracted_kernel10<<<blocks, threads, 0, stream>>>(output, density, hermite_density_offsets, offsets, nprim_ij, omega, grid_points, charge_exponents); break;
default:
fprintf(stderr, "roots=1 type_ijkl %d\n", type_ijkl);
}
break;
case 2: GINTfill_int3c1e_density_contracted_kernel_general<2> <<<blocks, threads, 0, stream>>>(output, density, hermite_density_offsets, offsets, i_l, j_l, nprim_ij, omega, grid_points, charge_exponents); break;
case 3: GINTfill_int3c1e_density_contracted_kernel_general<3> <<<blocks, threads, 0, stream>>>(output, density, hermite_density_offsets, offsets, i_l, j_l, nprim_ij, omega, grid_points, charge_exponents); break;
case 4: GINTfill_int3c1e_density_contracted_kernel_general<4> <<<blocks, threads, 0, stream>>>(output, density, hermite_density_offsets, offsets, i_l, j_l, nprim_ij, omega, grid_points, charge_exponents); break;
case 5: GINTfill_int3c1e_density_contracted_kernel_general<5> <<<blocks, threads, 0, stream>>>(output, density, hermite_density_offsets, offsets, i_l, j_l, nprim_ij, omega, grid_points, charge_exponents); break;
switch (i_l + j_l) {
case 0: GINTfill_int3c1e_density_contracted_kernel00<<<blocks, threads, 0, stream>>>(output, density, hermite_density_offsets, offsets, nprim_ij, omega, grid_points, charge_exponents); break;
case 1: GINTfill_int3c1e_density_contracted_kernel10<<<blocks, threads, 0, stream>>>(output, density, hermite_density_offsets, offsets, nprim_ij, omega, grid_points, charge_exponents); break;
case 2: GINTfill_int3c1e_density_contracted_kernel_general< 2> <<<blocks, threads, 0, stream>>>(output, density, hermite_density_offsets, offsets, nprim_ij, omega, grid_points, charge_exponents); break;
case 3: GINTfill_int3c1e_density_contracted_kernel_general< 3> <<<blocks, threads, 0, stream>>>(output, density, hermite_density_offsets, offsets, nprim_ij, omega, grid_points, charge_exponents); break;
case 4: GINTfill_int3c1e_density_contracted_kernel_general< 4> <<<blocks, threads, 0, stream>>>(output, density, hermite_density_offsets, offsets, nprim_ij, omega, grid_points, charge_exponents); break;
case 5: GINTfill_int3c1e_density_contracted_kernel_general< 5> <<<blocks, threads, 0, stream>>>(output, density, hermite_density_offsets, offsets, nprim_ij, omega, grid_points, charge_exponents); break;
case 6: GINTfill_int3c1e_density_contracted_kernel_general< 6> <<<blocks, threads, 0, stream>>>(output, density, hermite_density_offsets, offsets, nprim_ij, omega, grid_points, charge_exponents); break;
case 7: GINTfill_int3c1e_density_contracted_kernel_general< 7> <<<blocks, threads, 0, stream>>>(output, density, hermite_density_offsets, offsets, nprim_ij, omega, grid_points, charge_exponents); break;
case 8: GINTfill_int3c1e_density_contracted_kernel_general< 8> <<<blocks, threads, 0, stream>>>(output, density, hermite_density_offsets, offsets, nprim_ij, omega, grid_points, charge_exponents); break;
case 9: GINTfill_int3c1e_density_contracted_kernel_general< 9> <<<blocks, threads, 0, stream>>>(output, density, hermite_density_offsets, offsets, nprim_ij, omega, grid_points, charge_exponents); break;
case 10: GINTfill_int3c1e_density_contracted_kernel_general<10> <<<blocks, threads, 0, stream>>>(output, density, hermite_density_offsets, offsets, nprim_ij, omega, grid_points, charge_exponents); break;
default:
fprintf(stderr, "rys roots %d\n", nrys_roots);
fprintf(stderr, "i_l + j_l = %d out of range\n", i_l + j_l);
return 1;
}

Expand Down
Loading

0 comments on commit 2c79960

Please sign in to comment.