@@ -83,6 +83,7 @@ pub fn return_type(
8383 Ok ( coerced_data_types[ 0 ] . clone ( ) )
8484 }
8585 AggregateFunction :: ApproxMedian => Ok ( coerced_data_types[ 0 ] . clone ( ) ) ,
86+ AggregateFunction :: BoolAnd | AggregateFunction :: BoolOr => Ok ( DataType :: Boolean ) ,
8687 }
8788}
8889
@@ -297,6 +298,13 @@ pub fn create_aggregate_expr(
297298 "MEDIAN(DISTINCT) aggregations are not available" . to_string ( ) ,
298299 ) ) ;
299300 }
301+ ( AggregateFunction :: BoolAnd , _) => Arc :: new ( expressions:: BoolAnd :: new (
302+ coerced_phy_exprs[ 0 ] . clone ( ) ,
303+ name,
304+ ) ) ,
305+ ( AggregateFunction :: BoolOr , _) => {
306+ Arc :: new ( expressions:: BoolOr :: new ( coerced_phy_exprs[ 0 ] . clone ( ) , name) )
307+ }
300308 } )
301309}
302310
@@ -374,16 +382,19 @@ pub(super) fn signature(fun: &AggregateFunction) -> Signature {
374382 . collect ( ) ,
375383 Volatility :: Immutable ,
376384 ) ,
385+ AggregateFunction :: BoolAnd | AggregateFunction :: BoolOr => {
386+ Signature :: exact ( vec ! [ DataType :: Boolean ] , Volatility :: Immutable )
387+ }
377388 }
378389}
379390
380391#[ cfg( test) ]
381392mod tests {
382393 use super :: * ;
383394 use crate :: physical_plan:: expressions:: {
384- ApproxDistinct , ApproxMedian , ApproxPercentileCont , ArrayAgg , Avg , Correlation ,
385- Count , Covariance , DistinctArrayAgg , DistinctCount , Max , Min , Stddev , Sum ,
386- Variance ,
395+ ApproxDistinct , ApproxMedian , ApproxPercentileCont , ArrayAgg , Avg , BoolAnd ,
396+ BoolOr , Correlation , Count , Covariance , DistinctArrayAgg , DistinctCount , Max ,
397+ Min , Stddev , Sum , Variance ,
387398 } ;
388399 use crate :: { error:: Result , scalar:: ScalarValue } ;
389400
@@ -995,6 +1006,45 @@ mod tests {
9951006 Ok ( ( ) )
9961007 }
9971008
1009+ #[ test]
1010+ fn test_bool_and_or_expr ( ) -> Result < ( ) > {
1011+ let funcs = vec ! [ AggregateFunction :: BoolAnd , AggregateFunction :: BoolOr ] ;
1012+ for fun in funcs {
1013+ let input_schema =
1014+ Schema :: new ( vec ! [ Field :: new( "c1" , DataType :: Boolean , true ) ] ) ;
1015+ let input_phy_exprs: Vec < Arc < dyn PhysicalExpr > > = vec ! [ Arc :: new(
1016+ expressions:: Column :: new_with_schema( "c1" , & input_schema) . unwrap( ) ,
1017+ ) ] ;
1018+ let result_agg_phy_exprs = create_aggregate_expr (
1019+ & fun,
1020+ false ,
1021+ & input_phy_exprs[ 0 ..1 ] ,
1022+ & input_schema,
1023+ "c1" ,
1024+ ) ?;
1025+ match fun {
1026+ AggregateFunction :: BoolAnd => {
1027+ assert ! ( result_agg_phy_exprs. as_any( ) . is:: <BoolAnd >( ) ) ;
1028+ assert_eq ! ( "c1" , result_agg_phy_exprs. name( ) ) ;
1029+ assert_eq ! (
1030+ Field :: new( "c1" , DataType :: Boolean , true ) ,
1031+ result_agg_phy_exprs. field( ) . unwrap( )
1032+ ) ;
1033+ }
1034+ AggregateFunction :: BoolOr => {
1035+ assert ! ( result_agg_phy_exprs. as_any( ) . is:: <BoolOr >( ) ) ;
1036+ assert_eq ! ( "c1" , result_agg_phy_exprs. name( ) ) ;
1037+ assert_eq ! (
1038+ Field :: new( "c1" , DataType :: Boolean , true ) ,
1039+ result_agg_phy_exprs. field( ) . unwrap( )
1040+ ) ;
1041+ }
1042+ _ => { }
1043+ } ;
1044+ }
1045+ Ok ( ( ) )
1046+ }
1047+
9981048 #[ test]
9991049 fn test_median ( ) -> Result < ( ) > {
10001050 let observed = return_type ( & AggregateFunction :: ApproxMedian , & [ DataType :: Utf8 ] ) ;
@@ -1158,4 +1208,32 @@ mod tests {
11581208 let observed = return_type ( & AggregateFunction :: Stddev , & [ DataType :: Utf8 ] ) ;
11591209 assert ! ( observed. is_err( ) ) ;
11601210 }
1211+
1212+ #[ test]
1213+ fn test_bool_and_return_type ( ) -> Result < ( ) > {
1214+ let observed = return_type ( & AggregateFunction :: BoolAnd , & [ DataType :: Boolean ] ) ?;
1215+ assert_eq ! ( DataType :: Boolean , observed) ;
1216+
1217+ Ok ( ( ) )
1218+ }
1219+
1220+ #[ test]
1221+ fn test_bool_and_no_utf8 ( ) {
1222+ let observed = return_type ( & AggregateFunction :: BoolAnd , & [ DataType :: Utf8 ] ) ;
1223+ assert ! ( observed. is_err( ) ) ;
1224+ }
1225+
1226+ #[ test]
1227+ fn test_bool_or_return_type ( ) -> Result < ( ) > {
1228+ let observed = return_type ( & AggregateFunction :: BoolOr , & [ DataType :: Boolean ] ) ?;
1229+ assert_eq ! ( DataType :: Boolean , observed) ;
1230+
1231+ Ok ( ( ) )
1232+ }
1233+
1234+ #[ test]
1235+ fn test_bool_or_no_utf8 ( ) {
1236+ let observed = return_type ( & AggregateFunction :: BoolOr , & [ DataType :: Utf8 ] ) ;
1237+ assert ! ( observed. is_err( ) ) ;
1238+ }
11611239}
0 commit comments