1717
1818use pyo3:: { prelude:: * , wrap_pyfunction} ;
1919
20+ use crate :: common:: data_type:: NullTreatment ;
2021use crate :: context:: PySessionContext ;
2122use crate :: errors:: DataFusionError ;
2223use crate :: expr:: conditional_expr:: PyCaseBuilder ;
@@ -73,15 +74,15 @@ pub fn var(y: PyExpr) -> PyExpr {
7374}
7475
7576#[ pyfunction]
76- #[ pyo3( signature = ( * args, distinct = false , filter = None , order_by = None ) ) ]
77+ #[ pyo3( signature = ( * args, distinct = false , filter = None , order_by = None , null_treatment = None ) ) ]
7778pub fn first_value (
7879 args : Vec < PyExpr > ,
7980 distinct : bool ,
8081 filter : Option < PyExpr > ,
8182 order_by : Option < Vec < PyExpr > > ,
83+ null_treatment : Option < NullTreatment > ,
8284) -> PyExpr {
83- // TODO: allow user to select null_treatment
84- let null_treatment = None ;
85+ let null_treatment = null_treatment. map ( Into :: into) ;
8586 let args = args. into_iter ( ) . map ( |x| x. expr ) . collect :: < Vec < _ > > ( ) ;
8687 let order_by = order_by. map ( |x| x. into_iter ( ) . map ( |x| x. expr ) . collect :: < Vec < _ > > ( ) ) ;
8788 functions_aggregate:: expr_fn:: first_value (
@@ -95,15 +96,15 @@ pub fn first_value(
9596}
9697
9798#[ pyfunction]
98- #[ pyo3( signature = ( * args, distinct = false , filter = None , order_by = None ) ) ]
99+ #[ pyo3( signature = ( * args, distinct = false , filter = None , order_by = None , null_treatment = None ) ) ]
99100pub fn last_value (
100101 args : Vec < PyExpr > ,
101102 distinct : bool ,
102103 filter : Option < PyExpr > ,
103104 order_by : Option < Vec < PyExpr > > ,
105+ null_treatment : Option < NullTreatment > ,
104106) -> PyExpr {
105- // TODO: allow user to select null_treatment
106- let null_treatment = None ;
107+ let null_treatment = null_treatment. map ( Into :: into) ;
107108 let args = args. into_iter ( ) . map ( |x| x. expr ) . collect :: < Vec < _ > > ( ) ;
108109 let order_by = order_by. map ( |x| x. into_iter ( ) . map ( |x| x. expr ) . collect :: < Vec < _ > > ( ) ) ;
109110 functions_aggregate:: expr_fn:: last_value (
@@ -320,14 +321,20 @@ fn window(
320321 window_frame : Option < PyWindowFrame > ,
321322 ctx : Option < PySessionContext > ,
322323) -> PyResult < PyExpr > {
323- let fun = find_df_window_func ( name) . or_else ( || {
324- ctx. and_then ( |ctx| {
325- ctx. ctx
326- . udaf ( name)
327- . map ( WindowFunctionDefinition :: AggregateUDF )
328- . ok ( )
324+ // workaround for https://github.com/apache/datafusion-python/issues/730
325+ let fun = if name == "sum" {
326+ let sum_udf = functions_aggregate:: sum:: sum_udaf ( ) ;
327+ Some ( WindowFunctionDefinition :: AggregateUDF ( sum_udf) )
328+ } else {
329+ find_df_window_func ( name) . or_else ( || {
330+ ctx. and_then ( |ctx| {
331+ ctx. ctx
332+ . udaf ( name)
333+ . map ( WindowFunctionDefinition :: AggregateUDF )
334+ . ok ( )
335+ } )
329336 } )
330- } ) ;
337+ } ;
331338 if fun. is_none ( ) {
332339 return Err ( DataFusionError :: Common ( "window function not found" . to_string ( ) ) . into ( ) ) ;
333340 }
0 commit comments