@@ -535,12 +535,34 @@ class UnaryCase(Case):
535535
536536
537537r_unary_case = re .compile ("If ``x_i`` is (.+), the result is (.+)" )
538+ r_already_int_case = re .compile (
539+ "If ``x_i`` is already integer-valued, the result is ``x_i``"
540+ )
538541r_even_round_halves_case = re .compile (
539542 "If two integers are equally close to ``x_i``, "
540543 "the result is the even integer closest to ``x_i``"
541544)
542545
543546
547+ def integers_from_dtype (dtype : DataType , ** kw ) -> st .SearchStrategy [float ]:
548+ """
549+ Returns a strategy that generates float-casted integers within the bounds of dtype.
550+ """
551+ for k in kw .keys ():
552+ # sanity check
553+ assert k in ["min_value" , "max_value" , "exclude_min" , "exclude_max" ]
554+ m , M = dh .dtype_ranges [dtype ]
555+ if "min_value" in kw .keys ():
556+ m = kw ["min_value" ]
557+ if "exclude_min" in kw .keys ():
558+ m += 1
559+ if "max_value" in kw .keys ():
560+ M = kw ["max_value" ]
561+ if "exclude_max" in kw .keys ():
562+ M -= 1
563+ return st .integers (math .ceil (m ), math .floor (M )).map (float )
564+
565+
544566def trailing_halves_from_dtype (dtype : DataType ) -> st .SearchStrategy [float ]:
545567 """
546568 Returns a strategy that generates floats that end with .5 and are within the
@@ -557,6 +579,13 @@ def trailing_halves_from_dtype(dtype: DataType) -> st.SearchStrategy[float]:
557579 )
558580
559581
582+ already_int_case = UnaryCase (
583+ cond_expr = "x_i.is_integer()" ,
584+ cond = lambda i : i .is_integer (),
585+ cond_from_dtype = integers_from_dtype ,
586+ result_expr = "x_i" ,
587+ check_result = lambda i , result : i == result ,
588+ )
560589even_round_halves_case = UnaryCase (
561590 cond_expr = "modf(i)[0] == 0.5" ,
562591 cond = lambda i : math .modf (i )[0 ] == 0.5 ,
@@ -624,7 +653,11 @@ def parse_unary_case_block(case_block: str) -> List[UnaryCase]:
624653 cases = []
625654 for case_m in r_case .finditer (case_block ):
626655 case_str = case_m .group (1 )
627- if m := r_unary_case .search (case_str ):
656+ if m := r_already_int_case .search (case_str ):
657+ cases .append (already_int_case )
658+ elif m := r_even_round_halves_case .search (case_str ):
659+ cases .append (even_round_halves_case )
660+ elif m := r_unary_case .search (case_str ):
628661 try :
629662 cond , cond_expr_template , cond_from_dtype = parse_cond (m .group (1 ))
630663 _check_result , result_expr = parse_result (m .group (2 ))
@@ -643,8 +676,6 @@ def parse_unary_case_block(case_block: str) -> List[UnaryCase]:
643676 check_result = check_result ,
644677 )
645678 cases .append (case )
646- elif m := r_even_round_halves_case .search (case_str ):
647- cases .append (even_round_halves_case )
648679 else :
649680 if not r_remaining_case .search (case_str ):
650681 warn (f"case not machine-readable: '{ case_str } '" )
@@ -818,25 +849,6 @@ def check_result(i1: float, i2: float, result: float) -> bool:
818849 return check_result
819850
820851
821- def integers_from_dtype (dtype : DataType , ** kw ) -> st .SearchStrategy [float ]:
822- """
823- Returns a strategy that generates float-casted integers within the bounds of dtype.
824- """
825- for k in kw .keys ():
826- # sanity check
827- assert k in ["min_value" , "max_value" , "exclude_min" , "exclude_max" ]
828- m , M = dh .dtype_ranges [dtype ]
829- if "min_value" in kw .keys ():
830- m = kw ["min_value" ]
831- if "exclude_min" in kw .keys ():
832- m += 1
833- if "max_value" in kw .keys ():
834- M = kw ["max_value" ]
835- if "exclude_max" in kw .keys ():
836- M -= 1
837- return st .integers (math .ceil (m ), math .floor (M )).map (float )
838-
839-
840852def parse_binary_case (case_str : str ) -> BinaryCase :
841853 """
842854 Parses a Sphinx-formatted binary case string to return codified binary cases, e.g.
0 commit comments