@@ -131,27 +131,13 @@ def gen_ft_kernel(cell, kpts=None, verbose=None):
131131class FTOpt :
132132 def __init__ (self , cell , kpts = None , bvk_kmesh = None ):
133133 self .cell = cell
134- sorted_cell , coeff , uniq_l_ctr , l_ctr_counts = group_basis (cell , tile = 1 )
134+ sorted_cell , ao_idx , l_ctr_pad_counts , uniq_l_ctr , l_ctr_counts = group_basis (
135+ cell , tile = 1 , sparse_coeff = True )
135136 self .sorted_cell = sorted_cell
136137 self .uniq_l_ctr = uniq_l_ctr
137138 self .l_ctr_offsets = np .append (0 , np .cumsum (l_ctr_counts ))
138- self .coeff = cp .asarray (coeff )
139-
140- # TODO: ao_idx from group_basis
141- ls = np .repeat (cell ._bas [:,ANG_OF ], cell ._bas [:,NCTR_OF ])
142- nprims = np .repeat (cell ._bas [:,NPRIM_OF ], cell ._bas [:,NCTR_OF ])
143- l_ctrs = np .column_stack ((ls , - nprims ))
144- _ , inv_idx = np .unique (l_ctrs , return_inverse = True , axis = 0 )
145- sorted_idx = np .argsort (inv_idx .ravel (), kind = 'stable' )
146- if cell .cart :
147- dims = (ls + 1 ) * (ls + 2 ) // 2
148- else :
149- dims = ls * 2 + 1
150- ao_loc = np .append (0 , dims .cumsum ())
151- ao_idx = np .array_split (np .arange (ao_loc [- 1 ]), ao_loc [1 :- 1 ])
152- # mat[ao_idx[:,None],ao_idx] transforms the matrix in original cell into
153- # the matrix represented in the sorted AOs.
154- self .ao_idx = np .hstack ([ao_idx [i ] for i in sorted_idx ])
139+ self .ao_idx = ao_idx
140+ self .l_ctr_pad_counts = l_ctr_pad_counts
155141
156142 if bvk_kmesh is None :
157143 if kpts is None or is_zero (kpts ):
@@ -193,6 +179,35 @@ def build(self, verbose=None):
193179 init_constant (cell )
194180 return self
195181
182+ @property
183+ def coeff (self ):
184+ from pyscf import gto
185+ coeff = np .zeros ((self .sorted_cell .nao , self .cell .nao ))
186+
187+ l_max = max ([l_ctr [0 ] for l_ctr in self .uniq_l_ctr ])
188+ if self .cell .cart :
189+ cart2sph_per_l = [np .eye ((l + 1 )* (l + 2 )// 2 ) for l in range (l_max + 1 )]
190+ else :
191+ cart2sph_per_l = [gto .mole .cart2sph (l , normalized = "sp" ) for l in range (l_max + 1 )]
192+ i_spherical_offset = 0
193+ i_cartesian_offset = 0
194+ for i , l in enumerate (self .uniq_l_ctr [:,0 ]):
195+ cart2sph = cart2sph_per_l [l ]
196+ ncart , nsph = cart2sph .shape
197+ l_ctr_count = self .l_ctr_offsets [i + 1 ] - self .l_ctr_offsets [i ]
198+ cart_offs = i_cartesian_offset + np .arange (l_ctr_count ) * ncart
199+ sph_offs = i_spherical_offset + np .arange (l_ctr_count ) * nsph
200+ cart_idx = cart_offs [:,None ] + np .arange (ncart )
201+ sph_idx = sph_offs [:,None ] + np .arange (nsph )
202+ coeff [cart_idx [:,:,None ],sph_idx [:,None ,:]] = cart2sph
203+ l_ctr_pad_count = self .l_ctr_pad_counts [i ]
204+ i_cartesian_offset += (l_ctr_count + l_ctr_pad_count ) * ncart
205+ i_spherical_offset += l_ctr_count * nsph
206+ assert len (self .ao_idx ) == self .cell .nao
207+ out = cp .zeros_like (coeff )
208+ out [:,self .ao_idx ] = coeff
209+ return asarray (out )
210+
196211 @property
197212 def aft_envs (self ):
198213 _aft_envs = self ._aft_envs
@@ -334,7 +349,9 @@ def gen_ft_kernel(self, verbose=None):
334349 bvkmesh_Ls = cp .asarray (
335350 k2gamma .translation_vectors_for_kmesh (cell , bvk_kmesh , True ))
336351 conj_mapping = cp .asarray (conj_images_in_bvk_cell (bvk_kmesh ), dtype = np .int32 )
337- nao , nao_orig = self .coeff .shape
352+ nao = self .sorted_cell .nao
353+ nao_orig = self .cell .nao
354+ coeff = cp .asarray (self .coeff , dtype = np .complex128 )
338355
339356 def _ft_sub (Gv , q , kptjs , img_idx_cache , transform_ao = True ):
340357 t1 = log .init_timer ()
@@ -408,7 +425,6 @@ def _ft_sub(Gv, q, kptjs, img_idx_cache, transform_ao=True):
408425 out = contract ('Lk,LpqG->kpqG' , expLk , out )
409426
410427 if transform_ao :
411- coeff = cp .asarray (self .coeff , dtype = np .complex128 )
412428 log .debug1 ('transform basis' )
413429 #:out = einsum('pqLG,pi,qj->LGij', out, coeff, coeff)
414430 out = contract ('kpqG,pi->kiqG' , out , coeff )
0 commit comments