22import unittest
33import numpy as np
44from onnx import TensorProto
5- from onnx_array_api .ext_test_case import ExtTestCase
5+ from onnx_array_api .ext_test_case import ExtTestCase , ignore_warnings
66from onnx_array_api .array_api import onnx_numpy as xp
77from onnx_array_api .npx .npx_types import DType
88from onnx_array_api .npx .npx_numpy_tensors import EagerNumpyTensor as EagerTensor
9+ from onnx_array_api .npx .npx_functions import linspace as linspace_inline
10+ from onnx_array_api .npx .npx_types import Float64 , Int64
11+ from onnx_array_api .npx .npx_var import Input
12+ from onnx_array_api .reference import ExtendedReferenceEvaluator
913
1014
1115class TestOnnxNumpy (ExtTestCase ):
@@ -22,6 +26,7 @@ def test_zeros(self):
2226 a = xp .absolute (mat )
2327 self .assertEqualArray (np .absolute (mat .numpy ()), a .numpy ())
2428
29+ @ignore_warnings (DeprecationWarning )
2530 def test_arange_default (self ):
2631 a = EagerTensor (np .array ([0 ], dtype = np .int64 ))
2732 b = EagerTensor (np .array ([2 ], dtype = np .int64 ))
@@ -30,6 +35,7 @@ def test_arange_default(self):
3035 self .assertEqual (matnp .shape , (2 ,))
3136 self .assertEqualArray (matnp , np .arange (0 , 2 ).astype (np .int64 ))
3237
38+ @ignore_warnings (DeprecationWarning )
3339 def test_arange_step (self ):
3440 a = EagerTensor (np .array ([4 ], dtype = np .int64 ))
3541 s = EagerTensor (np .array ([2 ], dtype = np .int64 ))
@@ -78,6 +84,7 @@ def test_full_bool(self):
7884 self .assertNotEmpty (matnp [0 , 0 ])
7985 self .assertEqualArray (matnp , np .full ((4 , 5 ), False ))
8086
87+ @ignore_warnings (DeprecationWarning )
8188 def test_arange_int00a (self ):
8289 a = EagerTensor (np .array ([0 ], dtype = np .int64 ))
8390 b = EagerTensor (np .array ([0 ], dtype = np .int64 ))
@@ -89,6 +96,7 @@ def test_arange_int00a(self):
8996 expected = expected .astype (np .int64 )
9097 self .assertEqualArray (matnp , expected )
9198
99+ @ignore_warnings (DeprecationWarning )
92100 def test_arange_int00 (self ):
93101 mat = xp .arange (0 , 0 )
94102 matnp = mat .numpy ()
@@ -160,10 +168,94 @@ def test_eye_k(self):
160168 got = xp .eye (nr , k = 1 )
161169 self .assertEqualArray (expected , got .numpy ())
162170
171+ def test_linspace_int (self ):
172+ a = EagerTensor (np .array ([0 ], dtype = np .int64 ))
173+ b = EagerTensor (np .array ([6 ], dtype = np .int64 ))
174+ c = EagerTensor (np .array (3 , dtype = np .int64 ))
175+ mat = xp .linspace (a , b , c )
176+ matnp = mat .numpy ()
177+ expected = np .linspace (a .numpy (), b .numpy (), c .numpy ()).astype (np .int64 )
178+ self .assertEqualArray (expected , matnp )
179+
180+ def test_linspace_int5 (self ):
181+ a = EagerTensor (np .array ([0 ], dtype = np .int64 ))
182+ b = EagerTensor (np .array ([5 ], dtype = np .int64 ))
183+ c = EagerTensor (np .array (3 , dtype = np .int64 ))
184+ mat = xp .linspace (a , b , c )
185+ matnp = mat .numpy ()
186+ expected = np .linspace (a .numpy (), b .numpy (), c .numpy ()).astype (np .int64 )
187+ self .assertEqualArray (expected , matnp )
188+
189+ def test_linspace_float (self ):
190+ a = EagerTensor (np .array ([0.5 ], dtype = np .float64 ))
191+ b = EagerTensor (np .array ([5.5 ], dtype = np .float64 ))
192+ c = EagerTensor (np .array (2 , dtype = np .int64 ))
193+ mat = xp .linspace (a , b , c )
194+ matnp = mat .numpy ()
195+ expected = np .linspace (a .numpy (), b .numpy (), c .numpy ())
196+ self .assertEqualArray (expected , matnp )
197+
198+ def test_linspace_float_noendpoint (self ):
199+ a = EagerTensor (np .array ([0.5 ], dtype = np .float64 ))
200+ b = EagerTensor (np .array ([5.5 ], dtype = np .float64 ))
201+ c = EagerTensor (np .array (2 , dtype = np .int64 ))
202+ mat = xp .linspace (a , b , c , endpoint = 0 )
203+ matnp = mat .numpy ()
204+ expected = np .linspace (a .numpy (), b .numpy (), c .numpy (), endpoint = 0 )
205+ self .assertEqualArray (expected , matnp )
206+
207+ @ignore_warnings ((RuntimeWarning , DeprecationWarning )) # division by zero
208+ def test_linspace_zero (self ):
209+ expected = np .linspace (0.0 , 0.0 , 0 , endpoint = False )
210+ mat = xp .linspace (0.0 , 0.0 , 0 , endpoint = False )
211+ matnp = mat .numpy ()
212+ self .assertEqualArray (expected , matnp )
213+
214+ @ignore_warnings ((RuntimeWarning , DeprecationWarning )) # division by zero
215+ def test_linspace_zero_one (self ):
216+ expected = np .linspace (0.0 , 0.0 , 1 , endpoint = True )
217+
218+ f = linspace_inline (Input ("start" ), Input ("stop" ), Input ("num" ))
219+ onx = f .to_onnx (
220+ constraints = {
221+ "start" : Float64 [None ],
222+ "stop" : Float64 [None ],
223+ "num" : Int64 [None ],
224+ (0 , False ): Float64 [None ],
225+ }
226+ )
227+ ref = ExtendedReferenceEvaluator (onx )
228+ got = ref .run (
229+ None ,
230+ {
231+ "start" : np .array (0 , dtype = np .float64 ),
232+ "stop" : np .array (0 , dtype = np .float64 ),
233+ "num" : np .array (1 , dtype = np .int64 ),
234+ },
235+ )
236+ self .assertEqualArray (expected , got [0 ])
237+
238+ mat = xp .linspace (0.0 , 0.0 , 1 , endpoint = True )
239+ matnp = mat .numpy ()
240+
241+ self .assertEqualArray (expected , matnp )
242+
243+ def test_slice_minus_one (self ):
244+ g = EagerTensor (np .array ([0.0 ]))
245+ expected = g .numpy ()[:- 1 ]
246+ got = g [:- 1 ]
247+ self .assertEqualArray (expected , got .numpy ())
248+
249+ def test_linspace_bug1 (self ):
250+ expected = np .linspace (16777217.0 , 0.0 , 1 )
251+ mat = xp .linspace (16777217.0 , 0.0 , 1 )
252+ matnp = mat .numpy ()
253+ self .assertEqualArray (expected , matnp )
254+
163255
164256if __name__ == "__main__" :
165257 # import logging
166258
167259 # logging.basicConfig(level=logging.DEBUG)
168- TestOnnxNumpy ().test_eye ()
260+ TestOnnxNumpy ().test_linspace_float_noendpoint ()
169261 unittest .main (verbosity = 2 )
0 commit comments