@@ -91,19 +91,28 @@ def assert_dtype(
9191 """
9292 Assert the output dtype is as expected.
9393
94- We infer the expected dtype from in_dtype and to test out_dtype, e.g.
94+ If expected=None, we infer the expected dtype as in_dtype, to test
95+ out_dtype, e.g.
9596
9697 >>> x = xp.arange(5, dtype=xp.uint8)
9798 >>> out = xp.abs(x)
9899 >>> assert_dtype('abs', x.dtype, out.dtype)
99100
101+ is equivalent to
102+
103+ >>> assert out.dtype == xp.uint8
104+
100105 Or for multiple input dtypes, the expected dtype is inferred from their
101106 resulting type promotion, e.g.
102107
103108 >>> x1 = xp.arange(5, dtype=xp.uint8)
104109 >>> x2 = xp.arange(5, dtype=xp.uint16)
105110 >>> out = xp.add(x1, x2)
106- >>> assert_dtype('add', [x1.dtype, x2.dtype], out.dtype) # expected=xp.uint16
111+ >>> assert_dtype('add', [x1.dtype, x2.dtype], out.dtype)
112+
113+ is equivalent to
114+
115+ >>> assert out.dtype == xp.uint16
107116
108117 We can also specify the expected dtype ourselves, e.g.
109118
@@ -182,7 +191,7 @@ def assert_default_index(func_name: str, out_dtype: DataType, repr_name="out.dty
182191 """
183192 Assert the output dtype is the default index dtype, e.g.
184193
185- >>> out = xp.argmax(<array> )
194+ >>> out = xp.argmax(xp.arange(5) )
186195 >>> assert_default_int('argmax', out.dtype)
187196
188197 """
@@ -202,6 +211,13 @@ def assert_shape(
202211 repr_name = "out.shape" ,
203212 ** kw ,
204213):
214+ """
215+ Assert the output shape is as expected, e.g.
216+
217+ >>> out = xp.ones((3, 3, 3))
218+ >>> assert_shape('ones', out.shape, (3, 3, 3))
219+
220+ """
205221 if isinstance (out_shape , int ):
206222 out_shape = (out_shape ,)
207223 if isinstance (expected , int ):
@@ -222,6 +238,20 @@ def assert_result_shape(
222238 repr_name = "out.shape" ,
223239 ** kw ,
224240):
241+ """
242+ Assert the output shape is as expected.
243+
244+ If expected=None, we infer the expected shape as the result of broadcasting
245+ in_shapes, to test against out_shape, e.g.
246+
247+ >>> out = xp.add(xp.ones((3, 1)), xp.ones((1, 3)))
248+ >>> assert_shape('add', [(3, 1), (1, 3)], out.shape)
249+
250+ is equivalent to
251+
252+ >>> assert out.shape == (3, 3)
253+
254+ """
225255 if expected is None :
226256 expected = sh .broadcast_shapes (* in_shapes )
227257 f_in_shapes = " . " .join (str (s ) for s in in_shapes )
0 commit comments