@@ -2081,63 +2081,65 @@ def local_pow_to_nested_squaring(fgraph, node):
20812081 Note: This sounds like the kind of thing any half-decent compiler can do by itself?
20822082 """
20832083
2084- if node .op == at_pow :
2085- # the idea here is that we have pow(x, y)
2086- odtype = node .outputs [0 ].dtype
2087- xsym = node .inputs [0 ]
2088- ysym = node .inputs [1 ]
2089- y = get_constant (ysym )
2090-
2091- # The next line is needed to fix a strange case that I don't
2092- # know how to make a separate test for.
2093- # It happens in the `test_log_erfc` test.
2094- # y is an ndarray with dtype int8 and value 2, 4 or 6. This makes
2095- # the abs(y) <= 512 check fail!
2096- # Taking the value out of the ndarray solves the problem.
2097- # It could be that, in that case, numpy makes the comparison
2098- # in the wrong type (doing it in int8, which overflows).
2099- if isinstance (y , np .ndarray ):
2100- assert y .size == 1
2101- try :
2102- y = y [0 ]
2103- except IndexError :
2104- pass
2105- if (y is not None ) and not broadcasted_by (xsym , ysym ):
2106- rval = None
2107- # 512 is too small for the CPU and too big for some GPUs!
2108- if abs (y ) == int (abs (y )) and abs (y ) <= 512 :
2109- pow2 = [xsym ]
2110- pow2_scal = [aes .get_scalar_type (xsym .dtype )()]
2111- y_to_do = abs (y )
2112- for i in range (int (np .log2 (y_to_do ))):
2113- pow2 .append (sqr (pow2 [i ]))
2114- pow2_scal .append (aes .sqr (pow2_scal [i ]))
2115- rval1 = None
2116- rval1_scal = None
2117- while y_to_do > 0 :
2118- log_to_do = int (np .log2 (y_to_do ))
2119- if rval1 :
2120- rval1 *= pow2 [log_to_do ]
2121- rval1_scal *= pow2_scal [log_to_do ]
2122- else :
2123- rval1 = pow2 [log_to_do ]
2124- rval1_scal = pow2_scal [log_to_do ]
2125- y_to_do -= 2 ** log_to_do
2126-
2127- if abs (y ) > 2 :
2128- # We fuse all the pow together here to make
2129- # compilation faster
2130- rval1 = Elemwise (
2131- aes .Composite ([pow2_scal [0 ]], [rval1_scal ])
2132- ).make_node (xsym )
2133- if y < 0 :
2134- rval = [reciprocal (rval1 )]
2084+ # the idea here is that we have pow(x, y)
2085+ odtype = node .outputs [0 ].dtype
2086+ xsym = node .inputs [0 ]
2087+ ysym = node .inputs [1 ]
2088+ y = get_constant (ysym )
2089+
2090+ # The next line is needed to fix a strange case that I don't
2091+ # know how to make a separate test for.
2092+ # It happens in the `test_log_erfc` test.
2093+ # y is an ndarray with dtype int8 and value 2, 4 or 6. This makes
2094+ # the abs(y) <= 512 check fail!
2095+ # Taking the value out of the ndarray solves the problem.
2096+ # It could be that, in that case, numpy makes the comparison
2097+ # in the wrong type (doing it in int8, which overflows).
2098+ if isinstance (y , np .ndarray ):
2099+ assert y .size == 1
2100+ try :
2101+ y = y [0 ]
2102+ except IndexError :
2103+ pass
2104+ if (y is not None ) and not broadcasted_by (xsym , ysym ):
2105+ rval = None
2106+ # 512 is too small for the CPU and too big for some GPUs!
2107+ if abs (y ) == int (abs (y )) and abs (y ) <= 512 :
2108+ pow2 = [xsym ]
2109+ pow2_scal = [aes .get_scalar_type (xsym .dtype )()]
2110+ y_to_do = abs (y )
2111+ for i in range (int (np .log2 (y_to_do ))):
2112+ pow2 .append (sqr (pow2 [i ]))
2113+ pow2_scal .append (aes .sqr (pow2_scal [i ]))
2114+ rval1 = None
2115+ rval1_scal = None
2116+ while y_to_do > 0 :
2117+ log_to_do = int (np .log2 (y_to_do ))
2118+ if rval1 :
2119+ rval1 *= pow2 [log_to_do ]
2120+ rval1_scal *= pow2_scal [log_to_do ]
21352121 else :
2136- rval = [rval1 ]
2137- if rval :
2138- rval [0 ] = cast (rval [0 ], odtype )
2139- assert rval [0 ].type == node .outputs [0 ].type , (rval , node .outputs )
2140- return rval
2122+ rval1 = pow2 [log_to_do ]
2123+ rval1_scal = pow2_scal [log_to_do ]
2124+ y_to_do -= 2 ** log_to_do
2125+
2126+ if abs (y ) > 2 :
2127+ # We fuse all the pow together here to make
2128+ # compilation faster
2129+ rval1 = Elemwise (aes .Composite ([pow2_scal [0 ]], [rval1_scal ])).make_node (
2130+ xsym
2131+ )
2132+ if y < 0 :
2133+ rval = [reciprocal (rval1 )]
2134+ else :
2135+ rval = [rval1 ]
2136+ if rval :
2137+ rval [0 ] = cast (rval [0 ], odtype )
2138+ # TODO: We can add a specify_broadcastable and/or unbroadcast to make the
2139+ # output types compatible. Or work on #408 and let TensorType.filter_variable do it.
2140+ if rval [0 ].type .broadcastable != node .outputs [0 ].type .broadcastable :
2141+ return None
2142+ return rval
21412143
21422144
21432145@register_specialize
0 commit comments