From 4c3a018657d0472284a70d7920d5ad95e17f72ad Mon Sep 17 00:00:00 2001 From: Henry Wang Date: Tue, 26 Nov 2024 09:47:50 +0800 Subject: [PATCH] Applied the suggested changes to: long range omega treatment, bpcache change reverted, handle too big grid, one thread sum all pairs --- gpu4pyscf/gto/moleintor.py | 85 ++++--- gpu4pyscf/gto/tests/test_int1e_grids.py | 11 +- gpu4pyscf/lib/gint/bpcache.cu | 3 +- gpu4pyscf/lib/gint/g1e.cu | 6 +- gpu4pyscf/lib/gint/g1e_root_123.cu | 225 ++++++++++--------- gpu4pyscf/lib/gint/g3c1e.cu | 74 +++--- gpu4pyscf/lib/gint/gint.h | 1 - gpu4pyscf/lib/gint/j_engine_matrix_reorder.c | 19 +- gpu4pyscf/lib/gint/nr_fill_ao_int3c1e.cu | 8 +- 9 files changed, 229 insertions(+), 203 deletions(-) diff --git a/gpu4pyscf/gto/moleintor.py b/gpu4pyscf/gto/moleintor.py index 8411da43..9b3476d1 100644 --- a/gpu4pyscf/gto/moleintor.py +++ b/gpu4pyscf/gto/moleintor.py @@ -18,6 +18,7 @@ import numpy as np from pyscf.scf import _vhf +from pyscf.gto import ATOM_OF from gpu4pyscf.lib.cupy_helper import load_library, cart2sph, block_c2s_diag, get_avail_mem from gpu4pyscf.lib import logger from gpu4pyscf.scf.int4c2e import BasisProdCache @@ -186,7 +187,6 @@ def get_n_hermite_density_of_angular_pair(l): def get_int3c1e_slice(intopt, cp_ij_id, grids, out, omega): stream = cp.cuda.get_current_stream() - if omega is None: omega = 0.0 nao_cart = intopt.mol.nao cpi = intopt.cp_idx[cp_ij_id] @@ -222,7 +222,10 @@ def get_int3c1e_slice(intopt, cp_ij_id, grids, out, omega): if err != 0: raise RuntimeError('GINTfill_int3c1e failed') -def get_int3c1e(mol, grids, direct_scf_tol, omega): +def get_int3c1e(mol, grids, direct_scf_tol): + omega = mol.omega + assert omega >= 0.0, "Short-range one electron integrals with GPU acceleration is not implemented." + intopt = VHFOpt(mol, 'int2e') intopt.build(direct_scf_tol, diag_block_with_triu=True, aosym=True, group_size=BLKSIZE) @@ -257,7 +260,7 @@ def get_int3c1e(mol, grids, direct_scf_tol, omega): j0, j1 = intopt.cart_ao_loc[cpj], intopt.cart_ao_loc[cpj+1] int3c_angular_slice = cp.zeros([ngrids_of_split, j1-j0, i1-i0], order='C') - get_int3c1e_slice(intopt, cp_ij_id, grids[i_grid_split : i_grid_split + ngrids_of_split], out=int3c_angular_slice, omega=omega) + get_int3c1e_slice(intopt, cp_ij_id, grids[i_grid_split : i_grid_split + ngrids_of_split, :], out=int3c_angular_slice, omega=omega) i0, i1 = intopt.ao_loc[cpi], intopt.ao_loc[cpi+1] j0, j1 = intopt.ao_loc[cpj], intopt.ao_loc[cpj+1] if not mol.cart: @@ -270,15 +273,14 @@ def get_int3c1e(mol, grids, direct_scf_tol, omega): grid_idx = np.arange(ngrids_of_split) int3c_grid_slice = int3c_grid_slice[np.ix_(grid_idx, ao_idx, ao_idx)] - cp.cuda.runtime.memcpy(int3c[i_grid_split : i_grid_split + ngrids_of_split, :, :].ctypes.data, - int3c_grid_slice.data.ptr, - int3c_grid_slice.nbytes, - cp.cuda.runtime.memcpyDeviceToHost) - # int3c[i_grid_split : i_grid_split + ngrids_of_split, :, :] = cp.asnumpy(int3c_grid_slice) # This is certainly the wrong way of DtoH memcpy + int3c_grid_slice.get(out = int3c[i_grid_split : i_grid_split + ngrids_of_split, :, :]) return int3c -def get_int3c1e_density_contracted(mol, grids, dm, direct_scf_tol, omega): +def get_int3c1e_density_contracted(mol, grids, dm, direct_scf_tol): + omega = mol.omega + assert omega >= 0.0, "Short-range one electron integrals with GPU acceleration is not implemented." + if cp.get_array_module(dm) is cp: dm = cp.asnumpy(dm) assert cp.get_array_module(dm) is np @@ -289,19 +291,20 @@ def get_int3c1e_density_contracted(mol, grids, dm, direct_scf_tol, omega): nao_cart = intopt.mol.nao ngrids = grids.shape[0] - # TODO: Split ngrids to make sure GPU block and thread doesn't overflow dm = dm[np.ix_(intopt.ao_idx, intopt.ao_idx)] # intopt.ao_idx is in spherical basis if not mol.cart: cart2sph_transformation_matrix = cp.asnumpy(intopt.cart2sph) # TODO: This part is inefficient (O(N^3)), should be changed to the O(N^2) algorithm dm = cart2sph_transformation_matrix @ dm @ cart2sph_transformation_matrix.T + dm = dm.flatten(order='F') # Column major order matches (i + j * n_ao) access pattern in the C function ao_loc_sorted_order = intopt.sorted_mol.ao_loc_nr(cart = True) l_ij = intopt.l_ij.T.flatten() + bas_coords = intopt.sorted_mol.atom_coords()[intopt.sorted_mol._bas[:, ATOM_OF]].flatten() + n_total_hermite_density = intopt.density_offset[-1] dm_pair_ordered = np.zeros(n_total_hermite_density) - dm = dm.flatten(order='F') # Column major order matches (i + j * n_ao) access pattern in the following function libgvhf.GINTinit_J_density_rys_preprocess(dm.ctypes.data_as(ctypes.c_void_p), dm_pair_ordered.ctypes.data_as(ctypes.c_void_p), ctypes.c_int(1), ctypes.c_int(nao_cart), ctypes.c_int(len(intopt.bas_pairs_locs) - 1), @@ -310,36 +313,50 @@ def get_int3c1e_density_contracted(mol, grids, dm, direct_scf_tol, omega): l_ij.ctypes.data_as(ctypes.c_void_p), intopt.density_offset.ctypes.data_as(ctypes.c_void_p), ao_loc_sorted_order.ctypes.data_as(ctypes.c_void_p), - intopt.bpcache) + bas_coords.ctypes.data_as(ctypes.c_void_p)) dm_pair_ordered = cp.asarray(dm_pair_ordered) grids = cp.asarray(grids, order='C') int3c_density_contracted = cp.zeros(ngrids) - for cp_ij_id, _ in enumerate(intopt.log_qs): - stream = cp.cuda.get_current_stream() - if omega is None: omega = 0.0 + n_threads_per_block_1d = 16 + n_max_blocks_per_grid_1d = 65535 + n_max_threads_1d = n_threads_per_block_1d * n_max_blocks_per_grid_1d + n_grid_split = int(np.ceil(ngrids / n_max_threads_1d)) + if (n_grid_split > 100): + print(f"Grid dimension = {ngrids} is too large, more than 100 kernels for one electron integral will be launched.") + ngrids_per_split = (ngrids + n_grid_split - 1) // n_grid_split + + for i_grid_split in range(0, ngrids, ngrids_per_split): + ngrids_of_split = np.min([ngrids_per_split, ngrids - i_grid_split]) + for cp_ij_id, _ in enumerate(intopt.log_qs): + stream = cp.cuda.get_current_stream() + + log_q_ij = intopt.log_qs[cp_ij_id] + + nbins = 1 + bins_locs_ij = np.array([0, len(log_q_ij)], dtype=np.int32) - log_q_ij = intopt.log_qs[cp_ij_id] - nbins = 1 - bins_locs_ij = np.array([0, len(log_q_ij)], dtype=np.int32) + n_pair_sum_per_thread = nao_cart # 1 means every thread processes one pair and one grid + # nao_cart or larger number gaurantees one thread processes one grid and all pairs of the same type - err = libgint.GINTfill_int3c1e_density_contracted( - ctypes.cast(stream.ptr, ctypes.c_void_p), - intopt.bpcache, - ctypes.cast(grids.data.ptr, ctypes.c_void_p), - ctypes.c_int(grids.shape[0]), - ctypes.cast(dm_pair_ordered.data.ptr, ctypes.c_void_p), - intopt.density_offset.ctypes.data_as(ctypes.c_void_p), - ctypes.cast(int3c_density_contracted.data.ptr, ctypes.c_void_p), - bins_locs_ij.ctypes.data_as(ctypes.c_void_p), - ctypes.c_int(nbins), - ctypes.c_int(cp_ij_id), - ctypes.c_double(omega)) + err = libgint.GINTfill_int3c1e_density_contracted( + ctypes.cast(stream.ptr, ctypes.c_void_p), + intopt.bpcache, + ctypes.cast(grids[i_grid_split : i_grid_split + ngrids_of_split, :].data.ptr, ctypes.c_void_p), + ctypes.c_int(ngrids_of_split), + ctypes.cast(dm_pair_ordered.data.ptr, ctypes.c_void_p), + intopt.density_offset.ctypes.data_as(ctypes.c_void_p), + ctypes.cast(int3c_density_contracted[i_grid_split : i_grid_split + ngrids_of_split].data.ptr, ctypes.c_void_p), + bins_locs_ij.ctypes.data_as(ctypes.c_void_p), + ctypes.c_int(nbins), + ctypes.c_int(cp_ij_id), + ctypes.c_double(omega), + ctypes.c_int(n_pair_sum_per_thread)) - if err != 0: - raise RuntimeError('GINTfill_int3c1e failed') + if err != 0: + raise RuntimeError('GINTfill_int3c1e failed') return cp.asnumpy(int3c_density_contracted) @@ -350,9 +367,9 @@ def intor(mol, intor, grids, dm=None, charges=None, direct_scf_tol=1e-13, omega= "If so, pass in density, obtain the result with n_charge and contract with the charges yourself." if dm is None and charges is None: - return get_int3c1e(mol, grids, direct_scf_tol, omega) + return get_int3c1e(mol, grids, direct_scf_tol) elif dm is not None: - return get_int3c1e_density_contracted(mol, grids, dm, direct_scf_tol, omega) + return get_int3c1e_density_contracted(mol, grids, dm, direct_scf_tol) elif charges is not None: raise NotImplementedError() else: diff --git a/gpu4pyscf/gto/tests/test_int1e_grids.py b/gpu4pyscf/gto/tests/test_int1e_grids.py index 75fa05cd..98faad07 100644 --- a/gpu4pyscf/gto/tests/test_int1e_grids.py +++ b/gpu4pyscf/gto/tests/test_int1e_grids.py @@ -102,20 +102,21 @@ def test_int1e_grids_full_tensor_omega(self): omega = 0.8 mol_sph_omega = mol_sph.copy() mol_sph_omega.set_range_coulomb(omega) + ref_int1e = mol_sph_omega.intor('int1e_grids', grids=grid_points) - test_int1e = intor(mol_sph, 'int1e_grids', grid_points, omega = omega) + test_int1e = intor(mol_sph_omega, 'int1e_grids', grid_points) assert np.abs(ref_int1e - test_int1e).max() < integral_threshold def test_int1e_grids_density_contracted_omega(self): + np.random.seed(12349) + dm = np.random.uniform(-2.0, 2.0, (mol_sph.nao, mol_sph.nao)) + omega = 1.2 mol_sph_omega = mol_sph.copy() mol_sph_omega.set_range_coulomb(omega) - np.random.seed(12349) - dm = np.random.uniform(-2.0, 2.0, (mol_sph.nao, mol_sph.nao)) - ref_int1e_dot_D = np.einsum('pij,ij->p', mol_sph_omega.intor('int1e_grids', grids=grid_points), dm) - test_int1e_dot_D = intor(mol_sph, 'int1e_grids', grid_points, dm = dm, omega = omega) + test_int1e_dot_D = intor(mol_sph_omega, 'int1e_grids', grid_points, dm = dm) assert np.abs(ref_int1e_dot_D - test_int1e_dot_D).max() < integral_threshold if __name__ == "__main__": diff --git a/gpu4pyscf/lib/gint/bpcache.cu b/gpu4pyscf/lib/gint/bpcache.cu index 5fcde418..5c00b264 100644 --- a/gpu4pyscf/lib/gint/bpcache.cu +++ b/gpu4pyscf/lib/gint/bpcache.cu @@ -53,7 +53,6 @@ void GINTdel_basis_prod(BasisProdCache **pbp) if (bpcache->aexyz != NULL) { free(bpcache->aexyz); - free(bpcache->h_bas_coords); } if (bpcache->a12 != NULL) { @@ -94,7 +93,7 @@ void GINTinit_basis_prod(BasisProdCache **pbp, double diag_fac, int *ao_loc, GINTsort_bas_coordinates(bas_coords, atm, natm, bas, nbas, env); DEVICE_INIT(double, d_bas_coords, bas_coords, nbas * 3); bpcache->bas_coords = d_bas_coords; - bpcache->h_bas_coords = bas_coords; + free(bas_coords); // initialize pair data on GPU memory DEVICE_INIT(double, d_aexyz, aexyz, n_primitive_pairs * 7); diff --git a/gpu4pyscf/lib/gint/g1e.cu b/gpu4pyscf/lib/gint/g1e.cu index d18ec238..327abda0 100644 --- a/gpu4pyscf/lib/gint/g1e.cu +++ b/gpu4pyscf/lib/gint/g1e.cu @@ -116,9 +116,9 @@ static void GINTg1e(double* __restrict__ g, const double* __restrict__ grid_poin const double ABy = Ay - By; const double ABz = Az - Bz; - for (int i_root = 0; i_root < NROOTS; i_root++) { - for (int j_rys = 0; j_rys < j_l; j_rys++) { - for (int i_rys = i_l + j_l - j_rys - 1; i_rys >= 0; i_rys--) { + for (int j_rys = 0; j_rys < j_l; j_rys++) { + for (int i_rys = i_l + j_l - j_rys - 1; i_rys >= 0; i_rys--) { + for (int i_root = 0; i_root < NROOTS; i_root++) { gx[i_root + (i_rys + (j_rys+1) * (i_l+1)) * NROOTS] = gx[i_root + (i_rys+1 + j_rys * (i_l+1)) * NROOTS] + ABx * gx[i_root + (i_rys + j_rys * (i_l+1)) * NROOTS]; gy[i_root + (i_rys + (j_rys+1) * (i_l+1)) * NROOTS] = gy[i_root + (i_rys+1 + j_rys * (i_l+1)) * NROOTS] + ABy * gy[i_root + (i_rys + j_rys * (i_l+1)) * NROOTS]; gz[i_root + (i_rys + (j_rys+1) * (i_l+1)) * NROOTS] = gz[i_root + (i_rys+1 + j_rys * (i_l+1)) * NROOTS] + ABz * gz[i_root + (i_rys + j_rys * (i_l+1)) * NROOTS]; diff --git a/gpu4pyscf/lib/gint/g1e_root_123.cu b/gpu4pyscf/lib/gint/g1e_root_123.cu index ef71a1cc..67067158 100644 --- a/gpu4pyscf/lib/gint/g1e_root_123.cu +++ b/gpu4pyscf/lib/gint/g1e_root_123.cu @@ -87,58 +87,61 @@ static void GINTfill_int3c1e_density_contracted_kernel00(double* output, const d { const int ntasks_ij = offsets.ntasks_ij; const int ngrids = offsets.ntasks_kl; - const int task_ij = blockIdx.x * blockDim.x + threadIdx.x; const int task_grid = blockIdx.y * blockDim.y + threadIdx.y; - - if (task_ij >= ntasks_ij || task_grid >= ngrids) { + if (task_grid >= ngrids) { return; } - const int bas_ij = offsets.bas_ij + task_ij; - const int prim_ij = offsets.primitive_ij + task_ij * nprim_ij; - // const int* bas_pair2bra = c_bpcache.bas_pair2bra; - // const int* bas_pair2ket = c_bpcache.bas_pair2ket; - // const int ish = bas_pair2bra[bas_ij]; - // const int jsh = bas_pair2ket[bas_ij]; - const double* __restrict__ a12 = c_bpcache.a12; - const double* __restrict__ e12 = c_bpcache.e12; - const double* __restrict__ x12 = c_bpcache.x12; - const double* __restrict__ y12 = c_bpcache.y12; - const double* __restrict__ z12 = c_bpcache.z12; + double eri_pair_sum = 0.0; + for (int task_ij = blockIdx.x * blockDim.x + threadIdx.x; task_ij < ntasks_ij; task_ij += gridDim.x * blockDim.x) { + const int bas_ij = offsets.bas_ij + task_ij; + const int prim_ij = offsets.primitive_ij + task_ij * nprim_ij; + // const int* bas_pair2bra = c_bpcache.bas_pair2bra; + // const int* bas_pair2ket = c_bpcache.bas_pair2ket; + // const int ish = bas_pair2bra[bas_ij]; + // const int jsh = bas_pair2ket[bas_ij]; - const double* grid_point = grid_points + task_grid * 3; - const double Cx = grid_point[0]; - const double Cy = grid_point[1]; - const double Cz = grid_point[2]; + const double* __restrict__ a12 = c_bpcache.a12; + const double* __restrict__ e12 = c_bpcache.e12; + const double* __restrict__ x12 = c_bpcache.x12; + const double* __restrict__ y12 = c_bpcache.y12; + const double* __restrict__ z12 = c_bpcache.z12; - double gout0 = 0; - for (int ij = prim_ij; ij < prim_ij + nprim_ij; ij++) { - const double aij = a12[ij]; - const double eij = e12[ij]; - const double Px = x12[ij]; - const double Py = y12[ij]; - const double Pz = z12[ij]; - const double PCx = Px - Cx; - const double PCy = Py - Cy; - const double PCz = Pz - Cz; - double a0 = aij; - const double theta = omega > 0.0 ? omega * omega / (omega * omega + aij) : 1.0; - const double sqrt_theta = omega > 0.0 ? omega / sqrt(omega * omega + aij) : 1.0; - a0 *= theta; + const double* grid_point = grid_points + task_grid * 3; + const double Cx = grid_point[0]; + const double Cy = grid_point[1]; + const double Cz = grid_point[2]; - const double prefactor = 2.0 * M_PI / aij * eij * sqrt_theta; - const double boys_input = a0 * (PCx * PCx + PCy * PCy + PCz * PCz); - double eri = prefactor; - if (boys_input > 3.e-7) { - const double sqrt_boys_input = sqrt(boys_input); - const double boys_0 = SQRTPIE4 / sqrt_boys_input * erf(sqrt_boys_input); - eri *= boys_0; + double eri_per_pair = 0; + for (int ij = prim_ij; ij < prim_ij + nprim_ij; ij++) { + const double aij = a12[ij]; + const double eij = e12[ij]; + const double Px = x12[ij]; + const double Py = y12[ij]; + const double Pz = z12[ij]; + const double PCx = Px - Cx; + const double PCy = Py - Cy; + const double PCz = Pz - Cz; + double a0 = aij; + const double theta = omega > 0.0 ? omega * omega / (omega * omega + aij) : 1.0; + const double sqrt_theta = omega > 0.0 ? omega / sqrt(omega * omega + aij) : 1.0; + a0 *= theta; + + const double prefactor = 2.0 * M_PI / aij * eij * sqrt_theta; + const double boys_input = a0 * (PCx * PCx + PCy * PCy + PCz * PCz); + double eri_per_primitive = prefactor; + if (boys_input > 3.e-7) { + const double sqrt_boys_input = sqrt(boys_input); + const double boys_0 = SQRTPIE4 / sqrt_boys_input * erf(sqrt_boys_input); + eri_per_primitive *= boys_0; + } + eri_per_pair += eri_per_primitive; } - gout0 += eri; - } - const double D = density[bas_ij - hermite_density_offsets.pair_offset_of_angular_pair + hermite_density_offsets.density_offset_of_angular_pair]; - atomicAdd(output + task_grid, D * gout0); + const double D = density[bas_ij - hermite_density_offsets.pair_offset_of_angular_pair + hermite_density_offsets.density_offset_of_angular_pair]; + eri_pair_sum += D * eri_per_pair; + } + atomicAdd(output + task_grid, eri_pair_sum); } __global__ @@ -234,80 +237,84 @@ static void GINTfill_int3c1e_density_contracted_kernel10(double* output, const d { const int ntasks_ij = offsets.ntasks_ij; const int ngrids = offsets.ntasks_kl; - const int task_ij = blockIdx.x * blockDim.x + threadIdx.x; const int task_grid = blockIdx.y * blockDim.y + threadIdx.y; - - if (task_ij >= ntasks_ij || task_grid >= ngrids) { + if (task_grid >= ngrids) { return; } - const int bas_ij = offsets.bas_ij + task_ij; - const int prim_ij = offsets.primitive_ij + task_ij * nprim_ij; - const int* bas_pair2bra = c_bpcache.bas_pair2bra; - // const int* bas_pair2ket = c_bpcache.bas_pair2ket; - const int ish = bas_pair2bra[bas_ij]; - // const int jsh = bas_pair2ket[bas_ij]; - const double* __restrict__ a12 = c_bpcache.a12; - const double* __restrict__ e12 = c_bpcache.e12; - const double* __restrict__ x12 = c_bpcache.x12; - const double* __restrict__ y12 = c_bpcache.y12; - const double* __restrict__ z12 = c_bpcache.z12; + double eri_pair_sum = 0.0; + for (int task_ij = blockIdx.x * blockDim.x + threadIdx.x; task_ij < ntasks_ij; task_ij += gridDim.x * blockDim.x) { + const int bas_ij = offsets.bas_ij + task_ij; + const int prim_ij = offsets.primitive_ij + task_ij * nprim_ij; + const int* bas_pair2bra = c_bpcache.bas_pair2bra; + // const int* bas_pair2ket = c_bpcache.bas_pair2ket; + const int ish = bas_pair2bra[bas_ij]; + // const int jsh = bas_pair2ket[bas_ij]; - const int nbas = c_bpcache.nbas; - const double* __restrict__ bas_x = c_bpcache.bas_coords; - const double* __restrict__ bas_y = bas_x + nbas; - const double* __restrict__ bas_z = bas_y + nbas; - const double Ax = bas_x[ish]; - const double Ay = bas_y[ish]; - const double Az = bas_z[ish]; + const double* __restrict__ a12 = c_bpcache.a12; + const double* __restrict__ e12 = c_bpcache.e12; + const double* __restrict__ x12 = c_bpcache.x12; + const double* __restrict__ y12 = c_bpcache.y12; + const double* __restrict__ z12 = c_bpcache.z12; - const double* grid_point = grid_points + task_grid * 3; - const double Cx = grid_point[0]; - const double Cy = grid_point[1]; - const double Cz = grid_point[2]; + const int nbas = c_bpcache.nbas; + const double* __restrict__ bas_x = c_bpcache.bas_coords; + const double* __restrict__ bas_y = bas_x + nbas; + const double* __restrict__ bas_z = bas_y + nbas; + const double Ax = bas_x[ish]; + const double Ay = bas_y[ish]; + const double Az = bas_z[ish]; - double gout_x = 0; - double gout_y = 0; - double gout_z = 0; - for (int ij = prim_ij; ij < prim_ij + nprim_ij; ij++) { - const double aij = a12[ij]; - const double eij = e12[ij]; - const double Px = x12[ij]; - const double Py = y12[ij]; - const double Pz = z12[ij]; - const double PCx = Px - Cx; - const double PCy = Py - Cy; - const double PCz = Pz - Cz; - const double PAx = Px - Ax; - const double PAy = Py - Ay; - const double PAz = Pz - Az; - double a0 = aij; - const double one_over_two_p = 0.5 / aij; - const double theta = omega > 0.0 ? omega * omega / (omega * omega + aij) : 1.0; - const double sqrt_theta = omega > 0.0 ? omega / sqrt(omega * omega + aij) : 1.0; - a0 *= theta; + const double* grid_point = grid_points + task_grid * 3; + const double Cx = grid_point[0]; + const double Cy = grid_point[1]; + const double Cz = grid_point[2]; - const double prefactor = 2.0 * M_PI / aij * eij * sqrt_theta; - const double boys_input = a0 * (PCx * PCx + PCy * PCy + PCz * PCz); - double eri_x = prefactor; - double eri_y = prefactor; - double eri_z = prefactor; - if (boys_input > 3.e-7) { - const double sqrt_boys_input = sqrt(boys_input); - const double R000_0 = SQRTPIE4 / sqrt_boys_input * erf(sqrt_boys_input); - const double R000_1 = -a0 * (R000_0 - exp(-boys_input)) / boys_input; - eri_x *= R000_0 * PAx + R000_1 * PCx * one_over_two_p; - eri_y *= R000_0 * PAy + R000_1 * PCy * one_over_two_p; - eri_z *= R000_0 * PAz + R000_1 * PCz * one_over_two_p; + double eri_per_pair_x = 0; + double eri_per_pair_y = 0; + double eri_per_pair_z = 0; + for (int ij = prim_ij; ij < prim_ij + nprim_ij; ij++) { + const double aij = a12[ij]; + const double eij = e12[ij]; + const double Px = x12[ij]; + const double Py = y12[ij]; + const double Pz = z12[ij]; + const double PCx = Px - Cx; + const double PCy = Py - Cy; + const double PCz = Pz - Cz; + const double PAx = Px - Ax; + const double PAy = Py - Ay; + const double PAz = Pz - Az; + double a0 = aij; + const double one_over_two_p = 0.5 / aij; + const double theta = omega > 0.0 ? omega * omega / (omega * omega + aij) : 1.0; + const double sqrt_theta = omega > 0.0 ? omega / sqrt(omega * omega + aij) : 1.0; + a0 *= theta; + + const double prefactor = 2.0 * M_PI / aij * eij * sqrt_theta; + const double boys_input = a0 * (PCx * PCx + PCy * PCy + PCz * PCz); + double eri_per_primitive_x = prefactor; + double eri_per_primitive_y = prefactor; + double eri_per_primitive_z = prefactor; + if (boys_input > 3.e-7) { + const double sqrt_boys_input = sqrt(boys_input); + const double R000_0 = SQRTPIE4 / sqrt_boys_input * erf(sqrt_boys_input); + const double R000_1 = -a0 * (R000_0 - exp(-boys_input)) / boys_input; + eri_per_primitive_x *= R000_0 * PAx + R000_1 * PCx * one_over_two_p; + eri_per_primitive_y *= R000_0 * PAy + R000_1 * PCy * one_over_two_p; + eri_per_primitive_z *= R000_0 * PAz + R000_1 * PCz * one_over_two_p; + } + eri_per_pair_x += eri_per_primitive_x; + eri_per_pair_y += eri_per_primitive_y; + eri_per_pair_z += eri_per_primitive_z; } - gout_x += eri_x; - gout_y += eri_y; - gout_z += eri_z; - } - // Density element 0 is the 000 element and is not used in McMurchie-Davidson algorithm. Density element 1~3 is the unchanged z,y,x components. - const double D_x = density[bas_ij - hermite_density_offsets.pair_offset_of_angular_pair + hermite_density_offsets.density_offset_of_angular_pair + hermite_density_offsets.n_pair_of_angular_pair * 3]; - const double D_y = density[bas_ij - hermite_density_offsets.pair_offset_of_angular_pair + hermite_density_offsets.density_offset_of_angular_pair + hermite_density_offsets.n_pair_of_angular_pair * 2]; - const double D_z = density[bas_ij - hermite_density_offsets.pair_offset_of_angular_pair + hermite_density_offsets.density_offset_of_angular_pair + hermite_density_offsets.n_pair_of_angular_pair * 1]; - atomicAdd(output + task_grid, D_x * gout_x + D_y * gout_y + D_z * gout_z); + // Density element 0 is the 000 element and is not used in McMurchie-Davidson algorithm. Density element 1~3 is the unchanged z,y,x components. + const double D_x = density[bas_ij - hermite_density_offsets.pair_offset_of_angular_pair + hermite_density_offsets.density_offset_of_angular_pair + hermite_density_offsets.n_pair_of_angular_pair * 3]; + const double D_y = density[bas_ij - hermite_density_offsets.pair_offset_of_angular_pair + hermite_density_offsets.density_offset_of_angular_pair + hermite_density_offsets.n_pair_of_angular_pair * 2]; + const double D_z = density[bas_ij - hermite_density_offsets.pair_offset_of_angular_pair + hermite_density_offsets.density_offset_of_angular_pair + hermite_density_offsets.n_pair_of_angular_pair * 1]; + + eri_pair_sum += D_x * eri_per_pair_x + D_y * eri_per_pair_y + D_z * eri_per_pair_z; + } + atomicAdd(output + task_grid, eri_pair_sum); } diff --git a/gpu4pyscf/lib/gint/g3c1e.cu b/gpu4pyscf/lib/gint/g3c1e.cu index 7a5ee25d..2cfde36c 100644 --- a/gpu4pyscf/lib/gint/g3c1e.cu +++ b/gpu4pyscf/lib/gint/g3c1e.cu @@ -96,45 +96,53 @@ void GINT_int3c1e_density_contracted_kernel_general(double* output, const double { const int ntasks_ij = offsets.ntasks_ij; const int ngrids = offsets.ntasks_kl; - const int task_ij = blockIdx.x * blockDim.x + threadIdx.x; const int task_grid = blockIdx.y * blockDim.y + threadIdx.y; - - if (task_ij >= ntasks_ij || task_grid >= ngrids) { + if (task_grid >= ngrids) { return; } - const int bas_ij = offsets.bas_ij + task_ij; - const int prim_ij = offsets.primitive_ij + task_ij * nprim_ij; - const int* bas_pair2bra = c_bpcache.bas_pair2bra; - // const int* bas_pair2ket = c_bpcache.bas_pair2ket; - const int ish = bas_pair2bra[bas_ij]; - // const int jsh = bas_pair2ket[bas_ij]; - const double* grid_point = grid_points + task_grid * 3; - - constexpr int l_max = (NROOTS - 1) * 2 + 1; - double D_hermite[(l_max + 1) * (l_max + 2) * (l_max + 3) / 6]; - const int l = i_l + j_l; - for (int i_t = 0; i_t < (l + 1) * (l + 2) * (l + 3) / 6; i_t++) { - D_hermite[i_t] = density[bas_ij - hermite_density_offsets.pair_offset_of_angular_pair + hermite_density_offsets.density_offset_of_angular_pair + i_t * hermite_density_offsets.n_pair_of_angular_pair]; - } + double eri_with_density_pair_sum = 0.0; + for (int task_ij = blockIdx.x * blockDim.x + threadIdx.x; task_ij < ntasks_ij; task_ij += gridDim.x * blockDim.x) { + const int bas_ij = offsets.bas_ij + task_ij; + const int prim_ij = offsets.primitive_ij + task_ij * nprim_ij; + const int* bas_pair2bra = c_bpcache.bas_pair2bra; + // const int* bas_pair2ket = c_bpcache.bas_pair2ket; + const int ish = bas_pair2bra[bas_ij]; + // const int jsh = bas_pair2ket[bas_ij]; + + const double* grid_point = grid_points + task_grid * 3; + + constexpr int l_max = (NROOTS - 1) * 2 + 1; + double D_hermite[(l_max + 1) * (l_max + 2) * (l_max + 3) / 6]; + const int l = i_l + j_l; + for (int i_t = 0; i_t < (l + 1) * (l + 2) * (l + 3) / 6; i_t++) { + D_hermite[i_t] = density[bas_ij - hermite_density_offsets.pair_offset_of_angular_pair + hermite_density_offsets.density_offset_of_angular_pair + i_t * hermite_density_offsets.n_pair_of_angular_pair]; + } - double eri_with_density = 0.0; - for (int ij = prim_ij; ij < prim_ij+nprim_ij; ++ij) { - double g[NROOTS * (l_max + 1) * 3]; - GINT_g1e_without_hrr(g, grid_point, ish, ij, l, omega); - - double eri_with_density_primitive = 0.0; - for (int i_x = 0, i_t = 0; i_x <= l; i_x++) - for (int i_y = 0; i_x + i_y <= l; i_y++) - for (int i_z = 0; i_x + i_y + i_z <= l; i_z++, i_t++) - for (int i_root = 0; i_root < NROOTS; i_root++) { - const double gx = g[i_root + NROOTS * i_x]; - const double gy = g[i_root + NROOTS * i_y + NROOTS * (l + 1)]; - const double gz = g[i_root + NROOTS * i_z + NROOTS * (l + 1) * 2]; - eri_with_density_primitive += gx * gy * gz * D_hermite[i_t]; + double eri_with_density_per_pair = 0.0; + for (int ij = prim_ij; ij < prim_ij+nprim_ij; ++ij) { + double g[NROOTS * (l_max + 1) * 3]; + GINT_g1e_without_hrr(g, grid_point, ish, ij, l, omega); + + double eri_with_density_per_primitive = 0.0; + for (int i_x = 0, i_t = 0; i_x <= l; i_x++) { + for (int i_y = 0; i_x + i_y <= l; i_y++) { + for (int i_z = 0; i_x + i_y + i_z <= l; i_z++, i_t++) { + const double D_t = D_hermite[i_t]; + #pragma unroll + for (int i_root = 0; i_root < NROOTS; i_root++) { + const double gx = g[i_root + NROOTS * i_x]; + const double gy = g[i_root + NROOTS * i_y + NROOTS * (l + 1)]; + const double gz = g[i_root + NROOTS * i_z + NROOTS * (l + 1) * 2]; + eri_with_density_per_primitive += gx * gy * gz * D_t; + } } + } + } - eri_with_density += eri_with_density_primitive; + eri_with_density_per_pair += eri_with_density_per_primitive; + } + eri_with_density_pair_sum += eri_with_density_per_pair; } - atomicAdd(output + task_grid, eri_with_density); + atomicAdd(output + task_grid, eri_with_density_pair_sum); } diff --git a/gpu4pyscf/lib/gint/gint.h b/gpu4pyscf/lib/gint/gint.h index 5f50b97d..e8cd33c2 100644 --- a/gpu4pyscf/lib/gint/gint.h +++ b/gpu4pyscf/lib/gint/gint.h @@ -264,7 +264,6 @@ typedef struct { int *primitive_pairs_locs; // len(a12) = sum(cptype[:].nparis*cptype[:].nprim_12) int *bas_pair2shls; double *aexyz; - double *h_bas_coords; // Data below held on GPU global memory double *bas_coords; // basis coordinates diff --git a/gpu4pyscf/lib/gint/j_engine_matrix_reorder.c b/gpu4pyscf/lib/gint/j_engine_matrix_reorder.c index f40abc31..00878aa2 100644 --- a/gpu4pyscf/lib/gint/j_engine_matrix_reorder.c +++ b/gpu4pyscf/lib/gint/j_engine_matrix_reorder.c @@ -102,13 +102,8 @@ int hermite_xyz_to_t_index(const int x, const int y, const int z, const int l) void GINTinit_J_density_rys_preprocess(const double* D_matrix, double* D_pair_ordered, const int n_dm, const int n_ao, const int n_pair_type, const int* bas_pair2shls, const int* bas_pairs_locs, const int* l_ij, const int* density_offset, const int* ao_loc, - const BasisProdCache* bpcache) + const double* bas_coords) { - const double* bas_coords = bpcache->h_bas_coords; - const int nbas = bpcache->nbas; - const double* bas_x = bas_coords; - const double* bas_y = bas_x + nbas; - const double* bas_z = bas_y + nbas; const int n_bas_pairs = bas_pairs_locs[n_pair_type]; const int n_total_hermite_density = density_offset[n_pair_type]; for (int i_dm = 0; i_dm < n_dm; i_dm++) { @@ -120,12 +115,12 @@ void GINTinit_J_density_rys_preprocess(const double* D_matrix, double* D_pair_or for (int i_pair = bas_pairs_locs[i_pair_type]; i_pair < bas_pairs_locs[i_pair_type + 1]; i_pair++) { const int ish = bas_pair2shls[i_pair]; const int jsh = bas_pair2shls[n_bas_pairs + i_pair]; - const double Ax = bas_x[ish]; - const double Ay = bas_y[ish]; - const double Az = bas_z[ish]; - const double Bx = bas_x[jsh]; - const double By = bas_y[jsh]; - const double Bz = bas_z[jsh]; + const double Ax = bas_coords[ish * 3 + 0]; + const double Ay = bas_coords[ish * 3 + 1]; + const double Az = bas_coords[ish * 3 + 2]; + const double Bx = bas_coords[jsh * 3 + 0]; + const double By = bas_coords[jsh * 3 + 1]; + const double Bz = bas_coords[jsh * 3 + 2]; const int i0 = ao_loc[ish]; const int j0 = ao_loc[jsh]; diff --git a/gpu4pyscf/lib/gint/nr_fill_ao_int3c1e.cu b/gpu4pyscf/lib/gint/nr_fill_ao_int3c1e.cu index a475c413..dbd00367 100644 --- a/gpu4pyscf/lib/gint/nr_fill_ao_int3c1e.cu +++ b/gpu4pyscf/lib/gint/nr_fill_ao_int3c1e.cu @@ -96,10 +96,10 @@ static int GINTfill_int3c1e_tasks(double* output, const BasisProdOffsets offsets static int GINTfill_int3c1e_density_contracted_tasks(double* output, const double* density, const HermiteDensityOffsets hermite_density_offsets, const BasisProdOffsets offsets, const int i_l, const int j_l, const int nprim_ij, - const double omega, const double* grid_points, const cudaStream_t stream) + const double omega, const double* grid_points, const int n_pair_sum_per_thread, const cudaStream_t stream) { const int nrys_roots = (i_l + j_l) / 2 + 1; - const int ntasks_ij = offsets.ntasks_ij; + const int ntasks_ij = (offsets.ntasks_ij + n_pair_sum_per_thread - 1) / n_pair_sum_per_thread; const int ngrids = offsets.ntasks_kl; const dim3 threads(THREADSX, THREADSY); @@ -196,7 +196,7 @@ int GINTfill_int3c1e_density_contracted(const cudaStream_t stream, const BasisPr const double* dm_pair_ordered, const int* density_offset, double* integral_density_contracted, const int* bins_locs_ij, int nbins, - const int cp_ij_id, const double omega) + const int cp_ij_id, const double omega, const int n_pair_sum_per_thread) { const ContractionProdType *cp_ij = bpcache->cptype + cp_ij_id; const int i_l = cp_ij->l_bra; @@ -236,7 +236,7 @@ int GINTfill_int3c1e_density_contracted(const cudaStream_t stream, const BasisPr const int err = GINTfill_int3c1e_density_contracted_tasks(integral_density_contracted, dm_pair_ordered, hermite_density_offsets, offsets, i_l, j_l, nprim_ij, - omega, grid_points, stream); + omega, grid_points, n_pair_sum_per_thread, stream); if (err != 0) { return err;