@@ -636,93 +636,89 @@ def local_inplace_ger(fgraph, node):
636636@node_rewriter ([gemm_no_inplace ])
637637def local_gemm_to_gemv (fgraph , node ):
638638 """GEMM acting on row or column matrices -> GEMV."""
639- if node .op == gemm_no_inplace :
640- z , a , x , y , b = node .inputs
641- if z .broadcastable == x .broadcastable == (True , False ):
642- r = gemv_no_inplace (z .dimshuffle (1 ), a , y .T , x .dimshuffle (1 ), b )
643- new_out = [r .dimshuffle ("x" , 0 )]
644- elif z .broadcastable == y .broadcastable == (False , True ):
645- r = gemv_no_inplace (z .dimshuffle (0 ), a , x , y .dimshuffle (0 ), b )
646- new_out = [r .dimshuffle (0 , "x" )]
647- else :
648- return
649- copy_stack_trace (node .outputs , new_out )
650- return new_out
639+ z , a , x , y , b = node .inputs
640+ if z .broadcastable == x .broadcastable == (True , False ):
641+ r = gemv_no_inplace (z .dimshuffle (1 ), a , y .T , x .dimshuffle (1 ), b )
642+ new_out = [r .dimshuffle ("x" , 0 )]
643+ elif z .broadcastable == y .broadcastable == (False , True ):
644+ r = gemv_no_inplace (z .dimshuffle (0 ), a , x , y .dimshuffle (0 ), b )
645+ new_out = [r .dimshuffle (0 , "x" )]
646+ else :
647+ return
648+ copy_stack_trace (node .outputs , new_out )
649+ return new_out
651650
652651
653652@node_rewriter ([gemm_no_inplace ])
654653def local_gemm_to_ger (fgraph , node ):
655654 """GEMM computing an outer-product -> GER."""
656- if node .op == gemm_no_inplace :
657- z , a , x , y , b = node .inputs
658- if x .broadcastable [1 ] and y .broadcastable [0 ]:
659- # x and y are both vectors so this might qualifies for a GER
660- xv = x .dimshuffle (0 )
661- yv = y .dimshuffle (1 )
662- try :
663- bval = ptb .get_underlying_scalar_constant_value (b )
664- except NotScalarConstantError :
665- # b isn't a constant, GEMM is doing useful pre-scaling
666- return
667-
668- if bval == 1 : # best case a natural GER
669- rval = ger (z , a , xv , yv )
670- new_out = [rval ]
671- elif bval == 0 : # GER on zeros_like should be faster than GEMM
672- zeros = ptb .zeros ([x .shape [0 ], y .shape [1 ]], x .dtype )
673- rval = ger (zeros , a , xv , yv )
674- new_out = [rval ]
675- else :
676- # if bval is another constant, then z is being usefully
677- # pre-scaled and GER isn't really the right tool for the job.
678- return
679- copy_stack_trace (node .outputs , new_out )
680- return new_out
681-
655+ z , a , x , y , b = node .inputs
656+ if x .broadcastable [1 ] and y .broadcastable [0 ]:
657+ # x and y are both vectors so this might qualifies for a GER
658+ xv = x .dimshuffle (0 )
659+ yv = y .dimshuffle (1 )
660+ try :
661+ bval = ptb .get_underlying_scalar_constant_value (b )
662+ except NotScalarConstantError :
663+ # b isn't a constant, GEMM is doing useful pre-scaling
664+ return
682665
683- # TODO: delete this optimization when we have the proper dot->gemm->ger pipeline
684- # working
685- @node_rewriter ([_dot22 ])
686- def local_dot22_to_ger_or_gemv (fgraph , node ):
687- """dot22 computing an outer-product -> GER."""
688- if node .op == _dot22 :
689- x , y = node .inputs
690- xb = x .broadcastable
691- yb = y .broadcastable
692- one = ptb .as_tensor_variable (np .asarray (1 , dtype = x .dtype ))
693- zero = ptb .as_tensor_variable (np .asarray (0 , dtype = x .dtype ))
694- if xb [1 ] and yb [0 ]:
695- # x and y are both vectors so this might qualifies for a GER
696- xv = x .dimshuffle (0 )
697- yv = y .dimshuffle (1 )
698- zeros = ptb .zeros ([x .shape [0 ], y .shape [1 ]], dtype = x .dtype )
699- rval = ger (zeros , one , xv , yv )
666+ if bval == 1 : # best case a natural GER
667+ rval = ger (z , a , xv , yv )
668+ new_out = [rval ]
669+ elif bval == 0 : # GER on zeros_like should be faster than GEMM
670+ zeros = ptb .zeros ([x .shape [0 ], y .shape [1 ]], x .dtype )
671+ rval = ger (zeros , a , xv , yv )
700672 new_out = [rval ]
701- elif xb [0 ] and yb [1 ]:
702- # x and y are both vectors so this qualifies for a sdot / ddot
703- # PyTensor's CGemv will call sdot/ddot at runtime, the Scipy Gemv may not
704- xv = x .dimshuffle (1 )
705- zeros = ptb .AllocEmpty (x .dtype )(1 )
706- rval = gemv_no_inplace (zeros , one , y .T , xv , zero )
707- new_out = [rval .dimshuffle ("x" , 0 )]
708- elif xb [0 ] and not yb [0 ] and not yb [1 ]:
709- # x is vector, y is matrix so try gemv
710- xv = x .dimshuffle (1 )
711- zeros = ptb .AllocEmpty (x .dtype )(y .shape [1 ])
712- rval = gemv_no_inplace (zeros , one , y .T , xv , zero )
713- new_out = [rval .dimshuffle ("x" , 0 )]
714- elif not xb [0 ] and not xb [1 ] and yb [1 ]:
715- # x is matrix, y is vector, try gemv
716- yv = y .dimshuffle (0 )
717- zeros = ptb .AllocEmpty (x .dtype )(x .shape [0 ])
718- rval = gemv_no_inplace (zeros , one , x , yv , zero )
719- new_out = [rval .dimshuffle (0 , "x" )]
720673 else :
674+ # if bval is another constant, then z is being usefully
675+ # pre-scaled and GER isn't really the right tool for the job.
721676 return
722677 copy_stack_trace (node .outputs , new_out )
723678 return new_out
724679
725680
681+ # TODO: delete this optimization when we have the proper dot->gemm->ger pipeline working
682+ @node_rewriter ([_dot22 ])
683+ def local_dot22_to_ger_or_gemv (fgraph , node ):
684+ """dot22 computing an outer-product -> GER."""
685+ x , y = node .inputs
686+ xb = x .broadcastable
687+ yb = y .broadcastable
688+ one = ptb .as_tensor_variable (np .asarray (1 , dtype = x .dtype ))
689+ zero = ptb .as_tensor_variable (np .asarray (0 , dtype = x .dtype ))
690+ if xb [1 ] and yb [0 ]:
691+ # x and y are both vectors so this might qualifies for a GER
692+ xv = x .dimshuffle (0 )
693+ yv = y .dimshuffle (1 )
694+ zeros = ptb .zeros ([x .shape [0 ], y .shape [1 ]], dtype = x .dtype )
695+ rval = ger (zeros , one , xv , yv )
696+ new_out = [rval ]
697+ elif xb [0 ] and yb [1 ]:
698+ # x and y are both vectors so this qualifies for a sdot / ddot
699+ # PyTensor's CGemv will call sdot/ddot at runtime, the Scipy Gemv may not
700+ xv = x .dimshuffle (1 )
701+ zeros = ptb .AllocEmpty (x .dtype )(1 )
702+ rval = gemv_no_inplace (zeros , one , y .T , xv , zero )
703+ new_out = [rval .dimshuffle ("x" , 0 )]
704+ elif xb [0 ] and not yb [0 ] and not yb [1 ]:
705+ # x is vector, y is matrix so try gemv
706+ xv = x .dimshuffle (1 )
707+ zeros = ptb .AllocEmpty (x .dtype )(y .shape [1 ])
708+ rval = gemv_no_inplace (zeros , one , y .T , xv , zero )
709+ new_out = [rval .dimshuffle ("x" , 0 )]
710+ elif not xb [0 ] and not xb [1 ] and yb [1 ]:
711+ # x is matrix, y is vector, try gemv
712+ yv = y .dimshuffle (0 )
713+ zeros = ptb .AllocEmpty (x .dtype )(x .shape [0 ])
714+ rval = gemv_no_inplace (zeros , one , x , yv , zero )
715+ new_out = [rval .dimshuffle (0 , "x" )]
716+ else :
717+ return
718+ copy_stack_trace (node .outputs , new_out )
719+ return new_out
720+
721+
726722#################################
727723#
728724# Set up the BlasOpt optimizer
0 commit comments