@@ -82,10 +82,10 @@ def is_neg_zero(n: float) -> bool:
8282
8383def assert_dtype (
8484 func_name : str ,
85+ * ,
8586 in_dtype : Union [DataType , Sequence [DataType ]],
8687 out_dtype : DataType ,
8788 expected : Optional [DataType ] = None ,
88- * ,
8989 repr_name : str = "out.dtype" ,
9090):
9191 """
@@ -96,7 +96,7 @@ def assert_dtype(
9696
9797 >>> x = xp.arange(5, dtype=xp.uint8)
9898 >>> out = xp.abs(x)
99- >>> assert_dtype('abs', x.dtype, out.dtype)
99+ >>> assert_dtype('abs', in_dtype= x.dtype, out_dtype= out.dtype)
100100
101101 is equivalent to
102102
@@ -108,7 +108,7 @@ def assert_dtype(
108108 >>> x1 = xp.arange(5, dtype=xp.uint8)
109109 >>> x2 = xp.arange(5, dtype=xp.uint16)
110110 >>> out = xp.add(x1, x2)
111- >>> assert_dtype('add', [x1.dtype, x2.dtype], out.dtype)
111+ >>> assert_dtype('add', in_dtype= [x1.dtype, x2.dtype], out_dtype= out.dtype)
112112
113113 is equivalent to
114114
@@ -119,7 +119,7 @@ def assert_dtype(
119119 >>> x = xp.arange(5, dtype=xp.int8)
120120 >>> out = xp.sum(x)
121121 >>> default_int = xp.asarray(0).dtype
122- >>> assert_dtype('sum', x, out.dtype, default_int)
122+ >>> assert_dtype('sum', in_dtype= x, out_dtype= out.dtype, expected= default_int)
123123
124124 """
125125 in_dtypes = in_dtype if isinstance (in_dtype , Sequence ) and not isinstance (in_dtype , str ) else [in_dtype ]
@@ -135,13 +135,18 @@ def assert_dtype(
135135 assert out_dtype == expected , msg
136136
137137
138- def assert_kw_dtype (func_name : str , kw_dtype : DataType , out_dtype : DataType ):
138+ def assert_kw_dtype (
139+ func_name : str ,
140+ * ,
141+ kw_dtype : DataType ,
142+ out_dtype : DataType ,
143+ ):
139144 """
140145 Assert the output dtype is the passed keyword dtype, e.g.
141146
142147 >>> kw = {'dtype': xp.uint8}
143- >>> out = xp.ones(5, ** kw)
144- >>> assert_kw_dtype('ones', kw['dtype'], out.dtype)
148+ >>> out = xp.ones(5, kw= kw)
149+ >>> assert_kw_dtype('ones', kw_dtype= kw['dtype'], out_dtype= out.dtype)
145150
146151 """
147152 f_kw_dtype = dh .dtype_to_name [kw_dtype ]
@@ -222,17 +227,17 @@ def assert_default_index(func_name: str, out_dtype: DataType, repr_name="out.dty
222227
223228def assert_shape (
224229 func_name : str ,
230+ * ,
225231 out_shape : Union [int , Shape ],
226232 expected : Union [int , Shape ],
227- / ,
228233 repr_name = "out.shape" ,
229- ** kw ,
234+ kw : dict = {} ,
230235):
231236 """
232237 Assert the output shape is as expected, e.g.
233238
234239 >>> out = xp.ones((3, 3, 3))
235- >>> assert_shape('ones', out.shape, (3, 3, 3))
240+ >>> assert_shape('ones', out_shape= out.shape, expected= (3, 3, 3))
236241
237242 """
238243 if isinstance (out_shape , int ):
@@ -249,11 +254,10 @@ def assert_result_shape(
249254 func_name : str ,
250255 in_shapes : Sequence [Shape ],
251256 out_shape : Shape ,
252- / ,
253257 expected : Optional [Shape ] = None ,
254258 * ,
255259 repr_name = "out.shape" ,
256- ** kw ,
260+ kw : dict = {} ,
257261):
258262 """
259263 Assert the output shape is as expected.
@@ -262,7 +266,7 @@ def assert_result_shape(
262266 in_shapes, to test against out_shape, e.g.
263267
264268 >>> out = xp.add(xp.ones((3, 1)), xp.ones((1, 3)))
265- >>> assert_shape ('add', [(3, 1), (1, 3)], out.shape)
269+ >>> assert_result_shape ('add', in_shape= [(3, 1), (1, 3)], out_shape= out.shape)
266270
267271 is equivalent to
268272
@@ -281,21 +285,21 @@ def assert_result_shape(
281285
282286def assert_keepdimable_shape (
283287 func_name : str ,
288+ * ,
284289 in_shape : Shape ,
285290 out_shape : Shape ,
286291 axes : Tuple [int , ...],
287292 keepdims : bool ,
288- / ,
289- ** kw ,
293+ kw : dict = {},
290294):
291295 """
292296 Assert the output shape from a keepdimable function is as expected, e.g.
293297
294298 >>> x = xp.asarray([[0, 1, 2], [3, 4, 5], [6, 7, 8]])
295299 >>> out1 = xp.max(x, keepdims=False)
296300 >>> out2 = xp.max(x, keepdims=True)
297- >>> assert_keepdimable_shape('max', x.shape, out1.shape, (0, 1), False)
298- >>> assert_keepdimable_shape('max', x.shape, out2.shape, (0, 1), True)
301+ >>> assert_keepdimable_shape('max', in_shape= x.shape, out_shape= out1.shape, axes= (0, 1), keepdims= False)
302+ >>> assert_keepdimable_shape('max', in_shape= x.shape, out_shape= out2.shape, axes= (0, 1), keepdims= True)
299303
300304 is equivalent to
301305
@@ -307,19 +311,26 @@ def assert_keepdimable_shape(
307311 shape = tuple (1 if axis in axes else side for axis , side in enumerate (in_shape ))
308312 else :
309313 shape = tuple (side for axis , side in enumerate (in_shape ) if axis not in axes )
310- assert_shape (func_name , out_shape , shape , ** kw )
314+ assert_shape (func_name , out_shape = out_shape , expected = shape , kw = kw )
311315
312316
313317def assert_0d_equals (
314- func_name : str , x_repr : str , x_val : Array , out_repr : str , out_val : Array , ** kw
318+ func_name : str ,
319+ * ,
320+ x_repr : str ,
321+ x_val : Array ,
322+ out_repr : str ,
323+ out_val : Array ,
324+ kw : dict = {},
315325):
316326 """
317327 Assert a 0d array is as expected, e.g.
318328
319329 >>> x = xp.asarray([0, 1, 2])
320- >>> res = xp.asarray(x, copy=True)
330+ >>> kw = {'copy': True}
331+ >>> res = xp.asarray(x, **kw)
321332 >>> res[0] = 42
322- >>> assert_0d_equals('asarray', 'x[0]', x[0], 'x[0]', res[0])
333+ >>> assert_0d_equals('asarray', x_repr= 'x[0]', x_val= x[0], out_repr= 'x[0]', out_val= res[0], kw=kw )
323334
324335 is equivalent to
325336
@@ -338,20 +349,20 @@ def assert_0d_equals(
338349
339350def assert_scalar_equals (
340351 func_name : str ,
352+ * ,
341353 type_ : ScalarType ,
342354 idx : Shape ,
343355 out : Scalar ,
344356 expected : Scalar ,
345- / ,
346357 repr_name : str = "out" ,
347- ** kw ,
358+ kw : dict = {} ,
348359):
349360 """
350361 Assert a 0d array, convered to a scalar, is as expected, e.g.
351362
352363 >>> x = xp.ones(5, dtype=xp.uint8)
353364 >>> out = xp.sum(x)
354- >>> assert_scalar_equals('sum', int, (), int(out), 5)
365+ >>> assert_scalar_equals('sum', type_int, out= (), out= int(out), expected= 5)
355366
356367 is equivalent to
357368
@@ -372,13 +383,18 @@ def assert_scalar_equals(
372383
373384
374385def assert_fill (
375- func_name : str , fill_value : Scalar , dtype : DataType , out : Array , / , ** kw
386+ func_name : str ,
387+ * ,
388+ fill_value : Scalar ,
389+ dtype : DataType ,
390+ out : Array ,
391+ kw : dict = {},
376392):
377393 """
378394 Assert all elements of an array is as expected, e.g.
379395
380396 >>> out = xp.full(5, 42, dtype=xp.uint8)
381- >>> assert_fill('full', 42, xp.uint8, out, 5 )
397+ >>> assert_fill('full', fill_value= 42, dtype= xp.uint8, out=out, kw=dict(shape=5) )
382398
383399 is equivalent to
384400
@@ -408,22 +424,27 @@ def _assert_float_element(at_out: Array, at_expected: Array, msg: str):
408424
409425
410426def assert_array_elements (
411- func_name : str , out : Array , expected : Array , / , * , out_repr : str = "out" , ** kw
427+ func_name : str ,
428+ * ,
429+ out : Array ,
430+ expected : Array ,
431+ out_repr : str = "out" ,
432+ kw : dict = {},
412433):
413434 """
414435 Assert array elements are (strictly) as expected, e.g.
415436
416437 >>> x = xp.arange(5)
417438 >>> out = xp.asarray(x)
418- >>> assert_array_elements('asarray', out, x)
439+ >>> assert_array_elements('asarray', out=out, expected= x)
419440
420441 is equivalent to
421442
422443 >>> assert xp.all(out == x)
423444
424445 """
425446 dh .result_type (out .dtype , expected .dtype ) # sanity check
426- assert_shape (func_name , out .shape , expected .shape , ** kw ) # sanity check
447+ assert_shape (func_name , out_shape = out .shape , expected = expected .shape , kw = kw ) # sanity check
427448 f_func = f"[{ func_name } ({ fmt_kw (kw )} )]"
428449 if out .dtype in dh .float_dtypes :
429450 for idx in sh .ndindex (out .shape ):
0 commit comments