Skip to content

Commit

Permalink
added temp examples
Browse files Browse the repository at this point in the history
  • Loading branch information
wxj6000 committed Jan 10, 2024
1 parent 886e0c8 commit ef5d281
Show file tree
Hide file tree
Showing 31 changed files with 587 additions and 291 deletions.
46 changes: 46 additions & 0 deletions examples/18-to_gpu0.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# Copyright 2023 The GPU4PySCF Authors. All Rights Reserved.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

import pyscf
from pyscf import lib
from pyscf.dft import rks
lib.num_threads(8)

# Example: build a density-fitted RKS object with PySCF, convert it with
# to_gpu(), then run energy, gradient and Hessian on the converted object.

# Three-atom O/H/H geometry (one atom per line: symbol, x, y, z).
atom = '''
O 0.0000000000 -0.0000000000 0.1174000000
H -0.7570000000 -0.0000000000 -0.4696000000
H 0.7570000000 0.0000000000 -0.4696000000
'''

mol = pyscf.M(
    atom=atom,
    basis='def2-tzvpp',
    max_memory=32000,
    output='./pyscf.log',
)
mol.verbose = 4

# Set up the mean-field object first, then hand it over via to_gpu().
cpu_mf = rks.RKS(mol, xc='B3LYP').density_fit()
gpu_mf = cpu_mf.to_gpu()

# Energy
e_dft = gpu_mf.kernel()
print(f"total energy = {e_dft}")

# Gradient (with the auxiliary-basis response switched on)
grad_driver = gpu_mf.nuc_grad_method()
grad_driver.max_memory = 20000
grad_driver.auxbasis_response = True
g_dft = grad_driver.kernel()

# Hessian (auxbasis_response level 2)
hess_driver = gpu_mf.Hessian()
hess_driver.auxbasis_response = 2
h_dft = hess_driver.kernel()
41 changes: 41 additions & 0 deletions examples/19-pcm_optimize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# Copyright 2023 The GPU4PySCF Authors. All Rights Reserved.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

import numpy as np
import pyscf
from pyscf import lib, df
from gpu4pyscf.dft import rks
from pyscf.geomopt.geometric_solver import optimize
lib.num_threads(8)

# Example: geometry optimization of a density-fitted RKS calculation
# wrapped in a PCM solvent model.

# Three-atom O/H/H starting geometry (symbol, x, y, z per line).
atom ='''
O 0.0000000000 -0.0000000000 0.1174000000
H -0.7570000000 -0.0000000000 -0.4696000000
H 0.7570000000 0.0000000000 -0.4696000000
'''

mol = pyscf.M(atom=atom, basis='def2-tzvpp', verbose=0)

# Density-fitted RKS with the solvent wrapper attached in one chain.
mf = rks.RKS(mol, xc='HYB_GGA_XC_B3LYP').density_fit().PCM()

# Solvent model settings.
mf.with_solvent.method = 'C-PCM'
mf.with_solvent.eps = 78.3553  # dielectric constant (presumably water -- confirm)
mf.with_solvent.lebedev_order = 29 # 302 Lebedev grids

# SCF / integration-grid settings.
mf.verbose = 3
mf.grids.atom_grid = (99,590)
mf.small_rho_cutoff = 1e-10

# Run the optimizer for at most 20 steps; the return value is kept in mol_eq.
mol_eq = optimize(mf, maxsteps=20)

5 changes: 3 additions & 2 deletions examples/dft_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,17 +27,18 @@
parser.add_argument("--solvent", type=str, default='')
args = parser.parse_args()

lib.num_threads(16)
start_time = time.time()
bas = args.basis
mol = pyscf.M(
atom=args.input,
basis=bas,
max_memory=32000)
# set verbose >= 6 for debugging timer
mol.verbose = 4
mol.verbose = 7

mf_df = rks.RKS(mol, xc=args.xc).density_fit(auxbasis=args.auxbasis)
mf_df.verbose = 4
mf_df.verbose = 7

if args.solvent:
mf_df = mf_df.PCM()
Expand Down
2 changes: 1 addition & 1 deletion gpu4pyscf/__config__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# such as A100-80G
if props['totalGlobalMem'] >= 64 * GB:
min_ao_blksize = 256
min_grid_blksize = 256*256
min_grid_blksize = 128*128
ao_aligned = 32
grid_aligned = 128
mem_fraction = 0.9
Expand Down
22 changes: 3 additions & 19 deletions gpu4pyscf/df/df.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,18 +57,6 @@ def build(self, direct_scf_tol=1e-14, omega=None):
auxmol = self.auxmol
self.nao = mol.nao

# cache indices for better performance
nao = mol.nao
tril_row, tril_col = cupy.tril_indices(nao)
tril_row = cupy.asarray(tril_row)
tril_col = cupy.asarray(tril_col)

self.tril_row = tril_row
self.tril_col = tril_col

idx = np.arange(nao)
self.diag_idx = cupy.asarray(idx*(idx+1)//2+idx)

log = logger.new_logger(mol, mol.verbose)
t0 = log.init_timer()
if auxmol is None:
Expand Down Expand Up @@ -234,20 +222,20 @@ def cholesky_eri_gpu(intopt, mol, auxmol, cd_low, omega=None, sr_only=False):
nj = j1 - j0
if sr_only:
# TODO: in-place implementation or short-range kernel
ints_slices = cupy.zeros([naoaux, nj, ni], order='C')
ints_slices = cupy.empty([naoaux, nj, ni], order='C')
for cp_kl_id, _ in enumerate(intopt.aux_log_qs):
k0 = intopt.sph_aux_loc[cp_kl_id]
k1 = intopt.sph_aux_loc[cp_kl_id+1]
int3c2e.get_int3c2e_slice(intopt, cp_ij_id, cp_kl_id, out=ints_slices[k0:k1])
if omega is not None:
ints_slices_lr = cupy.zeros([naoaux, nj, ni], order='C')
ints_slices_lr = cupy.empty([naoaux, nj, ni], order='C')
for cp_kl_id, _ in enumerate(intopt.aux_log_qs):
k0 = intopt.sph_aux_loc[cp_kl_id]
k1 = intopt.sph_aux_loc[cp_kl_id+1]
int3c2e.get_int3c2e_slice(intopt, cp_ij_id, cp_kl_id, out=ints_slices[k0:k1], omega=omega)
ints_slices -= ints_slices_lr
else:
ints_slices = cupy.zeros([naoaux, nj, ni], order='C')
ints_slices = cupy.empty([naoaux, nj, ni], order='C')
for cp_kl_id, _ in enumerate(intopt.aux_log_qs):
k0 = intopt.sph_aux_loc[cp_kl_id]
k1 = intopt.sph_aux_loc[cp_kl_id+1]
Expand All @@ -261,11 +249,7 @@ def cholesky_eri_gpu(intopt, mol, auxmol, cd_low, omega=None, sr_only=False):

row = intopt.ao_pairs_row[cp_ij_id] - i0
col = intopt.ao_pairs_col[cp_ij_id] - j0
if cpi == cpj:
#ints_slices = ints_slices + ints_slices.transpose([0,2,1])
transpose_sum(ints_slices)
ints_slices = ints_slices[:,col,row]

if cd_low.tag == 'eig':
cderi_block = cupy.dot(cd_low.T, ints_slices)
ints_slices = None
Expand Down
4 changes: 2 additions & 2 deletions gpu4pyscf/df/df_jk.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,18 +253,18 @@ def get_jk(dfobj, dms_tag, hermi=1, with_j=True, with_k=True, direct_scf_tol=1e-
nao = dms_tag.shape[-1]
dms = dms_tag.reshape([-1,nao,nao])
nset = dms.shape[0]
t0 = log.init_timer()
t1 = t0 = log.init_timer()
if dfobj._cderi is None:
log.debug('CDERI not found, build...')
dfobj.build(direct_scf_tol=direct_scf_tol, omega=omega)
t1 = log.timer_debug1('init jk', *t0)

assert nao == dfobj.nao
vj = None
vk = None
ao_idx = dfobj.intopt.sph_ao_idx
dms = take_last2d(dms, ao_idx)

t1 = log.timer_debug1('init jk', *t0)
rows = dfobj.intopt.cderi_row
cols = dfobj.intopt.cderi_col
if with_j:
Expand Down
14 changes: 10 additions & 4 deletions gpu4pyscf/df/hessian/rhf.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import cupy
import numpy as np
from pyscf import lib, df
from gpu4pyscf.grad import rhf as rhf_grad
from gpu4pyscf.hessian import rhf as rhf_hess
from gpu4pyscf.lib.cupy_helper import contract, tag_array, release_gpu_stack, print_mem_info, take_last2d
from gpu4pyscf.df import int3c2e
Expand Down Expand Up @@ -283,6 +284,8 @@ def _partial_hess_ejk(hessobj, mo_energy=None, mo_coeff=None, mo_occ=None,
hk_aux_aux += .5 * contract('pqxy,pq->pqxy', rho2c_11, int2c_inv) # (00|1)(1|00)
rho2c_0 = rho2c_10 = rho2c_11 = rho2c0_10 = rho2c1_10 = rho2c0_11 = int2c_ip_ip = None
wk_ip2_P__ = int2c_ip1_inv = None
t1 = log.timer_debug1('contract int2c_*', *t1)

ao_idx = np.argsort(intopt.sph_ao_idx)
aux_idx = np.argsort(intopt.sph_aux_idx)
rev_ao_ao = cupy.ix_(ao_idx, ao_idx)
Expand Down Expand Up @@ -372,7 +375,7 @@ def _partial_hess_ejk(hessobj, mo_energy=None, mo_coeff=None, mo_occ=None,
e1[j0,i0] = e1[i0,j0].T
ej[j0,i0] = ej[i0,j0].T
ek[j0,i0] = ek[i0,j0].T

t1 = log.timer_debug1('hcore contribution', *t1)
log.timer('RHF partial hessian', *time0)
return e1, ej, ek

Expand All @@ -398,6 +401,7 @@ def make_h1(hessobj, mo_coeff, mo_occ, chkfile=None, atmlst=None, verbose=None):
else:
return chkfile
'''

def _gen_jk(hessobj, mo_coeff, mo_occ, chkfile=None, atmlst=None,
verbose=None, with_k=True, omega=None):
log = logger.new_logger(hessobj, verbose)
Expand Down Expand Up @@ -521,8 +525,9 @@ def _ao2mo(mat):
vk1_int3c = vk1_int3c_ip1 + vk1_int3c_ip2
vk1_int3c_ip1 = vk1_int3c_ip2 = None

grad_hcore = rhf_grad.get_grad_hcore(hessobj.base.nuc_grad_method())
cupy.get_default_memory_pool().free_all_blocks()
hcore_deriv = hessobj.base.nuc_grad_method().hcore_generator(mol)
#hcore_deriv = hessobj.base.nuc_grad_method().hcore_generator(mol)
vk1 = None
for i0, ia in enumerate(atmlst):
shl0, shl1, p0, p1 = aoslices[ia]
Expand All @@ -535,8 +540,9 @@ def _ao2mo(mat):
vk1_ao[:,p0:p1,:] -= vk1_buf[:,p0:p1,:]
vk1_ao[:,:,p0:p1] -= vk1_buf[:,p0:p1,:].transpose(0,2,1)

h1 = hcore_deriv(ia)
h1 = _ao2mo(cupy.asarray(h1, order='C'))
h1 = grad_hcore[:,i0]
#h1 = hcore_deriv(ia)
#h1 = _ao2mo(cupy.asarray(h1, order='C'))
vj1 = vj1_int3c[ia] + _ao2mo(vj1_ao)
if with_k:
vk1 = vk1_int3c[ia] + _ao2mo(vk1_ao)
Expand Down
46 changes: 27 additions & 19 deletions gpu4pyscf/df/int3c2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,10 +306,7 @@ def build(self, cutoff=1e-14, group_size=None,
ncptype = len(log_qs)

self.bpcache = ctypes.POINTER(BasisProdCache)()
if diag_block_with_triu:
scale_shellpair_diag = 1.
else:
scale_shellpair_diag = 0.5
scale_shellpair_diag = 1.
libgint.GINTinit_basis_prod(
ctypes.byref(self.bpcache), ctypes.c_double(scale_shellpair_diag),
ao_loc.ctypes.data_as(ctypes.c_void_p),
Expand Down Expand Up @@ -1194,6 +1191,32 @@ def get_dh1e(mol, dm0):
dh1e[k0:k1,:3] += cupy.einsum('xkji,ij->kx', int3c_blk, dm0_sorted[i0:i1,j0:j1])
return 2.0 * cupy.einsum('kx,k->kx', dh1e, -charges)

def get_d2h1e(mol, dm0):
    '''
    Contract the 'ipip1' and 'ipvip1' three-center integrals, evaluated
    against point charges placed at the atomic positions, with the density
    matrix dm0.

    Two pieces are accumulated per point charge k:
      * d2h1e_diag[k]       -- integrals fully contracted with dm0
                               (entered with a minus sign)
      * d2h1e_offdiag[k, i] -- integrals contracted over the second AO index
                               only, later mapped from AO i to atoms through
                               get_ao2atom

    Args:
        mol: molecule object providing natm, nao, atom_coords(),
            atom_charges() and aoslice_by_atom().
        dm0: density matrix in the original AO ordering; it is reordered
            here with intopt.sph_ao_idx.

    Returns:
        Array produced by einsum 'abx,a->xab', i.e. the 9 derivative
        components come first, scaled by 2 * charges[a].
        NOTE(review): the 'b' axis is assumed to have length natm (set by
        get_ao2atom) -- confirm.
    '''
    natm = mol.natm
    coords = mol.atom_coords()
    charges = mol.atom_charges()
    fakemol = gto.fakemol_for_charges(coords)

    nao = mol.nao
    d2h1e_diag = cupy.zeros([natm,9])
    d2h1e_offdiag = cupy.zeros([natm, nao, 9])
    intopt = VHFOpt(mol, fakemol, 'int2e')
    intopt.build(1e-14, diag_block_with_triu=True, aosym=False, group_size=BLKSIZE, group_size_aux=BLKSIZE)
    dm0_sorted = take_last2d(dm0, intopt.sph_ao_idx)

    # Both derivative types are accumulated in exactly the same way, so run
    # them through one loop ('ipip1' first, then 'ipvip1', as before).
    for ip_type in ('ipip1', 'ipvip1'):
        for i0,i1,j0,j1,k0,k1,int3c_blk in loop_int3c2e_general(intopt, ip_type=ip_type):
            dm_blk = dm0_sorted[i0:i1,j0:j1]
            d2h1e_diag[k0:k1,:9] -= contract('xaji,ij->ax', int3c_blk, dm_blk)
            d2h1e_offdiag[k0:k1,i0:i1,:9] += contract('xaji,ij->aix', int3c_blk, dm_blk)

    aoslices = mol.aoslice_by_atom()
    ao2atom = get_ao2atom(intopt, aoslices)
    d2h1e = contract('aix,ib->abx', d2h1e_offdiag, ao2atom)

    # Add the fully contracted term on the (a, a) diagonal only.  The former
    # `d2h1e[np.diag_indices(natm), :]` nested the two index arrays inside
    # another tuple, so they were read as a single (2, natm) index on axis 0
    # and d2h1e_diag[b] was added to every [a, b] element.
    atm_idx = cupy.arange(natm)
    d2h1e[atm_idx, atm_idx] += d2h1e_diag
    return 2.0 * cupy.einsum('abx,a->xab', d2h1e, charges)

def get_int3c2e_slice(intopt, cp_ij_id, cp_aux_id, aosym=None, out=None, omega=None, stream=None):
'''
Generate one int3c2e block for given ij, k
Expand Down Expand Up @@ -1443,14 +1466,6 @@ def get_pairing(p_offsets, q_offsets, q_cond,
for q0, q1 in zip(q_offsets[:-1], q_offsets[1:]):
if aosym and q0 < p0 or not aosym:
q_sub = q_cond[p0:p1,q0:q1].ravel()
'''
idx = q_sub.argsort(axis=None)[::-1]
q_sorted = q_sub[idx]
mask = q_sorted > cutoff
idx = idx[mask]
ishs, jshs = np.unravel_index(idx, (p1-p0, q1-q0))
print(ishs.shape)
'''
mask = q_sub > cutoff
ishs, jshs = np.indices((p1-p0,q1-q0))
ishs = ishs.ravel()[mask]
Expand All @@ -1464,13 +1479,6 @@ def get_pairing(p_offsets, q_offsets, q_cond,
log_qs.append(log_q)
elif aosym and p0 == q0 and p1 == q1:
q_sub = q_cond[p0:p1,p0:p1].ravel()
'''
idx = q_sub.argsort(axis=None)[::-1]
q_sorted = q_sub[idx]
ishs, jshs = np.unravel_index(idx, (p1-p0, p1-p0))
mask = q_sorted > cutoff
'''

ishs, jshs = np.indices((p1-p0, p1-p0))
ishs = ishs.ravel()
jshs = jshs.ravel()
Expand Down
26 changes: 26 additions & 0 deletions gpu4pyscf/df/tests/test_int3c2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,32 @@ def test_int1e_iprinv(self):
h1ao = mol.intor('int1e_iprinv', comp=3) # <\nabla|1/r|>
assert np.linalg.norm(int3c[:,:,:,i] - h1ao) < 1e-8

def test_int1e_ipiprinv(self):
    '''
    Compare the 'ipip1' integrals from get_int3c2e_general (point charges
    at the atomic positions as the auxiliary centers) with PySCF's
    'int1e_ipiprinv' evaluated with the rinv origin on each atom in turn.
    '''
    from pyscf import gto
    coords = mol.atom_coords()
    charges = mol.atom_charges()

    point_charges = gto.fakemol_for_charges(coords)
    gpu_ints = int3c2e.get_int3c2e_general(mol, point_charges, ip_type='ipip1').get()

    for iatm in range(len(charges)):
        # Reference: 9-component integral with the 1/r origin on atom iatm.
        mol.set_rinv_origin(coords[iatm])
        ref = mol.intor('int1e_ipiprinv', comp=9)
        err = np.linalg.norm(gpu_ints[:,:,:,iatm] - ref)
        assert err < 1e-8

def test_int1e_iprinvip(self):
    '''
    Compare the 'ipvip1' integrals from get_int3c2e_general (point charges
    at the atomic positions as the auxiliary centers) with PySCF's
    'int1e_iprinvip' evaluated with the rinv origin on each atom in turn.
    '''
    from pyscf import gto
    coords = mol.atom_coords()
    charges = mol.atom_charges()

    point_charges = gto.fakemol_for_charges(coords)
    gpu_ints = int3c2e.get_int3c2e_general(mol, point_charges, ip_type='ipvip1').get()

    for iatm in range(len(charges)):
        # Reference: 9-component integral with the 1/r origin on atom iatm.
        mol.set_rinv_origin(coords[iatm])
        ref = mol.intor('int1e_iprinvip', comp=9)
        err = np.linalg.norm(gpu_ints[:,:,:,iatm] - ref)
        assert err < 1e-8

if __name__ == "__main__":
print("Full Tests for int3c")
unittest.main()
8 changes: 5 additions & 3 deletions gpu4pyscf/dft/gen_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,14 +279,16 @@ def get_partition(mol, atom_grids_tab,
grid_coord and grid_weight arrays. grid_coord array has shape (N,3);
weight 1D array has N elements.
'''
atm_coords = numpy.asarray(mol.atom_coords() , order='C')
atm_coords = cupy.asarray(atm_coords)
'''
if callable(radii_adjust) and atomic_radii is not None:
f_radii_adjust = radii_adjust(mol, atomic_radii)
else:
f_radii_adjust = None
atm_coords = numpy.asarray(mol.atom_coords() , order='C')
atm_dist = gto.inter_distance(mol)
atm_coords = cupy.asarray(atm_coords)
atm_dist = cupy.asarray(atm_dist)
if (becke_scheme is original_becke and
(radii_adjust is radi.treutler_atomic_radii_adjust or
radii_adjust is radi.becke_atomic_radii_adjust or
Expand Down Expand Up @@ -324,7 +326,7 @@ def gen_grid_partition(coords):
pbecke[i] *= .5 * (1-g)
pbecke[j] *= .5 * (1+g)
return pbecke

'''
coords_all = []
weights_all = []
# support atomic_radii_adjust = None
Expand Down
Loading

0 comments on commit ef5d281

Please sign in to comment.