6060import numpy as np
6161
6262from pytensor .graph .traversal import toposort
63- from pytensor .scalar import Add , Mul , Neg , Sub
63+ from pytensor .scalar import Add , Mul , Neg , Sub , upcast
6464from pytensor .tensor .rewriting .basic import register_specialize
6565
6666
@@ -359,8 +359,12 @@ def _gemm_from_factored_list(fgraph, lst):
359359 if isinstance (sM , tuple ):
360360 sm0 , sm1 = sM
361361 sm0 = ptb .as_tensor_variable (sm0 )
362- if pytensor .scalar .upcast (sm0 .dtype , sm1 .dtype ) == sm1 .dtype :
363- lst2 .append ((ptb .cast (sm0 , sm1 .dtype ), sM [1 ]))
362+ sm0_dtype = sm0 .type .dtype
363+ sm1_dtype = sm1 .type .dtype
364+ if sm0_dtype == sm1_dtype :
365+ lst2 .append ((sm0 , sm1 ))
366+ elif upcast (sm0_dtype , sm1_dtype ) == sm1_dtype :
367+ lst2 .append ((ptb .cast (sm0 , sm1_dtype ), sm1 ))
364368
365369 lst = lst2
366370
@@ -385,20 +389,15 @@ def item_to_var(t):
385389 if not M_j .type .in_same_class (M_i .type ):
386390 continue
387391
388- # print 'TRYING', (s_i, M_i, s_j, M_j)
389-
390392 gemm_of_sM_list , old_dot22 = _beta_L_plus_alpha_M (
391393 fgraph , s_i , M_i , s_j , M_j
392394 )
393- # print 'GOT IT', gemm_of_sM_list
394395 if gemm_of_sM_list :
395- assert len ( gemm_of_sM_list ) == 1
396+ [ new_add_inp ] = gemm_of_sM_list
396397 add_inputs = [
397398 item_to_var (input ) for k , input in enumerate (lst ) if k not in (i , j )
398399 ]
399- add_inputs .extend (gemm_of_sM_list )
400- rval = [variadic_add (* add_inputs )]
401- # print "RETURNING GEMM THING", rval
400+ rval = [variadic_add (* add_inputs , new_add_inp )]
402401 return rval , old_dot22
403402
404403
0 commit comments