2121from pyscf import gto
2222from pyscf import scf
2323from pyscf import pbc
24+ from pyscf import mcscf
2425from pyscf import fci
2526
2627import trexio
@@ -31,6 +32,8 @@ def to_trexio(obj, filename, backend='h5'):
3132 _mol_to_trexio (obj , tf )
3233 elif isinstance (obj , scf .hf .SCF ):
3334 _scf_to_trexio (obj , tf )
35+ elif isinstance (obj , mcscf .casci .CASCI ) or isinstance (obj , mcscf .CASSCF ):
36+ _mcscf_to_trexio (obj , tf )
3437 else :
3538 raise NotImplementedError (f'Conversion function for { obj .__class__ } ' )
3639
@@ -249,7 +252,34 @@ def _cc_to_trexio(cc_obj, trexio_file):
249252 raise NotImplementedError
250253
251254def _mcscf_to_trexio (cas_obj , trexio_file ):
252- raise NotImplementedError
255+ mol = cas_obj .mol
256+ _mol_to_trexio (mol , trexio_file )
257+ mo_energy_cas = cas_obj .mo_energy
258+ mo_cas = cas_obj .mo_coeff
259+ num_mo = mo_energy_cas .size
260+ spin_cas = np .zeros (mo_energy_cas .size , dtype = int )
261+ mo_type_cas = 'CAS'
262+ trexio .write_mo_type (trexio_file , mo_type_cas )
263+ idx = _order_ao_index (mol )
264+ trexio .write_mo_num (trexio_file , num_mo )
265+ trexio .write_mo_coefficient (trexio_file , mo_cas [idx ].T .ravel ())
266+ trexio .write_mo_energy (trexio_file , mo_energy_cas )
267+ trexio .write_mo_spin (trexio_file , spin_cas )
268+
269+ ncore = cas_obj .ncore
270+ ncas = cas_obj .ncas
271+ mo_classes = np .array (["Virtual" ] * num_mo , dtype = str ) # Initialize all MOs as Virtual
272+ mo_classes [:ncore ] = "Core"
273+ mo_classes [ncore :ncore + ncas ] = "Active"
274+ trexio .write_mo_class (trexio_file , list (mo_classes ))
275+
276+ occupation = np .zeros (num_mo )
277+ occupation [:ncore ] = 2.0
278+ rdm1 = cas_obj .fcisolver .make_rdm1 (cas_obj .ci , ncas , cas_obj .nelecas )
279+ natural_occ = np .linalg .eigh (rdm1 )[0 ]
280+ occupation [ncore :ncore + ncas ] = natural_occ [::- 1 ]
281+ occupation [ncore + ncas :] = 0.0
282+ trexio .write_mo_occupation (trexio_file , occupation )
253283
254284def mol_from_trexio (filename ):
255285 mol = gto .Mole ()
@@ -434,10 +464,6 @@ def det_to_trexio(mcscf, norb, nelec, filename, backend='h5', ci_threshold=0., c
434464 with trexio .File (filename , 'u' , back_end = _mode (backend )) as tf :
435465 if trexio .has_determinant (tf ):
436466 trexio .delete_determinant (tf )
437- trexio .write_mo_num (tf , mo_num )
438- trexio .write_electron_up_num (tf , len (a ))
439- trexio .write_electron_dn_num (tf , len (b ))
440- trexio .write_electron_num (tf , len (a ) + len (b ))
441467
442468 offset_file = 0
443469 for i in range (n_chunks ):
0 commit comments