diff --git a/gpu4pyscf/df/tests/test_df_hessian.py b/gpu4pyscf/df/tests/test_df_hessian.py
index 166aa482c..e01a23796 100644
--- a/gpu4pyscf/df/tests/test_df_hessian.py
+++ b/gpu4pyscf/df/tests/test_df_hessian.py
@@ -33,10 +33,10 @@ def setUpModule():
     global mol_sph, mol_cart
     mol_sph = pyscf.M(atom=atom, basis=bas0, max_memory=32000, cart=0,
-                      output='/dev/null', verbose=1)
+                      output='/dev/null', verbose=6)
 
     mol_cart = pyscf.M(atom=atom, basis=bas0, max_memory=32000, cart=1,
-                      output='/dev/null', verbose=1)
+                      output='/dev/null', verbose=6)
 
 def tearDownModule():
     global mol_sph, mol_cart
diff --git a/gpu4pyscf/dft/numint.py b/gpu4pyscf/dft/numint.py
index 55c25dfab..866382635 100644
--- a/gpu4pyscf/dft/numint.py
+++ b/gpu4pyscf/dft/numint.py
@@ -29,7 +29,7 @@ from gpu4pyscf.lib import logger
 from gpu4pyscf.lib.multi_gpu import lru_cache
 from gpu4pyscf import __config__
-from gpu4pyscf.__config__ import _streams, num_devices
+from gpu4pyscf.__config__ import num_devices
 
 LMAX_ON_GPU = 8
 BAS_ALIGNED = 1
@@ -448,7 +448,7 @@ def _nr_rks_task(ni, mol, grids, xc_code, dm, mo_coeff, mo_occ,
                  verbose=None, with_lapl=False, device_id=0, hermi=1):
     ''' nr_rks task on given device
     '''
-    with cupy.cuda.Device(device_id), _streams[device_id]:
+    with cupy.cuda.Device(device_id):
         if isinstance(dm, cupy.ndarray):
             assert dm.ndim == 2
             # Ensure dm allocated on each device
@@ -858,7 +858,7 @@ def _nr_uks_task(ni, mol, grids, xc_code, dms, mo_coeff, mo_occ,
                  verbose=None, with_lapl=False, device_id=0, hermi=1):
     ''' nr_uks task on one device
     '''
-    with cupy.cuda.Device(device_id), _streams[device_id]:
+    with cupy.cuda.Device(device_id):
         if dms is not None:
             dma, dmb = dms
             dma = cupy.asarray(dma)
@@ -1117,7 +1117,7 @@ def get_rho(ni, mol, dm, grids, max_memory=2000, verbose=None):
 
 def _nr_rks_fxc_task(ni, mol, grids, xc_code, fxc, dms, mo1, occ_coeff,
                      verbose=None, hermi=1, device_id=0):
-    with cupy.cuda.Device(device_id), _streams[device_id]:
+    with cupy.cuda.Device(device_id):
         if dms is not None: dms = cupy.asarray(dms)
         if mo1 is not None: mo1 = cupy.asarray(mo1)
         if occ_coeff is not None: occ_coeff = cupy.asarray(occ_coeff)
@@ -1281,7 +1281,7 @@ def nr_rks_fxc_st(ni, mol, grids, xc_code, dm0=None, dms_alpha=None,
 
 def _nr_uks_fxc_task(ni, mol, grids, xc_code, fxc, dms, mo1, occ_coeff,
                      verbose=None, hermi=1, device_id=0):
-    with cupy.cuda.Device(device_id), _streams[device_id]:
+    with cupy.cuda.Device(device_id):
         if dms is not None:
             dma, dmb = dms
             dma = cupy.asarray(dma)
@@ -2272,14 +2272,24 @@ def _scale_ao(ao, wv, out=None):
         raise RuntimeError('CUDA Error')
     return out
 
-def _tau_dot(bra, ket, wv, buf=None, out=None):
-    '''1/2 <nabla i| v | nabla j>'''
-    # einsum('g,xig,xjg->ij', .5*wv, bra[1:4], ket[1:4])
-    wv = cupy.asarray(.5 * wv)
-    out = contract('ig,jg->ij', bra[1], _scale_ao(ket[1], wv, out=buf), out=out)
-    out = contract('ig,jg->ij', bra[2], _scale_ao(ket[2], wv, out=buf), beta=1., out=out)
-    out = contract('ig,jg->ij', bra[3], _scale_ao(ket[3], wv, out=buf), beta=1., out=out)
-    return out
+from gpu4pyscf.lib.cutensor import contract_trinary, __version__
+if __version__ is not None and __version__ >= 20301:
+    # NOTE: contract_trinary appears to work correctly only on the default
+    # stream. Calling it under _streams[device_id] produces random outputs.
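+    # For reference, the fused call below computes the same quantity as the
+    # two-step path: cupy.einsum('g,xig,xjg->ij', .5*wv, bra[1:4], ket[1:4]).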
+    def _tau_dot(bra, ket, wv, buf=None, out=None):
+        '''1/2 <nabla i| v | nabla j>'''
+        return contract_trinary('g,xig,xjg->ij', .5*wv, bra[1:4], ket[1:4], out=out)
+else:
+    def _tau_dot(bra, ket, wv, buf=None, out=None):
+        '''1/2 <nabla i| v | nabla j>'''
+        # einsum('g,xig,xjg->ij', .5*wv, bra[1:4], ket[1:4])
+        wv = cupy.asarray(.5 * wv)
+        out = contract('ig,jg->ij', bra[1], _scale_ao(ket[1], wv, out=buf), out=out)
+        out = contract('ig,jg->ij', bra[2], _scale_ao(ket[2], wv, out=buf), beta=1., out=out)
+        out = contract('ig,jg->ij', bra[3], _scale_ao(ket[3], wv, out=buf), beta=1., out=out)
+        return out
 
 class _GDFTOpt:
     def __init__(self, mol):
diff --git a/gpu4pyscf/hessian/rks.py b/gpu4pyscf/hessian/rks.py
index 321c2b360..6bed895ec 100644
--- a/gpu4pyscf/hessian/rks.py
+++ b/gpu4pyscf/hessian/rks.py
@@ -31,7 +31,7 @@ from gpu4pyscf.lib.cupy_helper import (contract, add_sparse, get_avail_mem,
                                        reduce_to_device, transpose_sum, take_last2d)
 from gpu4pyscf.lib import logger
-from gpu4pyscf.__config__ import _streams, num_devices, min_grid_blksize
+from gpu4pyscf.__config__ import num_devices, min_grid_blksize
 from gpu4pyscf.dft.numint import NLC_REMOVE_ZERO_RHO_GRID_THRESHOLD, _contract_rho1_fxc
 import ctypes
 from pyscf import __config__
@@ -392,7 +392,7 @@ def _get_vxc_deriv2_task(hessobj, grids, mo_coeff, mo_occ, max_memory, device_id
     ngrids_glob = grids.coords.shape[0]
     grid_start, grid_end = numint.gen_grid_range(ngrids_glob, device_id)
 
-    with cupy.cuda.Device(device_id), _streams[device_id]:
+    with cupy.cuda.Device(device_id):
         log = logger.new_logger(mol, verbose)
         t1 = t0 = log.init_timer()
         mo_occ = cupy.asarray(mo_occ)
@@ -1269,7 +1269,7 @@ def _get_vxc_deriv1_task(hessobj, grids, mo_coeff, mo_occ, max_memory, device_id
     ngrids_glob = grids.coords.shape[0]
     grid_start, grid_end = numint.gen_grid_range(ngrids_glob, device_id)
 
-    with cupy.cuda.Device(device_id), _streams[device_id]:
+    with cupy.cuda.Device(device_id):
         mo_occ = cupy.asarray(mo_occ)
         mo_coeff = cupy.asarray(mo_coeff)
         coeff = cupy.asarray(opt.coeff)
@@ -2133,9 +2133,53 @@ def _get_vnlc_deriv1(hessobj, mo_coeff, mo_occ, max_memory):
 
     return vmat_mo
 
+from gpu4pyscf.lib.cutensor import contract_trinary, __version__
+if __version__ is not None and __version__ >= 20301:
+    # NOTE: contract_trinary appears to work correctly only on the default
+    # stream. Calling it under _streams[device_id] produces random outputs.
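+    # A sketch of the fusion performed below: for each response set n, the
+    # LDA branch evaluates einsum('ng,ig,jg->nij', wv[:,0], ao, ao) in a
+    # single cuTENSOR call instead of a per-set _scale_ao followed by a dot.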
+    def _contract_vmat(xctype, ao, wv, vmat, mask, vtmp, buf=None):
+        if xctype == 'LDA':
+            v = contract_trinary('ng,ig,jg->nij', wv[:,0], ao, ao, out=vtmp)
+        elif xctype == 'GGA':
+            wv[:,0] *= .5
+            v = contract_trinary('nxg,xig,jg->nij', wv, ao[:4], ao[0], out=vtmp)
+        elif xctype == 'NLC':
+            raise NotImplementedError('NLC')
+        else:
+            wv[:,0] *= .5
+            wv[:,4] *= .25  # includes the extra 1/2 that _tau_dot applies internally
+            v = contract_trinary('ng,xig,xjg->nij', wv[:,4], ao[1:4], ao[1:4], out=vtmp)
+            v = contract_trinary('xig,nxg,jg->nij', ao[:4], wv[:,:4], ao[0], beta=1., out=v)
+        add_sparse(vmat, v, mask)
+        return vmat
+else:
+    def _contract_vmat(xctype, ao, wv, vmat, mask, vtmp, buf=None):
+        for i in range(len(wv)):
+            if xctype == 'LDA':
+                aow = numint._scale_ao(ao, wv[i,0], out=buf)
+                add_sparse(vmat[i], ao.dot(aow.T, out=vtmp[i]), mask)
+                # vmat_tmp = ao.dot(numint._scale_ao(ao, wv[i,0]).T)
+            elif xctype == 'GGA':
+                wv[i,0] *= .5
+                aow = numint._scale_ao(ao, wv[i], out=buf)
+                add_sparse(vmat[i], ao[0].dot(aow.T, out=vtmp[i]), mask)
+            elif xctype == 'NLC':
+                raise NotImplementedError('NLC')
+            else:
+                wv[i,0] *= .5
+                wv[i,4] *= .5
+                v = numint._tau_dot(ao, ao, wv[i,4], buf=buf, out=vtmp[i])
+                aow = numint._scale_ao(ao[:4], wv[i,:4], out=buf)
+                v = contract('ig,jg->ij', ao[0], aow, beta=1., out=v)
+                add_sparse(vmat[i], v, mask)
+        return vmat
+
 def _nr_rks_fxc_mo_task(ni, mol, grids, xc_code, fxc, mo_coeff, mo1, mocc,
                         verbose=None, hermi=1, device_id=0):
-    with cupy.cuda.Device(device_id), _streams[device_id]:
+    with cupy.cuda.Device(device_id):
         if mo_coeff is not None: mo_coeff = cupy.asarray(mo_coeff)
         if mo1 is not None: mo1 = cupy.asarray(mo1)
         if mocc is not None: mocc = cupy.asarray(mocc)
@@ -2163,7 +2207,6 @@ def _nr_rks_fxc_mo_task(ni, mol, grids, xc_code, fxc, mo_coeff, mo1, mocc,
 
         p0 = p1 = grid_start
         t1 = t0 = log.init_timer()
-        #### 初始化内存
         if xctype == 'LDA':
             ncomp = 1
         elif xctype == 'GGA':
@@ -2172,13 +2215,13 @@ def _nr_rks_fxc_mo_task(ni, mol, grids, xc_code, fxc, mo_coeff, mo1, mocc,
             ncomp = 5
         fxc_w_buf = cupy.empty(ncomp*ncomp*MIN_BLK_SIZE)
         buf = cupy.empty(MIN_BLK_SIZE * nao)
-        vtmp_buf = cupy.empty(nao*nao)
+        vtmp_buf = cupy.empty(nset*nao*nao)
         for ao, mask, weights, coords in ni.block_loop(_sorted_mol, grids, nao, ao_deriv,
                                                        max_memory=None, blksize=None,
                                                        grid_range=(grid_start, grid_end)):
             blk_size = len(weights)
             nao_sub = len(mask)
-            vtmp = cupy.ndarray((nao_sub, nao_sub), memptr=vtmp_buf.data)
+            vtmp = cupy.ndarray((nset, nao_sub, nao_sub), memptr=vtmp_buf.data)
             p0, p1 = p1, p1+len(weights)
 
             occ_coeff_mask = mocc[mask]
@@ -2191,26 +2234,7 @@ def _nr_rks_fxc_mo_task(ni, mol, grids, xc_code, fxc, mo_coeff, mo1, mocc,
             fxc_w = cupy.ndarray((ncomp, ncomp, blk_size), memptr=fxc_w_buf.data)
             fxc_w = cupy.multiply(fxc[:,:,p0:p1], weights, out=fxc_w)
             wv = contract('axg,xyg->ayg', rho1, fxc_w, out=rho1)
-
-            for i in range(nset):
-                if xctype == 'LDA':
-                    aow = numint._scale_ao(ao, wv[i][0], out=buf)
-                    add_sparse(vmat[i], ao.dot(aow.T, out=vtmp), mask)
-                    # vmat_tmp = ao.dot(numint._scale_ao(ao, wv[i][0]).T)
-                elif xctype == 'GGA':
-                    wv[i,0] *= .5
-                    aow = numint._scale_ao(ao, wv[i], out=buf)
-                    add_sparse(vmat[i], ao[0].dot(aow.T, out=vtmp), mask)
-                elif xctype == 'NLC':
-                    raise NotImplementedError('NLC')
-                else:
-                    wv[i,0] *= .5
-                    wv[i,4] *= .5
-                    vtmp = numint._tau_dot(ao, ao, wv[i,4], buf=buf, out=vtmp)
-                    aow = numint._scale_ao(ao[:4], wv[i,:4], out=buf)
-                    vtmp = contract('ig, jg->ij', ao[0], aow, beta=1, out=vtmp) # ao[0].dot(aow.T, out=vtmp)
-                    add_sparse(vmat[i], vtmp, mask)
-
+            _contract_vmat(xctype, ao, wv, vmat, mask, vtmp, buf)
             t1 = log.timer_debug2('integration', *t1)
         ao = rho1 = None
         t0 = log.timer_debug1(f'vxc on Device {device_id} ', *t0)
diff --git a/gpu4pyscf/lib/cupy_helper.py b/gpu4pyscf/lib/cupy_helper.py
index 86506f0d1..9c60cfcab 100644
--- a/gpu4pyscf/lib/cupy_helper.py
+++ b/gpu4pyscf/lib/cupy_helper.py
@@ -20,7 +20,7 @@ import cupy
 from pyscf import lib
 from gpu4pyscf.lib import logger
-from gpu4pyscf.lib.cutensor import contract
+from gpu4pyscf.lib.cutensor import contract, contract_trinary
 from gpu4pyscf.lib.cusolver import eigh, cholesky    #NOQA
 from gpu4pyscf.lib.memcpy import copy_array, p2p_transfer    #NOQA
 from gpu4pyscf.lib.multi_gpu import lru_cache
diff --git a/gpu4pyscf/lib/cutensor.py b/gpu4pyscf/lib/cutensor.py
index 24bef8353..af6d45b64 100644
--- a/gpu4pyscf/lib/cutensor.py
+++ b/gpu4pyscf/lib/cutensor.py
@@ -12,26 +12,36 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import warnings
 import numpy as np
 import cupy
 from gpu4pyscf.lib import logger
+from gpu4pyscf.__config__ import props
 
 try:
-    import cupy_backends.cuda.libs.cutensor  # NOQA
+    import cupy_backends.cuda.libs.cutensor as cutensor_backend
     from cupyx import cutensor
-    from cupy_backends.cuda.libs import cutensor as cutensor_backend
     ALGO_DEFAULT = cutensor_backend.ALGO_DEFAULT
     OP_IDENTITY = cutensor_backend.OP_IDENTITY
     JIT_MODE_NONE = cutensor_backend.JIT_MODE_NONE
     WORKSPACE_RECOMMENDED = cutensor_backend.WORKSPACE_MIN
     #WORKSPACE_RECOMMENDED = cutensor_backend.WORKSPACE_RECOMMENDED
-    _tensor_descriptors = {}
+    __version__ = cutensor_backend.get_version()
+    if __version__ >= 20301 and props['major'] <= 7:
+        warnings.warn(f'cuTENSOR {__version__} does not support sm_70 GPUs. '
+                      'It is recommended to install cuTENSOR 2.2.0 or older.')
+        cutensor = None
+        __version__ = None
+
+    import ctypes
+    libcutensor = ctypes.CDLL(cutensor_backend.__file__)
 except (ImportError, AttributeError):
     cutensor = None
     ALGO_DEFAULT = None
     OP_IDENTITY = None
     JIT_MODE_NONE = None
     WORKSPACE_RECOMMENDED = None
+    __version__ = None
 
 def _auto_create_mode(array, mode):
     if not isinstance(mode, cutensor.Mode):
@@ -132,6 +142,120 @@ def contraction(
         ws.data.ptr, ws_size)
     return out
 
+_contraction_operators = {}
+
+if cutensor is not None:
+    # Subclass OperationDescriptor to overwrite the readonly attribute .ptr
+    class _OperationDescriptor(cutensor.OperationDescriptor):
+        def __init__(self, ptr):
+            assert isinstance(ptr, int)
+            self.ctypes_ptr = ptr
+            self._ptr = ptr
+
+        @property
+        def ptr(self):
+            return self.ctypes_ptr
+
+def _create_contraction_trinary(desc_a, mode_a, op_a, desc_b, mode_b, op_b,
+                                desc_c, mode_c, op_c, desc_d, mode_d, op_d,
+                                compute_desc=0):
+    handler = cutensor._get_handle()
+    dtype = desc_d.cutensor_dtype
+    key = (desc_a.ptr, mode_a.data,
+           desc_b.ptr, mode_b.data,
+           desc_c.ptr, mode_c.data,
+           desc_d.ptr, mode_d.data, dtype)
+    if key not in _contraction_operators:
+        op_desc_ptr = ctypes.c_void_p()
+        mode_a_ptr = ctypes.cast(mode_a.data, ctypes.POINTER(ctypes.c_int32))
+        mode_b_ptr = ctypes.cast(mode_b.data, ctypes.POINTER(ctypes.c_int32))
+        mode_c_ptr = ctypes.cast(mode_c.data, ctypes.POINTER(ctypes.c_int32))
+        mode_d_ptr = ctypes.cast(mode_d.data, ctypes.POINTER(ctypes.c_int32))
+        if dtype == 0 or dtype == 4:
+            descCompute = ctypes.c_void_p.in_dll(libcutensor, "CUTENSOR_COMPUTE_DESC_32F")
+        elif dtype == 1 or dtype == 5:
+            descCompute = ctypes.c_void_p.in_dll(libcutensor, "CUTENSOR_COMPUTE_DESC_64F")
+        else:
+            raise RuntimeError(f'dtype {dtype} not supported')
+        err = libcutensor.cutensorCreateContractionTrinary(
+            ctypes.c_void_p(handler.ptr), ctypes.byref(op_desc_ptr),
+            ctypes.c_void_p(desc_a.ptr), mode_a_ptr, ctypes.c_int(op_a),
+            ctypes.c_void_p(desc_b.ptr), mode_b_ptr, ctypes.c_int(op_b),
+            ctypes.c_void_p(desc_c.ptr), mode_c_ptr, ctypes.c_int(op_c),
+            ctypes.c_void_p(desc_d.ptr), mode_d_ptr, ctypes.c_int(op_d),
+            ctypes.c_void_p(desc_d.ptr), mode_d_ptr,
+            descCompute)
+        if err != cutensor_backend.STATUS_SUCCESS:
+            raise RuntimeError(f'cutensorCreateContractionTrinary failed. err={err}')
+        _contraction_operators[key] = _OperationDescriptor(op_desc_ptr.value)
+    return _contraction_operators[key]
+
+def contract_trinary(pattern, a, b, c, alpha=1., beta=0., out=None):
+    '''Three-tensor contraction
+    out = alpha * A * B * C + beta * out
+    '''
+    pattern = pattern.replace(" ", "")
+    str_ops, str_out = pattern.split('->')
+    str_a, str_b, str_c = str_ops.split(',')
+    key = str_a + str_b + str_c
+    val = a.shape + b.shape + c.shape
+    shape = {k:v for k, v in zip(key, val)}
+
+    mode_a = list(str_a)
+    mode_b = list(str_b)
+    mode_c = list(str_c)
+    mode_out = list(str_out)
+    if len(mode_out) != len(set(mode_out)):
+        raise ValueError('Output subscripts string includes the same subscript multiple times.')
+
+    dtype = np.result_type(a.dtype, b.dtype, c.dtype)
+    a = cupy.asarray(a, dtype=dtype)
+    b = cupy.asarray(b, dtype=dtype)
+    c = cupy.asarray(c, dtype=dtype)
+    if out is None:
+        out = cupy.empty([shape[k] for k in str_out], order='C', dtype=dtype)
+
+    desc_a = cutensor.create_tensor_descriptor(a)
+    desc_b = cutensor.create_tensor_descriptor(b)
+    desc_c = cutensor.create_tensor_descriptor(c)
+    desc_out = cutensor.create_tensor_descriptor(out)
+
+    mode_a = _auto_create_mode(a, mode_a)
+    mode_b = _auto_create_mode(b, mode_b)
+    mode_c = _auto_create_mode(c, mode_c)
+    mode_out = _auto_create_mode(out, mode_out)
+
+    operator = _create_contraction_trinary(
+        desc_a, mode_a, OP_IDENTITY, desc_b, mode_b, OP_IDENTITY,
+        desc_c, mode_c, OP_IDENTITY, desc_out, mode_out, OP_IDENTITY)
+
+    handler = cutensor._get_handle()
+    algo = ALGO_DEFAULT
+    jit_mode = JIT_MODE_NONE
+    ws_pref = WORKSPACE_RECOMMENDED
+    plan_pref = cutensor.create_plan_preference(algo=algo, jit_mode=jit_mode)
+    ws_size = cutensor_backend.estimateWorkspaceSize(
+        handler.ptr, operator.ptr, plan_pref.ptr, ws_pref)
+    plan = cutensor.create_plan(operator, plan_pref, ws_limit=ws_size)
+    ws = cupy.empty(ws_size, dtype=np.int8)
+
+    alpha = np.asarray(alpha, dtype=dtype)
+    beta = np.asarray(beta, dtype=dtype)
+    stream = cupy.cuda.get_current_stream()
+    err = libcutensor.cutensorContractTrinary(
+        ctypes.c_void_p(handler.ptr), ctypes.c_void_p(plan.ptr),
+        alpha.ctypes,
+        ctypes.cast(a.data.ptr, ctypes.c_void_p),
+        ctypes.cast(b.data.ptr, ctypes.c_void_p),
+        ctypes.cast(c.data.ptr, ctypes.c_void_p),
+        beta.ctypes,
+        ctypes.cast(out.data.ptr, ctypes.c_void_p),
+        ctypes.cast(out.data.ptr, ctypes.c_void_p),
+        ctypes.cast(ws.data.ptr, ctypes.c_void_p),
+        ctypes.c_uint64(ws_size), ctypes.c_void_p(stream.ptr))
+    if err != cutensor_backend.STATUS_SUCCESS:
+        raise RuntimeError(f'cutensorContractTrinary failed. err={err}')
+    return out
+
 import os
 contract_engine = None
 if cutensor is None:
@@ -151,7 +275,6 @@ if cutensor is None:
     else:
         raise RuntimeError('unknown tensor contraction engine.')
-    import warnings
     warnings.warn(f'using {contract_engine} as the tensor contraction engine.')
 
     def contract(pattern, a, b, alpha=1.0, beta=0.0, out=None):
        try:
@@ -160,6 +283,10 @@ def contract(pattern, a, b, alpha=1.0, beta=0.0, out=None):
             print('Out of memory error caused by cupy.einsum. '
                   'It is recommended to install cutensor to resolve this.')
             raise
+
+    def contract_trinary(pattern, a, b, c, alpha=1., beta=0., out=None):
+        raise RuntimeError('contract_trinary is only supported with the cuTENSOR backend')
+
 else:
     def contract(pattern, a, b, alpha=1.0, beta=0.0, out=None):
         '''
@@ -167,3 +294,7 @@ def contract(pattern, a, b, alpha=1.0, beta=0.0, out=None):
         pattern has to be a standard einsum notation
         '''
         return contraction(pattern, a, b, alpha, beta, out=out)
+
+    if __version__ < 20301:
+        def contract_trinary(pattern, a, b, c, alpha=1., beta=0., out=None):
+            raise RuntimeError('cuTENSOR 2.3 or newer is required')
diff --git a/gpu4pyscf/lib/tests/test_cutensor.py b/gpu4pyscf/lib/tests/test_cutensor.py
index e44a193bd..7f6f4d2ba 100644
--- a/gpu4pyscf/lib/tests/test_cutensor.py
+++ b/gpu4pyscf/lib/tests/test_cutensor.py
@@ -15,7 +15,8 @@
 import unittest
 import numpy
 import cupy
-from gpu4pyscf.lib.cupy_helper import contract
+from gpu4pyscf.lib.cupy_helper import contract, contract_trinary
+from gpu4pyscf.lib import cutensor
 
 class KnownValues(unittest.TestCase):
     def test_contract(self):
@@ -56,6 +57,65 @@ def test_cache(self):
         c_einsum = cupy.einsum('ijkl,jl->ik', a, b)
         assert cupy.linalg.norm(c - c_einsum) < 1e-10
 
+    @unittest.skipIf(cutensor.__version__ is None or cutensor.__version__ < 20301,
+                     'requires cuTENSOR 2.3 or newer')
+    def test_trinary_contraction(self):
+        a = cupy.random.rand(48,11)
+        b = cupy.random.rand(48,13)
+        c = cupy.random.rand(48,4)
+        ref = numpy.einsum('pi,pj,pk->kij', a.get(), b.get(), c.get(), optimize=True)
+        out = contract_trinary('pi,pj,pk->kij', a, b, c)
+        assert abs(out.get() - ref).max() < 1e-10
+
+        a = cupy.random.rand(48,11)
+        b = cupy.random.rand(48,13)
+        c = cupy.random.rand(48,4)
+        ref = numpy.einsum('pi,pj,pk->ikj', a.get(), b.get(), c.get(), optimize=True)
+        out = contract_trinary('pi,pj,pk->ikj', a, b, c)
+        assert abs(out.get() - ref).max() < 1e-10
+
+        a = cupy.random.rand(48,11)
+        b = cupy.random.rand(48,13)
+        c = cupy.random.rand(48,4)
+        ref = numpy.einsum('pi,pj,pk->ij', a.get(), b.get(), c.get(), optimize=True)
+        out = contract_trinary('pi,pj,pk->ij', a, b, c)
+        assert abs(out.get() - ref).max() < 1e-10
+
+        a = cupy.random.rand(48,11)
+        b = cupy.random.rand(3,48,13)
+        c = cupy.random.rand(48,4)
+        ref = numpy.einsum('pi,xpj,pk->xkij', a.get(), b.get(), c.get(), optimize=True)
+        out = contract_trinary('pi,xpj,pk->xkij', a, b, c)
+        assert abs(out.get() - ref).max() < 1e-10
+
+        a = cupy.random.rand(48,11)
+        b = cupy.random.rand(3,48,13)
+        c = cupy.random.rand(48,4)
+        ref = numpy.einsum('xpj,pi,pk->xkij', b.get(), a.get(), c.get(), optimize=True)
+        out = contract_trinary('xpj,pi,pk->xkij', b, a, c)
+        assert abs(out.get() - ref).max() < 1e-10
+
+        a = cupy.random.rand(4,48,11)
+        b = cupy.random.rand(48,13)
+        c = cupy.random.rand(4,48)
+        ref = numpy.einsum('xpi,pj,xp->ij', a.get(), b.get(), c.get(), optimize=True)
+        out = contract_trinary('xpi,pj,xp->ij', a, b, c)
+        assert abs(out.get() - ref).max() < 1e-10
+
+        a = cupy.random.rand(4,48,11)
+        b = cupy.random.rand(48,13)
+        c = cupy.random.rand(4,48)
+        ref = numpy.einsum('xpi,xp,pj->ij', a.get(), c.get(), b.get(), optimize=True)
+        out = contract_trinary('xpi,xp,pj->ij', a, c, b)
+        assert abs(out.get() - ref).max() < 1e-10
+
+        a = cupy.random.rand(20,4096)
+        b = cupy.random.rand(6,4096)
+        c = cupy.empty((6,20,20))
+        c = contract_trinary('ig,ng,jg->nij', a, b, a, out=c)
+        ref = numpy.einsum('ig,ng,jg->nij', a.get(), b.get(), a.get(), optimize=True)
+        assert abs(c.get() - ref).max() < 1e-10
+
 if __name__ == "__main__":
     print("Full tests for cutensor module")
     unittest.main()
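
A quick way to exercise the stream caveat called out in the NOTE comments above (a hedged sketch, not part of the patch; it assumes a build with cuTENSOR >= 2.3 so that contract_trinary is available):

    import cupy
    from gpu4pyscf.lib.cutensor import contract_trinary

    w = cupy.random.rand(64)
    ao = cupy.random.rand(3, 10, 64)
    ref = cupy.einsum('g,xig,xjg->ij', w, ao, ao)

    # On the default stream the fused contraction matches einsum.
    out = contract_trinary('g,xig,xjg->ij', w, ao, ao)
    assert float(cupy.abs(out - ref).max()) < 1e-10

    # On a non-default stream the result has been observed to be unreliable,
    # which is why the per-device tasks above now run without _streams[device_id].
    with cupy.cuda.Stream() as s:
        out2 = contract_trinary('g,xig,xjg->ij', w, ao, ao)
        s.synchronize()
    print(float(cupy.abs(out2 - ref).max()))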