|
36 | 36 | from gpu4pyscf.__config__ import _streams, num_devices |
37 | 37 | from gpu4pyscf.lib import logger |
38 | 38 | from gpu4pyscf.lib import multi_gpu |
| 39 | +from gpu4pyscf.lib import utils |
39 | 40 | from gpu4pyscf.scf.jk import ( |
40 | 41 | LMAX, QUEUE_DEPTH, SHM_SIZE, THREADS, GROUP_SIZE, libvhf_rys, _VHFOpt, |
41 | 42 | init_constant, _make_tril_tile_mappings, _make_tril_pair_mappings, |
@@ -911,6 +912,11 @@ def _get_veff_resp_mo(hessobj, mol, dms, mo_coeff, mo_occ, hermi=1, omega=None): |
911 | 912 | return vj - 0.5 * vk |
912 | 913 |
|
913 | 914 | class HessianBase(lib.StreamObject): |
| 915 | + |
| 916 | + to_cpu = utils.to_cpu |
| 917 | + to_gpu = utils.to_gpu |
| 918 | + device = utils.device |
| 919 | + |
914 | 920 | # attributes |
915 | 921 | max_cycle = rhf_hess_cpu.HessianBase.max_cycle |
916 | 922 | level_shift = rhf_hess_cpu.HessianBase.level_shift |
@@ -947,19 +953,9 @@ def dump_flags(self, verbose=None): |
947 | 953 | self.max_memory, lib.current_memory()[0]) |
948 | 954 | return self |
949 | 955 |
|
950 | | - def to_cpu(self): |
951 | | - mf = self.base.to_cpu() |
952 | | - from importlib import import_module |
953 | | - mod = import_module(self.__module__.replace('gpu4pyscf', 'pyscf')) |
954 | | - cls = getattr(mod, self.__class__.__name__) |
955 | | - obj = cls(mf) |
956 | | - return obj |
957 | | - |
958 | 956 | class Hessian(HessianBase): |
959 | 957 | '''Non-relativistic restricted Hartree-Fock hessian''' |
960 | 958 |
|
961 | | - from gpu4pyscf.lib.utils import to_gpu, device |
962 | | - |
963 | 959 | def __init__(self, scf_method): |
964 | 960 | self.verbose = scf_method.verbose |
965 | 961 | self.stdout = scf_method.stdout |
|
0 commit comments