@@ -1010,22 +1010,28 @@ def partial_cond(i1: float, i2: float) -> bool:
10101010 st .sampled_from ([(True , False ), (False , True ), (True , True )])
10111011 )
10121012
1013- def _x1_cond_from_dtype (dtype ) -> st .SearchStrategy [float ]:
1013+ def _x1_cond_from_dtype (dtype , ** kw ) -> st .SearchStrategy [float ]:
1014+ assert len (kw ) == 0 # sanity check
10141015 return use_x1_or_x2_strat .flatmap (
10151016 lambda t : cond_from_dtype (dtype )
10161017 if t [0 ]
10171018 else xps .from_dtype (dtype )
10181019 )
10191020
1020- def _x2_cond_from_dtype (dtype ) -> st .SearchStrategy [float ]:
1021+ def _x2_cond_from_dtype (dtype , ** kw ) -> st .SearchStrategy [float ]:
1022+ assert len (kw ) == 0 # sanity check
10211023 return use_x1_or_x2_strat .flatmap (
10221024 lambda t : cond_from_dtype (dtype )
10231025 if t [1 ]
10241026 else xps .from_dtype (dtype )
10251027 )
10261028
1027- x1_cond_from_dtypes .append (_x1_cond_from_dtype )
1028- x2_cond_from_dtypes .append (_x2_cond_from_dtype )
1029+ x1_cond_from_dtypes .append (
1030+ BoundFromDtype (base_func = _x1_cond_from_dtype )
1031+ )
1032+ x2_cond_from_dtypes .append (
1033+ BoundFromDtype (base_func = _x2_cond_from_dtype )
1034+ )
10291035
10301036 partial_conds .append (partial_cond )
10311037 partial_exprs .append (partial_expr )
@@ -1050,18 +1056,8 @@ def _x2_cond_from_dtype(dtype) -> st.SearchStrategy[float]:
10501056 def cond (i1 : float , i2 : float ) -> bool :
10511057 return all (pc (i1 , i2 ) for pc in partial_conds )
10521058
1053- if len (x1_cond_from_dtypes ) == 0 :
1054- x1_cond_from_dtype = xps .from_dtype
1055- elif len (x1_cond_from_dtypes ) == 1 :
1056- x1_cond_from_dtype = x1_cond_from_dtypes [0 ]
1057- else :
1058- x1_cond_from_dtype = sum (x1_cond_from_dtypes , start = BoundFromDtype ())
1059- if len (x2_cond_from_dtypes ) == 0 :
1060- x2_cond_from_dtype = xps .from_dtype
1061- elif len (x2_cond_from_dtypes ) == 1 :
1062- x2_cond_from_dtype = x2_cond_from_dtypes [0 ]
1063- else :
1064- x2_cond_from_dtype = sum (x2_cond_from_dtypes , start = BoundFromDtype ())
1059+ x1_cond_from_dtype = sum (x1_cond_from_dtypes , start = BoundFromDtype ())
1060+ x2_cond_from_dtype = sum (x2_cond_from_dtypes , start = BoundFromDtype ())
10651061
10661062 return BinaryCase (
10671063 cond_expr = cond_expr ,
0 commit comments