@@ -732,6 +732,35 @@ def make_binary_cond(
732732 * ,
733733 input_wrapper : Optional [Callable [[float ], float ]] = None ,
734734) -> BinaryCond :
735+ """
736+ Wraps a unary condition as a binary condition, e.g.
737+
738+ >>> unary_cond = lambda i: i == 42
739+
740+ >>> binary_cond_first = make_binary_cond(BinaryCondArg.FIRST, unary_cond)
741+ >>> binary_cond_first(42, 0)
742+ True
743+ >>> binary_cond_second = make_binary_cond(BinaryCondArg.SECOND, unary_cond)
744+ >>> binary_cond_second(42, 0)
745+ False
746+ >>> binary_cond_second(0, 42)
747+ True
748+ >>> binary_cond_both = make_binary_cond(BinaryCondArg.BOTH, unary_cond)
749+ >>> binary_cond_both(42, 0)
750+ False
751+ >>> binary_cond_both(42, 42)
752+ True
753+ >>> binary_cond_either = make_binary_cond(BinaryCondArg.EITHER, unary_cond)
754+ >>> binary_cond_either(0, 0)
755+ False
756+ >>> binary_cond_either(42, 0)
757+ True
758+ >>> binary_cond_either(0, 42)
759+ True
760+ >>> binary_cond_either(42, 42)
761+ True
762+
763+ """
735764 if input_wrapper is None :
736765 input_wrapper = noop
737766
@@ -823,11 +852,13 @@ def parse_binary_case(case_str: str) -> BinaryCase:
823852 if in_sign != "" or other_no == in_no :
824853 raise ParseError (cond_str )
825854 partial_expr = f"{ in_sign } x{ in_no } _i == { other_sign } x{ other_no } _i"
855+
826856 input_wrapper = lambda i : - i if other_sign == "-" else noop
857+ # For these scenarios, we want to make sure both array elements
858+ # generate respective to one another by using a shared strategy.
827859 shared_from_dtype = lambda d , ** kw : st .shared (
828860 xps .from_dtype (d , ** kw ), key = cond_str
829861 )
830-
831862 if other_no == "1" :
832863
833864 def partial_cond (i1 : float , i2 : float ) -> bool :
0 commit comments