@@ -296,6 +296,19 @@ def assert_keepdimable_shape(
296296def assert_0d_equals (
297297 func_name : str , x_repr : str , x_val : Array , out_repr : str , out_val : Array , ** kw
298298):
299+ """
300+ Assert a 0d array is as expected, e.g.
301+
302+ >>> x = xp.asarray([0, 1, 2])
303+ >>> res = xp.asarray(x, copy=True)
304+ >>> res[0] = 42
305+ >>> assert_0d_equals('__setitem__', 'x[0]', x[0], 'x[0]', res[0])
306+
307+ is equivalent to
308+
309+ >>> assert res[0] == x[0]
310+
311+ """
299312 msg = (
300313 f"{ out_repr } ={ out_val } , but should be { x_repr } ={ x_val } "
301314 f"[{ func_name } ({ fmt_kw (kw )} )]"
@@ -316,9 +329,21 @@ def assert_scalar_equals(
316329 repr_name : str = "out" ,
317330 ** kw ,
318331):
332+ """
333+ Assert a 0d array, convered to a scalar, is as expected, e.g.
334+
335+ >>> x = xp.ones(5, dtype=xp.uint8)
336+ >>> out = xp.sum(x)
337+ >>> assert_scalar_equals('sum', int, (), int(out), 5)
338+
339+ is equivalent to
340+
341+ >>> assert int(out) == 5
342+
343+ """
319344 repr_name = repr_name if idx == () else f"{ repr_name } [{ idx } ]"
320345 f_func = f"{ func_name } ({ fmt_kw (kw )} )"
321- if type_ is bool or type_ is int :
346+ if type_ in [ bool , int ] :
322347 msg = f"{ repr_name } ={ out } , but should be { expected } [{ f_func } ]"
323348 assert out == expected , msg
324349 elif math .isnan (expected ):
@@ -332,6 +357,17 @@ def assert_scalar_equals(
332357def assert_fill (
333358 func_name : str , fill_value : Scalar , dtype : DataType , out : Array , / , ** kw
334359):
360+ """
361+ Assert all elements of an array is as expected, e.g.
362+
363+ >>> out = xp.full(5, 42, dtype=xp.uint8)
364+ >>> assert_fill('full', 42, xp.uint8, out, 5)
365+
366+ is equivalent to
367+
368+ >>> assert xp.all(out == 42)
369+
370+ """
335371 msg = f"out not filled with { fill_value } [{ func_name } ({ fmt_kw (kw )} )]\n { out = } "
336372 if math .isnan (fill_value ):
337373 assert ah .all (ah .isnan (out )), msg
@@ -340,6 +376,18 @@ def assert_fill(
340376
341377
342378def assert_array (func_name : str , out : Array , expected : Array , / , ** kw ):
379+ """
380+ Assert array is (strictly) as expected, e.g.
381+
382+ >>> x = xp.arange(5)
383+ >>> out = xp.asarray(x)
384+ >>> assert_array('asarray', out, x)
385+
386+ is equivalent to
387+
388+ >>> assert xp.all(out == x)
389+
390+ """
343391 assert_dtype (func_name , out .dtype , expected .dtype )
344392 assert_shape (func_name , out .shape , expected .shape , ** kw )
345393 f_func = f"[{ func_name } ({ fmt_kw (kw )} )]"
0 commit comments