@@ -88,6 +88,31 @@ def assert_dtype(
8888 * ,
8989 repr_name : str = "out.dtype" ,
9090):
91+ """
92+ Tests the output dtype is as expected.
93+
94+ We infer the expected dtype from in_dtype and to test out_dtype, e.g.
95+
96+ >>> x = xp.arange(5, dtype=xp.uint8)
97+ >>> out = xp.abs(x)
98+ >>> assert_dtype('abs', x.dtype, out.dtype)
99+
100+ Or for multiple input dtypes, the expected dtype is inferred from their
101+ resulting type promotion, e.g.
102+
103+ >>> x1 = xp.arange(5, dtype=xp.uint8)
104+ >>> x2 = xp.arange(5, dtype=xp.uint16)
105+ >>> out = xp.add(x1, x2)
106+ >>> assert_dtype('add', [x1.dtype, x2.dtype], out.dtype) # expected=xp.uint16
107+
108+ We can also specify the expected dtype ourselves, e.g.
109+
110+ >>> x = xp.arange(5, dtype=xp.int8)
111+ >>> out = xp.sum(x)
112+ >>> default_int = xp.asarray(0).dtype
113+ >>> assert_dtype('sum', x, out.dtype, default_int)
114+
115+ """
91116 in_dtypes = in_dtype if isinstance (in_dtype , Sequence ) else [in_dtype ]
92117 f_in_dtypes = dh .fmt_types (tuple (in_dtypes ))
93118 f_out_dtype = dh .dtype_to_name [out_dtype ]
0 commit comments