diff --git a/benchmarks/cupy_helper/benchmark_cart2sph.py b/benchmarks/cupy_helper/benchmark_cart2sph.py
new file mode 100644
index 00000000..620d72c3
--- /dev/null
+++ b/benchmarks/cupy_helper/benchmark_cart2sph.py
@@ -0,0 +1,52 @@
+# Copyright 2024 The GPU4PySCF Authors. All Rights Reserved.
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+import numpy as np
+import cupy
+from cupyx import profiler
+from gpu4pyscf.lib.cupy_helper import cart2sph, cart2sph_cutensor
+
+print('benchmarking cart2sph when ang=2')
+a = cupy.random.random([512,6*128,512])
+b = cupy.random.random([512,5*128,512])
+perf_kernel = profiler.benchmark(cart2sph, (a,1,2,b), n_repeat=20, n_warmup=3)
+perf_cutensor = profiler.benchmark(cart2sph_cutensor, (a,1,2), n_repeat=20, n_warmup=3)
+t_kernel = perf_kernel.gpu_times.mean()
+t_cutensor = perf_cutensor.gpu_times.mean()
+print('kernel:', t_kernel)
+print('cutensor:', t_cutensor)
+print('memory bandwidth:',(a.nbytes+b.nbytes)/t_kernel/1024**3, 'GB/s')
+
+print('benchmarking cart2sph when ang=3')
+a = cupy.random.random([512,10*128,512])
+b = cupy.random.random([512,7*128,512])
+perf_kernel = profiler.benchmark(cart2sph, (a,1,3,b), n_repeat=20, n_warmup=3)
+perf_cutensor = profiler.benchmark(cart2sph_cutensor, (a,1,3), n_repeat=20, n_warmup=3)
+t_kernel = perf_kernel.gpu_times.mean()
+t_cutensor = perf_cutensor.gpu_times.mean()
+print('kernel:', t_kernel)
+print('cutensor:', t_cutensor)
+print('memory bandwidth:',(a.nbytes+b.nbytes)/t_kernel/1024**3, 'GB/s')
+
+print('benchmarking cart2sph when ang=4')
+a = cupy.random.random([512,15*128,512])
+b = cupy.random.random([512,9*128,512])
+perf_kernel = profiler.benchmark(cart2sph, (a,1,4,b), n_repeat=20, n_warmup=3)
+perf_cutensor = profiler.benchmark(cart2sph_cutensor, (a,1,4), n_repeat=20, n_warmup=3)
+t_kernel = perf_kernel.gpu_times.mean()
+t_cutensor = perf_cutensor.gpu_times.mean()
+print('kernel:', t_kernel)
+print('cutensor:', t_cutensor)
+print('memory bandwidth:',(a.nbytes+b.nbytes)/t_kernel/1024**3, 'GB/s')
\ No newline at end of file
diff --git a/examples/dft_driver.py b/examples/dft_driver.py
index ef7a0923..65a5a3d7 100644
--- a/examples/dft_driver.py
+++ b/examples/dft_driver.py
@@ -36,13 +36,13 @@
basis=bas,
max_memory=32000)
# set verbose >= 6 for debugging timer
-mol.verbose = 7
+mol.verbose = 4
if args.unrestricted:
mf_df = uks.UKS(mol, xc=args.xc).density_fit(auxbasis=args.auxbasis)
else:
mf_df = rks.RKS(mol, xc=args.xc).density_fit(auxbasis=args.auxbasis)
-mf_df.verbose = 7
+mf_df.verbose = 4
if args.solvent:
mf_df = mf_df.PCM()
diff --git a/gpu4pyscf/lib/cupy_helper.py b/gpu4pyscf/lib/cupy_helper.py
index 4d4cfa47..fe03587e 100644
--- a/gpu4pyscf/lib/cupy_helper.py
+++ b/gpu4pyscf/lib/cupy_helper.py
@@ -393,7 +393,7 @@ def hermi_triu(mat, hermi=1, inplace=True):
return mat
-def cart2sph(t, axis=0, ang=1, out=None):
+def cart2sph_cutensor(t, axis=0, ang=1, out=None):
'''
transform 'axis' of a tensor from cartesian basis into spherical basis
'''
@@ -415,6 +415,42 @@ def cart2sph(t, axis=0, ang=1, out=None):
t_sph = contract('min,ip->mpn', t_cart, c2s, out=out)
return t_sph.reshape(out_shape)
+def cart2sph(t, axis=0, ang=1, out=None, stream=None):
+ '''
+ transform 'axis' of a tensor from cartesian basis into spherical basis
+ '''
+ if(ang <= 1):
+ if(out is not None): out[:] = t
+ return t
+ size = list(t.shape)
+ c2s = c2s_l[ang]
+ if(not t.flags['C_CONTIGUOUS']): t = cupy.asarray(t, order='C')
+ li_size = c2s.shape
+ nli = size[axis] // li_size[0]
+ i0 = max(1, np.prod(size[:axis]))
+ i3 = max(1, np.prod(size[axis+1:]))
+ out_shape = size[:axis] + [nli*li_size[1]] + size[axis+1:]
+
+ t_cart = t.reshape([i0*nli, li_size[0], i3])
+ if(out is not None):
+ out = out.reshape([i0*nli, li_size[1], i3])
+ else:
+ out = cupy.empty(out_shape)
+ count = i0*nli*i3
+ if stream is None:
+ stream = cupy.cuda.get_current_stream()
+ err = libcupy_helper.cart2sph(
+ ctypes.cast(stream.ptr, ctypes.c_void_p),
+ ctypes.cast(t_cart.data.ptr, ctypes.c_void_p),
+ ctypes.cast(out.data.ptr, ctypes.c_void_p),
+ ctypes.c_int(i3),
+ ctypes.c_int(count),
+ ctypes.c_int(ang)
+ )
+ if err != 0:
+ raise RuntimeError('failed in cart2sph kernel')
+ return out.reshape(out_shape)
+
# a copy with modification from
# https://github.com/pyscf/pyscf/blob/9219058ac0a1bcdd8058166cad0fb9127b82e9bf/pyscf/lib/linalg_helper.py#L1536
def krylov(aop, b, x0=None, tol=1e-10, max_cycle=30, dot=cupy.dot,
diff --git a/gpu4pyscf/lib/cupy_helper/CMakeLists.txt b/gpu4pyscf/lib/cupy_helper/CMakeLists.txt
index 1fb32035..657441ab 100644
--- a/gpu4pyscf/lib/cupy_helper/CMakeLists.txt
+++ b/gpu4pyscf/lib/cupy_helper/CMakeLists.txt
@@ -27,6 +27,7 @@ add_library(cupy_helper SHARED
dist_matrix.cu
grouped_gemm.cu
grouped_dot.cu
+ cart2sph.cu
)
add_dependencies(cupy_helper cutlass)
diff --git a/gpu4pyscf/lib/cupy_helper/cart2sph.cu b/gpu4pyscf/lib/cupy_helper/cart2sph.cu
new file mode 100644
index 00000000..2da0c139
--- /dev/null
+++ b/gpu4pyscf/lib/cupy_helper/cart2sph.cu
@@ -0,0 +1,137 @@
+/* Copyright 2024 The GPU4PySCF Authors. All Rights Reserved.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program. If not, see <https://www.gnu.org/licenses/>.
+ */
+
+#include <stdio.h>
+#include <cuda_runtime.h>
+
+#define THREADS 128
+
+// Cartesian -> spherical transformation for ang = 2.
+// (n,ncart,stride) -> (n,nsph,stride), count = n*stride, ncart = 6, nsph = 5
+// One thread transforms the components of one (i, j) pair, i < n, j < stride.
+// NOTE(review): the coefficients are expected to agree with c2s_l on the
+// Python side; the unit test compares the result with cart2sph_cutensor.
+__global__
+static void _cart2sph_ang2(const double *cart, double *sph, int stride, int count){
+    int idx = blockIdx.x * blockDim.x + threadIdx.x;
+    if (idx >= count){
+        return;
+    }
+    int i = idx / stride;
+    int j = idx % stride;
+    // 64-bit offsets: ncart*count can exceed INT_MAX even though count fits.
+    const size_t s = stride;
+    const double *c = cart + (6 * s * i + j);
+    double *o = sph + (5 * s * i + j);
+    double g0 = c[0*s];
+    double g1 = c[1*s];
+    double g2 = c[2*s];
+    double g3 = c[3*s];
+    double g4 = c[4*s];
+    double g5 = c[5*s];
+
+    o[0*s] = 1.092548430592079070 * g1;
+    o[1*s] = 1.092548430592079070 * g4;
+    o[2*s] = 0.630783130505040012 * g5 - 0.315391565252520002 * (g0 + g3);
+    o[3*s] = 1.092548430592079070 * g2;
+    o[4*s] = 0.546274215296039535 * (g0 - g3);
+}
+
+// Cartesian -> spherical transformation for ang = 3.
+// (n,ncart,stride) -> (n,nsph,stride), count = n*stride, ncart = 10, nsph = 7
+// One thread transforms the components of one (i, j) pair, i < n, j < stride.
+__global__
+static void _cart2sph_ang3(const double *cart, double *sph, int stride, int count){
+    int idx = blockIdx.x * blockDim.x + threadIdx.x;
+    if (idx >= count){
+        return;
+    }
+    int i = idx / stride;
+    int j = idx % stride;
+    // 64-bit offsets: ncart*count can exceed INT_MAX even though count fits.
+    const size_t s = stride;
+    const double *c = cart + (10 * s * i + j);
+    double *o = sph + (7 * s * i + j);
+    double g0 = c[0*s];
+    double g1 = c[1*s];
+    double g2 = c[2*s];
+    double g3 = c[3*s];
+    double g4 = c[4*s];
+    double g5 = c[5*s];
+    double g6 = c[6*s];
+    double g7 = c[7*s];
+    double g8 = c[8*s];
+    double g9 = c[9*s];
+
+    o[0*s] = 1.770130769779930531 * g1 - 0.590043589926643510 * g6;
+    o[1*s] = 2.890611442640554055 * g4;
+    o[2*s] = 1.828183197857862944 * g8 - 0.457045799464465739 * (g1 + g6);
+    o[3*s] = 0.746352665180230782 * g9 - 1.119528997770346170 * (g2 + g7);
+    o[4*s] = 1.828183197857862944 * g5 - 0.457045799464465739 * (g0 + g3);
+    o[5*s] = 1.445305721320277020 * (g2 - g7);
+    o[6*s] = 0.590043589926643510 * g0 - 1.770130769779930530 * g3;
+}
+
+// Cartesian -> spherical transformation for ang = 4.
+// (n,ncart,stride) -> (n,nsph,stride), count = n*stride, ncart = 15, nsph = 9
+// One thread transforms the components of one (i, j) pair, i < n, j < stride.
+__global__
+static void _cart2sph_ang4(const double *cart, double *sph, int stride, int count){
+    int idx = blockIdx.x * blockDim.x + threadIdx.x;
+    if (idx >= count){
+        return;
+    }
+    int i = idx / stride;
+    int j = idx % stride;
+    // 64-bit offsets: ncart*count can exceed INT_MAX even though count fits.
+    const size_t s = stride;
+    const double *c = cart + (15 * s * i + j);
+    double *o = sph + (9 * s * i + j);
+    double g0 = c[0*s];
+    double g1 = c[1*s];
+    double g2 = c[2*s];
+    double g3 = c[3*s];
+    double g4 = c[4*s];
+    double g5 = c[5*s];
+    double g6 = c[6*s];
+    double g7 = c[7*s];
+    double g8 = c[8*s];
+    double g9 = c[9*s];
+    double g10 = c[10*s];
+    double g11 = c[11*s];
+    double g12 = c[12*s];
+    double g13 = c[13*s];
+    double g14 = c[14*s];
+
+    o[0*s] = 2.503342941796704538 * (g1 - g6);
+    o[1*s] = 5.310392309339791593 * g4 - 1.770130769779930530 * g11;
+    o[2*s] = 5.677048174545360108 * g8 - 0.946174695757560014 * (g1 + g6);
+    o[3*s] = 2.676186174229156671 * g13 - 2.007139630671867500 * (g4 + g11);
+    o[4*s] = 0.317356640745612911 * (g0 + g10) + 0.634713281491225822 * g3 - 2.538853125964903290 * (g5 + g12) + 0.846284375321634430 * g14;
+    o[5*s] = 2.676186174229156671 * g9 - 2.007139630671867500 * (g2 + g7);
+    o[6*s] = 2.838524087272680054 * (g5 - g12) + 0.473087347878780009 * (g10 - g0);
+    o[7*s] = 1.770130769779930531 * g2 - 5.310392309339791590 * g7;
+    o[8*s] = 0.625835735449176134 * (g0 + g10) - 3.755014412695056800 * g3;
+}
+
+extern "C" {
+// Launch the cartesian -> spherical kernel for angular momentum `ang` on
+// `stream`.
+//   cart_gto: device buffer read as (n, ncart, stride) doubles
+//   sph_gto:  device buffer written as (n, nsph, stride) doubles
+//   stride:   length of the trailing dimension
+//   count:    n*stride, the number of threads that do work
+// Returns 0 on success and 1 if `ang` is not supported or the launch failed.
+// Nothing is launched for ang <= 1 or count <= 0.
+__host__
+int cart2sph(cudaStream_t stream, double *cart_gto, double *sph_gto, int stride, int count, int ang)
+{
+    if (ang < 0 || ang > 4) {
+        fprintf(stderr, "Ang > 4 is not supported!\n");
+        return 1;
+    }
+    if (ang <= 1 || count <= 0) {
+        // No transformation needed / no data. A launch with zero blocks
+        // would be reported as an invalid configuration.
+        return 0;
+    }
+
+    dim3 threads(THREADS);
+    // Ceiling division written so that it cannot overflow for large count.
+    dim3 blocks((count - 1)/THREADS + 1);
+    switch (ang) {
+    case 2: _cart2sph_ang2 <<<blocks, threads, 0, stream>>> (cart_gto, sph_gto, stride, count); break;
+    case 3: _cart2sph_ang3 <<<blocks, threads, 0, stream>>> (cart_gto, sph_gto, stride, count); break;
+    case 4: _cart2sph_ang4 <<<blocks, threads, 0, stream>>> (cart_gto, sph_gto, stride, count); break;
+    }
+
+    cudaError_t err = cudaGetLastError();
+    if (err != cudaSuccess) {
+        fprintf(stderr, "CUDA Error of cart2sph: %s\n", cudaGetErrorString(err));
+        return 1;
+    }
+    return 0;
+}
+}
diff --git a/gpu4pyscf/lib/tests/test_cupy_helper.py b/gpu4pyscf/lib/tests/test_cupy_helper.py
index 51cc2bc9..4ee8d3cd 100644
--- a/gpu4pyscf/lib/tests/test_cupy_helper.py
+++ b/gpu4pyscf/lib/tests/test_cupy_helper.py
@@ -19,7 +19,7 @@
from gpu4pyscf.lib.cupy_helper import (
take_last2d, transpose_sum, krylov, unpack_sparse,
add_sparse, takebak, empty_mapped, dist_matrix,
- grouped_dot, grouped_gemm, cond)
+ grouped_dot, grouped_gemm, cond, cart2sph_cutensor, cart2sph)
class KnownValues(unittest.TestCase):
def test_take_last2d(self):
@@ -160,6 +160,22 @@ def generate_problems(problems):
assert(cupy.linalg.norm(res_Cs - ans_Cs) < 1e-8)
assert(cupy.linalg.norm(res_Cs_2 - ans_Cs) < 1e-8)
+ def test_cart2sph(self):
+ a_cart = cupy.random.rand(10,6,11)
+ a_sph0 = cart2sph_cutensor(a_cart, axis=1, ang=2)
+ a_sph1 = cart2sph(a_cart, axis=1, ang=2)
+ assert cupy.linalg.norm(a_sph0 - a_sph1) < 1e-8
+
+ a_cart = cupy.random.rand(10,10,11)
+ a_sph0 = cart2sph_cutensor(a_cart, axis=1, ang=3)
+ a_sph1 = cart2sph(a_cart, axis=1, ang=3)
+ assert cupy.linalg.norm(a_sph0 - a_sph1) < 1e-8
+
+ a_cart = cupy.random.rand(10,15,11)
+ a_sph0 = cart2sph_cutensor(a_cart, axis=1, ang=4)
+ a_sph1 = cart2sph(a_cart, axis=1, ang=4)
+ assert cupy.linalg.norm(a_sph0 - a_sph1) < 1e-8
+
if __name__ == "__main__":
print("Full tests for cupy helper module")
unittest.main()