1+ import warnings
12from typing import Any , Callable , List , Optional , Tuple
23import numpy as np
34from onnx import ModelProto , TensorProto
@@ -221,13 +222,18 @@ def __bool__(self):
221222 if self .shape == (0 ,):
222223 return False
223224 if len (self .shape ) != 0 :
224- raise ValueError (
225- f"Conversion to bool only works for scalar, not for { self !r} ."
225+ warnings .warn (
226+ f"Conversion to bool only works for scalar, not for { self !r} , "
227+ f"bool(...)={ bool (self ._tensor )} ."
226228 )
229+ try :
230+ return bool (self ._tensor )
231+ except ValueError as e :
232+ raise ValueError (f"Unable to convert { self } to bool." ) from e
227233 return bool (self ._tensor )
228234
229235 def __int__ (self ):
230- "Implicit conversion to bool ."
236+ "Implicit conversion to int ."
231237 if len (self .shape ) != 0 :
232238 raise ValueError (
233239 f"Conversion to bool only works for scalar, not for { self !r} ."
@@ -249,7 +255,7 @@ def __int__(self):
249255 return int (self ._tensor )
250256
251257 def __float__ (self ):
252- "Implicit conversion to bool ."
258+ "Implicit conversion to float ."
253259 if len (self .shape ) != 0 :
254260 raise ValueError (
255261 f"Conversion to bool only works for scalar, not for { self !r} ."
@@ -261,11 +267,24 @@ def __float__(self):
261267 DType (TensorProto .BFLOAT16 ),
262268 }:
263269 raise TypeError (
264- f"Conversion to int only works for float scalar, "
270+ f"Conversion to float only works for float scalar, "
265271 f"not for dtype={ self .dtype } ."
266272 )
267273 return float (self ._tensor )
268274
275+ def __iter__ (self ):
276+ """
277+ The :epkg:`Array API` does not define this function (2022/12).
278+ This method raises an exception with a better error message.
279+ """
280+ warnings .warn (
281+ f"Iterators are not implemented in the generic case. "
282+ f"Every function using them cannot be converted into ONNX "
283+ f"(tensors - { type (self )} )."
284+ )
285+ for row in self ._tensor :
286+ yield self .__class__ (row )
287+
269288
270289class JitNumpyTensor (NumpyTensor , JitTensor ):
271290 """
0 commit comments