Skip to content

Commit

Permalink
Merge branch 'master' into bugfix-v0.6.9
Browse files (browse the repository at this point in the history)
  • Loading branch information
wxj6000 authored Dec 9, 2023
2 parents ef0722d + c818905 commit 3aa5f53
Show file tree
Hide file tree
Showing 11 changed files with 411 additions and 282 deletions.
2 changes: 1 addition & 1 deletion gpu4pyscf/df/df_jk.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def build_df():
ni = mf._numint
rks.initialize_grids(mf, mf.mol, dm0)
ni.build(mf.mol, mf.grids.coords)
mf._numint.xcfuns = numint._init_xcfuns(mf.xc)
mf._numint.xcfuns = numint._init_xcfuns(mf.xc, dm0.ndim==3)
dm0 = cupy.asarray(dm0)
return

Expand Down
33 changes: 19 additions & 14 deletions gpu4pyscf/df/hessian/rhf.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,11 +68,9 @@ def _partial_hess_ejk(hessobj, mo_energy=None, mo_coeff=None, mo_occ=None,
mo_coeff = cupy.asarray(mo_coeff, order='C')
nao, nmo = mo_coeff.shape
mocc = mo_coeff[:,mo_occ>0]
mocc_2 = cupy.einsum('pi,i->pi', mocc, mo_occ[mo_occ>0]**.5)
mocc_2 = mocc * mo_occ[mo_occ>0]**.5
dm0 = cupy.dot(mocc, mocc.T) * 2

hcore_deriv = hessobj.hcore_generator(mol)

# ------------------------------------
# overlap matrix contributions
# ------------------------------------
Expand Down Expand Up @@ -107,10 +105,10 @@ def _partial_hess_ejk(hessobj, mo_energy=None, mo_coeff=None, mo_occ=None,
int2c = cupy.asarray(int2c, order='C')
int2c = take_last2d(int2c, sph_aux_idx)
int2c_inv = cupy.linalg.pinv(int2c, rcond=1e-12)
int2c = None

int2c_ip1 = cupy.asarray(int2c_ip1, order='C')
int2c_ip1 = take_last2d(int2c_ip1, sph_aux_idx)
int2c_ip1_inv = contract('yqp,pr->yqr', int2c_ip1, int2c_inv)

hj_ao_ao = cupy.zeros([nao,nao,3,3])
hk_ao_ao = cupy.zeros([nao,nao,3,3])
Expand Down Expand Up @@ -143,10 +141,10 @@ def _partial_hess_ejk(hessobj, mo_energy=None, mo_coeff=None, mo_occ=None,
hj_ao_aux -= contract('q,iqxy->iqxy', rhoj0_P, wj1_01) # (10|0)(0|1)(0|00)
wj1_01 = None

int2c_ip1_inv = contract('yqp,pr->yqr', int2c_ip1, int2c_inv)
if with_k:
if hessobj.auxbasis_response:
wk1_P__ = contract('ypq,qor->ypor', int2c_ip1, rhok0_P__)
int2c_ip1_inv = cupy.asarray(int2c_ip1_inv)

for i0, i1 in lib.prange(0,nao,64):
wk1_Pko_islice = cupy.asarray(wk1_Pko[i0:i1])
Expand Down Expand Up @@ -232,7 +230,8 @@ def _partial_hess_ejk(hessobj, mo_energy=None, mo_coeff=None, mo_occ=None,
int2c_ipip1 = take_last2d(int2c_ipip1, sph_aux_idx)
rhoj2c_P = contract('xpq,q->xp', int2c_ipip1, rhoj0_P)
# (00|0)(2|0)(0|00)
hj_aux_diag -= cupy.einsum('p,xp->px', rhoj0_P, rhoj2c_P).reshape(-1,3,3)
# p,xp->px
hj_aux_diag -= (rhoj0_P*rhoj2c_P).T.reshape(-1,3,3)
if with_k:
rho2c_0 = contract('pij,qji->pq', rhok0_P__, rhok0_P__)
hk_aux_diag -= .5 * contract('pq,xpq->px', rho2c_0, int2c_ipip1).reshape(-1,3,3)
Expand All @@ -245,7 +244,7 @@ def _partial_hess_ejk(hessobj, mo_energy=None, mo_coeff=None, mo_occ=None,
int2c_ip1ip2 = auxmol.intor('int2c2e_ip1ip2', aosym='s1')
int2c_ip1ip2 = cupy.asarray(int2c_ip1ip2, order='C')
int2c_ip1ip2 = take_last2d(int2c_ip1ip2, sph_aux_idx)
hj_aux_aux = -.5 * cupy.einsum('p,xpq,q->pqx', rhoj0_P, int2c_ip1ip2, rhoj0_P).reshape(naux, naux,3,3)
hj_aux_aux = -.5 * contract('p,xpq->pqx', rhoj0_P, int2c_ip1ip2*rhoj0_P).reshape(naux, naux,3,3)
if with_k:
hk_aux_aux = -.5 * contract('xpq,pq->pqx', int2c_ip1ip2, rho2c_0).reshape(naux,naux,3,3)
t1 = log.timer_debug1('intermediate variables with int2c_*', *t1)
Expand Down Expand Up @@ -309,7 +308,9 @@ def _partial_hess_ejk(hessobj, mo_energy=None, mo_coeff=None, mo_occ=None,

#======================================== sort AO end ===========================================
# Energy weighted density matrix
dme0 = cupy.einsum('pi,qi,i->pq', mocc, mocc, mo_energy[mo_occ>0]) * 2.0
# pi,qi,i->pq
dme0 = cupy.dot(mocc, (mocc * mo_energy[mo_occ>0] * 2).T)
hcore_deriv = hessobj.hcore_generator(mol)
# -----------------------------------------
# collecting all
# -----------------------------------------
Expand All @@ -318,18 +319,18 @@ def _partial_hess_ejk(hessobj, mo_energy=None, mo_coeff=None, mo_occ=None,
ek = cupy.zeros([len(atmlst),len(atmlst),3,3])
for i0, ia in enumerate(atmlst):
shl0, shl1, p0, p1 = aoslices[ia]
e1[i0,i0] -= cupy.einsum('xypq,pq->xy', s1aa[:,:,p0:p1], dme0[p0:p1]) * 2.0
e1[i0,i0] -= contract('xypq,pq->xy', s1aa[:,:,p0:p1], dme0[p0:p1]) * 2.0
ej[i0,i0] += cupy.sum(hj_ao_diag[p0:p1,:,:], axis=0)
if with_k:
ek[i0,i0] += cupy.sum(hk_ao_diag[p0:p1,:,:], axis=0)
for j0, ja in enumerate(atmlst[:i0+1]):
q0, q1 = aoslices[ja][2:]
ej[i0,j0] += cupy.sum(hj_ao_ao[p0:p1,q0:q1], axis=[0,1])
e1[i0,j0] -= 2.0 * cupy.einsum('xypq,pq->xy', s1ab[:,:,p0:p1,q0:q1], dme0[p0:p1,q0:q1])
e1[i0,j0] -= 2.0 * contract('xypq,pq->xy', s1ab[:,:,p0:p1,q0:q1], dme0[p0:p1,q0:q1])
if with_k:
ek[i0,j0] += cupy.sum(hk_ao_ao[p0:p1,q0:q1], axis=[0,1])
h1ao = hcore_deriv(ia, ja)
e1[i0,j0] += cupy.einsum('xypq,pq->xy', h1ao, dm0)
e1[i0,j0] += contract('xypq,pq->xy', h1ao, dm0)
#
# The first order RI basis response
#
Expand Down Expand Up @@ -425,7 +426,6 @@ def _gen_jk(hessobj, mo_coeff, mo_occ, chkfile=None, atmlst=None,
intopt.build(mf.direct_scf_tol, diag_block_with_triu=True, aosym=False, group_size_aux=BLKSIZE, group_size=BLKSIZE)
sph_ao_idx = intopt.sph_ao_idx
sph_aux_idx = intopt.sph_aux_idx
rev_ao_idx = np.argsort(intopt.sph_ao_idx)

mocc = mocc[sph_ao_idx, :]
mo_coeff = mo_coeff[sph_ao_idx,:]
Expand All @@ -434,22 +434,27 @@ def _gen_jk(hessobj, mo_coeff, mo_occ, chkfile=None, atmlst=None,

int2c = take_last2d(int2c, sph_aux_idx)
int2c_inv = cupy.linalg.pinv(int2c, rcond=1e-12)
int2c = None

wj, wk_Pl_, wk_P__ = int3c2e.get_int3c2e_wjk(mol, auxmol, dm0_tag, omega=omega)
rhoj0 = contract('pq,q->p', int2c_inv, wj)
if with_k:
rhok0_P__ = contract('pq,qij->pij', int2c_inv, wk_P__)
wj = wk_P__ = None
if isinstance(wk_Pl_, cupy.ndarray):
rhok0_Pl_ = contract('pq,qio->pio', int2c_inv, wk_Pl_)
else:
rhok0_Pl_ = np.empty_like(wk_Pl_)
for p0, p1 in lib.prange(0,nao,64):
wk_tmp = cupy.asarray(wk_Pl_[:,p0:p1])
rhok0_Pl_[:,p0:p1] = contract('pq,qio->pio', int2c_inv, wk_tmp).get()
wj = wk_Pl_ = wk_P__ = int2c_inv = int2c = None
wk_tmp = None
wk_Pl_ = int2c_inv = None

# int3c_ip1 contributions
cupy.get_default_memory_pool().free_all_blocks()
vj1_buf, vk1_buf, vj1_ao, vk1_ao = int3c2e.get_int3c2e_ip1_vjk(intopt, rhoj0, rhok0_Pl_, dm0_tag, aoslices, omega=omega)
rev_ao_idx = np.argsort(sph_ao_idx)
vj1_buf = take_last2d(vj1_buf, rev_ao_idx)
vk1_buf = take_last2d(vk1_buf, rev_ao_idx)

Expand Down Expand Up @@ -489,7 +494,7 @@ def _gen_jk(hessobj, mo_coeff, mo_occ, chkfile=None, atmlst=None,
vk1_tmp += 2.0 * contract('xpro,pir->xpio', wk0_10_P__, rhok_tmp)
vk1_int3c_ip2[:,:,p0:p1] += contract('xpio,pa->axio', vk1_tmp, aux2atom)
wj0_10 = wk0_10_P__ = rhok0_P__ = int2c_ip1 = None
vj1_tmp = vk1_tmp = wk0_10_Pl_ = rhoj0 = rhok0_Pl_ = None
rhok_tmp = vj1_tmp = vk1_tmp = wk0_10_Pl_ = rhoj0 = rhok0_Pl_ = None
aux2atom = None

vj1_int3c_ip2 = contract('nxiq,ip->nxpq', vj1_int3c_ip2, mo_coeff)
Expand Down
46 changes: 26 additions & 20 deletions gpu4pyscf/dft/libxc.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,12 @@
import ctypes
import cupy
from pyscf import dft
from gpu4pyscf.dft.libxc_structs import xc_func_type

_libxc = np.ctypeslib.load_library(
'libxc', os.path.abspath(os.path.join(__file__, '..', '..', 'lib', 'deps', 'lib')))

def _check_arrays(current_arrays, fields, factor, required):
def _check_arrays(current_arrays, fields, sizes, factor, required):
"""
A specialized function built to construct and check the sizes of arrays given to the LibXCFunctional class.
"""
Expand All @@ -35,7 +36,8 @@ def _check_arrays(current_arrays, fields, factor, required):

for label in fields:
if required:
current_arrays[label] = cupy.zeros((factor, 1))
size = sizes[label]
current_arrays[label] = cupy.zeros((factor, size), dtype=np.float64)
else:
current_arrays[label] = None # cupy.empty((1))

Expand All @@ -44,16 +46,15 @@ def _check_arrays(current_arrays, fields, factor, required):
class _xcfun(ctypes.Structure):
pass

_xc_func_p = ctypes.POINTER(_xcfun)
_xc_func_p = ctypes.POINTER(xc_func_type)
_libxc.xc_func_alloc.restype = _xc_func_p
_libxc.xc_func_init.argtypes = (_xc_func_p, ctypes.c_int, ctypes.c_int)
_libxc.xc_func_end.argtypes = (_xc_func_p, )
_libxc.xc_func_free.argtypes = (_xc_func_p, )

class XCfun:
def __init__(self, xc, spin):
assert spin == 'unpolarized'
self._spin = 1
self._spin = 1 if spin == 'unpolarized' else 2
self.xc_func = _libxc.xc_func_alloc()
if isinstance(xc, str):
self.func_id = _libxc.xc_functional_get_number(ctypes.c_char_p(xc.encode()))
Expand All @@ -64,6 +65,10 @@ def __init__(self, xc, spin):
raise RuntimeError('failed to initialize xc fun')
self._family = dft.libxc.xc_type(xc)

self.xc_func_sizes = {}
for attr in dir(self.xc_func.contents.dim):
if "_" not in attr:
self.xc_func_sizes[attr] = getattr(self.xc_func.contents.dim, attr)
def __del__(self):
if self.xc_func is None:
return
Expand All @@ -89,6 +94,7 @@ def compute(self, inp, output=None, do_exc=True, do_vxc=True, do_fxc=False, do_k

# Find the right compute function
args = [self.xc_func, ctypes.c_size_t(npoints)]
xc_func_sizes = self.xc_func_sizes
if self._family == 'LDA':
input_labels = ["rho"]
input_num_args = 1
Expand All @@ -102,11 +108,11 @@ def compute(self, inp, output=None, do_exc=True, do_vxc=True, do_fxc=False, do_k
]

# Build input args
output = _check_arrays(output, output_labels[0:1], npoints, do_exc)
output = _check_arrays(output, output_labels[1:2], npoints, do_vxc)
output = _check_arrays(output, output_labels[2:3], npoints, do_fxc)
output = _check_arrays(output, output_labels[3:4], npoints, do_kxc)
output = _check_arrays(output, output_labels[4:5], npoints, do_lxc)
output = _check_arrays(output, output_labels[0:1], xc_func_sizes, npoints, do_exc)
output = _check_arrays(output, output_labels[1:2], xc_func_sizes, npoints, do_vxc)
output = _check_arrays(output, output_labels[2:3], xc_func_sizes, npoints, do_fxc)
output = _check_arrays(output, output_labels[3:4], xc_func_sizes, npoints, do_kxc)
output = _check_arrays(output, output_labels[4:5], xc_func_sizes, npoints, do_lxc)

args.extend([ inp[x] for x in input_labels])
args.extend([output[x] for x in output_labels])
Expand All @@ -129,11 +135,11 @@ def compute(self, inp, output=None, do_exc=True, do_vxc=True, do_fxc=False, do_k
]

# Build input args
output = _check_arrays(output, output_labels[0:1], npoints, do_exc)
output = _check_arrays(output, output_labels[1:3], npoints, do_vxc)
output = _check_arrays(output, output_labels[3:6], npoints, do_fxc)
output = _check_arrays(output, output_labels[6:10], npoints, do_kxc)
output = _check_arrays(output, output_labels[10:15], npoints, do_lxc)
output = _check_arrays(output, output_labels[0:1], xc_func_sizes, npoints, do_exc)
output = _check_arrays(output, output_labels[1:3], xc_func_sizes, npoints, do_vxc)
output = _check_arrays(output, output_labels[3:6], xc_func_sizes, npoints, do_fxc)
output = _check_arrays(output, output_labels[6:10], xc_func_sizes, npoints, do_kxc)
output = _check_arrays(output, output_labels[10:15], xc_func_sizes, npoints, do_lxc)

args.extend([ inp[x] for x in input_labels])
args.extend([output[x] for x in output_labels])
Expand Down Expand Up @@ -174,11 +180,11 @@ def compute(self, inp, output=None, do_exc=True, do_vxc=True, do_fxc=False, do_k
]

# Build input args
output = _check_arrays(output, output_labels[0:1], npoints, do_exc)
output = _check_arrays(output, output_labels[1:5], npoints, do_vxc)
output = _check_arrays(output, output_labels[5:15], npoints, do_fxc)
output = _check_arrays(output, output_labels[15:35], npoints, do_kxc)
output = _check_arrays(output, output_labels[35:70], npoints, do_lxc)
output = _check_arrays(output, output_labels[0:1], xc_func_sizes, npoints, do_exc)
output = _check_arrays(output, output_labels[1:5], xc_func_sizes, npoints, do_vxc)
output = _check_arrays(output, output_labels[5:15], xc_func_sizes, npoints, do_fxc)
output = _check_arrays(output, output_labels[15:35], xc_func_sizes, npoints, do_kxc)
output = _check_arrays(output, output_labels[35:70], xc_func_sizes, npoints, do_lxc)

args.extend([ inp[x] for x in input_labels])
if not self.needs_laplacian():
Expand Down
Loading

0 comments on commit 3aa5f53

Please sign in to comment.