2121from pyscf import gto
2222from pyscf import scf
2323from pyscf import pbc
24+ from pyscf import mcscf
2425from pyscf import fci
2526
2627import trexio
2728
28- def to_trexio (obj , filename , backend = 'h5' ):
29+ def to_trexio (obj , filename , backend = 'h5' , ci_threshold = None , chunk_size = None ):
2930 with trexio .File (filename , 'u' , back_end = _mode (backend )) as tf :
3031 if isinstance (obj , gto .Mole ):
3132 _mol_to_trexio (obj , tf )
3233 elif isinstance (obj , scf .hf .SCF ):
3334 _scf_to_trexio (obj , tf )
35+ elif isinstance (obj , mcscf .casci .CASCI ) or isinstance (obj , mcscf .CASSCF ):
36+ ci_threshold = ci_threshold if ci_threshold is not None else 0.
37+ chunk_size = chunk_size if chunk_size is not None else 100000
38+ _mcscf_to_trexio (obj , tf , ci_threshold = ci_threshold , chunk_size = chunk_size )
3439 else :
3540 raise NotImplementedError (f'Conversion function for { obj .__class__ } ' )
3641
@@ -255,8 +260,39 @@ def _scf_to_trexio(mf, trexio_file):
255260def _cc_to_trexio (cc_obj , trexio_file ):
256261 raise NotImplementedError
257262
258- def _mcscf_to_trexio (cas_obj , trexio_file ):
259- raise NotImplementedError
263+ def _mcscf_to_trexio (cas_obj , trexio_file , ci_threshold = 0. , chunk_size = 100000 ):
264+ mol = cas_obj .mol
265+ _mol_to_trexio (mol , trexio_file )
266+ mo_energy_cas = cas_obj .mo_energy
267+ mo_cas = cas_obj .mo_coeff
268+ num_mo = mo_energy_cas .size
269+ spin_cas = np .zeros (mo_energy_cas .size , dtype = int )
270+ mo_type_cas = 'CAS'
271+ trexio .write_mo_type (trexio_file , mo_type_cas )
272+ idx = _order_ao_index (mol )
273+ trexio .write_mo_num (trexio_file , num_mo )
274+ trexio .write_mo_coefficient (trexio_file , mo_cas [idx ].T .ravel ())
275+ trexio .write_mo_energy (trexio_file , mo_energy_cas )
276+ trexio .write_mo_spin (trexio_file , spin_cas )
277+
278+ ncore = cas_obj .ncore
279+ ncas = cas_obj .ncas
280+ mo_classes = np .array (["Virtual" ] * num_mo , dtype = str ) # Initialize all MOs as Virtual
281+ mo_classes [:ncore ] = "Core"
282+ mo_classes [ncore :ncore + ncas ] = "Active"
283+ trexio .write_mo_class (trexio_file , list (mo_classes ))
284+
285+ occupation = np .zeros (num_mo )
286+ occupation [:ncore ] = 2.0
287+ rdm1 = cas_obj .fcisolver .make_rdm1 (cas_obj .ci , ncas , cas_obj .nelecas )
288+ natural_occ = np .linalg .eigh (rdm1 )[0 ]
289+ occupation [ncore :ncore + ncas ] = natural_occ [::- 1 ]
290+ occupation [ncore + ncas :] = 0.0
291+ trexio .write_mo_occupation (trexio_file , occupation )
292+
293+ total_elec_cas = sum (cas_obj .nelecas )
294+
295+ det_to_trexio (cas_obj , ncas , total_elec_cas , trexio_file , ci_threshold , chunk_size )
260296
261297def mol_from_trexio (filename ):
262298 mol = gto .Mole ()
@@ -336,19 +372,25 @@ def write_eri(eri, filename, backend='h5'):
336372 idx [:,:,2 :] = idx_pair [None ,:,:]
337373 idx = idx [np .tril_indices (npair )]
338374
375+ # Physicist notation
376+ idx = idx .reshape ((num_integrals ,4 ))
377+ for i in range (num_integrals ):
378+ idx [i ,1 ],idx [i ,2 ]= idx [i ,2 ],idx [i ,1 ]
379+
380+ idx = idx .flatten ()
381+
339382 with trexio .File (filename , 'w' , back_end = _mode (backend )) as tf :
340383 trexio .write_mo_2e_int_eri (tf , 0 , num_integrals , idx , eri .ravel ())
341384
342385def read_eri (filename ):
343- '''Read ERIs in AO basis, 8-fold symmetry is assumed'''
344386 with trexio .File (filename , 'r' , back_end = trexio .TREXIO_AUTO ) as tf :
345387 nmo = trexio .read_mo_num (tf )
346388 nao_pair = nmo * (nmo + 1 ) // 2
347389 eri_size = nao_pair * (nao_pair + 1 ) // 2
348390 idx , data , n_read , eof_flag = trexio .read_mo_2e_int_eri (tf , 0 , eri_size )
349391 eri = np .zeros (eri_size )
350- x = idx [:,0 ]* (idx [:,0 ]+ 1 )// 2 + idx [:,1 ]
351- y = idx [:,2 ]* (idx [:,2 ]+ 1 )// 2 + idx [:,3 ]
392+ x = idx [:,0 ]* (idx [:,0 ]+ 1 )// 2 + idx [:,2 ]
393+ y = idx [:,1 ]* (idx [:,1 ]+ 1 )// 2 + idx [:,3 ]
352394 eri [x * (x + 1 )// 2 + y ] = data
353395 return eri
354396
@@ -417,17 +459,18 @@ def get_occsa_and_occsb(mcscf, norb, nelec, ci_threshold=0.):
417459
418460 return occsa_sorted , occsb_sorted , ci_values_sorted , num_determinants
419461
420- def det_to_trexio (mcscf , norb , nelec , filename , backend = 'h5' , ci_threshold = 0. , chunk_size = 100000 ):
462+ def det_to_trexio (mcscf , norb , nelec , trexio_file , ci_threshold = 0. , chunk_size = 100000 ):
421463 from trexio_tools .group_tools import determinant as trexio_det
422464
423- mo_num = norb
424- int64_num = int ((mo_num - 1 ) / 64 ) + 1
465+ ncore = mcscf .ncore
466+ int64_num = trexio .get_int64_num (trexio_file )
467+
425468 occsa , occsb , ci_values , num_determinants = get_occsa_and_occsb (mcscf , norb , nelec , ci_threshold )
426469
427470 det_list = []
428471 for a , b , coeff in zip (occsa , occsb , ci_values ):
429- occsa_upshifted = [orb + 1 for orb in a ]
430- occsb_upshifted = [orb + 1 for orb in b ]
472+ occsa_upshifted = [orb for orb in range ( ncore )] + [ orb + ncore for orb in a ]
473+ occsb_upshifted = [orb for orb in range ( ncore )] + [ orb + ncore for orb in b ]
431474 det_tmp = []
432475 det_tmp += trexio_det .to_determinant_list (occsa_upshifted , int64_num )
433476 det_tmp += trexio_det .to_determinant_list (occsb_upshifted , int64_num )
@@ -438,24 +481,19 @@ def det_to_trexio(mcscf, norb, nelec, filename, backend='h5', ci_threshold=0., c
438481 else :
439482 n_chunks = 1
440483
441- with trexio .File (filename , 'u' , back_end = _mode (backend )) as tf :
442- if trexio .has_determinant (tf ):
443- trexio .delete_determinant (tf )
444- trexio .write_mo_num (tf , mo_num )
445- trexio .write_electron_up_num (tf , len (a ))
446- trexio .write_electron_dn_num (tf , len (b ))
447- trexio .write_electron_num (tf , len (a ) + len (b ))
484+ if trexio .has_determinant (trexio_file ):
485+ trexio .delete_determinant (trexio_file )
448486
449- offset_file = 0
450- for i in range (n_chunks ):
451- start = i * chunk_size
452- end = min ((i + 1 ) * chunk_size , num_determinants )
453- current_chunk_size = end - start
454-
455- if current_chunk_size > 0 :
456- trexio .write_determinant_list (tf , offset_file , current_chunk_size , det_list [start :end ])
457- trexio .write_determinant_coefficient (tf , offset_file , current_chunk_size , ci_values [start :end ])
458- offset_file += current_chunk_size
487+ offset_file = 0
488+ for i in range (n_chunks ):
489+ start = i * chunk_size
490+ end = min ((i + 1 ) * chunk_size , num_determinants )
491+ current_chunk_size = end - start
492+
493+ if current_chunk_size > 0 :
494+ trexio .write_determinant_list (trexio_file , offset_file , current_chunk_size , det_list [start :end ])
495+ trexio .write_determinant_coefficient (trexio_file , offset_file , current_chunk_size , ci_values [start :end ])
496+ offset_file += current_chunk_size
459497
460498def read_det_trexio (filename ):
461499 with trexio .File (filename , 'r' , back_end = trexio .TREXIO_AUTO ) as tf :
0 commit comments