2626from gpu4pyscf .lib import utils
2727from gpu4pyscf .lib .cupy_helper import (
2828 load_library , tag_array , contract , sandwich_dot , block_diag , transpose_sum ,
29- dist_matrix )
29+ dist_matrix , batched_vec3_norm2 )
3030from gpu4pyscf .gto .mole import cart2sph_by_l
3131from gpu4pyscf .dft import numint
3232from gpu4pyscf .pbc import tools
@@ -723,22 +723,23 @@ def eval_vpplocG(cell, mesh):
723723 '''PRB, 58, 3641 Eq (5) first term
724724 '''
725725 assert cell .dimension != 2
726- Gv , (basex , basey , basez ) = cell .get_Gv_weights (mesh )[:2 ]
727- basex = cp .asarray (basex )
728- basey = cp .asarray (basey )
729- basez = cp .asarray (basez )
726+ Gv , (basex , basey , basez ) = tools .pbc ._get_Gv_with_base (cell , mesh )
730727 b = cell .reciprocal_vectors ()
731728 coords = cell .atom_coords ()
732729 rb = cp .asarray (coords .dot (b .T ))
733730 SIx = cp .exp (- 1j * rb [:,0 ,None ] * basex )
734731 SIy = cp .exp (- 1j * rb [:,1 ,None ] * basey )
735732 SIz = cp .exp (- 1j * rb [:,2 ,None ] * basez )
736- G2 = contract ('px,px->p' , Gv , Gv )
733+ # G2 = contract('px,px->p', Gv, Gv)
734+ G2 = batched_vec3_norm2 (Gv )
737735 charges = cell .atom_charges ()
738736
739737 coulG = tools .get_coulG (cell , Gv = Gv )
740738 vlocG = cp .zeros (len (G2 ), dtype = np .complex128 )
741739 vlocG0 = 0
740+
741+ _kernel_registery = {}
742+
742743 for ia in range (cell .natm ):
743744 symb = cell .atom_symbol (ia )
744745 if symb not in cell ._pseudo :
@@ -749,24 +750,97 @@ def eval_vpplocG(cell, mesh):
749750 if nexp == 0 :
750751 continue
751752
752- SI = (SIx [ia ,:,None ,None ] * SIy [ia ,:,None ] * SIz [ia ]).ravel ()
753- G2_red = G2 * rloc ** 2
754- SI *= cp .exp (- 0.5 * G2_red )
755753 vlocG0 += 2 * np .pi * charges [ia ]* rloc ** 2
756- vlocG -= charges [ia ] * coulG * SI
757754
758- # Add the C1, C2, C3, C4 contributions
759- cfacs = 0
755+ fn_name = f"gth_loc_reciporcal_nexp_{ nexp } _kernel"
756+ if fn_name not in _kernel_registery :
757+ C_declaration = ''
758+ C_contribution = ''
759+ if nexp >= 1 :
760+ C_declaration += ', const double cexp0'
761+ C_contribution += 'cfacs += cexp0;'
762+ if nexp >= 2 :
763+ C_declaration += ', const double cexp1'
764+ C_contribution += 'cfacs += cexp1 * (3 - G2_red);'
765+ if nexp >= 3 :
766+ C_declaration += ', const double cexp2'
767+ C_contribution += 'cfacs += cexp2 * (15 - 10 * G2_red + G2_red * G2_red);'
768+ if nexp >= 4 :
769+ C_declaration += ', const double cexp3'
770+ C_contribution += 'cfacs += cexp3 * (105 - 105 * G2_red + 21 * G2_red * G2_red - G2_red * G2_red * G2_red);'
771+ kernel_code = r'''
772+ #include <cupy/complex.cuh>
773+ extern "C" __global__
774+ void ''' + fn_name + '''(
775+ const double* __restrict__ grids_G2, const double* __restrict__ grids_coulG,
776+ const complex<double>* __restrict__ grids_SIx, const complex<double>* __restrict__ grids_SIy, const complex<double>* __restrict__ grids_SIz,
777+ complex<double>* __restrict__ grids_vlocG,
778+ const int n_mesh_x, const int n_mesh_y, const int n_mesh_z, const int i_atom,
779+ const double charge, const double rloc''' + C_declaration + r''')
780+ {
781+ const int i_grid = blockDim.x * blockIdx.x + threadIdx.x;
782+ const int ngrids = n_mesh_x * n_mesh_y * n_mesh_z;
783+ if (i_grid >= ngrids) return;
784+
785+ const double G2 = grids_G2[i_grid];
786+ const double coulG = grids_coulG[i_grid];
787+ const double G2_red = G2 * rloc * rloc;
788+ const int i_grid_x = i_grid / (n_mesh_y * n_mesh_z);
789+ const int i_grid_y = (i_grid - i_grid_x * (n_mesh_y * n_mesh_z)) / n_mesh_z;
790+ const int i_grid_z = i_grid - i_grid_x * (n_mesh_y * n_mesh_z) - i_grid_y * n_mesh_z;
791+ const complex<double> SIx = grids_SIx[i_atom * n_mesh_x + i_grid_x];
792+ const complex<double> SIy = grids_SIy[i_atom * n_mesh_y + i_grid_y];
793+ const complex<double> SIz = grids_SIz[i_atom * n_mesh_z + i_grid_z];
794+ const complex<double> SI = SIx * SIy * SIz * exp(-0.5 * G2_red);
795+ complex<double> vlocG = -charge * coulG * SI;
796+
797+ double cfacs = 0;
798+ ''' + C_contribution + r'''
799+ vlocG += 15.749609945722419 * rloc * rloc * rloc * cfacs * SI;
800+
801+ grids_vlocG[i_grid] += vlocG;
802+ }
803+ '''
804+ _kernel_registery [fn_name ] = cp .RawKernel (kernel_code , fn_name )
805+ kernel = _kernel_registery [fn_name ]
806+
807+ ngrids = G2 .shape [0 ]
808+ assert G2 .shape == (ngrids ,) and G2 .dtype == cp .float64
809+ assert coulG .shape == (ngrids ,) and coulG .dtype == cp .float64
810+ assert SIx .shape == (cell .natm , mesh [0 ]) and SIx .dtype == cp .complex128 and SIx .flags .c_contiguous
811+ assert SIy .shape == (cell .natm , mesh [1 ]) and SIy .dtype == cp .complex128 and SIy .flags .c_contiguous
812+ assert SIz .shape == (cell .natm , mesh [2 ]) and SIz .dtype == cp .complex128 and SIz .flags .c_contiguous
813+ assert vlocG .shape == (ngrids ,) and vlocG .dtype == cp .complex128
814+ assert ngrids < np .iinfo (np .int32 ).max
815+
816+ kernel_parameters = [G2 , coulG , SIx , SIy , SIz , vlocG , cp .int32 (mesh [0 ]), cp .int32 (mesh [1 ]), cp .int32 (mesh [2 ]),
817+ cp .int32 (ia ), cp .float64 (charges [ia ]), cp .float64 (rloc )]
760818 if nexp >= 1 :
761- cfacs += cexp [0 ]
819+ kernel_parameters . append ( cp . float64 ( cexp [0 ]))
762820 if nexp >= 2 :
763- cfacs += cexp [1 ] * ( 3 - G2_red )
821+ kernel_parameters . append ( cp . float64 ( cexp [1 ]) )
764822 if nexp >= 3 :
765- cfacs += cexp [2 ] * ( 15 - 10 * G2_red + G2_red ** 2 )
823+ kernel_parameters . append ( cp . float64 ( cexp [2 ]) )
766824 if nexp >= 4 :
767- cfacs += cexp [3 ] * (105 - 105 * G2_red + 21 * G2_red ** 2 - G2_red ** 3 )
768-
769- vlocG += (2 * np .pi )** (3 / 2. )* rloc ** 3 * cfacs * SI
825+ kernel_parameters .append (cp .float64 (cexp [3 ]))
826+ kernel (((ngrids + 1024 - 1 ) // 1024 , ), (1024 , ), kernel_parameters )
827+
828+ # SI = (SIx[ia,:,None,None] * SIy[ia,:,None] * SIz[ia]).ravel()
829+ # G2_red = G2 * rloc**2
830+ # SI *= cp.exp(-0.5*G2_red)
831+ # vlocG -= charges[ia] * coulG * SI
832+
833+ # # Add the C1, C2, C3, C4 contributions
834+ # cfacs = 0
835+ # if nexp >= 1:
836+ # cfacs += cexp[0]
837+ # if nexp >= 2:
838+ # cfacs += cexp[1] * (3 - G2_red)
839+ # if nexp >= 3:
840+ # cfacs += cexp[2] * (15 - 10*G2_red + G2_red**2)
841+ # if nexp >= 4:
842+ # cfacs += cexp[3] * (105 - 105*G2_red + 21*G2_red**2 - G2_red**3)
843+ # vlocG += (2*np.pi)**(3/2.)*rloc**3 * cfacs * SI
770844
771845 vlocG [0 ] += vlocG0
772846 return vlocG
0 commit comments