From a33940b5eb9e93a486a3514f8152ab1cff07ae3a Mon Sep 17 00:00:00 2001 From: Henry Wang Date: Thu, 12 Dec 2024 08:15:45 +0800 Subject: [PATCH] int1e first derivative with density and charge contraction working, not optimized --- gpu4pyscf/gto/int3c1e.py | 18 +- gpu4pyscf/gto/int3c1e_ip.py | 189 ++++++++++++++++- gpu4pyscf/gto/moleintor.py | 6 +- gpu4pyscf/gto/tests/test_int1e_grids_ip.py | 135 +++++++++++++ gpu4pyscf/lib/gint/g1e.cu | 95 +++++++++ gpu4pyscf/lib/gint/g3c1e.cu | 7 +- gpu4pyscf/lib/gint/g3c1e_ip.cu | 202 +++++++++++++++++++ gpu4pyscf/lib/gint/j_engine_matrix_reorder.c | 6 +- gpu4pyscf/lib/gint/nr_fill_ao_int3c1e_ip.cu | 185 +++++++++++++++++ 9 files changed, 824 insertions(+), 19 deletions(-) diff --git a/gpu4pyscf/gto/int3c1e.py b/gpu4pyscf/gto/int3c1e.py index e9420758..d65fee5b 100644 --- a/gpu4pyscf/gto/int3c1e.py +++ b/gpu4pyscf/gto/int3c1e.py @@ -29,7 +29,6 @@ BLKSIZE = 128 -libgvhf = load_library('libgvhf') libgint = load_library('libgint') class VHFOpt(_vhf.VHFOpt): @@ -58,7 +57,7 @@ def __init__(self, mol, intor='int2e', prescreen='CVHFnoscreen', def clear(self): _vhf.VHFOpt.__del__(self) for n, bpcache in self._bpcache.items(): - libgvhf.GINTdel_basis_prod(ctypes.byref(bpcache)) + libgint.GINTdel_basis_prod(ctypes.byref(bpcache)) return self def __del__(self): @@ -296,7 +295,7 @@ def get_int3c1e_charge_contracted(mol, grids, charge_exponents, charges, intopt) if charge_exponents is not None: charge_exponents = cp.asarray(charge_exponents, order='C') - int1e = cp.zeros([mol.nao, mol.nao], order='C') + int1e_charge_contracted = cp.zeros([mol.nao, mol.nao], order='C') for cp_ij_id, _ in enumerate(intopt.log_qs): cpi = intopt.cp_idx[cp_ij_id] cpj = intopt.cp_jdx[cp_ij_id] @@ -355,14 +354,14 @@ def get_int3c1e_charge_contracted(mol, grids, charge_exponents, charges, intopt) int1e_angular_slice = cart2sph(int1e_angular_slice, axis=0, ang=lj) int1e_angular_slice = cart2sph(int1e_angular_slice, axis=1, ang=li) - int1e[j0:j1, i0:i1] = 
int1e_angular_slice + int1e_charge_contracted[j0:j1, i0:i1] = int1e_angular_slice row, col = np.tril_indices(nao) - int1e[row, col] = int1e[col, row] + int1e_charge_contracted[row, col] = int1e_charge_contracted[col, row] ao_idx = np.argsort(intopt._ao_idx) - int1e = int1e[np.ix_(ao_idx, ao_idx)] + int1e_charge_contracted = int1e_charge_contracted[np.ix_(ao_idx, ao_idx)] - return int1e + return int1e_charge_contracted def get_int3c1e_density_contracted(mol, grids, charge_exponents, dm, intopt): omega = mol.omega @@ -397,7 +396,7 @@ def get_int3c1e_density_contracted(mol, grids, charge_exponents, dm, intopt): n_total_hermite_density = intopt.density_offset[-1] dm_pair_ordered = np.zeros(n_total_hermite_density) - libgvhf.GINTinit_J_density_rys_preprocess(dm.ctypes.data_as(ctypes.c_void_p), + libgint.GINTinit_J_density_rys_preprocess(dm.ctypes.data_as(ctypes.c_void_p), dm_pair_ordered.ctypes.data_as(ctypes.c_void_p), ctypes.c_int(1), ctypes.c_int(nao_cart), ctypes.c_int(len(intopt.bas_pairs_locs) - 1), intopt.bas_pair2shls.ctypes.data_as(ctypes.c_void_p), @@ -405,7 +404,8 @@ def get_int3c1e_density_contracted(mol, grids, charge_exponents, dm, intopt): l_ij.ctypes.data_as(ctypes.c_void_p), intopt.density_offset.ctypes.data_as(ctypes.c_void_p), ao_loc_sorted_order.ctypes.data_as(ctypes.c_void_p), - bas_coords.ctypes.data_as(ctypes.c_void_p)) + bas_coords.ctypes.data_as(ctypes.c_void_p), + ctypes.c_bool(True)) dm_pair_ordered = cp.asarray(dm_pair_ordered) diff --git a/gpu4pyscf/gto/int3c1e_ip.py b/gpu4pyscf/gto/int3c1e_ip.py index 01b29f64..a7941b56 100644 --- a/gpu4pyscf/gto/int3c1e_ip.py +++ b/gpu4pyscf/gto/int3c1e_ip.py @@ -17,13 +17,12 @@ import cupy as cp import numpy as np +from pyscf.gto import ATOM_OF from pyscf.lib import c_null_ptr from gpu4pyscf.lib.cupy_helper import load_library, cart2sph, get_avail_mem - libgint = load_library('libgint') - def get_int3c1e_ip(mol, grids, charge_exponents, intopt): omega = mol.omega assert omega >= 0.0, "Short-range one 
electron integrals with GPU acceleration is not implemented." @@ -122,3 +121,189 @@ def get_int3c1e_ip(mol, grids, charge_exponents, intopt): int3c_grid_slice[5, :, :, :].get(out = int3c_ip2[2, i_grid_split : i_grid_split + ngrids_of_split, :, :]) return int3c_ip1, int3c_ip2 + +def get_int3c1e_ip1_charge_contracted(mol, grids, charge_exponents, charges, intopt): + omega = mol.omega + assert omega >= 0.0, "Short-range one electron integrals with GPU acceleration is not implemented." + + nao = mol.nao + + charges = cp.asarray(charges).reshape([-1, 1], order='C') + grids = cp.concatenate([grids, charges], axis=1) + + int1e_charge_contracted = cp.zeros([3, mol.nao, mol.nao], order='C') + for cp_ij_id, _ in enumerate(intopt.log_qs): + cpi = intopt.cp_idx[cp_ij_id] + cpj = intopt.cp_jdx[cp_ij_id] + li = intopt.angular[cpi] + lj = intopt.angular[cpj] + + stream = cp.cuda.get_current_stream() + nao_cart = intopt._sorted_mol.nao + + log_q_ij = intopt.log_qs[cp_ij_id] + + nbins = 1 + bins_locs_ij = np.array([0, len(log_q_ij)], dtype=np.int32) + + i0, i1 = intopt.cart_ao_loc[cpi], intopt.cart_ao_loc[cpi+1] + j0, j1 = intopt.cart_ao_loc[cpj], intopt.cart_ao_loc[cpj+1] + ni = i1 - i0 + nj = j1 - j0 + + ao_offsets = np.array([i0, j0], dtype=np.int32) + strides = np.array([ni, ni*nj], dtype=np.int32) + + charge_exponents_pointer = c_null_ptr() + if charge_exponents is not None: + charge_exponents_pointer = charge_exponents.data.ptr + + ngrids = grids.shape[0] + # n_charge_sum_per_thread = 1 # means every thread processes one pair and one grid + # n_charge_sum_per_thread = ngrids # or larger number gaurantees one thread processes one pair and all grid points + n_charge_sum_per_thread = 10 + + int1e_angular_slice = cp.zeros([3, j1-j0, i1-i0], order='C') + + err = libgint.GINTfill_int3c1e_ip1_charge_contracted( + ctypes.cast(stream.ptr, ctypes.c_void_p), + intopt.bpcache, + ctypes.cast(grids.data.ptr, ctypes.c_void_p), + ctypes.cast(charge_exponents_pointer, ctypes.c_void_p), + 
ctypes.c_int(ngrids), + ctypes.cast(int1e_angular_slice.data.ptr, ctypes.c_void_p), + ctypes.c_int(nao_cart), + strides.ctypes.data_as(ctypes.c_void_p), + ao_offsets.ctypes.data_as(ctypes.c_void_p), + bins_locs_ij.ctypes.data_as(ctypes.c_void_p), + ctypes.c_int(nbins), + ctypes.c_int(cp_ij_id), + ctypes.c_double(omega), + ctypes.c_int(n_charge_sum_per_thread)) + + if err != 0: + raise RuntimeError('GINTfill_int3c1e_charge_contracted failed') + + i0, i1 = intopt.ao_loc[cpi], intopt.ao_loc[cpi+1] + j0, j1 = intopt.ao_loc[cpj], intopt.ao_loc[cpj+1] + if not mol.cart: + int1e_angular_slice = cart2sph(int1e_angular_slice, axis=1, ang=lj) + int1e_angular_slice = cart2sph(int1e_angular_slice, axis=2, ang=li) + + int1e_charge_contracted[:, j0:j1, i0:i1] = int1e_angular_slice + + ao_idx = np.argsort(intopt._ao_idx) + derivative_idx = np.arange(3) + int1e_charge_contracted = int1e_charge_contracted[np.ix_(derivative_idx, ao_idx, ao_idx)] + + return int1e_charge_contracted + +def get_int3c1e_ip2_density_contracted(mol, grids, charge_exponents, dm, intopt): + omega = mol.omega + assert omega >= 0.0, "Short-range one electron integrals with GPU acceleration is not implemented." 
+ + nao_cart = intopt._sorted_mol.nao + ngrids = grids.shape[0] + + dm = intopt.sort_orbitals(dm, [0,1]) + if not mol.cart: + cart2sph_transformation_matrix = intopt.cart2sph + # TODO: This part is inefficient (O(N^3)), should be changed to the O(N^2) algorithm + dm = cart2sph_transformation_matrix @ dm @ cart2sph_transformation_matrix.T + dm = dm.flatten(order='F') # Column major order matches (i + j * n_ao) access pattern in the C function + + dm = cp.asnumpy(dm) + + ao_loc_sorted_order = intopt._sorted_mol.ao_loc_nr(cart = True) + l_ij = intopt.l_ij.T.flatten() + bas_coords = intopt._sorted_mol.atom_coords()[intopt._sorted_mol._bas[:, ATOM_OF]].flatten() + + n_total_hermite_density = intopt.density_offset[-1] + dm_pair_ordered = np.zeros(n_total_hermite_density) + libgint.GINTinit_J_density_rys_preprocess(dm.ctypes.data_as(ctypes.c_void_p), + dm_pair_ordered.ctypes.data_as(ctypes.c_void_p), + ctypes.c_int(1), ctypes.c_int(nao_cart), ctypes.c_int(len(intopt.bas_pairs_locs) - 1), + intopt.bas_pair2shls.ctypes.data_as(ctypes.c_void_p), + intopt.bas_pairs_locs.ctypes.data_as(ctypes.c_void_p), + l_ij.ctypes.data_as(ctypes.c_void_p), + intopt.density_offset.ctypes.data_as(ctypes.c_void_p), + ao_loc_sorted_order.ctypes.data_as(ctypes.c_void_p), + bas_coords.ctypes.data_as(ctypes.c_void_p), + ctypes.c_bool(False)) + + dm_pair_ordered = cp.asarray(dm_pair_ordered) + + n_threads_per_block_1d = 16 + n_max_blocks_per_grid_1d = 65535 + n_max_threads_1d = n_threads_per_block_1d * n_max_blocks_per_grid_1d + n_grid_split = int(np.ceil(ngrids / n_max_threads_1d)) + if (n_grid_split > 100): + print(f"Grid dimension = {ngrids} is too large, more than 100 kernels for one electron integral will be launched.") + ngrids_per_split = (ngrids + n_grid_split - 1) // n_grid_split + + int3c_density_contracted = cp.zeros([3, ngrids], order='C') + + for i_grid_split in range(0, ngrids, ngrids_per_split): + ngrids_of_split = np.min([ngrids_per_split, ngrids - i_grid_split]) + for cp_ij_id, _ 
in enumerate(intopt.log_qs): + stream = cp.cuda.get_current_stream() + + log_q_ij = intopt.log_qs[cp_ij_id] + + nbins = 1 + bins_locs_ij = np.array([0, len(log_q_ij)], dtype=np.int32) + + charge_exponents_pointer = c_null_ptr() + if charge_exponents is not None: + charge_exponents_pointer = charge_exponents[i_grid_split : i_grid_split + ngrids_of_split].data.ptr + + # n_pair_sum_per_thread = 1 # means every thread processes one pair and one grid + # n_pair_sum_per_thread = nao_cart # or larger number gaurantees one thread processes one grid and all pairs of the same type + n_pair_sum_per_thread = nao_cart + + err = libgint.GINTfill_int3c1e_ip2_density_contracted( + ctypes.cast(stream.ptr, ctypes.c_void_p), + intopt.bpcache, + ctypes.cast(grids[i_grid_split : i_grid_split + ngrids_of_split, :].data.ptr, ctypes.c_void_p), + ctypes.cast(charge_exponents_pointer, ctypes.c_void_p), + ctypes.c_int(ngrids_of_split), + ctypes.cast(dm_pair_ordered.data.ptr, ctypes.c_void_p), + intopt.density_offset.ctypes.data_as(ctypes.c_void_p), + ctypes.cast(int3c_density_contracted[:, i_grid_split : i_grid_split + ngrids_of_split].data.ptr, ctypes.c_void_p), + bins_locs_ij.ctypes.data_as(ctypes.c_void_p), + ctypes.c_int(nbins), + ctypes.c_int(cp_ij_id), + ctypes.c_double(omega), + ctypes.c_int(n_pair_sum_per_thread)) + + if err != 0: + raise RuntimeError('GINTfill_int3c1e_density_contracted failed') + + return int3c_density_contracted + +def get_int3c1e_ip_contracted(mol, grids, charge_exponents, dm, charges, intopt): + dm = cp.asarray(dm) + if dm.ndim == 3: + if dm.shape[0] > 2: + print("Warning: There are more than two density matrices to contract with one electron integrals, " + "it's not from an unrestricted calculation, and we're unsure about your purpose. 
" + "We sum the density matrices up, please check if that's expected.") + dm = cp.einsum("ijk->jk", dm) + + assert dm.ndim == 2 + assert dm.shape[0] == dm.shape[1] and dm.shape[0] == mol.nao + + grids = cp.asarray(grids, order='C') + if charge_exponents is not None: + charge_exponents = cp.asarray(charge_exponents, order='C') + + assert charges.ndim == 1 and charges.shape[0] == grids.shape[0] + charges = cp.asarray(charges) + + int3c_ip2 = get_int3c1e_ip2_density_contracted(mol, grids, charge_exponents, dm, intopt) + int3c_ip2 = int3c_ip2 * charges + + int3c_ip1 = get_int3c1e_ip1_charge_contracted(mol, grids, charge_exponents, charges, intopt) + int3c_ip1 = cp.einsum('xji,ij->xi', int3c_ip1, dm) + + return int3c_ip1, int3c_ip2 \ No newline at end of file diff --git a/gpu4pyscf/gto/moleintor.py b/gpu4pyscf/gto/moleintor.py index 28a62737..3c1ad6fb 100644 --- a/gpu4pyscf/gto/moleintor.py +++ b/gpu4pyscf/gto/moleintor.py @@ -18,7 +18,7 @@ import numpy as np from gpu4pyscf.gto.int3c1e import VHFOpt, get_int3c1e, get_int3c1e_density_contracted, get_int3c1e_charge_contracted -from gpu4pyscf.gto.int3c1e_ip import get_int3c1e_ip +from gpu4pyscf.gto.int3c1e_ip import get_int3c1e_ip, get_int3c1e_ip_contracted def intor(mol, intor, grids, charge_exponents=None, dm=None, charges=None, direct_scf_tol=1e-13, intopt=None): assert grids is not None @@ -52,6 +52,8 @@ def intor(mol, intor, grids, charge_exponents=None, dm=None, charges=None, direc if dm is None and charges is None: return get_int3c1e_ip(mol, grids, charge_exponents, intopt) else: - raise NotImplementedError() + assert dm is not None + assert charges is not None + return get_int3c1e_ip_contracted(mol, grids, charge_exponents, dm, charges, intopt) else: raise NotImplementedError(f"GPU intor {intor} is not implemented.") diff --git a/gpu4pyscf/gto/tests/test_int1e_grids_ip.py b/gpu4pyscf/gto/tests/test_int1e_grids_ip.py index 5f234053..77d4ec72 100644 --- a/gpu4pyscf/gto/tests/test_int1e_grids_ip.py +++ 
b/gpu4pyscf/gto/tests/test_int1e_grids_ip.py @@ -169,6 +169,141 @@ def test_int1e_grids_ip1_full_tensor_gaussian_charge_omega(self): np.testing.assert_allclose(ref_int1e_dA, test_int1e_dA, atol = integral_threshold) np.testing.assert_allclose(ref_int1e_dC, test_int1e_dC, atol = integral_threshold) + def test_int1e_grids_ip1_contracted_cart(self): + np.random.seed(12346) + dm = np.random.uniform(-2.0, 2.0, (mol_cart.nao, mol_cart.nao)) + charges = np.random.uniform(-2.0, 2.0, grid_points.shape[0]) + + mol = mol_cart + fakemol = gto.fakemol_for_charges(grid_points) + + int3c2e_ip1 = mol._add_suffix('int3c2e_ip1') + cintopt = gto.moleintor.make_cintopt(mol._atm, mol._bas, mol._env, int3c2e_ip1) + v_nj = df.incore.aux_e2(mol, fakemol, intor=int3c2e_ip1, aosym='s1', cintopt=cintopt) + ref_int1e_dA = np.einsum('xijk,ij,k->xi', v_nj, dm, charges) + + int3c2e_ip2 = mol._add_suffix('int3c2e_ip2') + cintopt = gto.moleintor.make_cintopt(mol._atm, mol._bas, mol._env, int3c2e_ip2) + q_nj = df.incore.aux_e2(mol, fakemol, intor=int3c2e_ip2, aosym='s1', cintopt=cintopt) + ref_int1e_dC = np.einsum('xijk,ij,k->xk', q_nj, dm, charges) + + test_int1e_dA, test_int1e_dC = intor(mol, 'int1e_grids_ip', grid_points, dm = dm, charges = charges) + + assert isinstance(test_int1e_dA, cp.ndarray) + assert isinstance(test_int1e_dC, cp.ndarray) + cp.testing.assert_allclose(ref_int1e_dA, test_int1e_dA, atol = integral_threshold) + cp.testing.assert_allclose(ref_int1e_dC, test_int1e_dC, atol = integral_threshold) + + def test_int1e_grids_ip1_contracted_sph(self): + np.random.seed(12346) + dm = np.random.uniform(-2.0, 2.0, (mol_sph.nao, mol_sph.nao)) + charges = np.random.uniform(-2.0, 2.0, grid_points.shape[0]) + + mol = mol_sph + fakemol = gto.fakemol_for_charges(grid_points) + + int3c2e_ip1 = mol._add_suffix('int3c2e_ip1') + cintopt = gto.moleintor.make_cintopt(mol._atm, mol._bas, mol._env, int3c2e_ip1) + v_nj = df.incore.aux_e2(mol, fakemol, intor=int3c2e_ip1, aosym='s1', cintopt=cintopt) + 
ref_int1e_dA = np.einsum('xijk,ij,k->xi', v_nj, dm, charges) + + int3c2e_ip2 = mol._add_suffix('int3c2e_ip2') + cintopt = gto.moleintor.make_cintopt(mol._atm, mol._bas, mol._env, int3c2e_ip2) + q_nj = df.incore.aux_e2(mol, fakemol, intor=int3c2e_ip2, aosym='s1', cintopt=cintopt) + ref_int1e_dC = np.einsum('xijk,ij,k->xk', q_nj, dm, charges) + + test_int1e_dA, test_int1e_dC = intor(mol, 'int1e_grids_ip', grid_points, dm = dm, charges = charges) + + assert isinstance(test_int1e_dA, cp.ndarray) + assert isinstance(test_int1e_dC, cp.ndarray) + cp.testing.assert_allclose(ref_int1e_dA, test_int1e_dA, atol = integral_threshold) + cp.testing.assert_allclose(ref_int1e_dC, test_int1e_dC, atol = integral_threshold) + + def test_int1e_grids_ip1_contracted_gaussian_charge(self): + np.random.seed(12347) + dm = np.random.uniform(-2.0, 2.0, (mol_sph.nao, mol_sph.nao)) + charges = np.random.uniform(-2.0, 2.0, grid_points.shape[0]) + charge_exponents = np.random.uniform(0.5, 1.0, grid_points.shape[0]) + + mol = mol_sph + fakemol = gto.fakemol_for_charges(grid_points, expnt=charge_exponents) + + int3c2e_ip1 = mol._add_suffix('int3c2e_ip1') + cintopt = gto.moleintor.make_cintopt(mol._atm, mol._bas, mol._env, int3c2e_ip1) + v_nj = df.incore.aux_e2(mol, fakemol, intor=int3c2e_ip1, aosym='s1', cintopt=cintopt) + ref_int1e_dA = np.einsum('xijk,ij,k->xi', v_nj, dm, charges) + + int3c2e_ip2 = mol._add_suffix('int3c2e_ip2') + cintopt = gto.moleintor.make_cintopt(mol._atm, mol._bas, mol._env, int3c2e_ip2) + q_nj = df.incore.aux_e2(mol, fakemol, intor=int3c2e_ip2, aosym='s1', cintopt=cintopt) + ref_int1e_dC = np.einsum('xijk,ij,k->xk', q_nj, dm, charges) + + test_int1e_dA, test_int1e_dC = intor(mol, 'int1e_grids_ip', grid_points, dm = dm, charges = charges, charge_exponents = charge_exponents) + + assert isinstance(test_int1e_dA, cp.ndarray) + assert isinstance(test_int1e_dC, cp.ndarray) + cp.testing.assert_allclose(ref_int1e_dA, test_int1e_dA, atol = integral_threshold) + 
cp.testing.assert_allclose(ref_int1e_dC, test_int1e_dC, atol = integral_threshold) + + def test_int1e_grids_ip1_contracted_omega(self): + np.random.seed(12348) + dm = np.random.uniform(-2.0, 2.0, (mol_sph.nao, mol_sph.nao)) + charges = np.random.uniform(-2.0, 2.0, grid_points.shape[0]) + + omega = 1.2 + mol_sph_omega = mol_sph.copy() + mol_sph_omega.set_range_coulomb(omega) + + mol = mol_sph_omega + fakemol = gto.fakemol_for_charges(grid_points) + + int3c2e_ip1 = mol._add_suffix('int3c2e_ip1') + cintopt = gto.moleintor.make_cintopt(mol._atm, mol._bas, mol._env, int3c2e_ip1) + v_nj = df.incore.aux_e2(mol, fakemol, intor=int3c2e_ip1, aosym='s1', cintopt=cintopt) + ref_int1e_dA = np.einsum('xijk,ij,k->xi', v_nj, dm, charges) + + int3c2e_ip2 = mol._add_suffix('int3c2e_ip2') + cintopt = gto.moleintor.make_cintopt(mol._atm, mol._bas, mol._env, int3c2e_ip2) + q_nj = df.incore.aux_e2(mol, fakemol, intor=int3c2e_ip2, aosym='s1', cintopt=cintopt) + ref_int1e_dC = np.einsum('xijk,ij,k->xk', q_nj, dm, charges) + + test_int1e_dA, test_int1e_dC = intor(mol, 'int1e_grids_ip', grid_points, dm = dm, charges = charges) + + assert isinstance(test_int1e_dA, cp.ndarray) + assert isinstance(test_int1e_dC, cp.ndarray) + cp.testing.assert_allclose(ref_int1e_dA, test_int1e_dA, atol = integral_threshold) + cp.testing.assert_allclose(ref_int1e_dC, test_int1e_dC, atol = integral_threshold) + + def test_int1e_grids_ip1_contracted_gaussian_charge_omega(self): + np.random.seed(12349) + dm = np.random.uniform(-2.0, 2.0, (mol_sph.nao, mol_sph.nao)) + charges = np.random.uniform(-2.0, 2.0, grid_points.shape[0]) + charge_exponents = np.random.uniform(0.5, 1.0, grid_points.shape[0]) + + omega = 0.8 + mol_sph_omega = mol_sph.copy() + mol_sph_omega.set_range_coulomb(omega) + + mol = mol_sph_omega + fakemol = gto.fakemol_for_charges(grid_points, expnt=charge_exponents) + + int3c2e_ip1 = mol._add_suffix('int3c2e_ip1') + cintopt = gto.moleintor.make_cintopt(mol._atm, mol._bas, mol._env, int3c2e_ip1) + 
v_nj = df.incore.aux_e2(mol, fakemol, intor=int3c2e_ip1, aosym='s1', cintopt=cintopt) + ref_int1e_dA = np.einsum('xijk,ij,k->xi', v_nj, dm, charges) + + int3c2e_ip2 = mol._add_suffix('int3c2e_ip2') + cintopt = gto.moleintor.make_cintopt(mol._atm, mol._bas, mol._env, int3c2e_ip2) + q_nj = df.incore.aux_e2(mol, fakemol, intor=int3c2e_ip2, aosym='s1', cintopt=cintopt) + ref_int1e_dC = np.einsum('xijk,ij,k->xk', q_nj, dm, charges) + + test_int1e_dA, test_int1e_dC = intor(mol, 'int1e_grids_ip', grid_points, dm = dm, charges = charges, charge_exponents = charge_exponents) + + assert isinstance(test_int1e_dA, cp.ndarray) + assert isinstance(test_int1e_dC, cp.ndarray) + cp.testing.assert_allclose(ref_int1e_dA, test_int1e_dA, atol = integral_threshold) + cp.testing.assert_allclose(ref_int1e_dC, test_int1e_dC, atol = integral_threshold) + if __name__ == "__main__": print("Full Tests for One Electron Coulomb Integrals") unittest.main() diff --git a/gpu4pyscf/lib/gint/g1e.cu b/gpu4pyscf/lib/gint/g1e.cu index 1d0cbb1d..16944eae 100644 --- a/gpu4pyscf/lib/gint/g1e.cu +++ b/gpu4pyscf/lib/gint/g1e.cu @@ -345,3 +345,98 @@ static void GINT_g1e_without_hrr(double* __restrict__ g, const double grid_x, co } } + +template +__device__ +static void GINT_g1e_without_hrr_save_u2(double* __restrict__ g, double* __restrict__ u2_save, const double grid_x, const double grid_y, const double grid_z, + const int ish, const int prim_ij, const int l, const double charge_exponent, const double omega) +{ + const double* __restrict__ a12 = c_bpcache.a12; + const double* __restrict__ e12 = c_bpcache.e12; + const double* __restrict__ x12 = c_bpcache.x12; + const double* __restrict__ y12 = c_bpcache.y12; + const double* __restrict__ z12 = c_bpcache.z12; + const double aij = a12[prim_ij]; + const double eij = e12[prim_ij]; + const double Px = x12[prim_ij]; + const double Py = y12[prim_ij]; + const double Pz = z12[prim_ij]; + + const double PCx = Px - grid_x; + const double PCy = Py - grid_y; + const double 
PCz = Pz - grid_z; + + double a0 = aij; + const double q_over_p_plus_q = charge_exponent > 0.0 ? charge_exponent / (aij + charge_exponent) : 1.0; + const double sqrt_q_over_p_plus_q = charge_exponent > 0.0 ? sqrt(q_over_p_plus_q) : 1.0; + a0 *= q_over_p_plus_q; + const double theta = omega > 0.0 ? omega * omega / (omega * omega + a0) : 1.0; + const double sqrt_theta = omega > 0.0 ? sqrt(theta) : 1.0; + a0 *= theta; + + const double prefactor = 2.0 * M_PI / aij * eij * sqrt_theta * sqrt_q_over_p_plus_q; + const double boys_input = a0 * (PCx * PCx + PCy * PCy + PCz * PCz); + double uw[NROOTS * 2]; + GINTrys_root(boys_input, uw); + GINTscale_u(uw, theta); + + const double* __restrict__ u = uw; + const double* __restrict__ w = u + NROOTS; + const int g_size = NROOTS * (l + 1); + double* __restrict__ gx = g; + double* __restrict__ gy = g + g_size; + double* __restrict__ gz = g + g_size * 2; + + const int nbas = c_bpcache.nbas; + const double* __restrict__ bas_x = c_bpcache.bas_coords; + const double* __restrict__ bas_y = bas_x + nbas; + const double* __restrict__ bas_z = bas_y + nbas; + const double Ax = bas_x[ish]; + const double Ay = bas_y[ish]; + const double Az = bas_z[ish]; + const double PAx = Px - Ax; + const double PAy = Py - Ay; + const double PAz = Pz - Az; + +#pragma unroll + for (int i_root = 0; i_root < NROOTS; i_root++) { + gx[i_root] = 1.0; + gy[i_root] = prefactor; + gz[i_root] = w[i_root]; + + const double u2 = a0 * u[i_root]; + u2_save[i_root] = charge_exponent > 0.0 ? 
u2 * charge_exponent / (u2 + charge_exponent) : u2; + const double qt2_over_p_plus_q = u2 / (u2 + aij * q_over_p_plus_q) * q_over_p_plus_q; + const double b10 = 0.5 / aij * (1.0 - qt2_over_p_plus_q); + const double c00x = PAx - qt2_over_p_plus_q * PCx; + const double c00y = PAy - qt2_over_p_plus_q * PCy; + const double c00z = PAz - qt2_over_p_plus_q * PCz; + + if (l > 0) { + double s0x = gx[i_root]; // i - 1 + double s0y = gy[i_root]; + double s0z = gz[i_root]; + double s1x = c00x * s0x; // i + double s1y = c00y * s0y; + double s1z = c00z * s0z; + gx[i_root + 1 * NROOTS] = s1x; + gy[i_root + 1 * NROOTS] = s1y; + gz[i_root + 1 * NROOTS] = s1z; + for (int i_rys = 1; i_rys < l; i_rys++) { + const double s2x = c00x * s1x + i_rys * b10 * s0x; // i + 1 + const double s2y = c00y * s1y + i_rys * b10 * s0y; + const double s2z = c00z * s1z + i_rys * b10 * s0z; + gx[i_root + (i_rys+1) * NROOTS] = s2x; + gy[i_root + (i_rys+1) * NROOTS] = s2y; + gz[i_root + (i_rys+1) * NROOTS] = s2z; + s0x = s1x; + s0y = s1y; + s0z = s1z; + s1x = s2x; + s1y = s2y; + s1z = s2z; + } + } + } + +} diff --git a/gpu4pyscf/lib/gint/g3c1e.cu b/gpu4pyscf/lib/gint/g3c1e.cu index c5a9c89b..5ba9547b 100644 --- a/gpu4pyscf/lib/gint/g3c1e.cu +++ b/gpu4pyscf/lib/gint/g3c1e.cu @@ -92,10 +92,9 @@ static void GINTfill_int3c1e_kernel_general(double* output, const BasisProdOffse } } - template __device__ -static void GINTwrite_int3c1e_charge_contracted(const double* g, double* local_output, double prefactor, const int i_l, const int j_l) +static void GINTwrite_int3c1e_charge_contracted(const double* g, double* local_output, const double prefactor, const int i_l, const int j_l) { const int *idx = c_idx; const int *idy = c_idx + TOT_NF; @@ -170,8 +169,8 @@ static void GINTfill_int3c1e_charge_contracted_kernel_general(double* output, co const int* ao_loc = c_bpcache.ao_loc; - const int i0 = ao_loc[ish ] - ao_offsets_i; - const int j0 = ao_loc[jsh ] - ao_offsets_j; + const int i0 = ao_loc[ish] - ao_offsets_i; + const 
int j0 = ao_loc[jsh] - ao_offsets_j; for (int j = 0; j < (j_l + 1) * (j_l + 2) / 2; j++) { for (int i = 0; i < (i_l + 1) * (i_l + 2) / 2; i++) { const double eri_grid_sum = output_cache[i + j * ((i_l + 1) * (i_l + 2) / 2)]; diff --git a/gpu4pyscf/lib/gint/g3c1e_ip.cu b/gpu4pyscf/lib/gint/g3c1e_ip.cu index cbbbebca..4f54c81a 100644 --- a/gpu4pyscf/lib/gint/g3c1e_ip.cu +++ b/gpu4pyscf/lib/gint/g3c1e_ip.cu @@ -135,3 +135,205 @@ static void GINTfill_int3c1e_ip_kernel_general(double* output, const BasisProdOf GINTwrite_int3c1e_ip(g, output, minus_two_a, u2, AC, ish, jsh, task_grid, i_l, j_l, stride_j, stride_ij, ao_offsets_i, ao_offsets_j, ngrids); } } + +template +__device__ +static void GINTwrite_int3c1e_ip1_charge_contracted(const double* g, double* local_output, const double minus_two_a, const double prefactor, const int i_l, const int j_l) +{ + const int *idx = c_idx; + const int *idy = c_idx + TOT_NF; + const int *idz = c_idx + TOT_NF * 2; + + const int g_size = NROOTS * (i_l + 1 + 1) * (j_l + 1); + const double* __restrict__ gx = g; + const double* __restrict__ gy = g + g_size; + const double* __restrict__ gz = g + g_size * 2; + + for (int j = 0; j < (j_l + 1) * (j_l + 2) / 2; j++) { + for (int i = 0; i < (i_l + 1) * (i_l + 2) / 2; i++) { + const int loc_j = c_l_locs[j_l] + j; + const int loc_i = c_l_locs[i_l] + i; + const int ix = idx[loc_i]; + const int iy = idy[loc_i]; + const int iz = idz[loc_i]; + const int jx = idx[loc_j]; + const int jy = idy[loc_j]; + const int jz = idz[loc_j]; + const int gx_offset = ix + jx * (i_l + 1 + 1); + const int gy_offset = iy + jy * (i_l + 1 + 1); + const int gz_offset = iz + jz * (i_l + 1 + 1); + + double deri_dAx = 0; + double deri_dAy = 0; + double deri_dAz = 0; +#pragma unroll + for (int i_root = 0; i_root < NROOTS; i_root++) { + const double gx_0 = gx[gx_offset * NROOTS + i_root]; + const double gy_0 = gy[gy_offset * NROOTS + i_root]; + const double gz_0 = gz[gz_offset * NROOTS + i_root]; + const double dgx_dAx = (ix > 0 ? 
ix * gx[(gx_offset - 1) * NROOTS + i_root] : 0) + minus_two_a * gx[(gx_offset + 1) * NROOTS + i_root]; + const double dgy_dAy = (iy > 0 ? iy * gy[(gy_offset - 1) * NROOTS + i_root] : 0) + minus_two_a * gy[(gy_offset + 1) * NROOTS + i_root]; + const double dgz_dAz = (iz > 0 ? iz * gz[(gz_offset - 1) * NROOTS + i_root] : 0) + minus_two_a * gz[(gz_offset + 1) * NROOTS + i_root]; + deri_dAx += dgx_dAx * gy_0 * gz_0; + deri_dAy += gx_0 * dgy_dAy * gz_0; + deri_dAz += gx_0 * gy_0 * dgz_dAz; + } + const int n_density_elements_i = (i_l + 1) * (i_l + 2) / 2; + const int n_density_elements_j = (j_l + 1) * (j_l + 2) / 2; + const int n_density_elements_ij = n_density_elements_i * n_density_elements_j; + local_output[i + j * n_density_elements_i + 0 * n_density_elements_ij] += deri_dAx * prefactor; + local_output[i + j * n_density_elements_i + 1 * n_density_elements_ij] += deri_dAy * prefactor; + local_output[i + j * n_density_elements_i + 2 * n_density_elements_ij] += deri_dAz * prefactor; + } + } +} + +template +__global__ +static void GINTfill_int3c1e_ip1_charge_contracted_kernel_general(double* output, const BasisProdOffsets offsets, const int i_l, const int j_l, const int nprim_ij, + const int stride_j, const int stride_ij, const int ao_offsets_i, const int ao_offsets_j, + const double omega, const double* grid_points, const double* charge_exponents) +{ + const int ntasks_ij = offsets.ntasks_ij; + const int ngrids = offsets.ntasks_kl; + const int task_ij = blockIdx.x * blockDim.x + threadIdx.x; + if (task_ij >= ntasks_ij) { + return; + } + + const int bas_ij = offsets.bas_ij + task_ij; + const int prim_ij = offsets.primitive_ij + task_ij * nprim_ij; + const int* bas_pair2bra = c_bpcache.bas_pair2bra; + const int* bas_pair2ket = c_bpcache.bas_pair2ket; + const int ish = bas_pair2bra[bas_ij]; + const int jsh = bas_pair2ket[bas_ij]; + const double* __restrict__ a_exponents = c_bpcache.a1; + + constexpr int l_sum_max = (NROOTS - 1) * 2 + 1; + constexpr int 
l_i_max_density_elements = (l_sum_max + 1) / 2; + constexpr int l_j_max_density_elements = l_sum_max - l_i_max_density_elements; + double output_cache[(l_i_max_density_elements + 1) * (l_i_max_density_elements + 2) / 2 + * (l_j_max_density_elements + 1) * (l_j_max_density_elements + 2) / 2 + * 3] { 0.0 }; + + for (int task_grid = blockIdx.y * blockDim.y + threadIdx.y; task_grid < ngrids; task_grid += gridDim.y * blockDim.y) { + const double* grid_point = grid_points + task_grid * 4; + const double charge = grid_point[3]; + const double charge_exponent = (charge_exponents != NULL) ? charge_exponents[task_grid] : 0.0; + + double g[GSIZE_INT3C_1E]; + + for (int ij = prim_ij; ij < prim_ij+nprim_ij; ++ij) { + GINT_g1e(g, grid_point, ish, jsh, ij, i_l + 1, j_l, charge_exponent, omega); + const double minus_two_a = -2.0 * a_exponents[ij]; + GINTwrite_int3c1e_ip1_charge_contracted(g, output_cache, minus_two_a, charge, i_l, j_l); + } + } + + const int* ao_loc = c_bpcache.ao_loc; + + const int i0 = ao_loc[ish] - ao_offsets_i; + const int j0 = ao_loc[jsh] - ao_offsets_j; + const int n_density_elements_i = (i_l + 1) * (i_l + 2) / 2; + const int n_density_elements_j = (j_l + 1) * (j_l + 2) / 2; + const int n_density_elements_ij = n_density_elements_i * n_density_elements_j; + for (int j = 0; j < n_density_elements_j; j++) { + for (int i = 0; i < n_density_elements_i; i++) { + const double deri_dAx = output_cache[i + j * n_density_elements_i + 0 * n_density_elements_ij]; + const double deri_dAy = output_cache[i + j * n_density_elements_i + 1 * n_density_elements_ij]; + const double deri_dAz = output_cache[i + j * n_density_elements_i + 2 * n_density_elements_ij]; + atomicAdd(output + ((i + i0) + (j + j0) * stride_j + 0 * stride_ij), deri_dAx); + atomicAdd(output + ((i + i0) + (j + j0) * stride_j + 1 * stride_ij), deri_dAy); + atomicAdd(output + ((i + i0) + (j + j0) * stride_j + 2 * stride_ij), deri_dAz); + } + } +} + +template +__global__ +static void 
GINTfill_int3c1e_ip2_density_contracted_kernel_general(double* output, const double* density, const HermiteDensityOffsets hermite_density_offsets, + const BasisProdOffsets offsets, const int i_l, const int j_l, const int nprim_ij, + const double omega, const double* grid_points, const double* charge_exponents) +{ + const int ntasks_ij = offsets.ntasks_ij; + const int ngrids = offsets.ntasks_kl; + const int task_grid = blockIdx.y * blockDim.y + threadIdx.y; + if (task_grid >= ngrids) { + return; + } + + const double* grid_point = grid_points + task_grid * 3; + const double Cx = grid_point[0]; + const double Cy = grid_point[1]; + const double Cz = grid_point[2]; + const double charge_exponent = (charge_exponents != NULL) ? charge_exponents[task_grid] : 0.0; + + double deri_dCx_pair_sum = 0.0; + double deri_dCy_pair_sum = 0.0; + double deri_dCz_pair_sum = 0.0; + for (int task_ij = blockIdx.x * blockDim.x + threadIdx.x; task_ij < ntasks_ij; task_ij += gridDim.x * blockDim.x) { + const int bas_ij = offsets.bas_ij + task_ij; + const int prim_ij = offsets.primitive_ij + task_ij * nprim_ij; + const int* bas_pair2bra = c_bpcache.bas_pair2bra; + // const int* bas_pair2ket = c_bpcache.bas_pair2ket; + const int ish = bas_pair2bra[bas_ij]; + // const int jsh = bas_pair2ket[bas_ij]; + const int nbas = c_bpcache.nbas; + const double* __restrict__ bas_x = c_bpcache.bas_coords; + const double* __restrict__ bas_y = bas_x + nbas; + const double* __restrict__ bas_z = bas_y + nbas; + const double Ax = bas_x[ish]; + const double Ay = bas_y[ish]; + const double Az = bas_z[ish]; + + constexpr int l_max = (NROOTS - 1) * 2 + 1; + double D_hermite[(l_max + 1) * (l_max + 2) * (l_max + 3) / 6]; + const int l = i_l + j_l; + for (int i_t = 0; i_t < (l + 1) * (l + 2) * (l + 3) / 6; i_t++) { + D_hermite[i_t] = density[bas_ij - hermite_density_offsets.pair_offset_of_angular_pair + hermite_density_offsets.density_offset_of_angular_pair + i_t * hermite_density_offsets.n_pair_of_angular_pair]; + } + + 
double deri_dCx_per_pair = 0.0; + double deri_dCy_per_pair = 0.0; + double deri_dCz_per_pair = 0.0; + for (int ij = prim_ij; ij < prim_ij+nprim_ij; ++ij) { + double g[NROOTS * (l_max + 1) * 3]; + double u2[NROOTS]; + GINT_g1e_without_hrr_save_u2(g, u2, Cx, Cy, Cz, ish, ij, l + 1, charge_exponent, omega); + + const double* __restrict__ gx = g; + const double* __restrict__ gy = g + NROOTS * (l + 1 + 1); + const double* __restrict__ gz = g + NROOTS * (l + 1 + 1) * 2; + + for (int i_x = 0, i_t = 0; i_x <= l; i_x++) { + for (int i_y = 0; i_x + i_y <= l; i_y++) { + for (int i_z = 0; i_x + i_y + i_z <= l; i_z++, i_t++) { + const double D_t = D_hermite[i_t]; +#pragma unroll + for (int i_root = 0; i_root < NROOTS; i_root++) { + const double gx_0 = gx[i_root + NROOTS * i_x]; + const double gy_0 = gy[i_root + NROOTS * i_y]; + const double gz_0 = gz[i_root + NROOTS * i_z]; + const double gx_1 = gx[i_root + NROOTS * (i_x + 1)]; + const double gy_1 = gy[i_root + NROOTS * (i_y + 1)]; + const double gz_1 = gz[i_root + NROOTS * (i_z + 1)]; + const double minus_two_u2 = -2.0 * u2[i_root]; + const double dgx_dCx = minus_two_u2 * (gx_1 + (Ax - Cx) * gx_0); + const double dgy_dCy = minus_two_u2 * (gy_1 + (Ay - Cy) * gy_0); + const double dgz_dCz = minus_two_u2 * (gz_1 + (Az - Cz) * gz_0); + deri_dCx_per_pair += dgx_dCx * gy_0 * gz_0 * D_t; + deri_dCy_per_pair += gx_0 * dgy_dCy * gz_0 * D_t; + deri_dCz_per_pair += gx_0 * gy_0 * dgz_dCz * D_t; + } + } + } + } + } + deri_dCx_pair_sum += deri_dCx_per_pair; + deri_dCy_pair_sum += deri_dCy_per_pair; + deri_dCz_pair_sum += deri_dCz_per_pair; + } + atomicAdd(output + task_grid + ngrids * 0, deri_dCx_pair_sum); + atomicAdd(output + task_grid + ngrids * 1, deri_dCy_pair_sum); + atomicAdd(output + task_grid + ngrids * 2, deri_dCz_pair_sum); +} diff --git a/gpu4pyscf/lib/gint/j_engine_matrix_reorder.c b/gpu4pyscf/lib/gint/j_engine_matrix_reorder.c index 00878aa2..8eeb14de 100644 --- a/gpu4pyscf/lib/gint/j_engine_matrix_reorder.c +++ 
b/gpu4pyscf/lib/gint/j_engine_matrix_reorder.c @@ -16,6 +16,8 @@ #include "gint.h" +#include + // void GINTinit_J_density_reorder(const double* D_matrix, double* D_pair_ordered, const int n_dm, const int n_ao, const int n_pair_type, // const int* bas_pair2shls, const int* bas_pairs_locs, const int* density_offset, const int* ao_loc) // { @@ -102,7 +104,7 @@ int hermite_xyz_to_t_index(const int x, const int y, const int z, const int l) void GINTinit_J_density_rys_preprocess(const double* D_matrix, double* D_pair_ordered, const int n_dm, const int n_ao, const int n_pair_type, const int* bas_pair2shls, const int* bas_pairs_locs, const int* l_ij, const int* density_offset, const int* ao_loc, - const double* bas_coords) + const double* bas_coords, const bool symmetric) { const int n_bas_pairs = bas_pairs_locs[n_pair_type]; const int n_total_hermite_density = density_offset[n_pair_type]; @@ -139,7 +141,7 @@ void GINTinit_J_density_rys_preprocess(const double* D_matrix, double* D_pair_or for (int i_y_j = lj - i_x_j; i_y_j >= 0; i_y_j--, i_density_j++) { const int i_z_j = lj - i_x_j - i_y_j; - const double D_cartesian = (i0 == j0) ? + const double D_cartesian = ((!symmetric) || i0 == j0) ? 
D_matrix[(i0 + i_density_i) + (j0 + i_density_j) * n_ao + i_dm * n_ao * n_ao] : D_matrix[(i0 + i_density_i) + (j0 + i_density_j) * n_ao + i_dm * n_ao * n_ao] + D_matrix[(j0 + i_density_j) + (i0 + i_density_i) * n_ao + i_dm * n_ao * n_ao];
diff --git a/gpu4pyscf/lib/gint/nr_fill_ao_int3c1e_ip.cu b/gpu4pyscf/lib/gint/nr_fill_ao_int3c1e_ip.cu
index 35bd37f1..a670b4fd 100644
--- a/gpu4pyscf/lib/gint/nr_fill_ao_int3c1e_ip.cu
+++ b/gpu4pyscf/lib/gint/nr_fill_ao_int3c1e_ip.cu
@@ -67,6 +67,108 @@ static int GINTfill_int3c1e_ip_tasks(double* output, const BasisProdOffsets offs
     return 0;
 }
 
+/*
+ * Launch the int3c1e ip1 charge-contracted kernel for one batch of basis pairs.
+ *
+ * The Rys root count (i_l + j_l + 1) / 2 + 1 selects the kernel template instance.
+ * The launch spans offsets.ntasks_ij pairs along x and
+ * ceil(offsets.ntasks_kl / n_charge_sum_per_thread) grid-point groups along y.
+ *
+ * Returns 0 on success (including when there is nothing to launch) and 1 for a
+ * non-positive n_charge_sum_per_thread, an unsupported root count or a CUDA error.
+ */
+static int GINTfill_int3c1e_ip1_charge_contracted_tasks(double* output, const BasisProdOffsets offsets, const int i_l, const int j_l, const int nprim_ij,
+                                                        const int stride_j, const int stride_ij, const int ao_offsets_i, const int ao_offsets_j,
+                                                        const double omega, const double* grid_points, const double* charge_exponents,
+                                                        const int n_charge_sum_per_thread, const cudaStream_t stream)
+{
+    // Guard the ceiling division below against a zero or negative divisor.
+    if (n_charge_sum_per_thread <= 0) {
+        fprintf(stderr, "n_charge_sum_per_thread = %d must be positive\n", n_charge_sum_per_thread);
+        return 1;
+    }
+
+    const int nrys_roots = (i_l + j_l + 1) / 2 + 1;
+    const int ntasks_ij = offsets.ntasks_ij;
+    const int ngrids = (offsets.ntasks_kl + n_charge_sum_per_thread - 1) / n_charge_sum_per_thread;
+    // A grid with a zero dimension is an invalid launch configuration, so skip empty work.
+    if (ntasks_ij <= 0 || ngrids <= 0) {
+        return 0;
+    }
+
+    const dim3 threads(THREADSX, THREADSY);
+    const dim3 blocks((ntasks_ij+THREADSX-1)/THREADSX, (ngrids+THREADSY-1)/THREADSY);
+    // Only the general kernels are dispatched here, one template instance per root count.
+    switch (nrys_roots) {
+    case 1: GINTfill_int3c1e_ip1_charge_contracted_kernel_general<1, GSIZE1_INT3C_1E> <<<blocks, threads, 0, stream>>>(output, offsets, i_l, j_l, nprim_ij, stride_j, stride_ij, ao_offsets_i, ao_offsets_j, omega, grid_points, charge_exponents); break;
+    case 2: GINTfill_int3c1e_ip1_charge_contracted_kernel_general<2, GSIZE2_INT3C_1E> <<<blocks, threads, 0, stream>>>(output, offsets, i_l, j_l, nprim_ij, stride_j, stride_ij, ao_offsets_i, ao_offsets_j, omega, grid_points, charge_exponents); break;
+    case 3: GINTfill_int3c1e_ip1_charge_contracted_kernel_general<3, GSIZE3_INT3C_1E> <<<blocks, threads, 0, stream>>>(output, offsets, i_l, j_l, nprim_ij, stride_j, stride_ij, ao_offsets_i, ao_offsets_j, omega, grid_points, charge_exponents); break;
+    case 4: GINTfill_int3c1e_ip1_charge_contracted_kernel_general<4, GSIZE4_INT3C_1E> <<<blocks, threads, 0, stream>>>(output, offsets, i_l, j_l, nprim_ij, stride_j, stride_ij, ao_offsets_i, ao_offsets_j, omega, grid_points, charge_exponents); break;
+    case 5: GINTfill_int3c1e_ip1_charge_contracted_kernel_general<5, GSIZE5_INT3C_1E> <<<blocks, threads, 0, stream>>>(output, offsets, i_l, j_l, nprim_ij, stride_j, stride_ij, ao_offsets_i, ao_offsets_j, omega, grid_points, charge_exponents); break;
+    default:
+        fprintf(stderr, "rys roots %d\n", nrys_roots);
+        return 1;
+    }
+
+    cudaError_t err = cudaGetLastError();
+    if (err != cudaSuccess) {
+        fprintf(stderr, "CUDA Error in %s: %s\n", __func__, cudaGetErrorString(err));
+        return 1;
+    }
+    return 0;
+}
+
+/*
+ * Launch the int3c1e ip2 density-contracted kernel for one batch of basis pairs.
+ *
+ * The Rys root count (i_l + j_l + 1) / 2 + 1 selects the kernel template instance.
+ * The launch spans ceil(offsets.ntasks_ij / n_pair_sum_per_thread) pair groups
+ * along x and offsets.ntasks_kl grid points along y.
+ *
+ * Returns 0 on success (including when there is nothing to launch) and 1 for a
+ * non-positive n_pair_sum_per_thread, an unsupported root count or a CUDA error.
+ */
+static int GINTfill_int3c1e_ip2_density_contracted_tasks(double* output, const double* density, const HermiteDensityOffsets hermite_density_offsets,
+                                                         const BasisProdOffsets offsets, const int i_l, const int j_l, const int nprim_ij,
+                                                         const double omega, const double* grid_points, const double* charge_exponents,
+                                                         const int n_pair_sum_per_thread, const cudaStream_t stream)
+{
+    // Guard the ceiling division below against a zero or negative divisor.
+    if (n_pair_sum_per_thread <= 0) {
+        fprintf(stderr, "n_pair_sum_per_thread = %d must be positive\n", n_pair_sum_per_thread);
+        return 1;
+    }
+
+    const int nrys_roots = (i_l + j_l + 1) / 2 + 1;
+    const int ntasks_ij = (offsets.ntasks_ij + n_pair_sum_per_thread - 1) / n_pair_sum_per_thread;
+    const int ngrids = offsets.ntasks_kl;
+    // A grid with a zero dimension is an invalid launch configuration, so skip empty work.
+    if (ntasks_ij <= 0 || ngrids <= 0) {
+        return 0;
+    }
+
+    const dim3 threads(THREADSX, THREADSY);
+    const dim3 blocks((ntasks_ij+THREADSX-1)/THREADSX, (ngrids+THREADSY-1)/THREADSY);
+    // Only the general kernels are dispatched here, one template instance per root count.
+    switch (nrys_roots) {
+    case 1: GINTfill_int3c1e_ip2_density_contracted_kernel_general<1> <<<blocks, threads, 0, stream>>>(output, density, hermite_density_offsets, offsets, i_l, j_l, nprim_ij, omega, grid_points, charge_exponents); break;
+    case 2: GINTfill_int3c1e_ip2_density_contracted_kernel_general<2> <<<blocks, threads, 0, stream>>>(output, density, hermite_density_offsets, offsets, i_l, j_l, nprim_ij, omega, grid_points, charge_exponents); break;
+    case 3: GINTfill_int3c1e_ip2_density_contracted_kernel_general<3> <<<blocks, threads, 0, stream>>>(output, density, hermite_density_offsets, offsets, i_l, j_l, nprim_ij, omega, grid_points, charge_exponents); break;
+    case 4: GINTfill_int3c1e_ip2_density_contracted_kernel_general<4> <<<blocks, threads, 0, stream>>>(output, density, hermite_density_offsets, offsets, i_l, j_l, nprim_ij, omega, grid_points, charge_exponents); break;
+    case 5: GINTfill_int3c1e_ip2_density_contracted_kernel_general<5> <<<blocks, threads, 0, stream>>>(output, density, hermite_density_offsets, offsets, i_l, j_l, nprim_ij, omega, grid_points, charge_exponents); break;
+    default:
+        fprintf(stderr, "rys roots %d\n", nrys_roots);
+        return 1;
+    }
+
+    cudaError_t err = cudaGetLastError();
+    if (err != cudaSuccess) {
+        fprintf(stderr, "CUDA Error in %s: %s\n", __func__, cudaGetErrorString(err));
+        return 1;
+    }
+    return 0;
+}
+
 extern "C" {
 int GINTfill_int3c1e_ip(const cudaStream_t stream, const BasisProdCache* bpcache,
                         const double* grid_points, const double* charge_exponents, const int ngrids,
@@ -117,4 +219,134 @@ int GINTfill_int3c1e_ip(const cudaStream_t stream, const BasisProdCache* bpcache
 
     return 0;
 }
+
+/*
+ * Fill the charge-contracted int3c1e ip1 integrals of contraction pair type cp_ij_id,
+ * issuing one kernel launch per non-empty bin of basis pairs.
+ *
+ * bins_locs_ij is read at indices 0..nbins; bin b covers the pair range
+ * [bins_locs_ij[b], bins_locs_ij[b + 1]) counted from the first pair of the type.
+ * strides[0..1] and ao_offsets[0..1] are forwarded to the kernel; nao is not used here.
+ *
+ * Returns 0 on success, 2 when the Rys root count exceeds MAX_NROOTS_INT3C_1E + 1,
+ * otherwise the non-zero status of the failing launch.
+ */
+int GINTfill_int3c1e_ip1_charge_contracted(const cudaStream_t stream, const BasisProdCache* bpcache,
+                                           const double* grid_points, const double* charge_exponents, const int ngrids,
+                                           double* integral_charge_contracted, const int nao,
+                                           const int* strides, const int* ao_offsets,
+                                           const int* bins_locs_ij, int nbins,
+                                           const int cp_ij_id, const double omega, const int n_charge_sum_per_thread)
+{
+    const ContractionProdType *cp_ij = bpcache->cptype + cp_ij_id;
+    const int i_l = cp_ij->l_bra;
+    const int j_l = cp_ij->l_ket;
+    const int nrys_roots = (i_l + j_l + 1) / 2 + 1;
+    const int nprim_ij = cp_ij->nprim_12;
+
+    if (nrys_roots > MAX_NROOTS_INT3C_1E + 1) {
+        fprintf(stderr, "nrys_roots = %d too high\n", nrys_roots);
+        return 2;
+    }
+
+    // Copy the cache into the c_bpcache device symbol before any kernel is launched.
+    checkCudaErrors(cudaMemcpyToSymbol(c_bpcache, bpcache, sizeof(BasisProdCache)));
+
+    const int* bas_pairs_locs = bpcache->bas_pairs_locs;
+    const int* primitive_pairs_locs = bpcache->primitive_pairs_locs;
+    for (int ij_bin = 0; ij_bin < nbins; ij_bin++) {
+        const int bas_ij0 = bins_locs_ij[ij_bin];
+        const int bas_ij1 = bins_locs_ij[ij_bin + 1];
+        const int ntasks_ij = bas_ij1 - bas_ij0;
+        if (ntasks_ij <= 0) {
+            continue;
+        }
+
+        BasisProdOffsets offsets;
+        offsets.ntasks_ij = ntasks_ij;
+        offsets.ntasks_kl = ngrids;  // the kl slot carries the grid-point count
+        offsets.bas_ij = bas_pairs_locs[cp_ij_id] + bas_ij0;
+        offsets.bas_kl = -1;
+        offsets.primitive_ij = primitive_pairs_locs[cp_ij_id] + bas_ij0 * nprim_ij;
+        offsets.primitive_kl = -1;
+
+        const int err = GINTfill_int3c1e_ip1_charge_contracted_tasks(integral_charge_contracted, offsets, i_l, j_l, nprim_ij,
+                                                                     strides[0], strides[1], ao_offsets[0], ao_offsets[1],
+                                                                     omega, grid_points, charge_exponents, n_charge_sum_per_thread, stream);
+
+        if (err != 0) {
+            return err;
+        }
+    }
+
+    return 0;
+}
+
+/*
+ * Fill the density-contracted int3c1e ip2 integrals of contraction pair type cp_ij_id,
+ * issuing one kernel launch per non-empty bin of basis pairs.
+ *
+ * bins_locs_ij is read at indices 0..nbins; bin b covers the pair range
+ * [bins_locs_ij[b], bins_locs_ij[b + 1]) counted from the first pair of the type.
+ * density_offset[cp_ij_id] and the pair range of the type from bpcache->bas_pairs_locs
+ * are handed to the kernel so it can index dm_pair_ordered.
+ *
+ * Returns 0 on success, 2 when the Rys root count exceeds MAX_NROOTS_INT3C_1E + 1,
+ * otherwise the non-zero status of the failing launch.
+ */
+int GINTfill_int3c1e_ip2_density_contracted(const cudaStream_t stream, const BasisProdCache* bpcache,
+                                            const double* grid_points, const double* charge_exponents, const int ngrids,
+                                            const double* dm_pair_ordered, const int* density_offset,
+                                            double* integral_density_contracted,
+                                            const int* bins_locs_ij, int nbins,
+                                            const int cp_ij_id, const double omega, const int n_pair_sum_per_thread)
+{
+    const ContractionProdType *cp_ij = bpcache->cptype + cp_ij_id;
+    const int i_l = cp_ij->l_bra;
+    const int j_l = cp_ij->l_ket;
+    const int nrys_roots = (i_l + j_l + 1) / 2 + 1;
+    const int nprim_ij = cp_ij->nprim_12;
+
+    if (nrys_roots > MAX_NROOTS_INT3C_1E + 1) {
+        fprintf(stderr, "nrys_roots = %d too high\n", nrys_roots);
+        return 2;
+    }
+
+    // Copy the cache into the c_bpcache device symbol before any kernel is launched.
+    checkCudaErrors(cudaMemcpyToSymbol(c_bpcache, bpcache, sizeof(BasisProdCache)));
+
+    const int* bas_pairs_locs = bpcache->bas_pairs_locs;
+    const int* primitive_pairs_locs = bpcache->primitive_pairs_locs;
+    for (int ij_bin = 0; ij_bin < nbins; ij_bin++) {
+        const int bas_ij0 = bins_locs_ij[ij_bin];
+        const int bas_ij1 = bins_locs_ij[ij_bin + 1];
+        const int ntasks_ij = bas_ij1 - bas_ij0;
+        if (ntasks_ij <= 0) {
+            continue;
+        }
+
+        BasisProdOffsets offsets;
+        offsets.ntasks_ij = ntasks_ij;
+        offsets.ntasks_kl = ngrids;  // the kl slot carries the grid-point count
+        offsets.bas_ij = bas_pairs_locs[cp_ij_id] + bas_ij0;
+        offsets.bas_kl = -1;
+        offsets.primitive_ij = primitive_pairs_locs[cp_ij_id] + bas_ij0 * nprim_ij;
+        offsets.primitive_kl = -1;
+
+        HermiteDensityOffsets hermite_density_offsets;
+        hermite_density_offsets.density_offset_of_angular_pair = density_offset[cp_ij_id];
+        hermite_density_offsets.pair_offset_of_angular_pair = bas_pairs_locs[cp_ij_id];
+        hermite_density_offsets.n_pair_of_angular_pair = bas_pairs_locs[cp_ij_id + 1] - bas_pairs_locs[cp_ij_id];
+
+        const int err = GINTfill_int3c1e_ip2_density_contracted_tasks(integral_density_contracted, dm_pair_ordered, hermite_density_offsets,
+                                                                      offsets, i_l, j_l, nprim_ij,
+                                                                      omega, grid_points, charge_exponents, n_pair_sum_per_thread, stream);
+
+        if (err != 0) {
+            return err;
+        }
+    }
+
+    return 0;
+}
 }