Skip to content

Commit

Permalink
replace cart2sph with cuda kernel (pyscf#137)
Browse files Browse the repository at this point in the history
* cart2sph kernel

* raise error in cart2sph kernel
  • Loading branch information
wxj6000 authored Apr 12, 2024
1 parent 86b977a commit 6047760
Show file tree
Hide file tree
Showing 6 changed files with 246 additions and 4 deletions.
52 changes: 52 additions & 0 deletions benchmarks/cupy_helper/benchmark_cart2sph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# Copyright 2024 The GPU4PySCF Authors. All Rights Reserved.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

import numpy as np
import cupy
from cupyx import profiler
from gpu4pyscf.lib.cupy_helper import cart2sph, cart2sph_cutensor

print('benchmarking cart2sph when ang=2')
a = cupy.random.random([512,6*128,512])
b = cupy.random.random([512,5*128,512])
perf_kernel = profiler.benchmark(cart2sph, (a,1,2,b), n_repeat=20, n_warmup=3)
perf_cutensor = profiler.benchmark(cart2sph_cutensor, (a,1,2), n_repeat=20, n_warmup=3)
t_kernel = perf_kernel.gpu_times.mean()
t_cutensor = perf_cutensor.gpu_times.mean()
print('kernel:', t_kernel)
print('cutensor:', t_cutensor)
print('memory bandwidth:',(a.nbytes+b.nbytes)/t_kernel/1024**3, 'GB/s')

print('benchmarking cart2sph when ang=3')
a = cupy.random.random([512,10*128,512])
b = cupy.random.random([512,7*128,512])
perf_kernel = profiler.benchmark(cart2sph, (a,1,3,b), n_repeat=20, n_warmup=3)
perf_cutensor = profiler.benchmark(cart2sph_cutensor, (a,1,3), n_repeat=20, n_warmup=3)
t_kernel = perf_kernel.gpu_times.mean()
t_cutensor = perf_cutensor.gpu_times.mean()
print('kernel:', t_kernel)
print('cutensor:', t_cutensor)
print('memory bandwidth:',(a.nbytes+b.nbytes)/t_kernel/1024**3, 'GB/s')

print('benchmarking cart2sph when ang=4')
a = cupy.random.random([512,15*128,512])
b = cupy.random.random([512,9*128,512])
perf_kernel = profiler.benchmark(cart2sph, (a,1,4,b), n_repeat=20, n_warmup=3)
perf_cutensor = profiler.benchmark(cart2sph_cutensor, (a,1,4), n_repeat=20, n_warmup=3)
t_kernel = perf_kernel.gpu_times.mean()
t_cutensor = perf_cutensor.gpu_times.mean()
print('kernel:', t_kernel)
print('cutensor:', t_cutensor)
print('memory bandwidth:',(a.nbytes+b.nbytes)/t_kernel/1024**3, 'GB/s')
4 changes: 2 additions & 2 deletions examples/dft_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,13 @@
basis=bas,
max_memory=32000)
# set verbose >= 6 for debugging timer
mol.verbose = 7
mol.verbose = 4

if args.unrestricted:
mf_df = uks.UKS(mol, xc=args.xc).density_fit(auxbasis=args.auxbasis)
else:
mf_df = rks.RKS(mol, xc=args.xc).density_fit(auxbasis=args.auxbasis)
mf_df.verbose = 7
mf_df.verbose = 4

if args.solvent:
mf_df = mf_df.PCM()
Expand Down
38 changes: 37 additions & 1 deletion gpu4pyscf/lib/cupy_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,7 +393,7 @@ def hermi_triu(mat, hermi=1, inplace=True):

return mat

def cart2sph(t, axis=0, ang=1, out=None):
def cart2sph_cutensor(t, axis=0, ang=1, out=None):
'''
transform 'axis' of a tensor from cartesian basis into spherical basis
'''
Expand All @@ -415,6 +415,42 @@ def cart2sph(t, axis=0, ang=1, out=None):
t_sph = contract('min,ip->mpn', t_cart, c2s, out=out)
return t_sph.reshape(out_shape)

def cart2sph(t, axis=0, ang=1, out=None, stream=None):
'''
transform 'axis' of a tensor from cartesian basis into spherical basis
'''
if(ang <= 1):
if(out is not None): out[:] = t
return t
size = list(t.shape)
c2s = c2s_l[ang]
if(not t.flags['C_CONTIGUOUS']): t = cupy.asarray(t, order='C')
li_size = c2s.shape
nli = size[axis] // li_size[0]
i0 = max(1, np.prod(size[:axis]))
i3 = max(1, np.prod(size[axis+1:]))
out_shape = size[:axis] + [nli*li_size[1]] + size[axis+1:]

t_cart = t.reshape([i0*nli, li_size[0], i3])
if(out is not None):
out = out.reshape([i0*nli, li_size[1], i3])
else:
out = cupy.empty(out_shape)
count = i0*nli*i3
if stream is None:
stream = cupy.cuda.get_current_stream()
err = libcupy_helper.cart2sph(
ctypes.cast(stream.ptr, ctypes.c_void_p),
ctypes.cast(t_cart.data.ptr, ctypes.c_void_p),
ctypes.cast(out.data.ptr, ctypes.c_void_p),
ctypes.c_int(i3),
ctypes.c_int(count),
ctypes.c_int(ang)
)
if err != 0:
raise RuntimeError('failed in cart2sph kernel')
return out.reshape(out_shape)

# a copy with modification from
# https://github.com/pyscf/pyscf/blob/9219058ac0a1bcdd8058166cad0fb9127b82e9bf/pyscf/lib/linalg_helper.py#L1536
def krylov(aop, b, x0=None, tol=1e-10, max_cycle=30, dot=cupy.dot,
Expand Down
1 change: 1 addition & 0 deletions gpu4pyscf/lib/cupy_helper/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ add_library(cupy_helper SHARED
dist_matrix.cu
grouped_gemm.cu
grouped_dot.cu
cart2sph.cu
)

add_dependencies(cupy_helper cutlass)
Expand Down
137 changes: 137 additions & 0 deletions gpu4pyscf/lib/cupy_helper/cart2sph.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
/* Copyright 2024 The GPU4PySCF Authors. All Rights Reserved.
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/

#include <cuda_runtime.h>
#include <stdio.h>

#define THREADS 128

// (n,ncart,stride) -> (n,nsph,stride), count = n*stride
__global__
static void _cart2sph_ang2(double *cart, double *sph, int stride, int count){
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= count){
return;
}
int i = idx / stride;
int j = idx % stride;
int sph_offset = 5 * stride * i + j;
int cart_offset = 6 * stride * i + j;
double g0 = cart[cart_offset+0*stride];
double g1 = cart[cart_offset+1*stride];
double g2 = cart[cart_offset+2*stride];
double g3 = cart[cart_offset+3*stride];
double g4 = cart[cart_offset+4*stride];
double g5 = cart[cart_offset+5*stride];

sph[sph_offset+0*stride] = 1.092548430592079070 * g1;
sph[sph_offset+1*stride] = 1.092548430592079070 * g4;
sph[sph_offset+2*stride] = 0.630783130505040012 * g5 - 0.315391565252520002 * (g0 + g3);
sph[sph_offset+3*stride] = 1.092548430592079070 * g2;
sph[sph_offset+4*stride] = 0.546274215296039535 * (g0 - g3);
}

__global__
static void _cart2sph_ang3(double *cart, double *sph, int stride, int count){
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= count){
return;
}
int i = idx / stride;
int j = idx % stride;
int sph_offset = 7 * stride * i + j;
int cart_offset = 10 * stride * i + j;
double g0 = cart[cart_offset+0*stride];
double g1 = cart[cart_offset+1*stride];
double g2 = cart[cart_offset+2*stride];
double g3 = cart[cart_offset+3*stride];
double g4 = cart[cart_offset+4*stride];
double g5 = cart[cart_offset+5*stride];
double g6 = cart[cart_offset+6*stride];
double g7 = cart[cart_offset+7*stride];
double g8 = cart[cart_offset+8*stride];
double g9 = cart[cart_offset+9*stride];

sph[sph_offset+0*stride] = 1.770130769779930531 * g1 - 0.590043589926643510 * g6;
sph[sph_offset+1*stride] = 2.890611442640554055 * g4;
sph[sph_offset+2*stride] = 1.828183197857862944 * g8 - 0.457045799464465739 * (g1 + g6);
sph[sph_offset+3*stride] = 0.746352665180230782 * g9 - 1.119528997770346170 * (g2 + g7);
sph[sph_offset+4*stride] = 1.828183197857862944 * g5 - 0.457045799464465739 * (g0 + g3);
sph[sph_offset+5*stride] = 1.445305721320277020 * (g2 - g7);
sph[sph_offset+6*stride] = 0.590043589926643510 * g0 - 1.770130769779930530 * g3;
}

__global__
static void _cart2sph_ang4(double *cart, double *sph, int stride, int count){
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= count){
return;
}
int i = idx / stride;
int j = idx % stride;
int sph_offset = 9 * stride * i + j;
int cart_offset = 15 * stride * i + j;
double g0 = cart[cart_offset+0*stride];
double g1 = cart[cart_offset+1*stride];
double g2 = cart[cart_offset+2*stride];
double g3 = cart[cart_offset+3*stride];
double g4 = cart[cart_offset+4*stride];
double g5 = cart[cart_offset+5*stride];
double g6 = cart[cart_offset+6*stride];
double g7 = cart[cart_offset+7*stride];
double g8 = cart[cart_offset+8*stride];
double g9 = cart[cart_offset+9*stride];
double g10 = cart[cart_offset+10*stride];
double g11 = cart[cart_offset+11*stride];
double g12 = cart[cart_offset+12*stride];
double g13 = cart[cart_offset+13*stride];
double g14 = cart[cart_offset+14*stride];

sph[sph_offset+0*stride] = 2.503342941796704538 * (g1 - g6);
sph[sph_offset+1*stride] = 5.310392309339791593 * g4 - 1.770130769779930530 * g11;
sph[sph_offset+2*stride] = 5.677048174545360108 * g8 - 0.946174695757560014 * (g1 + g6);
sph[sph_offset+3*stride] = 2.676186174229156671 * g13- 2.007139630671867500 * (g4 + g11);
sph[sph_offset+4*stride] = 0.317356640745612911 * (g0 + g10) + 0.634713281491225822 * g3 - 2.538853125964903290 * (g5 + g12) + 0.846284375321634430 * g14;
sph[sph_offset+5*stride] = 2.676186174229156671 * g9 - 2.007139630671867500 * (g2 + g7);
sph[sph_offset+6*stride] = 2.838524087272680054 * (g5 - g12) + 0.473087347878780009 * (g10 - g0);
sph[sph_offset+7*stride] = 1.770130769779930531 * g2 - 5.310392309339791590 * g7 ;
sph[sph_offset+8*stride] = 0.625835735449176134 * (g0 + g10) - 3.755014412695056800 * g3;
}

extern "C" {
__host__
int cart2sph(cudaStream_t stream, double *cart_gto, double *sph_gto, int stride, int count, int ang)
{
dim3 threads(THREADS);
dim3 blocks((count + THREADS - 1)/THREADS);
switch (ang) {
case 0: break;
case 1: break;
case 2: _cart2sph_ang2 <<<blocks, threads, 0, stream>>> (cart_gto, sph_gto, stride, count); break;
case 3: _cart2sph_ang3 <<<blocks, threads, 0, stream>>> (cart_gto, sph_gto, stride, count); break;
case 4: _cart2sph_ang4 <<<blocks, threads, 0, stream>>> (cart_gto, sph_gto, stride, count); break;
default:
fprintf(stderr, "Ang > 4 is not supported!");
return 1;
}

cudaError_t err = cudaGetLastError();
if (err != cudaSuccess) {
return 1;
}
return 0;
}
}
18 changes: 17 additions & 1 deletion gpu4pyscf/lib/tests/test_cupy_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from gpu4pyscf.lib.cupy_helper import (
take_last2d, transpose_sum, krylov, unpack_sparse,
add_sparse, takebak, empty_mapped, dist_matrix,
grouped_dot, grouped_gemm, cond)
grouped_dot, grouped_gemm, cond, cart2sph_cutensor, cart2sph)

class KnownValues(unittest.TestCase):
def test_take_last2d(self):
Expand Down Expand Up @@ -160,6 +160,22 @@ def generate_problems(problems):
assert(cupy.linalg.norm(res_Cs - ans_Cs) < 1e-8)
assert(cupy.linalg.norm(res_Cs_2 - ans_Cs) < 1e-8)

def test_cart2sph(self):
a_cart = cupy.random.rand(10,6,11)
a_sph0 = cart2sph_cutensor(a_cart, axis=1, ang=2)
a_sph1 = cart2sph(a_cart, axis=1, ang=2)
assert cupy.linalg.norm(a_sph0 - a_sph1) < 1e-8

a_cart = cupy.random.rand(10,10,11)
a_sph0 = cart2sph_cutensor(a_cart, axis=1, ang=3)
a_sph1 = cart2sph(a_cart, axis=1, ang=3)
assert cupy.linalg.norm(a_sph0 - a_sph1) < 1e-8

a_cart = cupy.random.rand(10,15,11)
a_sph0 = cart2sph_cutensor(a_cart, axis=1, ang=4)
a_sph1 = cart2sph(a_cart, axis=1, ang=4)
assert cupy.linalg.norm(a_sph0 - a_sph1) < 1e-8

if __name__ == "__main__":
print("Full tests for cupy helper module")
unittest.main()

0 comments on commit 6047760

Please sign in to comment.