Skip to content

Commit b81e01b

Browse files
committed
MCSCF object can now be directly registered using TREXIO capabilities
1 parent 8fb92d6 commit b81e01b

File tree

1 file changed

+31
-5
lines changed

1 file changed

+31
-5
lines changed

pyscf/tools/trexio.py

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from pyscf import gto
2222
from pyscf import scf
2323
from pyscf import pbc
24+
from pyscf import mcscf
2425
from pyscf import fci
2526

2627
import trexio
@@ -31,6 +32,8 @@ def to_trexio(obj, filename, backend='h5'):
3132
_mol_to_trexio(obj, tf)
3233
elif isinstance(obj, scf.hf.SCF):
3334
_scf_to_trexio(obj, tf)
35+
elif isinstance(obj, mcscf.casci.CASCI) or isinstance(obj, mcscf.CASSCF):
36+
_mcscf_to_trexio(obj, tf)
3437
else:
3538
raise NotImplementedError(f'Conversion function for {obj.__class__}')
3639

@@ -249,7 +252,34 @@ def _cc_to_trexio(cc_obj, trexio_file):
249252
raise NotImplementedError
250253

251254
def _mcscf_to_trexio(cas_obj, trexio_file):
252-
raise NotImplementedError
255+
mol = cas_obj.mol
256+
_mol_to_trexio(mol, trexio_file)
257+
mo_energy_cas = cas_obj.mo_energy
258+
mo_cas = cas_obj.mo_coeff
259+
num_mo = mo_energy_cas.size
260+
spin_cas = np.zeros(mo_energy_cas.size, dtype=int)
261+
mo_type_cas = 'CAS'
262+
trexio.write_mo_type(trexio_file, mo_type_cas)
263+
idx = _order_ao_index(mol)
264+
trexio.write_mo_num(trexio_file, num_mo)
265+
trexio.write_mo_coefficient(trexio_file, mo_cas[idx].T.ravel())
266+
trexio.write_mo_energy(trexio_file, mo_energy_cas)
267+
trexio.write_mo_spin(trexio_file, spin_cas)
268+
269+
ncore = cas_obj.ncore
270+
ncas = cas_obj.ncas
271+
mo_classes = np.array(["Virtual"] * num_mo, dtype=str) # Initialize all MOs as Virtual
272+
mo_classes[:ncore] = "Core"
273+
mo_classes[ncore:ncore + ncas] = "Active"
274+
trexio.write_mo_class(trexio_file, list(mo_classes))
275+
276+
occupation = np.zeros(num_mo)
277+
occupation[:ncore] = 2.0
278+
rdm1 = cas_obj.fcisolver.make_rdm1(cas_obj.ci, ncas, cas_obj.nelecas)
279+
natural_occ = np.linalg.eigh(rdm1)[0]
280+
occupation[ncore:ncore + ncas] = natural_occ[::-1]
281+
occupation[ncore + ncas:] = 0.0
282+
trexio.write_mo_occupation(trexio_file, occupation)
253283

254284
def mol_from_trexio(filename):
255285
mol = gto.Mole()
@@ -434,10 +464,6 @@ def det_to_trexio(mcscf, norb, nelec, filename, backend='h5', ci_threshold=0., c
434464
with trexio.File(filename, 'u', back_end=_mode(backend)) as tf:
435465
if trexio.has_determinant(tf):
436466
trexio.delete_determinant(tf)
437-
trexio.write_mo_num(tf, mo_num)
438-
trexio.write_electron_up_num(tf, len(a))
439-
trexio.write_electron_dn_num(tf, len(b))
440-
trexio.write_electron_num(tf, len(a) + len(b))
441467

442468
offset_file = 0
443469
for i in range(n_chunks):

0 commit comments

Comments
 (0)