Skip to content

Commit

Permalink
Add Gaussian charge support to g3c1e
Browse files Browse the repository at this point in the history
  • Loading branch information
henryw7 committed Dec 4, 2024
1 parent 2f21fbe commit 4017ef8
Show file tree
Hide file tree
Showing 6 changed files with 263 additions and 97 deletions.
46 changes: 35 additions & 11 deletions gpu4pyscf/gto/moleintor.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from pyscf.scf import _vhf
from pyscf.gto import ATOM_OF
from pyscf.lib import c_null_ptr
from gpu4pyscf.lib.cupy_helper import load_library, cart2sph, block_c2s_diag, get_avail_mem
from gpu4pyscf.lib import logger
from gpu4pyscf.scf.int4c2e import BasisProdCache
Expand Down Expand Up @@ -168,7 +169,7 @@ def cart2sph(self):
# end of class VHFOpt


def get_int3c1e(mol, grids, intopt):
def get_int3c1e(mol, grids, charge_exponents, intopt):
omega = mol.omega
assert omega >= 0.0, "Short-range one electron integrals with GPU acceleration is not implemented."

Expand All @@ -190,6 +191,8 @@ def get_int3c1e(mol, grids, intopt):
# int3c = np.zeros([ngrids, nao, nao], order='C') # Using unpinned (pageable) memory, each memcpy is much slower, but there's no initialization time

grids = cp.asarray(grids, order='C')
if charge_exponents is not None:
charge_exponents = cp.asarray(charge_exponents, order='C')

for i_grid_split in range(0, ngrids, ngrids_per_split):
ngrids_of_split = np.min([ngrids_per_split, ngrids - i_grid_split])
Expand Down Expand Up @@ -218,10 +221,15 @@ def get_int3c1e(mol, grids, intopt):

int3c_angular_slice = cp.zeros([ngrids_of_split, j1-j0, i1-i0], order='C')

charge_exponents_pointer = c_null_ptr()
if charge_exponents is not None:
charge_exponents_pointer = charge_exponents[i_grid_split : i_grid_split + ngrids_of_split].data.ptr

err = libgint.GINTfill_int3c1e(
ctypes.cast(stream.ptr, ctypes.c_void_p),
intopt.bpcache,
ctypes.cast(grids[i_grid_split : i_grid_split + ngrids_of_split, :].data.ptr, ctypes.c_void_p),
ctypes.cast(charge_exponents_pointer, ctypes.c_void_p),
ctypes.c_int(ngrids_of_split),
ctypes.cast(int3c_angular_slice.data.ptr, ctypes.c_void_p),
ctypes.c_int(nao_cart),
Expand Down Expand Up @@ -253,7 +261,7 @@ def get_int3c1e(mol, grids, intopt):

return int3c

def get_int3c1e_charge_contracted(mol, grids, charges, intopt):
def get_int3c1e_charge_contracted(mol, grids, charge_exponents, charges, intopt):
omega = mol.omega
assert omega >= 0.0, "Short-range one electron integrals with GPU acceleration is not implemented."

Expand All @@ -262,8 +270,9 @@ def get_int3c1e_charge_contracted(mol, grids, charges, intopt):
assert charges.ndim == 1 and charges.shape[0] == grids.shape[0]

grids = cp.asarray(grids, order='C')
charges = cp.asarray(charges).reshape([-1, 1], order='C')
grids = cp.concatenate([grids, charges], axis=1)
charges = cp.asarray(charges)
if charge_exponents is not None:
charge_exponents = cp.asarray(charge_exponents, order='C')

int1e = cp.zeros([mol.nao, mol.nao], order='C')
for cp_ij_id, _ in enumerate(intopt.log_qs):
Expand All @@ -288,6 +297,10 @@ def get_int3c1e_charge_contracted(mol, grids, charges, intopt):
ao_offsets = np.array([i0, j0], dtype=np.int32)
strides = np.array([ni, ni*nj], dtype=np.int32)

charge_exponents_pointer = c_null_ptr()
if charge_exponents is not None:
charge_exponents_pointer = charge_exponents.data.ptr

ngrids = grids.shape[0]
# n_charge_sum_per_thread = 1 # means every thread processes one pair and one grid
# n_charge_sum_per_thread = ngrids # or larger number guarantees one thread processes one pair and all grid points
Expand All @@ -299,6 +312,8 @@ def get_int3c1e_charge_contracted(mol, grids, charges, intopt):
ctypes.cast(stream.ptr, ctypes.c_void_p),
intopt.bpcache,
ctypes.cast(grids.data.ptr, ctypes.c_void_p),
ctypes.cast(charges.data.ptr, ctypes.c_void_p),
ctypes.cast(charge_exponents_pointer, ctypes.c_void_p),
ctypes.c_int(ngrids),
ctypes.cast(int1e_angular_slice.data.ptr, ctypes.c_void_p),
ctypes.c_int(nao_cart),
Expand Down Expand Up @@ -328,7 +343,7 @@ def get_int3c1e_charge_contracted(mol, grids, charges, intopt):

return int1e

def get_int3c1e_density_contracted(mol, grids, dm, intopt):
def get_int3c1e_density_contracted(mol, grids, charge_exponents, dm, intopt):
omega = mol.omega
assert omega >= 0.0, "Short-range one electron integrals with GPU acceleration is not implemented."

Expand Down Expand Up @@ -368,8 +383,6 @@ def get_int3c1e_density_contracted(mol, grids, dm, intopt):
bas_coords.ctypes.data_as(ctypes.c_void_p))

dm_pair_ordered = cp.asarray(dm_pair_ordered)
grids = cp.asarray(grids, order='C')
int3c_density_contracted = cp.zeros(ngrids)

n_threads_per_block_1d = 16
n_max_blocks_per_grid_1d = 65535
Expand All @@ -379,6 +392,12 @@ def get_int3c1e_density_contracted(mol, grids, dm, intopt):
print(f"Grid dimension = {ngrids} is too large, more than 100 kernels for one electron integral will be launched.")
ngrids_per_split = (ngrids + n_grid_split - 1) // n_grid_split

grids = cp.asarray(grids, order='C')
if charge_exponents is not None:
charge_exponents = cp.asarray(charge_exponents, order='C')

int3c_density_contracted = cp.zeros(ngrids)

for i_grid_split in range(0, ngrids, ngrids_per_split):
ngrids_of_split = np.min([ngrids_per_split, ngrids - i_grid_split])
for cp_ij_id, _ in enumerate(intopt.log_qs):
Expand All @@ -389,6 +408,10 @@ def get_int3c1e_density_contracted(mol, grids, dm, intopt):
nbins = 1
bins_locs_ij = np.array([0, len(log_q_ij)], dtype=np.int32)

charge_exponents_pointer = c_null_ptr()
if charge_exponents is not None:
charge_exponents_pointer = charge_exponents[i_grid_split : i_grid_split + ngrids_of_split].data.ptr

# n_pair_sum_per_thread = 1 # means every thread processes one pair and one grid
# n_pair_sum_per_thread = nao_cart # or larger number guarantees one thread processes one grid and all pairs of the same type
n_pair_sum_per_thread = nao_cart
Expand All @@ -397,6 +420,7 @@ def get_int3c1e_density_contracted(mol, grids, dm, intopt):
ctypes.cast(stream.ptr, ctypes.c_void_p),
intopt.bpcache,
ctypes.cast(grids[i_grid_split : i_grid_split + ngrids_of_split, :].data.ptr, ctypes.c_void_p),
ctypes.cast(charge_exponents_pointer, ctypes.c_void_p),
ctypes.c_int(ngrids_of_split),
ctypes.cast(dm_pair_ordered.data.ptr, ctypes.c_void_p),
intopt.density_offset.ctypes.data_as(ctypes.c_void_p),
Expand All @@ -412,7 +436,7 @@ def get_int3c1e_density_contracted(mol, grids, dm, intopt):

return int3c_density_contracted

def intor(mol, intor, grids, dm=None, charges=None, direct_scf_tol=1e-13, intopt=None):
def intor(mol, intor, grids, charge_exponents=None, dm=None, charges=None, direct_scf_tol=1e-13, intopt=None):
assert intor == 'int1e_grids'
assert grids is not None
assert dm is None or charges is None, \
Expand All @@ -428,10 +452,10 @@ def intor(mol, intor, grids, dm=None, charges=None, direct_scf_tol=1e-13, intopt
assert hasattr(intopt, "density_offset"), "Please call build() function for VHFOpt object first."

if dm is None and charges is None:
return get_int3c1e(mol, grids, intopt)
return get_int3c1e(mol, grids, charge_exponents, intopt)
elif dm is not None:
return get_int3c1e_density_contracted(mol, grids, dm, intopt)
return get_int3c1e_density_contracted(mol, grids, charge_exponents, dm, intopt)
elif charges is not None:
return get_int3c1e_charge_contracted(mol, grids, charges, intopt)
return get_int3c1e_charge_contracted(mol, grids, charge_exponents, charges, intopt)
else:
raise ValueError(f"Logic error in {__file__} {__name__}")
106 changes: 105 additions & 1 deletion gpu4pyscf/gto/tests/test_int1e_grids.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import numpy as np
import cupy as cp
import pyscf
from pyscf import lib
from pyscf import lib, gto, df
from gpu4pyscf.gto.moleintor import intor

def setUpModule():
Expand Down Expand Up @@ -157,6 +157,110 @@ def test_int1e_grids_charge_contracted_omega(self):
assert isinstance(test_int1e_dot_q, cp.ndarray)
cp.testing.assert_allclose(ref_int1e_dot_q, test_int1e_dot_q, atol = charge_contraction_threshold)

# Gaussian charges

def test_int1e_grids_full_tensor_guassian_charge(self):
    """Check the full int1e_grids tensor with per-point charge exponents.

    The reference is the pyscf 'int3c2e' integral between mol_sph and the
    fake molecule built by gto.fakemol_for_charges from the grid points
    and the random exponents, with the grid index moved to the first axis.
    """
    n_points = grid_points.shape[0]
    np.random.seed(12351)
    exponents = np.random.uniform(0.5, 1.0, n_points)

    # Reference integrals from pyscf.
    intor_name = mol_sph._add_suffix('int3c2e')
    optimizer = gto.moleintor.make_cintopt(
        mol_sph._atm, mol_sph._bas, mol_sph._env, intor_name)
    charge_mol = gto.fakemol_for_charges(grid_points, expnt=exponents)
    reference = df.incore.aux_e2(
        mol_sph, charge_mol, intor=intor_name, aosym='s1', cintopt=optimizer,
    ).transpose((2, 0, 1))

    # Integrals from the implementation under test.
    computed = intor(mol_sph, 'int1e_grids', grid_points,
                     charge_exponents=exponents)
    np.testing.assert_allclose(reference, computed, atol=integral_threshold)

def test_int1e_grids_density_contracted_guassian_charge(self):
    """Check the density-contracted int1e_grids with charge exponents.

    The pyscf 'int3c2e' reference tensor is contracted with a random
    matrix over both AO indices and compared with the result of intor()
    called with dm and charge_exponents.
    """
    n_points = grid_points.shape[0]
    np.random.seed(12351)
    # Draw order (exponents first, then the matrix) fixes the random data.
    exponents = np.random.uniform(0.5, 1.0, n_points)
    density = np.random.uniform(-2.0, 2.0, (mol_sph.nao, mol_sph.nao))

    # Reference integrals from pyscf.
    intor_name = mol_sph._add_suffix('int3c2e')
    optimizer = gto.moleintor.make_cintopt(
        mol_sph._atm, mol_sph._bas, mol_sph._env, intor_name)
    charge_mol = gto.fakemol_for_charges(grid_points, expnt=exponents)
    reference = df.incore.aux_e2(
        mol_sph, charge_mol, intor=intor_name, aosym='s1', cintopt=optimizer,
    ).transpose((2, 0, 1))
    reference_contracted = np.einsum('pij,ij->p', reference, density)

    # Contracted integrals from the implementation under test.
    computed = intor(mol_sph, 'int1e_grids', grid_points,
                     dm=density, charge_exponents=exponents)
    assert isinstance(computed, cp.ndarray)
    cp.testing.assert_allclose(reference_contracted, computed,
                               atol=density_contraction_threshold)

def test_int1e_grids_charge_contracted_guassian_charge(self):
    """Check the charge-contracted int1e_grids with charge exponents.

    The pyscf 'int3c2e' reference tensor is contracted with random charges
    over the grid index and compared with the result of intor() called
    with charges and charge_exponents.
    """
    n_points = grid_points.shape[0]
    np.random.seed(12351)
    # Draw order (exponents first, then charges) fixes the random data.
    exponents = np.random.uniform(0.5, 1.0, n_points)
    point_charges = np.random.uniform(-2.0, 2.0, n_points)

    # Reference integrals from pyscf.
    intor_name = mol_sph._add_suffix('int3c2e')
    optimizer = gto.moleintor.make_cintopt(
        mol_sph._atm, mol_sph._bas, mol_sph._env, intor_name)
    charge_mol = gto.fakemol_for_charges(grid_points, expnt=exponents)
    reference = df.incore.aux_e2(
        mol_sph, charge_mol, intor=intor_name, aosym='s1', cintopt=optimizer,
    ).transpose((2, 0, 1))
    reference_contracted = np.einsum('pij,p->ij', reference, point_charges)

    # Contracted integrals from the implementation under test.
    computed = intor(mol_sph, 'int1e_grids', grid_points,
                     charges=point_charges, charge_exponents=exponents)
    assert isinstance(computed, cp.ndarray)
    cp.testing.assert_allclose(reference_contracted, computed,
                               atol=charge_contraction_threshold)

def test_int1e_grids_full_tensor_guassian_charge_omega(self):
    """Check the full int1e_grids tensor with charge exponents and omega.

    Same comparison as the non-omega full-tensor test, but on a copy of
    mol_sph on which set_range_coulomb(0.8) has been called.
    """
    n_points = grid_points.shape[0]
    np.random.seed(12351)
    exponents = np.random.uniform(0.5, 1.0, n_points)

    # Work on a copy so the shared module-level molecule keeps its omega.
    mol_rsh = mol_sph.copy()
    mol_rsh.set_range_coulomb(0.8)

    # Reference integrals from pyscf.
    intor_name = mol_rsh._add_suffix('int3c2e')
    optimizer = gto.moleintor.make_cintopt(
        mol_rsh._atm, mol_rsh._bas, mol_rsh._env, intor_name)
    charge_mol = gto.fakemol_for_charges(grid_points, expnt=exponents)
    reference = df.incore.aux_e2(
        mol_rsh, charge_mol, intor=intor_name, aosym='s1', cintopt=optimizer,
    ).transpose((2, 0, 1))

    # Integrals from the implementation under test.
    computed = intor(mol_rsh, 'int1e_grids', grid_points,
                     charge_exponents=exponents)
    np.testing.assert_allclose(reference, computed, atol=integral_threshold)

def test_int1e_grids_density_contracted_guassian_charge_omega(self):
    """Check density-contracted int1e_grids with charge exponents and omega.

    Same comparison as the non-omega density-contracted test, but on a
    copy of mol_sph on which set_range_coulomb(1.2) has been called.
    """
    n_points = grid_points.shape[0]
    np.random.seed(12351)
    # Draw order (exponents first, then the matrix) fixes the random data.
    exponents = np.random.uniform(0.5, 1.0, n_points)
    density = np.random.uniform(-2.0, 2.0, (mol_sph.nao, mol_sph.nao))

    # Work on a copy so the shared module-level molecule keeps its omega.
    mol_rsh = mol_sph.copy()
    mol_rsh.set_range_coulomb(1.2)

    # Reference integrals from pyscf.
    intor_name = mol_rsh._add_suffix('int3c2e')
    optimizer = gto.moleintor.make_cintopt(
        mol_rsh._atm, mol_rsh._bas, mol_rsh._env, intor_name)
    charge_mol = gto.fakemol_for_charges(grid_points, expnt=exponents)
    reference = df.incore.aux_e2(
        mol_rsh, charge_mol, intor=intor_name, aosym='s1', cintopt=optimizer,
    ).transpose((2, 0, 1))
    reference_contracted = np.einsum('pij,ij->p', reference, density)

    # Contracted integrals from the implementation under test.
    computed = intor(mol_rsh, 'int1e_grids', grid_points,
                     dm=density, charge_exponents=exponents)
    assert isinstance(computed, cp.ndarray)
    cp.testing.assert_allclose(reference_contracted, computed,
                               atol=density_contraction_threshold)

def test_int1e_grids_charge_contracted_guassian_charge_omega(self):
    """Check charge-contracted int1e_grids with charge exponents and omega.

    Same comparison as the non-omega charge-contracted test, but on a
    copy of mol_sph on which set_range_coulomb(1.2) has been called.
    """
    n_points = grid_points.shape[0]
    np.random.seed(12351)
    # Draw order (exponents first, then charges) fixes the random data.
    exponents = np.random.uniform(0.5, 1.0, n_points)
    point_charges = np.random.uniform(-2.0, 2.0, n_points)

    # Work on a copy so the shared module-level molecule keeps its omega.
    mol_rsh = mol_sph.copy()
    mol_rsh.set_range_coulomb(1.2)

    # Reference integrals from pyscf.
    intor_name = mol_rsh._add_suffix('int3c2e')
    optimizer = gto.moleintor.make_cintopt(
        mol_rsh._atm, mol_rsh._bas, mol_rsh._env, intor_name)
    charge_mol = gto.fakemol_for_charges(grid_points, expnt=exponents)
    reference = df.incore.aux_e2(
        mol_rsh, charge_mol, intor=intor_name, aosym='s1', cintopt=optimizer,
    ).transpose((2, 0, 1))
    reference_contracted = np.einsum('pij,p->ij', reference, point_charges)

    # Contracted integrals from the implementation under test.
    computed = intor(mol_rsh, 'int1e_grids', grid_points,
                     charges=point_charges, charge_exponents=exponents)
    assert isinstance(computed, cp.ndarray)
    cp.testing.assert_allclose(reference_contracted, computed,
                               atol=charge_contraction_threshold)

# Script entry point: print a banner, then hand control to the unittest runner.
if __name__ == "__main__":
print("Full Tests for One Electron Coulomb Integrals")
unittest.main()
47 changes: 28 additions & 19 deletions gpu4pyscf/lib/gint/g1e.cu
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,9 @@
// This function assumes i_l >= j_l
template <int NROOTS>
__device__
static void GINTg1e(double* __restrict__ g, const double* __restrict__ grid_point, const int ish, const int jsh, const int prim_ij,
const int i_l, const int j_l, const double omega)
static void GINTg1e(double* __restrict__ g, const double* __restrict__ grid_point,
const int ish, const int jsh, const int prim_ij,
const int i_l, const int j_l, const double charge_exponent, const double omega)
{
const double* __restrict__ a12 = c_bpcache.a12;
const double* __restrict__ e12 = c_bpcache.e12;
Expand All @@ -40,12 +41,16 @@ static void GINTg1e(double* __restrict__ g, const double* __restrict__ grid_poin
const double PCx = Px - Cx;
const double PCy = Py - Cy;
const double PCz = Pz - Cz;

double a0 = aij;
const double theta = omega > 0.0 ? omega * omega / (omega * omega + aij) : 1.0;
const double sqrt_theta = omega > 0.0 ? omega / sqrt(omega * omega + aij) : 1.0;
const double q_over_p_plus_q = charge_exponent > 0.0 ? charge_exponent / (aij + charge_exponent) : 1.0;
const double sqrt_q_over_p_plus_q = charge_exponent > 0.0 ? sqrt(q_over_p_plus_q) : 1.0;
a0 *= q_over_p_plus_q;
const double theta = omega > 0.0 ? omega * omega / (omega * omega + a0) : 1.0;
const double sqrt_theta = omega > 0.0 ? sqrt(theta) : 1.0;
a0 *= theta;

const double prefactor = 2.0 * M_PI / aij * eij * sqrt_theta;
const double prefactor = 2.0 * M_PI / aij * eij * sqrt_theta * sqrt_q_over_p_plus_q;
const double boys_input = a0 * (PCx * PCx + PCy * PCy + PCz * PCz);
double uw[NROOTS * 2];
GINTrys_root<NROOTS>(boys_input, uw);
Expand Down Expand Up @@ -76,11 +81,11 @@ static void GINTg1e(double* __restrict__ g, const double* __restrict__ grid_poin
gz[i_root] = w[i_root];

const double u2 = a0 * u[i_root];
const double t2 = u2 / (u2 + aij);
const double b10 = 0.5 / aij * (1.0 - t2);
const double c00x = PAx - t2 * PCx;
const double c00y = PAy - t2 * PCy;
const double c00z = PAz - t2 * PCz;
const double qt2_over_p_plus_q = u2 / (u2 + aij * q_over_p_plus_q) * q_over_p_plus_q;
const double b10 = 0.5 / aij * (1.0 - qt2_over_p_plus_q);
const double c00x = PAx - qt2_over_p_plus_q * PCx;
const double c00y = PAy - qt2_over_p_plus_q * PCy;
const double c00z = PAz - qt2_over_p_plus_q * PCz;

if (i_l + j_l > 0) {
double s0x = gx[i_root]; // i - 1
Expand Down Expand Up @@ -132,7 +137,7 @@ static void GINTg1e(double* __restrict__ g, const double* __restrict__ grid_poin
template <int NROOTS>
__device__
static void GINT_g1e_without_hrr(double* __restrict__ g, const double grid_x, const double grid_y, const double grid_z,
const int ish, const int prim_ij, const int l, const double omega)
const int ish, const int prim_ij, const int l, const double charge_exponent, const double omega)
{
const double* __restrict__ a12 = c_bpcache.a12;
const double* __restrict__ e12 = c_bpcache.e12;
Expand All @@ -148,12 +153,16 @@ static void GINT_g1e_without_hrr(double* __restrict__ g, const double grid_x, co
const double PCx = Px - grid_x;
const double PCy = Py - grid_y;
const double PCz = Pz - grid_z;

double a0 = aij;
const double theta = omega > 0.0 ? omega * omega / (omega * omega + aij) : 1.0;
const double sqrt_theta = omega > 0.0 ? omega / sqrt(omega * omega + aij) : 1.0;
const double q_over_p_plus_q = charge_exponent > 0.0 ? charge_exponent / (aij + charge_exponent) : 1.0;
const double sqrt_q_over_p_plus_q = charge_exponent > 0.0 ? sqrt(q_over_p_plus_q) : 1.0;
a0 *= q_over_p_plus_q;
const double theta = omega > 0.0 ? omega * omega / (omega * omega + a0) : 1.0;
const double sqrt_theta = omega > 0.0 ? sqrt(theta) : 1.0;
a0 *= theta;

const double prefactor = 2.0 * M_PI / aij * eij * sqrt_theta;
const double prefactor = 2.0 * M_PI / aij * eij * sqrt_theta * sqrt_q_over_p_plus_q;
const double boys_input = a0 * (PCx * PCx + PCy * PCy + PCz * PCz);
double uw[NROOTS * 2];
GINTrys_root<NROOTS>(boys_input, uw);
Expand Down Expand Up @@ -184,11 +193,11 @@ static void GINT_g1e_without_hrr(double* __restrict__ g, const double grid_x, co
gz[i_root] = w[i_root];

const double u2 = a0 * u[i_root];
const double t2 = u2 / (u2 + aij);
const double b10 = 0.5 / aij * (1.0 - t2);
const double c00x = PAx - t2 * PCx;
const double c00y = PAy - t2 * PCy;
const double c00z = PAz - t2 * PCz;
const double qt2_over_p_plus_q = u2 / (u2 + aij * q_over_p_plus_q) * q_over_p_plus_q;
const double b10 = 0.5 / aij * (1.0 - qt2_over_p_plus_q);
const double c00x = PAx - qt2_over_p_plus_q * PCx;
const double c00y = PAy - qt2_over_p_plus_q * PCy;
const double c00z = PAz - qt2_over_p_plus_q * PCz;

if (l > 0) {
double s0x = gx[i_root]; // i - 1
Expand Down
Loading

0 comments on commit 4017ef8

Please sign in to comment.