Skip to content

Commit 078c9ba

Browse files
committed
upcast not needed all the time
1 parent 30809e6 commit 078c9ba

File tree

1 file changed

+9
-10
lines changed

1 file changed

+9
-10
lines changed

pytensor/tensor/rewriting/blas.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@
6060
import numpy as np
6161

6262
from pytensor.graph.traversal import toposort
63-
from pytensor.scalar import Add, Mul, Neg, Sub
63+
from pytensor.scalar import Add, Mul, Neg, Sub, upcast
6464
from pytensor.tensor.rewriting.basic import register_specialize
6565

6666

@@ -359,8 +359,12 @@ def _gemm_from_factored_list(fgraph, lst):
359359
if isinstance(sM, tuple):
360360
sm0, sm1 = sM
361361
sm0 = ptb.as_tensor_variable(sm0)
362-
if pytensor.scalar.upcast(sm0.dtype, sm1.dtype) == sm1.dtype:
363-
lst2.append((ptb.cast(sm0, sm1.dtype), sM[1]))
362+
sm0_dtype = sm0.type.dtype
363+
sm1_dtype = sm1.type.dtype
364+
if sm0_dtype == sm1_dtype:
365+
lst2.append((sm0, sm1))
366+
elif upcast(sm0_dtype, sm1_dtype) == sm1_dtype:
367+
lst2.append((ptb.cast(sm0, sm1_dtype), sm1))
364368

365369
lst = lst2
366370

@@ -385,20 +389,15 @@ def item_to_var(t):
385389
if not M_j.type.in_same_class(M_i.type):
386390
continue
387391

388-
# print 'TRYING', (s_i, M_i, s_j, M_j)
389-
390392
gemm_of_sM_list, old_dot22 = _beta_L_plus_alpha_M(
391393
fgraph, s_i, M_i, s_j, M_j
392394
)
393-
# print 'GOT IT', gemm_of_sM_list
394395
if gemm_of_sM_list:
395-
assert len(gemm_of_sM_list) == 1
396+
[new_add_inp] = gemm_of_sM_list
396397
add_inputs = [
397398
item_to_var(input) for k, input in enumerate(lst) if k not in (i, j)
398399
]
399-
add_inputs.extend(gemm_of_sM_list)
400-
rval = [variadic_add(*add_inputs)]
401-
# print "RETURNING GEMM THING", rval
400+
rval = [variadic_add(*add_inputs, new_add_inp)]
402401
return rval, old_dot22
403402

404403

0 commit comments

Comments
 (0)