@@ -56,6 +56,20 @@ def _auto_create_mode(array, mode):
5656# cutensor_dtype, alignment_req=alignment_req)
5757# return _tensor_descriptors[key]
5858
59+ def _contract_einsum (pattern , a , b , alpha , beta , out = None , einsum = cupy .einsum ):
60+ if out is None :
61+ out = einsum (pattern , a , b )
62+ out *= alpha
63+ elif beta == 0. :
64+ out [:] = einsum (pattern , a , b )
65+ out *= alpha
66+ else :
67+ out *= beta
68+ tmp = einsum (pattern , a , b )
69+ tmp *= alpha
70+ out += tmp
71+ return cupy .asarray (out , order = 'C' )
72+
5973def contraction (
6074 pattern , a , b , alpha , beta ,
6175 out = None ,
@@ -67,6 +81,9 @@ def contraction(
6781 compute_desc = 0 ,
6882 ws_pref = WORKSPACE_RECOMMENDED
6983):
84+ if a .size == 0 or b .size == 0 :
85+ # cutensor does not support the 0-sized operands
86+ return _contract_einsum (pattern , a , b , alpha , beta , out )
7087
7188 pattern = pattern .replace (" " , "" )
7289 str_a , rest = pattern .split (',' )
@@ -138,22 +155,11 @@ def contraction(
138155 warnings .warn (f'using { contract_engine } as the tensor contraction engine.' )
139156 def contract (pattern , a , b , alpha = 1.0 , beta = 0.0 , out = None ):
140157 try :
141- if out is None :
142- out = einsum (pattern , a , b )
143- out *= alpha
144- elif beta == 0. :
145- out [:] = einsum (pattern , a , b )
146- out *= alpha
147- else :
148- out *= beta
149- tmp = einsum (pattern , a , b )
150- tmp *= alpha
151- out += tmp
158+ return _contract_einsum (pattern , a , b , alpha , beta , out , einsum )
152159 except cupy .cuda .memory .OutOfMemoryError :
153160 print ('Out of memory error caused by cupy.einsum. '
154161 'It is recommended to install cutensor to resolve this.' )
155162 raise
156- return cupy .asarray (out , order = 'C' )
157163else :
158164 def contract (pattern , a , b , alpha = 1.0 , beta = 0.0 , out = None ):
159165 '''
0 commit comments