Improve eval gto (pyscf#72)
* fixed a bug in screen_index

* added unit test for to_gpu

* new grids group scheme

* use grid_aligned in gpu4pyscf.__config__

* fixed a bug in eval_ao
wxj6000 authored Jan 2, 2024
1 parent 2e5bc47 commit 1dff11d
Showing 9 changed files with 168 additions and 131 deletions.
5 changes: 3 additions & 2 deletions examples/dft_driver.py
@@ -35,10 +35,10 @@
max_memory=32000)
# set verbose >= 6 for debugging timer

mol.verbose = 6
mol.verbose = 4

mf_df = rks.RKS(mol, xc=args.xc).density_fit(auxbasis=args.auxbasis)
mf_df.verbose = 6
mf_df.verbose = 4

if args.solvent:
mf_df = mf_df.PCM()
@@ -52,6 +52,7 @@
mf_df.direct_scf_tol = 1e-14
mf_df.direct_scf = 1e-14
mf_df.conv_tol = 1e-10
mf_df.chkfile = None
e_tot = mf_df.kernel()
scf_time = time.time() - start_time
print(f'compute time for energy: {scf_time:.3f} s')
4 changes: 2 additions & 2 deletions gpu4pyscf/__config__.py
@@ -4,8 +4,8 @@
GB = 1024*1024*1024
# such as A100-80G
if props['totalGlobalMem'] >= 64 * GB:
min_ao_blksize = 128
min_grid_blksize = 256*256#128*128
min_ao_blksize = 256
min_grid_blksize = 256*256
ao_aligned = 32
grid_aligned = 128
mem_fraction = 0.9
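
Note on the __config__.py hunk: on devices with at least 64 GB of global memory (e.g. an A100-80G) the AO block size is raised from 128 to 256, so more basis functions are handled per batch. A minimal sketch of such a memory-dependent branch with CuPy (the property lookup mirrors how this file selects values; the small-card fallback values are illustrative, not quoted from the file):

    import cupy

    GB = 1024 * 1024 * 1024
    props = cupy.cuda.runtime.getDeviceProperties(0)   # properties of device 0 as a dict
    if props['totalGlobalMem'] >= 64 * GB:             # high-memory GPUs such as an A100-80G
        min_ao_blksize = 256
        min_grid_blksize = 256 * 256
    else:                                              # illustrative fallback for smaller cards
        min_ao_blksize = 128
        min_grid_blksize = 128 * 128
    ao_aligned = 32
    grid_aligned = 128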
1 change: 0 additions & 1 deletion gpu4pyscf/df/df.py
@@ -279,7 +279,6 @@ def cholesky_eri_gpu(intopt, mol, auxmol, cd_low, omega=None, sr_only=False):
with data_stream:
for i in range(naux):
cderi_block[i].get(out=cderi[i,ij0:ij1])

t1 = log.timer_debug1(f'solve {cp_ij_id} / {nq}', *t1)

cupy.cuda.Device().synchronize()
47 changes: 31 additions & 16 deletions gpu4pyscf/df/df_jk.py
@@ -19,6 +19,7 @@
import copy
import cupy
import numpy
from cupy import cublas
from pyscf import lib, scf, __config__
from pyscf.scf import dhf
from pyscf.df import df_jk, addons
@@ -264,69 +265,83 @@ def get_jk(dfobj, dms_tag, hermi=1, with_j=True, with_k=True, direct_scf_tol=1e-
dms = take_last2d(dms, ao_idx)

t1 = log.timer_debug1('init jk', *t0)
rows = dfobj.intopt.cderi_row
cols = dfobj.intopt.cderi_col
if with_j:
rows = dfobj.intopt.cderi_row
cols = dfobj.intopt.cderi_col
dm_sparse = dms[:,rows,cols]
dm_sparse[:, dfobj.intopt.cderi_diag] *= .5
vj = cupy.zeros_like(dms)
vj_tmp = cupy.zeros_like(dms)

if with_k:
vk = cupy.zeros_like(dms)

def get_j(cderi_sparse):
rhoj = 2.0*dm_sparse.dot(cderi_sparse)
vj_sparse = cupy.dot(rhoj, cderi_sparse.T)
vj_tmp[:,rows,cols] = vj_sparse
vj_tmp[:,cols,rows] = vj_sparse
vj_sparse = None
return vj_tmp

# SCF K matrix with occ
if nset == 1 and hasattr(dms_tag, 'occ_coeff'):
occ_coeff = cupy.asarray(dms_tag.occ_coeff[ao_idx, :], order='C')
nocc = occ_coeff.shape[1]
blksize = dfobj.get_blksize(extra=nao*nocc)
if with_j:
vj_packed = cupy.zeros_like(dm_sparse)

for cderi, cderi_sparse in dfobj.loop(blksize=blksize, unpack=with_k):
# leading dimension is 1
if with_j:
vj += get_j(cderi_sparse)
rhoj = 2.0*dm_sparse.dot(cderi_sparse)
vj_packed += cupy.dot(rhoj, cderi_sparse.T)
if with_k:
rhok = contract('Lij,jk->Lki', cderi, occ_coeff)
#vk[0] += contract('Lki,Lkj->ij', rhok, rhok)
contract('Lki,Lkj->ij', rhok, rhok, alpha=1.0, beta=1.0, out=vk[0])
cublas.syrk('T', rhok.reshape([-1,nao]), out=vk[0], alpha=1.0, beta=1.0, lower=True)
if with_j:
vj[:,rows,cols] = vj_packed
vj[:,cols,rows] = vj_packed
if with_k:
vk[0][numpy.diag_indices(nao)] *= 0.5
transpose_sum(vk)
vk *= 2.0
# CP-HF K matrix
elif hasattr(dms_tag, 'mo1'):
if with_j:
vj_sparse = cupy.zeros_like(dm_sparse)
mo1 = dms_tag.mo1[:,ao_idx,:]
nocc = mo1.shape[2]
# 2.0 due to rhok and rhok1, put it here for symmetry
occ_coeff = dms_tag.occ_coeff[ao_idx,:] * 2.0
blksize = dfobj.get_blksize(extra=2*nao*nocc)
for cderi, cderi_sparse in dfobj.loop(blksize=blksize, unpack=with_k):
if with_j:
vj += get_j(cderi_sparse)
#vj += get_j(cderi_sparse)
rhoj = 2.0*dm_sparse.dot(cderi_sparse)
vj_sparse += cupy.dot(rhoj, cderi_sparse.T)
if with_k:
rhok = contract('Lij,jk->Lki', cderi, occ_coeff)
for i in range(mo1.shape[0]):
rhok1 = contract('Lij,jk->Lki', cderi, mo1[i])
#vk[i] += contract('Lki,Lkj->ij', rhok, rhok1)
contract('Lki,Lkj->ij', rhok, rhok1, alpha=1.0, beta=1.0, out=vk[i])
occ_coeff = rhok1 = rhok = mo1 = None
if with_j:
vj[:,rows,cols] = vj_sparse
vj[:,cols,rows] = vj_sparse
if with_k:
vk = vk + vk.transpose(0,2,1)
#vk = vk + vk.transpose(0,2,1)
transpose_sum(vk)
# general K matrix with density matrix
else:
if with_j:
vj_sparse = cupy.zeros_like(dm_sparse)
blksize = dfobj.get_blksize()
for cderi, cderi_sparse in dfobj.loop(blksize=blksize, unpack=with_k):
if with_j:
vj += get_j(cderi_sparse)
rhoj = 2.0*dm_sparse.dot(cderi_sparse)
vj_sparse += cupy.dot(rhoj, cderi_sparse.T)
if with_k:
for k in range(nset):
rhok = contract('Lij,jk->Lki', cderi, dms[k])
vk[k] += contract('Lki,Lkj->ij', cderi, rhok)
if with_j:
vj[:,rows,cols] = vj_sparse
vj[:,cols,rows] = vj_sparse
rhok = None

rev_ao_idx = dfobj.intopt.rev_ao_idx
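
Note on the df_jk.py hunk: the J matrix is now accumulated in the packed AO-pair representation (vj_packed / vj_sparse) across all cderi blocks and scattered into the square matrix only once after the loop, instead of scattering inside every iteration; for the single-set SCF case the K build additionally calls cublas.syrk, which forms the symmetric product of rhok on the lower triangle only, with the diagonal halved and transpose_sum applied afterwards to recover the full matrix. A minimal self-contained CuPy sketch of the packed-J pattern (names mirror the diff; the sizes are assumed example values, and the real code runs the accumulation once per block of dfobj.loop):

    import cupy

    nset, nao, naux = 1, 8, 40                          # assumed example sizes
    rows, cols = cupy.tril_indices(nao)                 # unique AO pairs (lower triangle)
    diag = cupy.where(rows == cols)[0]                  # positions of the diagonal pairs
    dms = cupy.random.rand(nset, nao, nao)
    dms = dms + dms.transpose(0, 2, 1)                  # symmetric density matrices
    cderi_sparse = cupy.random.rand(rows.size, naux)    # packed 3-center integrals (pq|L)

    dm_sparse = dms[:, rows, cols]
    dm_sparse[:, diag] *= 0.5                           # avoid double counting the diagonal

    vj_packed = cupy.zeros_like(dm_sparse)
    rhoj = 2.0 * dm_sparse.dot(cderi_sparse)            # rho_L = 2 * sum_{p>=q} D_pq (pq|L)
    vj_packed += cupy.dot(rhoj, cderi_sparse.T)         # accumulate J in the pair basis
    # (in the real code this accumulation repeats for every cderi block)

    vj = cupy.zeros((nset, nao, nao))
    vj[:, rows, cols] = vj_packed                       # scatter once after the loop
    vj[:, cols, rows] = vj_packed                       # and symmetrize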
26 changes: 19 additions & 7 deletions gpu4pyscf/dft/gen_grid.py
@@ -186,16 +186,17 @@ def gen_grids_partition(atm_coords, coords, a):
natm = atm_coords.shape[0]
ngrids = coords.shape[0]
assert ngrids < 65535 * 16
x_i = cupy.expand_dims(atm_coords, axis=1)
x_g = cupy.expand_dims(coords, axis=0)
squared_diff = (x_i - x_g)**2
dist_ig = cupy.sum(squared_diff, axis=2)**0.5
#x_i = cupy.expand_dims(atm_coords, axis=1)
#x_g = cupy.expand_dims(coords, axis=0)
#squared_diff = (x_i - x_g)**2
#dist_ig = cupy.sum(squared_diff, axis=2)**0.5

x_j = cupy.expand_dims(atm_coords, axis=0)
squared_diff = (x_i - x_j)**2
dist_ij = cupy.sum(squared_diff, axis=2)**0.5
#x_j = cupy.expand_dims(atm_coords, axis=0)
#squared_diff = (x_i - x_j)**2
#dist_ij = cupy.sum(squared_diff, axis=2)**0.5

pbecke = cupy.ones([natm, ngrids], order='C')
'''
err = libgdft.GDFTgen_grid_partition(
ctypes.cast(stream.ptr, ctypes.c_void_p),
ctypes.cast(pbecke.data.ptr, ctypes.c_void_p),
@@ -205,6 +206,17 @@
ctypes.c_int(ngrids),
ctypes.c_int(natm)
)
'''
atm_coords = cupy.asarray(atm_coords, order='F')
err = libgdft.GDFTgen_grid_partition(
ctypes.cast(stream.ptr, ctypes.c_void_p),
ctypes.cast(pbecke.data.ptr, ctypes.c_void_p),
ctypes.cast(coords.data.ptr, ctypes.c_void_p),
ctypes.cast(atm_coords.data.ptr, ctypes.c_void_p),
ctypes.cast(a.data.ptr, ctypes.c_void_p),
ctypes.c_int(ngrids),
ctypes.c_int(natm)
)
if err != 0:
raise RuntimeError('CUDA Error')
return pbecke
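
Note on the gen_grid.py hunk: the (natm, ngrids) grid-atom and (natm, natm) atom-atom distance matrices are no longer built with CuPy broadcasting on the host side; the kernel now receives the raw grid and atom coordinates (atm_coords converted to Fortran order so x, y, z can be read as column slices) and evaluates all distances on the fly. A small sketch of what the old path had to materialize, with assumed example sizes:

    import cupy

    natm, ngrids = 100, 1_000_000                      # assumed example sizes
    atm_coords = cupy.random.rand(natm, 3)
    coords = cupy.random.rand(ngrids, 3)

    # Old path (the broadcasting now commented out above): an (natm, ngrids) matrix.
    dist_ig = cupy.sum((atm_coords[:, None, :] - coords[None, :, :])**2, axis=2)**0.5
    print(f'{dist_ig.nbytes / 1024**3:.2f} GB')        # ~0.75 GB of float64 per call

    # New path: only the coordinates cross the C interface, in Fortran order.
    atm_coords_f = cupy.asarray(atm_coords, order='F')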
12 changes: 7 additions & 5 deletions gpu4pyscf/dft/numint.py
@@ -185,7 +185,8 @@ def eval_rho2(mol, ao, mo_coeff, mo_occ, non0tab=None, xctype='LDA',
ao_loc = mol.ao_loc_nr()

#cpos = cupy.einsum('ij,j->ij', mo_coeff[:,mo_occ>0], cupy.sqrt(mo_occ[mo_occ>0]))
cpos = mo_coeff[:,mo_occ>0] * cupy.sqrt(mo_occ[mo_occ>0])
#cpos = mo_coeff[:,mo_occ>0] * cupy.sqrt(mo_occ[mo_occ>0])
cpos = (mo_coeff * mo_occ**0.5)[:,mo_occ>0]
if xctype == 'LDA' or xctype == 'HF':
c0 = _dot_ao_dm(mol, ao, cpos, non0tab, shls_slice, ao_loc)
#:rho = numpy.einsum('pi,pi->p', c0, c0)
@@ -194,11 +195,12 @@
rho = cupy.empty((4,ngrids))
c0 = _dot_ao_dm(mol, ao[0], cpos, non0tab, shls_slice, ao_loc)
#:rho[0] = numpy.einsum('pi,pi->p', c0, c0)
rho[0] = _contract_rho(c0, c0)
_contract_rho(c0, c0, rho=rho[0])
for i in range(1, 4):
c1 = _dot_ao_dm(mol, ao[i], cpos, non0tab, shls_slice, ao_loc)
#:rho[i] = numpy.einsum('pi,pi->p', c0, c1) * 2 # *2 for +c.c.
rho[i] = _contract_rho(c0, c1) * 2
_contract_rho(c0, c1, rho=rho[i])
rho[i] *= 2
else: # meta-GGA
if with_lapl:
# rho[4] = \nabla^2 rho, rho[5] = 1/2 |nabla f|^2
@@ -209,7 +211,7 @@ def eval_rho2(mol, ao, mo_coeff, mo_occ, non0tab=None, xctype='LDA',
tau_idx = 4
c0 = _dot_ao_dm(mol, ao[0], cpos, non0tab, shls_slice, ao_loc)
#:rho[0] = numpy.einsum('pi,pi->p', c0, c0)
rho[0] = _contract_rho(c0, c0)
_contract_rho(c0, c0, rho=rho[0])

rho[tau_idx] = 0
for i in range(1, 4):
@@ -1212,7 +1214,7 @@ def _block_loop(ni, mol, grids, nao=None, deriv=0, max_memory=2000,
sorted_ao: by default ao_value is sorted for GPU
'''
if grids.coords is None:
grids.build(with_non0tab=True)
grids.build(with_non0tab=False, sort_grids=True)
if nao is None:
nao = mol.nao
ngrids = grids.coords.shape[0]
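
Note on the numint.py hunk: the occupied-orbital coefficients are now formed by scaling the full MO coefficient matrix by mo_occ**0.5 before slicing the occupied columns (algebraically identical to slicing first), each density component is written directly into the preallocated rho buffer through the rho= argument of _contract_rho, and _block_loop builds the grids with with_non0tab=False and sort_grids=True. A pure-CuPy stand-in for the row-wise contraction that _contract_rho performs (a sketch, not the library kernel):

    import cupy

    def contract_rho_ref(c0, c1, rho=None):
        # rho[p] = sum_i c0[p, i] * c1[p, i]  (the 'pi,pi->p' contraction noted above)
        if rho is None:
            rho = cupy.empty(c0.shape[0], dtype=c0.dtype)
        cupy.sum(c0 * c1, axis=1, out=rho)
        return rho

    # Usage mirroring the GGA branch: write into a row of the preallocated buffer,
    # then apply the factor of 2 for the +c.c. term.
    ngrids, nocc = 64, 5                               # assumed example sizes
    c0 = cupy.random.rand(ngrids, nocc)
    c1 = cupy.random.rand(ngrids, nocc)
    rho = cupy.empty((4, ngrids))
    contract_rho_ref(c0, c0, rho=rho[0])
    contract_rho_ref(c0, c1, rho=rho[1])
    rho[1] *= 2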
2 changes: 1 addition & 1 deletion gpu4pyscf/dft/tests/test_ao_values.py
@@ -30,7 +30,7 @@ def setUpModule():
'''
mol_sph = pyscf.M(
atom=atom,
basis='ccpvdz',
basis='ccpvqz',
spin=None,
cart = 0,
output = '/dev/null')
59 changes: 45 additions & 14 deletions gpu4pyscf/lib/gdft/gen_grids.cu
@@ -22,39 +22,71 @@
#define NATOM_PER_BLOCK 128

__global__
void GDFTgen_grid_kernel(double *pbecke, const double *dist_ig, const double *dist_ij,
const double *a, int ngrids, int natm)
void GDFTgen_grid_kernel(double *pbecke, const double *coords, const double *atm_coords, const double *a,
int ngrids, int natm)
{
int grid_id = blockIdx.x * blockDim.x + threadIdx.x;
const bool active = grid_id < ngrids;

__shared__ double dij_smem[NATOM_PER_BLOCK];
double xg = 0.0;
double yg = 0.0;
double zg = 0.0;
if(active){
xg = coords[3*grid_id+0];
yg = coords[3*grid_id+1];
zg = coords[3*grid_id+2];
}
__shared__ double xj[NATOM_PER_BLOCK];
__shared__ double yj[NATOM_PER_BLOCK];
__shared__ double zj[NATOM_PER_BLOCK];
__shared__ double a_smem[NATOM_PER_BLOCK];
__shared__ double dij_smem[NATOM_PER_BLOCK];

const int tx = threadIdx.x;

for (int atom_i = 0; atom_i < natm; atom_i++){
double xi = atm_coords[atom_i];
double yi = atm_coords[atom_i + natm];
double zi = atm_coords[atom_i + 2*natm];

double becke = 2.0;
double dig = 0.0;
double dx, dy, dz, dig;
if (active){
// distance between grids and atom i
dig = dist_ig[atom_i * ngrids + grid_id];
dx = xg - xi;
dy = yg - yi;
dz = zg - zi;
dig = norm3d(dx, dy, dz);
}
for (int j = 0; j < natm; j+=blockDim.x){
int atom_idx = j + tx;
if (atom_idx < natm){
double xj_t = atm_coords[atom_idx];
double yj_t = atm_coords[atom_idx + natm];
double zj_t = atm_coords[atom_idx + 2*natm];

// distance between atom i and atom j
dij_smem[tx] = dist_ij[atom_i * natm + atom_idx];
dx = xi - xj_t;
dy = yi - yj_t;
dz = zi - zj_t;
double dij = norm3d(dx, dy, dz);

// distance between atom i and atom j
dij_smem[tx] = dij;
xj[tx] = xj_t;
yj[tx] = yj_t;
zj[tx] = zj_t;
a_smem[tx] = a[atom_i * natm + atom_idx];
}
__syncthreads();

for (int l = 0, M = min(NATOM_PER_BLOCK, natm-j); l < M; ++l){
int atom_j = j + l;
// distance between grids and atom j
double djg = 0;
if (active){
djg = dist_ig[atom_j * ngrids + grid_id];
}
dx = xg - xj[l];
dy = yg - yj[l];
dz = zg - zj[l];
double djg = norm3d(dx, dy, dz);

double dij = dij_smem[l];
double aij = a_smem[l];
double g = (atom_i == atom_j) ? 0.0 : (dig - djg) / dij;
@@ -122,12 +154,11 @@ void GDFTgroup_grids_kernel(int* group_ids, const double* atom_coords, const dou
extern "C"{
__host__
int GDFTgen_grid_partition(cudaStream_t stream, double *pbecke,
const double *dist_ig, const double *dist_ij,
const double *a, int ngrids, int natm)
const double *coords, const double *atm_coords, const double *a, int ngrids, int natm)
{
dim3 threads(NATOM_PER_BLOCK);
dim3 blocks((ngrids+NATOM_PER_BLOCK-1)/NATOM_PER_BLOCK);
GDFTgen_grid_kernel<<<blocks, threads, 0, stream>>>(pbecke, dist_ig, dist_ij, a, ngrids, natm);
GDFTgen_grid_kernel<<<blocks, threads, 0, stream>>>(pbecke, coords, atm_coords, a, ngrids, natm);
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess){
fprintf(stderr, "CUDA Error of gen grids: %s\n", cudaGetErrorString(err));
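
Note on the gen_grids.cu hunk: each thread block now stages the x/y/z coordinates and the a_ij size-adjustment factors of NATOM_PER_BLOCK atoms in shared memory and evaluates every grid-atom and atom-atom distance with norm3d, so the kernel no longer needs the precomputed dist_ig/dist_ij arrays. The partition argument it forms is the usual Becke elliptical coordinate g = (d_ig - d_jg)/d_ij together with the a_ij adjustment. As a reference, a pure-CuPy sketch of the standard Becke cell weights such a kernel produces; the smoothing and normalization beyond this hunk are not shown in the diff, so the third-order smoothing below follows Becke's original scheme and is an assumption:

    import cupy

    def becke_partition_ref(coords, atm_coords, a):
        # dist_ig[i, g]: atom i to grid point g;  dist_ij[i, j]: atom i to atom j
        dist_ig = cupy.linalg.norm(atm_coords[:, None, :] - coords[None, :, :], axis=2)
        dist_ij = cupy.linalg.norm(atm_coords[:, None, :] - atm_coords[None, :, :], axis=2)
        natm = atm_coords.shape[0]
        pbecke = cupy.ones((natm, coords.shape[0]))
        for i in range(natm):
            for j in range(natm):
                if i == j:
                    continue
                g = (dist_ig[i] - dist_ig[j]) / dist_ij[i, j]   # elliptical coordinate
                g = g + a[i, j] * (1.0 - g**2)                  # atomic size adjustment
                for _ in range(3):                              # Becke's third-order smoothing
                    g = (3.0 - g**2) * g * 0.5
                pbecke[i] *= 0.5 * (1.0 - g)
        return pbecke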
(diff for the remaining changed file not shown)
