@@ -494,6 +494,7 @@ def check_result(result: float) -> bool:
494494class Case (Protocol ):
495495 cond_expr : str
496496 result_expr : str
497+ raw_case : Optional [str ]
497498
498499 def cond (self , * args ) -> bool :
499500 ...
@@ -532,6 +533,7 @@ class UnaryCase(Case):
532533 cond_from_dtype : FromDtypeFunc
533534 cond : UnaryCheck
534535 check_result : UnaryResultCheck
536+ raw_case : Optional [str ] = field (default = None )
535537
536538
537539r_unary_case = re .compile ("If ``x_i`` is (.+), the result is (.+)" )
@@ -674,6 +676,7 @@ def parse_unary_case_block(case_block: str) -> List[UnaryCase]:
674676 cond_from_dtype = cond_from_dtype ,
675677 result_expr = result_expr ,
676678 check_result = check_result ,
679+ raw_case = case_str ,
677680 )
678681 cases .append (case )
679682 else :
@@ -700,6 +703,7 @@ class BinaryCase(Case):
700703 x2_cond_from_dtype : FromDtypeFunc
701704 cond : BinaryCond
702705 check_result : BinaryResultCheck
706+ raw_case : Optional [str ] = field (default = None )
703707
704708
705709r_binary_case = re .compile ("If (.+), the result (.+)" )
@@ -1058,6 +1062,7 @@ def cond(i1: float, i2: float) -> bool:
10581062 x2_cond_from_dtype = x2_cond_from_dtype ,
10591063 result_expr = result_expr ,
10601064 check_result = check_result ,
1065+ raw_case = case_str ,
10611066 )
10621067
10631068
0 commit comments