Skip to content

Commit 0a33323

Browse files
Merge branch 'master' into remove_mcpdft_1
2 parents 777098c + b4da987 commit 0a33323

File tree

2 files changed

+67
-28
lines changed

2 files changed

+67
-28
lines changed

.github/workflows/run_ci.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ sudo apt-get -qq install \
99
curl
1010

1111
python -m pip install --upgrade pip
12+
pip install "scipy<1.16"
1213
pip install pytest
1314
pip install .
1415

pyscf/tools/trexio.py

Lines changed: 66 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -21,16 +21,21 @@
2121
from pyscf import gto
2222
from pyscf import scf
2323
from pyscf import pbc
24+
from pyscf import mcscf
2425
from pyscf import fci
2526

2627
import trexio
2728

28-
def to_trexio(obj, filename, backend='h5'):
29+
def to_trexio(obj, filename, backend='h5', ci_threshold=None, chunk_size=None):
2930
with trexio.File(filename, 'u', back_end=_mode(backend)) as tf:
3031
if isinstance(obj, gto.Mole):
3132
_mol_to_trexio(obj, tf)
3233
elif isinstance(obj, scf.hf.SCF):
3334
_scf_to_trexio(obj, tf)
35+
elif isinstance(obj, mcscf.casci.CASCI) or isinstance(obj, mcscf.CASSCF):
36+
ci_threshold = ci_threshold if ci_threshold is not None else 0.
37+
chunk_size = chunk_size if chunk_size is not None else 100000
38+
_mcscf_to_trexio(obj, tf, ci_threshold=ci_threshold, chunk_size=chunk_size)
3439
else:
3540
raise NotImplementedError(f'Conversion function for {obj.__class__}')
3641

@@ -255,8 +260,39 @@ def _scf_to_trexio(mf, trexio_file):
255260
def _cc_to_trexio(cc_obj, trexio_file):
256261
raise NotImplementedError
257262

258-
def _mcscf_to_trexio(cas_obj, trexio_file):
259-
raise NotImplementedError
263+
def _mcscf_to_trexio(cas_obj, trexio_file, ci_threshold=0., chunk_size=100000):
264+
mol = cas_obj.mol
265+
_mol_to_trexio(mol, trexio_file)
266+
mo_energy_cas = cas_obj.mo_energy
267+
mo_cas = cas_obj.mo_coeff
268+
num_mo = mo_energy_cas.size
269+
spin_cas = np.zeros(mo_energy_cas.size, dtype=int)
270+
mo_type_cas = 'CAS'
271+
trexio.write_mo_type(trexio_file, mo_type_cas)
272+
idx = _order_ao_index(mol)
273+
trexio.write_mo_num(trexio_file, num_mo)
274+
trexio.write_mo_coefficient(trexio_file, mo_cas[idx].T.ravel())
275+
trexio.write_mo_energy(trexio_file, mo_energy_cas)
276+
trexio.write_mo_spin(trexio_file, spin_cas)
277+
278+
ncore = cas_obj.ncore
279+
ncas = cas_obj.ncas
280+
mo_classes = np.array(["Virtual"] * num_mo, dtype=str) # Initialize all MOs as Virtual
281+
mo_classes[:ncore] = "Core"
282+
mo_classes[ncore:ncore + ncas] = "Active"
283+
trexio.write_mo_class(trexio_file, list(mo_classes))
284+
285+
occupation = np.zeros(num_mo)
286+
occupation[:ncore] = 2.0
287+
rdm1 = cas_obj.fcisolver.make_rdm1(cas_obj.ci, ncas, cas_obj.nelecas)
288+
natural_occ = np.linalg.eigh(rdm1)[0]
289+
occupation[ncore:ncore + ncas] = natural_occ[::-1]
290+
occupation[ncore + ncas:] = 0.0
291+
trexio.write_mo_occupation(trexio_file, occupation)
292+
293+
total_elec_cas = sum(cas_obj.nelecas)
294+
295+
det_to_trexio(cas_obj, ncas, total_elec_cas, trexio_file, ci_threshold, chunk_size)
260296

261297
def mol_from_trexio(filename):
262298
mol = gto.Mole()
@@ -336,19 +372,25 @@ def write_eri(eri, filename, backend='h5'):
336372
idx[:,:,2:] = idx_pair[None,:,:]
337373
idx = idx[np.tril_indices(npair)]
338374

375+
# Physicist notation
376+
idx=idx.reshape((num_integrals,4))
377+
for i in range(num_integrals):
378+
idx[i,1],idx[i,2]=idx[i,2],idx[i,1]
379+
380+
idx=idx.flatten()
381+
339382
with trexio.File(filename, 'w', back_end=_mode(backend)) as tf:
340383
trexio.write_mo_2e_int_eri(tf, 0, num_integrals, idx, eri.ravel())
341384

342385
def read_eri(filename):
343-
'''Read ERIs in AO basis, 8-fold symmetry is assumed'''
344386
with trexio.File(filename, 'r', back_end=trexio.TREXIO_AUTO) as tf:
345387
nmo = trexio.read_mo_num(tf)
346388
nao_pair = nmo * (nmo+1) // 2
347389
eri_size = nao_pair * (nao_pair+1) // 2
348390
idx, data, n_read, eof_flag = trexio.read_mo_2e_int_eri(tf, 0, eri_size)
349391
eri = np.zeros(eri_size)
350-
x = idx[:,0]*(idx[:,0]+1)//2 + idx[:,1]
351-
y = idx[:,2]*(idx[:,2]+1)//2 + idx[:,3]
392+
x = idx[:,0]*(idx[:,0]+1)//2 + idx[:,2]
393+
y = idx[:,1]*(idx[:,1]+1)//2 + idx[:,3]
352394
eri[x*(x+1)//2+y] = data
353395
return eri
354396

@@ -417,17 +459,18 @@ def get_occsa_and_occsb(mcscf, norb, nelec, ci_threshold=0.):
417459

418460
return occsa_sorted, occsb_sorted, ci_values_sorted, num_determinants
419461

420-
def det_to_trexio(mcscf, norb, nelec, filename, backend='h5', ci_threshold=0., chunk_size=100000):
462+
def det_to_trexio(mcscf, norb, nelec, trexio_file, ci_threshold=0., chunk_size=100000):
421463
from trexio_tools.group_tools import determinant as trexio_det
422464

423-
mo_num = norb
424-
int64_num = int((mo_num - 1) / 64) + 1
465+
ncore = mcscf.ncore
466+
int64_num = trexio.get_int64_num(trexio_file)
467+
425468
occsa, occsb, ci_values, num_determinants = get_occsa_and_occsb(mcscf, norb, nelec, ci_threshold)
426469

427470
det_list = []
428471
for a, b, coeff in zip(occsa, occsb, ci_values):
429-
occsa_upshifted = [orb + 1 for orb in a]
430-
occsb_upshifted = [orb + 1 for orb in b]
472+
occsa_upshifted = [orb for orb in range(ncore)] + [orb+ncore for orb in a]
473+
occsb_upshifted = [orb for orb in range(ncore)] + [orb+ncore for orb in b]
431474
det_tmp = []
432475
det_tmp += trexio_det.to_determinant_list(occsa_upshifted, int64_num)
433476
det_tmp += trexio_det.to_determinant_list(occsb_upshifted, int64_num)
@@ -438,24 +481,19 @@ def det_to_trexio(mcscf, norb, nelec, filename, backend='h5', ci_threshold=0., c
438481
else:
439482
n_chunks = 1
440483

441-
with trexio.File(filename, 'u', back_end=_mode(backend)) as tf:
442-
if trexio.has_determinant(tf):
443-
trexio.delete_determinant(tf)
444-
trexio.write_mo_num(tf, mo_num)
445-
trexio.write_electron_up_num(tf, len(a))
446-
trexio.write_electron_dn_num(tf, len(b))
447-
trexio.write_electron_num(tf, len(a) + len(b))
484+
if trexio.has_determinant(trexio_file):
485+
trexio.delete_determinant(trexio_file)
448486

449-
offset_file = 0
450-
for i in range(n_chunks):
451-
start = i * chunk_size
452-
end = min((i + 1) * chunk_size, num_determinants)
453-
current_chunk_size = end - start
454-
455-
if current_chunk_size > 0:
456-
trexio.write_determinant_list(tf, offset_file, current_chunk_size, det_list[start:end])
457-
trexio.write_determinant_coefficient(tf, offset_file, current_chunk_size, ci_values[start:end])
458-
offset_file += current_chunk_size
487+
offset_file = 0
488+
for i in range(n_chunks):
489+
start = i * chunk_size
490+
end = min((i + 1) * chunk_size, num_determinants)
491+
current_chunk_size = end - start
492+
493+
if current_chunk_size > 0:
494+
trexio.write_determinant_list(trexio_file, offset_file, current_chunk_size, det_list[start:end])
495+
trexio.write_determinant_coefficient(trexio_file, offset_file, current_chunk_size, ci_values[start:end])
496+
offset_file += current_chunk_size
459497

460498
def read_det_trexio(filename):
461499
with trexio.File(filename, 'r', back_end=trexio.TREXIO_AUTO) as tf:

0 commit comments

Comments
 (0)