Skip to content

Commit 52bd613

Browse files
committed
rsdf_builder dimension error for generally contracted basis
1 parent c8ae173 commit 52bd613

File tree

2 files changed

+76
-20
lines changed

2 files changed

+76
-20
lines changed

gpu4pyscf/pbc/df/ft_ao.py

Lines changed: 36 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -131,27 +131,13 @@ def gen_ft_kernel(cell, kpts=None, verbose=None):
131131
class FTOpt:
132132
def __init__(self, cell, kpts=None, bvk_kmesh=None):
133133
self.cell = cell
134-
sorted_cell, coeff, uniq_l_ctr, l_ctr_counts = group_basis(cell, tile=1)
134+
sorted_cell, ao_idx, l_ctr_pad_counts, uniq_l_ctr, l_ctr_counts = group_basis(
135+
cell, tile=1, sparse_coeff=True)
135136
self.sorted_cell = sorted_cell
136137
self.uniq_l_ctr = uniq_l_ctr
137138
self.l_ctr_offsets = np.append(0, np.cumsum(l_ctr_counts))
138-
self.coeff = cp.asarray(coeff)
139-
140-
# TODO: ao_idx from group_basis
141-
ls = np.repeat(cell._bas[:,ANG_OF], cell._bas[:,NCTR_OF])
142-
nprims = np.repeat(cell._bas[:,NPRIM_OF], cell._bas[:,NCTR_OF])
143-
l_ctrs = np.column_stack((ls, -nprims))
144-
_, inv_idx = np.unique(l_ctrs, return_inverse=True, axis=0)
145-
sorted_idx = np.argsort(inv_idx.ravel(), kind='stable')
146-
if cell.cart:
147-
dims = (ls + 1) * (ls + 2) // 2
148-
else:
149-
dims = ls * 2 + 1
150-
ao_loc = np.append(0, dims.cumsum())
151-
ao_idx = np.array_split(np.arange(ao_loc[-1]), ao_loc[1:-1])
152-
# mat[ao_idx[:,None],ao_idx] transforms the matrix in original cell into
153-
# the matrix represented in the sorted AOs.
154-
self.ao_idx = np.hstack([ao_idx[i] for i in sorted_idx])
139+
self.ao_idx = ao_idx
140+
self.l_ctr_pad_counts = l_ctr_pad_counts
155141

156142
if bvk_kmesh is None:
157143
if kpts is None or is_zero(kpts):
@@ -193,6 +179,35 @@ def build(self, verbose=None):
193179
init_constant(cell)
194180
return self
195181

182+
@property
183+
def coeff(self):
184+
from pyscf import gto
185+
coeff = np.zeros((self.sorted_cell.nao, self.cell.nao))
186+
187+
l_max = max([l_ctr[0] for l_ctr in self.uniq_l_ctr])
188+
if self.cell.cart:
189+
cart2sph_per_l = [np.eye((l+1)*(l+2)//2) for l in range(l_max + 1)]
190+
else:
191+
cart2sph_per_l = [gto.mole.cart2sph(l, normalized = "sp") for l in range(l_max + 1)]
192+
i_spherical_offset = 0
193+
i_cartesian_offset = 0
194+
for i, l in enumerate(self.uniq_l_ctr[:,0]):
195+
cart2sph = cart2sph_per_l[l]
196+
ncart, nsph = cart2sph.shape
197+
l_ctr_count = self.l_ctr_offsets[i + 1] - self.l_ctr_offsets[i]
198+
cart_offs = i_cartesian_offset + np.arange(l_ctr_count) * ncart
199+
sph_offs = i_spherical_offset + np.arange(l_ctr_count) * nsph
200+
cart_idx = cart_offs[:,None] + np.arange(ncart)
201+
sph_idx = sph_offs[:,None] + np.arange(nsph)
202+
coeff[cart_idx[:,:,None],sph_idx[:,None,:]] = cart2sph
203+
l_ctr_pad_count = self.l_ctr_pad_counts[i]
204+
i_cartesian_offset += (l_ctr_count + l_ctr_pad_count) * ncart
205+
i_spherical_offset += l_ctr_count * nsph
206+
assert len(self.ao_idx) == self.cell.nao
207+
out = cp.zeros_like(coeff)
208+
out[:,self.ao_idx] = coeff
209+
return asarray(out)
210+
196211
@property
197212
def aft_envs(self):
198213
_aft_envs = self._aft_envs
@@ -334,7 +349,9 @@ def gen_ft_kernel(self, verbose=None):
334349
bvkmesh_Ls = cp.asarray(
335350
k2gamma.translation_vectors_for_kmesh(cell, bvk_kmesh, True))
336351
conj_mapping = cp.asarray(conj_images_in_bvk_cell(bvk_kmesh), dtype=np.int32)
337-
nao, nao_orig = self.coeff.shape
352+
nao = self.sorted_cell.nao
353+
nao_orig = self.cell.nao
354+
coeff = cp.asarray(self.coeff, dtype=np.complex128)
338355

339356
def _ft_sub(Gv, q, kptjs, img_idx_cache, transform_ao=True):
340357
t1 = log.init_timer()
@@ -408,7 +425,6 @@ def _ft_sub(Gv, q, kptjs, img_idx_cache, transform_ao=True):
408425
out = contract('Lk,LpqG->kpqG', expLk, out)
409426

410427
if transform_ao:
411-
coeff = cp.asarray(self.coeff, dtype=np.complex128)
412428
log.debug1('transform basis')
413429
#:out = einsum('pqLG,pi,qj->LGij', out, coeff, coeff)
414430
out = contract('kpqG,pi->kiqG', out, coeff)

gpu4pyscf/pbc/df/tests/test_rsdf_builder.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -351,6 +351,46 @@ def test_kpts_compressed1():
351351
print(ki, kj)
352352
assert abs(_ref - out[ki]).max() < 1e-10
353353

354+
def test_kpts_compressed_general_contraction():
355+
cell = pyscf.M(
356+
atom='''C 1.3 .2 .3
357+
C .19 .1 1.1
358+
''',
359+
basis='''
360+
C D
361+
173 0.27 -0.03
362+
5.8 0.8 -0.26
363+
1.9 0.1 0.81
364+
''',
365+
a=np.eye(3)*6)
366+
367+
auxcell = cell.copy()
368+
auxcell.basis = '''
369+
C S
370+
2.00 1.
371+
C D
372+
0.59 1.''',
373+
auxcell.build()
374+
nao = cell.nao
375+
omega = 0.3
376+
kmesh = [2,1,1]
377+
kpts = cell.make_kpts(kmesh)
378+
dat, dat_neg, idx = rsdf_builder.compressed_cderi_kk(cell, auxcell, kpts, omega=omega)
379+
ref = build_cderi(cell, auxcell, kpts, omega=omega)[0]
380+
kk_conserv = k2gamma.double_translation_indices(kmesh)
381+
bvkmesh_Ls = k2gamma.translation_vectors_for_kmesh(cell, kmesh, True)
382+
expLk = cp.exp(1j*cp.asarray(bvkmesh_Ls.dot(kpts.T)))
383+
for kp in sorted(dat):
384+
out = rsdf_builder.unpack_cderi(dat[kp], idx, kp, kk_conserv, expLk, nao)
385+
ki_idx, kj_idx = np.where(kk_conserv == kp)
386+
for ki, kj in zip(ki_idx, kj_idx):
387+
if (ki, kj) in ref:
388+
_ref = ref[ki, kj]
389+
else:
390+
_ref = ref[kj, ki].conj().transpose(0,2,1)
391+
print(ki, kj)
392+
assert abs(_ref - out[ki]).max() < 1e-11
393+
354394
@pytest.mark.skip('Must include gamma point')
355395
def test_kpts_compressed2():
356396
from pyscf.pbc.df import df as df_cpu

0 commit comments

Comments
 (0)