Skip to content

Commit 7bd1c9b

Browse files
authored
Fix SMD hessian (#534)
* Fix SMD hessian * lint * import error * Forcing an empty commit. * Update SMD implementation to make it consistent with PySCF SMD module * Mute some tests * Update tests * Fix to_gpu to_cpu issues * lint
1 parent 1c9489d commit 7bd1c9b

File tree

17 files changed

+148
-181
lines changed

17 files changed

+148
-181
lines changed

gpu4pyscf/df/hessian/rhf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -643,7 +643,7 @@ def _get_jk_mo(hessobj, mol, dms, mo_coeff, mocc,
643643
class Hessian(rhf_hess.Hessian):
644644
'''Non-relativistic restricted Hartree-Fock hessian'''
645645

646-
from gpu4pyscf.lib.utils import to_gpu, device
646+
_keys = {'auxbasis_response',}
647647

648648
auxbasis_response = 1
649649
partial_hess_elec = partial_hess_elec

gpu4pyscf/df/hessian/rks.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,8 @@ def make_h1(hessobj, mo_coeff, mo_occ, chkfile=None, atmlst=None, verbose=None):
104104

105105
class Hessian(rks_hess.Hessian):
106106
'''Non-relativistic RKS hessian'''
107-
from gpu4pyscf.lib.utils import to_gpu, device
107+
108+
_keys = {'auxbasis_response',}
108109

109110
auxbasis_response = 1
110111
partial_hess_elec = partial_hess_elec

gpu4pyscf/df/hessian/uhf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -708,7 +708,7 @@ def _ao2mo(mat, mocc, mo):
708708
class Hessian(uhf_hess.Hessian):
709709
'''Non-relativistic restricted Hartree-Fock hessian'''
710710

711-
from gpu4pyscf.lib.utils import to_gpu, device
711+
_keys = {'auxbasis_response',}
712712

713713
auxbasis_response = 1
714714
partial_hess_elec = partial_hess_elec

gpu4pyscf/df/hessian/uks.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,8 @@ def make_h1(hessobj, mo_coeff, mo_occ, chkfile=None, atmlst=None, verbose=None):
132132

133133
class Hessian(uks_hess.Hessian):
134134
'''Non-relativistic RKS hessian'''
135-
from gpu4pyscf.lib.utils import to_gpu, device
135+
136+
_keys = {'auxbasis_response',}
136137

137138
auxbasis_response = 1
138139
partial_hess_elec = partial_hess_elec

gpu4pyscf/hessian/rhf.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
from gpu4pyscf.__config__ import _streams, num_devices
3737
from gpu4pyscf.lib import logger
3838
from gpu4pyscf.lib import multi_gpu
39+
from gpu4pyscf.lib import utils
3940
from gpu4pyscf.scf.jk import (
4041
LMAX, QUEUE_DEPTH, SHM_SIZE, THREADS, GROUP_SIZE, libvhf_rys, _VHFOpt,
4142
init_constant, _make_tril_tile_mappings, _make_tril_pair_mappings,
@@ -911,6 +912,11 @@ def _get_veff_resp_mo(hessobj, mol, dms, mo_coeff, mo_occ, hermi=1, omega=None):
911912
return vj - 0.5 * vk
912913

913914
class HessianBase(lib.StreamObject):
915+
916+
to_cpu = utils.to_cpu
917+
to_gpu = utils.to_gpu
918+
device = utils.device
919+
914920
# attributes
915921
max_cycle = rhf_hess_cpu.HessianBase.max_cycle
916922
level_shift = rhf_hess_cpu.HessianBase.level_shift
@@ -947,19 +953,9 @@ def dump_flags(self, verbose=None):
947953
self.max_memory, lib.current_memory()[0])
948954
return self
949955

950-
def to_cpu(self):
951-
mf = self.base.to_cpu()
952-
from importlib import import_module
953-
mod = import_module(self.__module__.replace('gpu4pyscf', 'pyscf'))
954-
cls = getattr(mod, self.__class__.__name__)
955-
obj = cls(mf)
956-
return obj
957-
958956
class Hessian(HessianBase):
959957
'''Non-relativistic restricted Hartree-Fock hessian'''
960958

961-
from gpu4pyscf.lib.utils import to_gpu, device
962-
963959
def __init__(self, scf_method):
964960
self.verbose = scf_method.verbose
965961
self.stdout = scf_method.stdout

gpu4pyscf/hessian/rks.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4250,8 +4250,6 @@ def get_veff_resp_mo(hessobj, mol, dms, mo_coeff, mo_occ, hermi=1, omega=None):
42504250
class Hessian(rhf_hess.HessianBase):
42514251
'''Non-relativistic RKS hessian'''
42524252

4253-
from gpu4pyscf.lib.utils import to_gpu, device
4254-
42554253
_keys = {'grids', 'grid_response'}
42564254

42574255
def __init__(self, mf):

gpu4pyscf/hessian/uhf.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -459,8 +459,6 @@ def _get_veff_resp_mo(hessobj, mol, dms, mo_coeff, mo_occ, hermi=1):
459459
class Hessian(rhf_hess_gpu.HessianBase):
460460
'''Non-relativistic unrestricted Hartree-Fock hessian'''
461461

462-
from gpu4pyscf.lib.utils import to_gpu, device
463-
464462
__init__ = rhf_hess_gpu.Hessian.__init__
465463
partial_hess_elec = partial_hess_elec
466464
hess_elec = hess_elec

gpu4pyscf/hessian/uks.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -895,7 +895,6 @@ def get_veff_resp_mo(hessobj, mol, dms, mo_coeff, mo_occ, hermi=1):
895895

896896
class Hessian(rhf_hess.HessianBase):
897897
'''Non-relativistic UKS hessian'''
898-
from gpu4pyscf.lib.utils import to_gpu, device
899898

900899
def __init__(self, mf):
901900
rhf_hess.Hessian.__init__(self, mf)

gpu4pyscf/lib/utils.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,10 @@ def to_cpu(method, out=None):
8787
val = val.get()
8888
setattr(out, key, val)
8989
if hasattr(out, 'reset'):
90-
out.reset()
90+
try:
91+
out.reset()
92+
except NotImplementedError:
93+
pass
9194
return out
9295

9396
def to_gpu(method, device=None):

gpu4pyscf/scf/hf.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -763,6 +763,7 @@ def reset(self, mol=None):
763763
self.mol = mol
764764
self._opt_gpu = {None: None}
765765
self._opt_jengine = {None: None}
766+
self._eri = None
766767
self.scf_summary = {}
767768
self.overlap_canonical_decomposed_x = None
768769
return self

0 commit comments

Comments
 (0)