@@ -184,6 +184,53 @@ def binary_assert_against_refimpl(
184184 )
185185
186186
187+ def right_scalar_assert_against_refimpl (
188+ func_name : str ,
189+ left : Array ,
190+ right : Scalar ,
191+ res : Array ,
192+ refimpl : Callable [[T , T ], T ],
193+ expr_template : str = None ,
194+ res_stype : Optional [ScalarType ] = None ,
195+ left_sym : str = "x1" ,
196+ res_name : str = "out" ,
197+ filter_ : Callable [[Scalar ], bool ] = default_filter ,
198+ strict_check : Optional [bool ] = None ,
199+ ):
200+ if filter_ (right ):
201+ return # short-circuit here as there will be nothing to test
202+ in_stype = dh .get_scalar_type (left .dtype )
203+ if res_stype is None :
204+ res_stype = in_stype
205+ m , M = dh .dtype_ranges .get (left .dtype , (None , None ))
206+ for idx in sh .ndindex (res .shape ):
207+ scalar_l = in_stype (left [idx ])
208+ if not filter_ (scalar_l ):
209+ continue
210+ try :
211+ expected = refimpl (scalar_l , right )
212+ except Exception :
213+ continue
214+ if left .dtype != xp .bool :
215+ assert m is not None and M is not None # for mypy
216+ if expected <= m or expected >= M :
217+ continue
218+ scalar_o = res_stype (res [idx ])
219+ f_l = sh .fmt_idx (left_sym , idx )
220+ f_o = sh .fmt_idx (res_name , idx )
221+ expr = expr_template .format (f_l , right , expected )
222+ if strict_check == False or dh .is_float_dtype (res .dtype ):
223+ assert isclose (scalar_o , expected ), (
224+ f"{ f_o } ={ scalar_o } , but should be roughly { expr } [{ func_name } ()]\n "
225+ f"{ f_l } ={ scalar_l } "
226+ )
227+ else :
228+ assert scalar_o == expected , (
229+ f"{ f_o } ={ scalar_o } , but should be { expr } [{ func_name } ()]\n "
230+ f"{ f_l } ={ scalar_l } "
231+ )
232+
233+
187234# When appropiate, this module tests operators alongside their respective
188235# elementwise methods. We do this by parametrizing a generalised test method
189236# with every relevant method and operator.
@@ -392,40 +439,19 @@ def binary_param_assert_against_refimpl(
392439):
393440 expr_template = "({} " + op_sym + " {})={}"
394441 if ctx .right_is_scalar :
395- if filter_ (right ):
396- return # short-circuit here as there will be nothing to test
397- in_stype = dh .get_scalar_type (left .dtype )
398- if res_stype is None :
399- res_stype = in_stype
400- m , M = dh .dtype_ranges .get (left .dtype , (None , None ))
401- for idx in sh .ndindex (res .shape ):
402- scalar_l = in_stype (left [idx ])
403- if not filter_ (scalar_l ):
404- continue
405- try :
406- expected = refimpl (scalar_l , right )
407- except Exception :
408- continue
409- if left .dtype != xp .bool :
410- assert m is not None and M is not None # for mypy
411- if expected <= m or expected >= M :
412- continue
413- scalar_o = res_stype (res [idx ])
414- f_l = sh .fmt_idx (ctx .left_sym , idx )
415- f_o = sh .fmt_idx (ctx .res_name , idx )
416- expr = expr_template .format (f_l , right , expected )
417- if strict_check == False or dh .is_float_dtype (res .dtype ):
418- assert isclose (scalar_o , expected ), (
419- f"{ f_o } ={ scalar_o } , but should be roughly { expr } "
420- f"[{ ctx .func_name } ()]\n "
421- f"{ f_l } ={ scalar_l } "
422- )
423- else :
424- assert scalar_o == expected , (
425- f"{ f_o } ={ scalar_o } , but should be { expr } "
426- f"[{ ctx .func_name } ()]\n "
427- f"{ f_l } ={ scalar_l } "
428- )
442+ right_scalar_assert_against_refimpl (
443+ func_name = ctx .func_name ,
444+ left_sym = ctx .left_sym ,
445+ left = left ,
446+ right = right ,
447+ res_stype = res_stype ,
448+ res_name = ctx .res_name ,
449+ res = res ,
450+ refimpl = refimpl ,
451+ expr_template = expr_template ,
452+ filter_ = filter_ ,
453+ strict_check = strict_check ,
454+ )
429455 else :
430456 binary_assert_against_refimpl (
431457 func_name = ctx .func_name ,
0 commit comments