4242 paddle .complex128 ,
4343}
4444
45+ # NOTE: Implicit promotion rules of Paddle is a bit strict than other frameworks,
46+ # see details: https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/guides/advanced/auto_type_promotion_cn.html
4547_promotion_table = {
4648 # bool
4749 (paddle .bool , paddle .bool ): paddle .bool ,
4850 # ints
4951 (paddle .int8 , paddle .int8 ): paddle .int8 ,
50- (paddle .int8 , paddle .int16 ): paddle .int16 ,
51- (paddle .int8 , paddle .int32 ): paddle .int32 ,
52- (paddle .int8 , paddle .int64 ): paddle .int64 ,
53- (paddle .int16 , paddle .int8 ): paddle .int16 ,
5452 (paddle .int16 , paddle .int16 ): paddle .int16 ,
55- (paddle .int16 , paddle .int32 ): paddle .int32 ,
56- (paddle .int16 , paddle .int64 ): paddle .int64 ,
57- (paddle .int32 , paddle .int8 ): paddle .int32 ,
58- (paddle .int32 , paddle .int16 ): paddle .int32 ,
5953 (paddle .int32 , paddle .int32 ): paddle .int32 ,
60- (paddle .int32 , paddle .int64 ): paddle .int64 ,
61- (paddle .int64 , paddle .int8 ): paddle .int64 ,
62- (paddle .int64 , paddle .int16 ): paddle .int64 ,
63- (paddle .int64 , paddle .int32 ): paddle .int64 ,
6454 (paddle .int64 , paddle .int64 ): paddle .int64 ,
6555 # uints
6656 (paddle .uint8 , paddle .uint8 ): paddle .uint8 ,
67- # ints and uints (mixed sign)
68- (paddle .int8 , paddle .uint8 ): paddle .int16 ,
69- (paddle .int16 , paddle .uint8 ): paddle .int16 ,
70- (paddle .int32 , paddle .uint8 ): paddle .int32 ,
71- (paddle .int64 , paddle .uint8 ): paddle .int64 ,
72- (paddle .uint8 , paddle .int8 ): paddle .int16 ,
73- (paddle .uint8 , paddle .int16 ): paddle .int16 ,
74- (paddle .uint8 , paddle .int32 ): paddle .int32 ,
75- (paddle .uint8 , paddle .int64 ): paddle .int64 ,
7657 # floats
7758 (paddle .float32 , paddle .float32 ): paddle .float32 ,
7859 (paddle .float32 , paddle .float64 ): paddle .float64 ,
@@ -158,12 +139,12 @@ def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool:
158139 paddle .float64 : True ,
159140 paddle .complex64 : True ,
160141 paddle .complex128 : True ,
161- paddle .uint8 : False ,
162- paddle .int8 : False ,
163- paddle .int16 : False ,
164- paddle .int32 : False ,
165- paddle .int64 : False ,
166- paddle .bool : False ,
142+ paddle .uint8 : True ,
143+ paddle .int8 : True ,
144+ paddle .int16 : True ,
145+ paddle .int32 : True ,
146+ paddle .int64 : True ,
147+ paddle .bool : True ,
167148 },
168149 paddle .float16 : {
169150 paddle .bfloat16 : True ,
@@ -172,12 +153,12 @@ def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool:
172153 paddle .float64 : True ,
173154 paddle .complex64 : True ,
174155 paddle .complex128 : True ,
175- paddle .uint8 : False ,
176- paddle .int8 : False ,
177- paddle .int16 : False ,
178- paddle .int32 : False ,
179- paddle .int64 : False ,
180- paddle .bool : False ,
156+ paddle .uint8 : True ,
157+ paddle .int8 : True ,
158+ paddle .int16 : True ,
159+ paddle .int32 : True ,
160+ paddle .int64 : True ,
161+ paddle .bool : True ,
181162 },
182163 paddle .float32 : {
183164 paddle .bfloat16 : True ,
@@ -186,12 +167,12 @@ def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool:
186167 paddle .float64 : True ,
187168 paddle .complex64 : True ,
188169 paddle .complex128 : True ,
189- paddle .uint8 : False ,
190- paddle .int8 : False ,
191- paddle .int16 : False ,
192- paddle .int32 : False ,
193- paddle .int64 : False ,
194- paddle .bool : False ,
170+ paddle .uint8 : True ,
171+ paddle .int8 : True ,
172+ paddle .int16 : True ,
173+ paddle .int32 : True ,
174+ paddle .int64 : True ,
175+ paddle .bool : True ,
195176 },
196177 paddle .float64 : {
197178 paddle .bfloat16 : True ,
@@ -200,40 +181,40 @@ def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool:
200181 paddle .float64 : True ,
201182 paddle .complex64 : True ,
202183 paddle .complex128 : True ,
203- paddle .uint8 : False ,
204- paddle .int8 : False ,
205- paddle .int16 : False ,
206- paddle .int32 : False ,
207- paddle .int64 : False ,
208- paddle .bool : False ,
184+ paddle .uint8 : True ,
185+ paddle .int8 : True ,
186+ paddle .int16 : True ,
187+ paddle .int32 : True ,
188+ paddle .int64 : True ,
189+ paddle .bool : True ,
209190 },
210191 paddle .complex64 : {
211- paddle .bfloat16 : False ,
212- paddle .float16 : False ,
213- paddle .float32 : False ,
214- paddle .float64 : False ,
192+ paddle .bfloat16 : True ,
193+ paddle .float16 : True ,
194+ paddle .float32 : True ,
195+ paddle .float64 : True ,
215196 paddle .complex64 : True ,
216197 paddle .complex128 : True ,
217- paddle .uint8 : False ,
218- paddle .int8 : False ,
219- paddle .int16 : False ,
220- paddle .int32 : False ,
221- paddle .int64 : False ,
222- paddle .bool : False ,
198+ paddle .uint8 : True ,
199+ paddle .int8 : True ,
200+ paddle .int16 : True ,
201+ paddle .int32 : True ,
202+ paddle .int64 : True ,
203+ paddle .bool : True ,
223204 },
224205 paddle .complex128 : {
225- paddle .bfloat16 : False ,
226- paddle .float16 : False ,
227- paddle .float32 : False ,
228- paddle .float64 : False ,
206+ paddle .bfloat16 : True ,
207+ paddle .float16 : True ,
208+ paddle .float32 : True ,
209+ paddle .float64 : True ,
229210 paddle .complex64 : True ,
230211 paddle .complex128 : True ,
231- paddle .uint8 : False ,
232- paddle .int8 : False ,
233- paddle .int16 : False ,
234- paddle .int32 : False ,
235- paddle .int64 : False ,
236- paddle .bool : False ,
212+ paddle .uint8 : True ,
213+ paddle .int8 : True ,
214+ paddle .int16 : True ,
215+ paddle .int32 : True ,
216+ paddle .int64 : True ,
217+ paddle .bool : True ,
237218 },
238219 paddle .uint8 : {
239220 paddle .bfloat16 : True ,
@@ -247,7 +228,7 @@ def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool:
247228 paddle .int16 : True ,
248229 paddle .int32 : True ,
249230 paddle .int64 : True ,
250- paddle .bool : False ,
231+ paddle .bool : True ,
251232 },
252233 paddle .int8 : {
253234 paddle .bfloat16 : True ,
@@ -261,7 +242,7 @@ def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool:
261242 paddle .int16 : True ,
262243 paddle .int32 : True ,
263244 paddle .int64 : True ,
264- paddle .bool : False ,
245+ paddle .bool : True ,
265246 },
266247 paddle .int16 : {
267248 paddle .bfloat16 : True ,
@@ -275,7 +256,7 @@ def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool:
275256 paddle .int16 : True ,
276257 paddle .int32 : True ,
277258 paddle .int64 : True ,
278- paddle .bool : False ,
259+ paddle .bool : True ,
279260 },
280261 paddle .int32 : {
281262 paddle .bfloat16 : True ,
@@ -289,7 +270,7 @@ def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool:
289270 paddle .int16 : True ,
290271 paddle .int32 : True ,
291272 paddle .int64 : True ,
292- paddle .bool : False ,
273+ paddle .bool : True ,
293274 },
294275 paddle .int64 : {
295276 paddle .bfloat16 : True ,
@@ -303,7 +284,7 @@ def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool:
303284 paddle .int16 : True ,
304285 paddle .int32 : True ,
305286 paddle .int64 : True ,
306- paddle .bool : False ,
287+ paddle .bool : True ,
307288 },
308289 paddle .bool : {
309290 paddle .bfloat16 : True ,
0 commit comments