@@ -89,7 +89,7 @@ def assert_dtype(
8989 repr_name : str = "out.dtype" ,
9090):
9191 """
92- Tests the output dtype is as expected.
92+ Assert the output dtype is as expected.
9393
9494 We infer the expected dtype from in_dtype and to test out_dtype, e.g.
9595
@@ -128,7 +128,7 @@ def assert_dtype(
128128
129129def assert_kw_dtype (func_name : str , kw_dtype : DataType , out_dtype : DataType ):
130130 """
131- Test the output dtype is the passed keyword dtype, e.g.
131+ Assert the output dtype is the passed keyword dtype, e.g.
132132
133133 >>> kw = {'dtype': xp.uint8}
134134 >>> out = xp.ones(5, **kw)
@@ -144,33 +144,54 @@ def assert_kw_dtype(func_name: str, kw_dtype: DataType, out_dtype: DataType):
144144 assert out_dtype == kw_dtype , msg
145145
146146
147- def assert_default_float (func_name : str , dtype : DataType ):
148- f_dtype = dh .dtype_to_name [dtype ]
147+ def assert_default_float (func_name : str , out_dtype : DataType ):
148+ """
149+ Assert the output dtype is the default float, e.g.
150+
151+ >>> out = xp.ones(5)
152+ >>> assert_default_float('ones', out.dtype)
153+
154+ """
155+ f_dtype = dh .dtype_to_name [out_dtype ]
149156 f_default = dh .dtype_to_name [dh .default_float ]
150157 msg = (
151158 f"out.dtype={ f_dtype } , should be default "
152159 f"floating-point dtype { f_default } [{ func_name } ()]"
153160 )
154- assert dtype == dh .default_float , msg
161+ assert out_dtype == dh .default_float , msg
162+
163+
164+ def assert_default_int (func_name : str , out_dtype : DataType ):
165+ """
166+ Assert the output dtype is the default int, e.g.
155167
168+ >>> out = xp.full(5, 42)
169+ >>> assert_default_int('full', out.dtype)
156170
157- def assert_default_int ( func_name : str , dtype : DataType ):
158- f_dtype = dh .dtype_to_name [dtype ]
171+ """
172+ f_dtype = dh .dtype_to_name [out_dtype ]
159173 f_default = dh .dtype_to_name [dh .default_int ]
160174 msg = (
161175 f"out.dtype={ f_dtype } , should be default "
162176 f"integer dtype { f_default } [{ func_name } ()]"
163177 )
164- assert dtype == dh .default_int , msg
178+ assert out_dtype == dh .default_int , msg
179+
165180
181+ def assert_default_index (func_name : str , out_dtype : DataType , repr_name = "out.dtype" ):
182+ """
183+ Assert the output dtype is the default index dtype, e.g.
184+
185+ >>> out = xp.argmax(<array>)
186+ >>> assert_default_int('argmax', out.dtype)
166187
167- def assert_default_index ( func_name : str , dtype : DataType , repr_name = "out.dtype" ):
168- f_dtype = dh .dtype_to_name [dtype ]
188+ """
189+ f_dtype = dh .dtype_to_name [out_dtype ]
169190 msg = (
170191 f"{ repr_name } ={ f_dtype } , should be the default index dtype, "
171192 f"which is either int32 or int64 [{ func_name } ()]"
172193 )
173- assert dtype in (xp .int32 , xp .int64 ), msg
194+ assert out_dtype in (xp .int32 , xp .int64 ), msg
174195
175196
176197def assert_shape (
0 commit comments