@@ -526,6 +526,10 @@ def __repr__(self) -> str:
526526 return f"{ self .__class__ .__name__ } (<{ self } >)"
527527
528528
529+ r_case_block = re .compile (r"\*\*Special [Cc]ases\*\*\n+((?:(.*\n)+))\n+\s*Parameters" )
530+ r_case = re .compile (r"\s+-\s*(.*)\." )
531+
532+
529533class UnaryCond (Protocol ):
530534 def __call__ (self , i : float ) -> bool :
531535 ...
@@ -586,7 +590,7 @@ def check_result(i: float, result: float) -> bool:
586590 return check_result
587591
588592
589- def parse_unary_docstring ( docstring : str ) -> List [UnaryCase ]:
593+ def parse_unary_case_block ( case_block : str ) -> List [UnaryCase ]:
590594 """
591595 Parses a Sphinx-formatted docstring of a unary function to return a list of
592596 codified unary cases, e.g.
@@ -616,7 +620,8 @@ def parse_unary_docstring(docstring: str) -> List[UnaryCase]:
616620 ... an array containing the square root of each element in ``x``
617621 ... '''
618622 ...
619- >>> unary_cases = parse_unary_docstring(sqrt.__doc__)
623+ >>> case_block = r_case_block.match(sqrt.__doc__).group(1)
624+ >>> unary_cases = parse_unary_case_block(case_block)
620625 >>> for case in unary_cases:
621626 ... print(repr(case))
622627 UnaryCase(<x_i < 0 -> NaN>)
@@ -631,19 +636,10 @@ def parse_unary_docstring(docstring: str) -> List[UnaryCase]:
631636 True
632637
633638 """
634-
635- match = r_special_cases .search (docstring )
636- if match is None :
637- return []
638- lines = match .group (1 ).split ("\n " )[:- 1 ]
639639 cases = []
640- for line in lines :
641- if m := r_case .match (line ):
642- case = m .group (1 )
643- else :
644- warn (f"line not machine-readable: '{ line } '" )
645- continue
646- if m := r_unary_case .search (case ):
640+ for case_m in r_case .finditer (case_block ):
641+ case_str = case_m .group (1 )
642+ if m := r_unary_case .search (case_str ):
647643 try :
648644 cond , cond_expr_template , cond_from_dtype = parse_cond (m .group (1 ))
649645 _check_result , result_expr = parse_result (m .group (2 ))
@@ -662,11 +658,11 @@ def parse_unary_docstring(docstring: str) -> List[UnaryCase]:
662658 check_result = check_result ,
663659 )
664660 cases .append (case )
665- elif m := r_even_round_halves_case .search (case ):
661+ elif m := r_even_round_halves_case .search (case_str ):
666662 cases .append (even_round_halves_case )
667663 else :
668- if not r_remaining_case .search (case ):
669- warn (f"case not machine-readable: '{ case } '" )
664+ if not r_remaining_case .search (case_str ):
665+ warn (f"case not machine-readable: '{ case_str } '" )
670666 return cases
671667
672668
@@ -690,12 +686,6 @@ class BinaryCase(Case):
690686 check_result : BinaryResultCheck
691687
692688
693- r_special_cases = re .compile (
694- r"\*\*Special [Cc]ases\*\*(?:\n.*)+"
695- r"For floating-point operands,\n+"
696- r"((?:\s*-\s*.*\n)+)"
697- )
698- r_case = re .compile (r"\s+-\s*(.*)\.\n?" )
699689r_binary_case = re .compile ("If (.+), the result (.+)" )
700690r_remaining_case = re .compile ("In the remaining cases.+" )
701691r_cond_sep = re .compile (r"(?<!``x1_i``),? and |(?<!i\.e\.), " )
@@ -880,8 +870,7 @@ def parse_binary_case(case_str: str) -> BinaryCase:
880870
881871 """
882872 case_m = r_binary_case .match (case_str )
883- if case_m is None :
884- raise ParseError (case_str )
873+ assert case_m is not None # sanity check
885874 cond_strs = r_cond_sep .split (case_m .group (1 ))
886875
887876 partial_conds = []
@@ -1078,7 +1067,7 @@ def cond(i1: float, i2: float) -> bool:
10781067r_redundant_case = re .compile ("result.+determined by the rule already stated above" )
10791068
10801069
1081- def parse_binary_docstring ( docstring : str ) -> List [BinaryCase ]:
1070+ def parse_binary_case_block ( case_block : str ) -> List [BinaryCase ]:
10821071 """
10831072 Parses a Sphinx-formatted docstring of a binary function to return a list of
10841073 codified binary cases, e.g.
@@ -1108,29 +1097,21 @@ def parse_binary_docstring(docstring: str) -> List[BinaryCase]:
11081097 ... an array containing the results
11091098 ... '''
11101099 ...
1111- >>> binary_cases = parse_binary_docstring(logaddexp.__doc__)
1100+ >>> case_block = r_case_block.match(logaddexp.__doc__).group(1)
1101+ >>> binary_cases = parse_binary_case_block(case_block)
11121102 >>> for case in binary_cases:
11131103 ... print(repr(case))
11141104 BinaryCase(<x1_i == NaN or x2_i == NaN -> NaN>)
11151105 BinaryCase(<x1_i == +infinity and not x2_i == NaN -> +infinity>)
11161106 BinaryCase(<not x1_i == NaN and x2_i == +infinity -> +infinity>)
11171107
11181108 """
1119-
1120- match = r_special_cases .search (docstring )
1121- if match is None :
1122- return []
1123- lines = match .group (1 ).split ("\n " )[:- 1 ]
11241109 cases = []
1125- for line in lines :
1126- if m := r_case .match (line ):
1127- case_str = m .group (1 )
1128- else :
1129- warn (f"line not machine-readable: '{ line } '" )
1130- continue
1110+ for case_m in r_case .finditer (case_block ):
1111+ case_str = case_m .group (1 )
11311112 if r_redundant_case .search (case_str ):
11321113 continue
1133- if m := r_binary_case .match (case_str ):
1114+ if r_binary_case .match (case_str ):
11341115 try :
11351116 case = parse_binary_case (case_str )
11361117 cases .append (case )
@@ -1142,14 +1123,19 @@ def parse_binary_docstring(docstring: str) -> List[BinaryCase]:
11421123 return cases
11431124
11441125
1126+ category_stub_pairs = [(c , s ) for c , stubs in category_to_funcs .items () for s in stubs ]
11451127unary_params = []
11461128binary_params = []
11471129iop_params = []
11481130func_to_op : Dict [str , str ] = {v : k for k , v in dh .op_to_func .items ()}
1149- for stub in category_to_funcs [ "elementwise" ] :
1131+ for category , stub in category_stub_pairs :
11501132 if stub .__doc__ is None :
11511133 warn (f"{ stub .__name__ } () stub has no docstring" )
11521134 continue
1135+ if m := r_case_block .search (stub .__doc__ ):
1136+ case_block = m .group (1 )
1137+ else :
1138+ continue
11531139 marks = []
11541140 try :
11551141 func = getattr (xp , stub .__name__ )
@@ -1163,47 +1149,56 @@ def parse_binary_docstring(docstring: str) -> List[BinaryCase]:
11631149 if len (sig .parameters ) == 0 :
11641150 warn (f"{ func = } has no parameters" )
11651151 continue
1166- if param_names [0 ] == "x" :
1167- if cases := parse_unary_docstring (stub .__doc__ ):
1168- name_to_func = {stub .__name__ : func }
1169- if stub .__name__ in func_to_op .keys ():
1170- op_name = func_to_op [stub .__name__ ]
1171- op = getattr (operator , op_name )
1172- name_to_func [op_name ] = op
1173- for func_name , func in name_to_func .items ():
1174- for case in cases :
1175- id_ = f"{ func_name } ({ case .cond_expr } ) -> { case .result_expr } "
1176- p = pytest .param (func_name , func , case , id = id_ )
1177- unary_params .append (p )
1178- continue
1179- if len (sig .parameters ) == 1 :
1180- warn (f"{ func = } has one parameter '{ param_names [0 ]} ' which is not named 'x'" )
1181- continue
1182- if param_names [0 ] == "x1" and param_names [1 ] == "x2" :
1183- if cases := parse_binary_docstring (stub .__doc__ ):
1184- name_to_func = {stub .__name__ : func }
1185- if stub .__name__ in func_to_op .keys ():
1186- op_name = func_to_op [stub .__name__ ]
1187- op = getattr (operator , op_name )
1188- name_to_func [op_name ] = op
1189- # We collect inplaceoperator test cases seperately
1190- iop_name = "__i" + op_name [2 :]
1191- iop = getattr (operator , iop_name )
1192- for case in cases :
1193- id_ = f"{ iop_name } ({ case .cond_expr } ) -> { case .result_expr } "
1194- p = pytest .param (iop_name , iop , case , id = id_ )
1195- iop_params .append (p )
1196- for func_name , func in name_to_func .items ():
1197- for case in cases :
1198- id_ = f"{ func_name } ({ case .cond_expr } ) -> { case .result_expr } "
1199- p = pytest .param (func_name , func , case , id = id_ )
1200- binary_params .append (p )
1201- continue
1152+ if category == "elementwise" :
1153+ if param_names [0 ] == "x" :
1154+ if cases := parse_unary_case_block (case_block ):
1155+ name_to_func = {stub .__name__ : func }
1156+ if stub .__name__ in func_to_op .keys ():
1157+ op_name = func_to_op [stub .__name__ ]
1158+ op = getattr (operator , op_name )
1159+ name_to_func [op_name ] = op
1160+ for func_name , func in name_to_func .items ():
1161+ for case in cases :
1162+ id_ = f"{ func_name } ({ case .cond_expr } ) -> { case .result_expr } "
1163+ p = pytest .param (func_name , func , case , id = id_ )
1164+ unary_params .append (p )
1165+ else :
1166+ warn ("TODO" )
1167+ continue
1168+ if len (sig .parameters ) == 1 :
1169+ warn (f"{ func = } has one parameter '{ param_names [0 ]} ' which is not named 'x'" )
1170+ continue
1171+ if param_names [0 ] == "x1" and param_names [1 ] == "x2" :
1172+ if cases := parse_binary_case_block (case_block ):
1173+ name_to_func = {stub .__name__ : func }
1174+ if stub .__name__ in func_to_op .keys ():
1175+ op_name = func_to_op [stub .__name__ ]
1176+ op = getattr (operator , op_name )
1177+ name_to_func [op_name ] = op
1178+ # We collect inplace operator test cases seperately
1179+ iop_name = "__i" + op_name [2 :]
1180+ iop = getattr (operator , iop_name )
1181+ for case in cases :
1182+ id_ = f"{ iop_name } ({ case .cond_expr } ) -> { case .result_expr } "
1183+ p = pytest .param (iop_name , iop , case , id = id_ )
1184+ iop_params .append (p )
1185+ for func_name , func in name_to_func .items ():
1186+ for case in cases :
1187+ id_ = f"{ func_name } ({ case .cond_expr } ) -> { case .result_expr } "
1188+ p = pytest .param (func_name , func , case , id = id_ )
1189+ binary_params .append (p )
1190+ else :
1191+ warn ("TODO" )
1192+ continue
1193+ else :
1194+ warn (
1195+ f"{ func = } starts with two parameters '{ param_names [0 ]} ' and "
1196+ f"'{ param_names [1 ]} ', which are not named 'x1' and 'x2'"
1197+ )
1198+ elif category == "statistical" :
1199+ pass # TODO
12021200 else :
1203- warn (
1204- f"{ func = } starts with two parameters '{ param_names [0 ]} ' and "
1205- f"'{ param_names [1 ]} ', which are not named 'x1' and 'x2'"
1206- )
1201+ warn ("TODO" )
12071202
12081203
12091204# test_unary and test_binary naively generate arrays, i.e. arrays that might not
0 commit comments