@@ -319,7 +319,7 @@ def from_dtype(dtype: DataType, **kw) -> st.SearchStrategy[float]:
319319 return from_dtype
320320
321321
322- def parse_cond (cond_str : str ) -> Tuple [UnaryCheck , str , FromDtypeFunc ]:
322+ def parse_cond (cond_str : str ) -> Tuple [UnaryCheck , str , BoundFromDtype ]:
323323 """
324324 Parses a Sphinx-formatted condition string to return:
325325
@@ -348,22 +348,30 @@ def parse_cond(cond_str: str) -> Tuple[UnaryCheck, str, FromDtypeFunc]:
348348 124.978
349349
350350 """
351+ # We first identify whether the condition starts with "not". If so, we note
352+ # this but parse the condition as if it was not negated.
351353 if m := r_not .match (cond_str ):
352354 cond_str = m .group (1 )
353355 not_cond = True
354356 else :
355357 not_cond = False
356358
359+ # We parse the condition to identify the condition function, expression
360+ # template, and xps.from_dtype()-like condition strategy.
357361 kwargs = {}
358362 filter_ = None
359363 from_dtype = None # type: ignore
360- strat = None
361364 if m := r_code .match (cond_str ):
362365 value = parse_value (m .group (1 ))
363366 cond = make_strict_eq (value )
364367 expr_template = "{} == " + m .group (1 )
365- if not not_cond :
366- strat = st .just (value )
368+ from_dtype = wrap_strat_as_from_dtype (st .just (value ))
369+ elif m := r_either_code .match (cond_str ):
370+ v1 = parse_value (m .group (1 ))
371+ v2 = parse_value (m .group (2 ))
372+ cond = make_or (make_strict_eq (v1 ), make_strict_eq (v2 ))
373+ expr_template = "({} == " + m .group (1 ) + " or {} == " + m .group (2 ) + ")"
374+ from_dtype = wrap_strat_as_from_dtype (st .sampled_from ([v1 , v2 ]))
367375 elif m := r_equal_to .match (cond_str ):
368376 value = parse_value (m .group (1 ))
369377 if math .isnan (value ):
@@ -374,97 +382,73 @@ def parse_cond(cond_str: str) -> Tuple[UnaryCheck, str, FromDtypeFunc]:
374382 value = parse_value (m .group (1 ))
375383 cond = make_gt (value )
376384 expr_template = "{} > " + m .group (1 )
377- if not not_cond :
378- kwargs = {"min_value" : value , "exclude_min" : True }
385+ kwargs = {"min_value" : value , "exclude_min" : True }
379386 elif m := r_lt .match (cond_str ):
380387 value = parse_value (m .group (1 ))
381388 cond = make_lt (value )
382389 expr_template = "{} < " + m .group (1 )
383- if not not_cond :
384- kwargs = {"max_value" : value , "exclude_max" : True }
385- elif m := r_either_code .match (cond_str ):
386- v1 = parse_value (m .group (1 ))
387- v2 = parse_value (m .group (2 ))
388- cond = make_or (make_strict_eq (v1 ), make_strict_eq (v2 ))
389- expr_template = "({} == " + m .group (1 ) + " or {} == " + m .group (2 ) + ")"
390- if not not_cond :
391- strat = st .sampled_from ([v1 , v2 ])
390+ kwargs = {"max_value" : value , "exclude_max" : True }
392391 elif cond_str in ["finite" , "a finite number" ]:
393392 cond = math .isfinite
394393 expr_template = "isfinite({})"
395- if not not_cond :
396- kwargs = {"allow_nan" : False , "allow_infinity" : False }
394+ kwargs = {"allow_nan" : False , "allow_infinity" : False }
397395 elif cond_str in "a positive (i.e., greater than ``0``) finite number" :
398396 cond = lambda i : math .isfinite (i ) and i > 0
399397 expr_template = "isfinite({}) and {} > 0"
400- if not not_cond :
401- kwargs = {
402- "allow_nan" : False ,
403- "allow_infinity" : False ,
404- "min_value" : 0 ,
405- "exclude_min" : True ,
406- }
398+ kwargs = {
399+ "allow_nan" : False ,
400+ "allow_infinity" : False ,
401+ "min_value" : 0 ,
402+ "exclude_min" : True ,
403+ }
407404 elif cond_str == "a negative (i.e., less than ``0``) finite number" :
408405 cond = lambda i : math .isfinite (i ) and i < 0
409406 expr_template = "isfinite({}) and {} < 0"
410- if not not_cond :
411- kwargs = {
412- "allow_nan" : False ,
413- "allow_infinity" : False ,
414- "max_value" : 0 ,
415- "exclude_max" : True ,
416- }
407+ kwargs = {
408+ "allow_nan" : False ,
409+ "allow_infinity" : False ,
410+ "max_value" : 0 ,
411+ "exclude_max" : True ,
412+ }
417413 elif cond_str == "positive" :
418414 cond = lambda i : math .copysign (1 , i ) == 1
419415 expr_template = "copysign(1, {}) == 1"
420- if not not_cond :
421- # We assume (positive) zero is special cased seperately
422- kwargs = {"min_value" : 0 , "exclude_min" : True }
416+ # We assume (positive) zero is special cased seperately
417+ kwargs = {"min_value" : 0 , "exclude_min" : True }
423418 elif cond_str == "negative" :
424419 cond = lambda i : math .copysign (1 , i ) == - 1
425420 expr_template = "copysign(1, {}) == -1"
426- if not not_cond :
427- # We assume (negative) zero is special cased seperately
428- kwargs = {"max_value" : 0 , "exclude_max" : True }
421+ # We assume (negative) zero is special cased seperately
422+ kwargs = {"max_value" : 0 , "exclude_max" : True }
429423 elif "nonzero finite" in cond_str :
430424 cond = lambda i : math .isfinite (i ) and i != 0
431425 expr_template = "isfinite({}) and {} != 0"
432- if not not_cond :
433- kwargs = {"allow_nan" : False , "allow_infinity" : False }
434- filter_ = lambda n : n != 0
426+ kwargs = {"allow_nan" : False , "allow_infinity" : False }
427+ filter_ = lambda n : n != 0
435428 elif cond_str == "an integer value" :
436429 cond = lambda i : i .is_integer ()
437430 expr_template = "{}.is_integer()"
438- if not not_cond :
439- from_dtype = integers_from_dtype # type: ignore
431+ from_dtype = integers_from_dtype # type: ignore
440432 elif cond_str == "an odd integer value" :
441433 cond = lambda i : i .is_integer () and i % 2 == 1
442434 expr_template = "{}.is_integer() and {} % 2 == 1"
443- if not not_cond :
435+ from_dtype = integers_from_dtype # type: ignore
444436
445- def from_dtype (dtype : DataType , ** kw ) -> st .SearchStrategy [float ]:
446- return integers_from_dtype (dtype , ** kw ).filter (lambda n : n % 2 == 1 )
437+ def from_dtype (dtype : DataType , ** kw ) -> st .SearchStrategy [float ]:
438+ return integers_from_dtype (dtype , ** kw ).filter (lambda n : n % 2 == 1 )
447439
448440 else :
449441 raise ParseError (cond_str )
450442
451- if strat is not None :
452- if (
453- not_cond
454- or len (kwargs ) != 0
455- or filter_ is not None
456- or from_dtype is not None
457- ):
458- raise ParseError (cond_str )
459- return cond , expr_template , wrap_strat_as_from_dtype (strat )
460-
461443 if not_cond :
462- expr_template = f"not { expr_template } "
444+ # We handle negated conitions by simply negating the condition function
445+ # and using it as a filter for xps.from_dtype() (or an equivalent).
463446 cond = make_not_cond (cond )
464- kwargs = {}
447+ expr_template = f"not { expr_template } "
465448 filter_ = cond
466- assert kwargs is not None
467- return cond , expr_template , BoundFromDtype (kwargs , filter_ , from_dtype )
449+ return cond , expr_template , BoundFromDtype (filter_ = filter_ )
450+ else :
451+ return cond , expr_template , BoundFromDtype (kwargs , filter_ , from_dtype )
468452
469453
470454def parse_result (result_str : str ) -> Tuple [UnaryCheck , str ]:
@@ -838,6 +822,9 @@ def check_result(i1: float, i2: float, result: float) -> bool:
838822
839823
840824def integers_from_dtype (dtype : DataType , ** kw ) -> st .SearchStrategy [float ]:
825+ """
826+ Returns a strategy that generates float-casted integers within the bounds of dtype.
827+ """
841828 for k in kw .keys ():
842829 # sanity check
843830 assert k in ["min_value" , "max_value" , "exclude_min" , "exclude_max" ]
@@ -1036,16 +1023,12 @@ def cond(i1: float, i2: float) -> bool:
10361023 elif len (x1_cond_from_dtypes ) == 1 :
10371024 x1_cond_from_dtype = x1_cond_from_dtypes [0 ]
10381025 else :
1039- if not all (isinstance (fd , BoundFromDtype ) for fd in x1_cond_from_dtypes ):
1040- raise ParseError (case_str )
10411026 x1_cond_from_dtype = sum (x1_cond_from_dtypes , start = BoundFromDtype ())
10421027 if len (x2_cond_from_dtypes ) == 0 :
10431028 x2_cond_from_dtype = xps .from_dtype
10441029 elif len (x2_cond_from_dtypes ) == 1 :
10451030 x2_cond_from_dtype = x2_cond_from_dtypes [0 ]
10461031 else :
1047- if not all (isinstance (fd , BoundFromDtype ) for fd in x2_cond_from_dtypes ):
1048- raise ParseError (case_str )
10491032 x2_cond_from_dtype = sum (x2_cond_from_dtypes , start = BoundFromDtype ())
10501033
10511034 return BinaryCase (
0 commit comments