@@ -583,7 +583,12 @@ def max_and_argmax(a, axis=None, keepdims=False):
583583 return [out , argout ]
584584
585585
586- class NonZeroCAReduce (CAReduce ):
586+ class FixedOpCAReduce (CAReduce ):
587+ def __str__ (self ):
588+ return f"{ type (self ).__name__ } {{{ self ._axis_str ()} }}"
589+
590+
591+ class NonZeroDimsCAReduce (FixedOpCAReduce ):
587592 def _c_all (self , node , name , inames , onames , sub ):
588593 decl , checks , alloc , loop , end = super ()._c_all (node , name , inames , onames , sub )
589594
@@ -614,7 +619,7 @@ def _c_all(self, node, name, inames, onames, sub):
614619 return decl , checks , alloc , loop , end
615620
616621
617- class Max (NonZeroCAReduce ):
622+ class Max (NonZeroDimsCAReduce ):
618623 nfunc_spec = ("max" , 1 , 1 )
619624
620625 def __init__ (self , axis ):
@@ -625,7 +630,7 @@ def clone(self, **kwargs):
625630 return type (self )(axis = axis )
626631
627632
628- class Min (NonZeroCAReduce ):
633+ class Min (NonZeroDimsCAReduce ):
629634 nfunc_spec = ("min" , 1 , 1 )
630635
631636 def __init__ (self , axis ):
@@ -1496,7 +1501,7 @@ def complex_from_polar(abs, angle):
14961501 """Return complex-valued tensor from polar coordinate specification."""
14971502
14981503
1499- class Mean (CAReduce ):
1504+ class Mean (FixedOpCAReduce ):
15001505 __props__ = ("axis" ,)
15011506 nfunc_spec = ("mean" , 1 , 1 )
15021507
@@ -2356,7 +2361,7 @@ def outer(x, y):
23562361 return dot (x .dimshuffle (0 , "x" ), y .dimshuffle ("x" , 0 ))
23572362
23582363
2359- class All (CAReduce ):
2364+ class All (FixedOpCAReduce ):
23602365 """Applies `logical and` to all the values of a tensor along the
23612366 specified axis(es).
23622367
@@ -2370,12 +2375,6 @@ def __init__(self, axis=None):
23702375 def _output_dtype (self , idtype ):
23712376 return "bool"
23722377
2373- def __str__ (self ):
2374- if self .axis is None :
2375- return "All"
2376- else :
2377- return "All{%s}" % ", " .join (map (str , self .axis ))
2378-
23792378 def make_node (self , input ):
23802379 input = as_tensor_variable (input )
23812380 if input .dtype != "bool" :
@@ -2392,7 +2391,7 @@ def clone(self, **kwargs):
23922391 return type (self )(axis = axis )
23932392
23942393
2395- class Any (CAReduce ):
2394+ class Any (FixedOpCAReduce ):
23962395 """Applies `bitwise or` to all the values of a tensor along the
23972396 specified axis(es).
23982397
@@ -2406,12 +2405,6 @@ def __init__(self, axis=None):
24062405 def _output_dtype (self , idtype ):
24072406 return "bool"
24082407
2409- def __str__ (self ):
2410- if self .axis is None :
2411- return "Any"
2412- else :
2413- return "Any{%s}" % ", " .join (map (str , self .axis ))
2414-
24152408 def make_node (self , input ):
24162409 input = as_tensor_variable (input )
24172410 if input .dtype != "bool" :
@@ -2428,7 +2421,7 @@ def clone(self, **kwargs):
24282421 return type (self )(axis = axis )
24292422
24302423
2431- class Sum (CAReduce ):
2424+ class Sum (FixedOpCAReduce ):
24322425 """
24332426 Sums all the values of a tensor along the specified axis(es).
24342427
@@ -2449,14 +2442,6 @@ def __init__(self, axis=None, dtype=None, acc_dtype=None):
24492442 upcast_discrete_output = True ,
24502443 )
24512444
2452- def __str__ (self ):
2453- name = self .__class__ .__name__
2454- axis = ""
2455- if self .axis is not None :
2456- axis = ", " .join (str (x ) for x in self .axis )
2457- axis = f"axis=[{ axis } ], "
2458- return f"{ name } {{{ axis } acc_dtype={ self .acc_dtype } }}"
2459-
24602445 def L_op (self , inp , out , grads ):
24612446 (x ,) = inp
24622447
@@ -2526,7 +2511,7 @@ def sum(input, axis=None, dtype=None, keepdims=False, acc_dtype=None):
25262511pprint .assign (Sum , printing .FunctionPrinter (["sum" ], ["axis" ]))
25272512
25282513
2529- class Prod (CAReduce ):
2514+ class Prod (FixedOpCAReduce ):
25302515 """
25312516 Multiplies all the values of a tensor along the specified axis(es).
25322517
@@ -2537,7 +2522,6 @@ class Prod(CAReduce):
25372522 """
25382523
25392524 __props__ = ("scalar_op" , "axis" , "dtype" , "acc_dtype" , "no_zeros_in_input" )
2540-
25412525 nfunc_spec = ("prod" , 1 , 1 )
25422526
25432527 def __init__ (self , axis = None , dtype = None , acc_dtype = None , no_zeros_in_input = False ):
@@ -2683,6 +2667,14 @@ def clone(self, **kwargs):
26832667 no_zeros_in_input = no_zeros_in_input ,
26842668 )
26852669
2670+ def __str__ (self ):
2671+ if self .no_zeros_in_input :
2672+ return f"{ super ().__str__ ()[:- 1 ]} , no_zeros_in_input}})"
2673+ return super ().__str__ ()
2674+
2675+ def __repr__ (self ):
2676+ return f"{ super ().__repr__ ()[:- 1 ]} , no_zeros_in_input={ self .no_zeros_in_input } )"
2677+
26862678
26872679def prod (
26882680 input ,
@@ -2751,7 +2743,7 @@ def c_code_cache_version(self):
27512743mul_without_zeros = MulWithoutZeros (aes .upcast_out , name = "mul_without_zeros" )
27522744
27532745
2754- class ProdWithoutZeros (CAReduce ):
2746+ class ProdWithoutZeros (FixedOpCAReduce ):
27552747 def __init__ (self , axis = None , dtype = None , acc_dtype = None ):
27562748 super ().__init__ (
27572749 mul_without_zeros ,
0 commit comments