@@ -610,48 +610,68 @@ fn case(expr: PyExpr) -> PyResult<PyCaseBuilder> {
610610 } )
611611}
612612
613- /// Helper function to find the appropriate window function. First, if a session
614- /// context is defined check it's registered functions. If no context is defined,
615- /// attempt to find from all default functions. Lastly, as a fall back attempt
616- /// to use built in window functions, which are being deprecated.
613+ /// Helper function to find the appropriate window function.
614+ ///
615+ /// Search procedure:
616+ /// 1) If a session context is provided:
617+ /// a) search User Defined Aggregate Functions (UDAFs)
618+ /// b) search registered window functions
619+ /// c) search registered aggregate functions
620+ /// 2) If no function has been found, search default aggregate functions.
621+ /// 3) Lastly, as a fall back attempt, search built in window functions, which are being deprecated.
617622fn find_window_fn ( name : & str , ctx : Option < PySessionContext > ) -> PyResult < WindowFunctionDefinition > {
618- let mut maybe_fn = match & ctx {
619- Some ( ctx) => {
620- let session_state = ctx. ctx . state ( ) ;
621-
622- match session_state. window_functions ( ) . contains_key ( name) {
623- true => session_state
624- . window_functions ( )
625- . get ( name)
626- . map ( |f| WindowFunctionDefinition :: WindowUDF ( f. clone ( ) ) ) ,
627- false => session_state
628- . aggregate_functions ( )
629- . get ( name)
630- . map ( |f| WindowFunctionDefinition :: AggregateUDF ( f. clone ( ) ) ) ,
631- }
623+ if let Some ( ctx) = ctx {
624+ // search UDAFs
625+ let udaf = ctx
626+ . ctx
627+ . udaf ( name)
628+ . map ( WindowFunctionDefinition :: AggregateUDF )
629+ . ok ( ) ;
630+
631+ if let Some ( udaf) = udaf {
632+ return Ok ( udaf) ;
632633 }
633- None => {
634- let default_aggregate_fns = all_default_aggregate_functions ( ) ;
635634
636- default_aggregate_fns
637- . iter ( )
638- . find ( |v| v. aliases ( ) . contains ( & name. to_string ( ) ) )
639- . map ( |f| WindowFunctionDefinition :: AggregateUDF ( f. clone ( ) ) )
635+ let session_state = ctx. ctx . state ( ) ;
636+
637+ // search registered window functions
638+ let window_fn = session_state
639+ . window_functions ( )
640+ . get ( name)
641+ . map ( |f| WindowFunctionDefinition :: WindowUDF ( f. clone ( ) ) ) ;
642+
643+ if let Some ( window_fn) = window_fn {
644+ return Ok ( window_fn) ;
640645 }
641- } ;
642646
643- if maybe_fn. is_none ( ) {
644- maybe_fn = find_df_window_func ( name) . or_else ( || {
645- ctx. and_then ( |ctx| {
646- ctx. ctx
647- . udaf ( name)
648- . map ( WindowFunctionDefinition :: AggregateUDF )
649- . ok ( )
650- } )
651- } ) ;
647+ // search registered aggregate functions
648+ let agg_fn = session_state
649+ . aggregate_functions ( )
650+ . get ( name)
651+ . map ( |f| WindowFunctionDefinition :: AggregateUDF ( f. clone ( ) ) ) ;
652+
653+ if let Some ( agg_fn) = agg_fn {
654+ return Ok ( agg_fn) ;
655+ }
656+ }
657+
658+ // search default aggregate functions
659+ let agg_fn = all_default_aggregate_functions ( )
660+ . iter ( )
661+ . find ( |v| v. aliases ( ) . contains ( & name. to_string ( ) ) )
662+ . map ( |f| WindowFunctionDefinition :: AggregateUDF ( f. clone ( ) ) ) ;
663+
664+ if let Some ( agg_fn) = agg_fn {
665+ return Ok ( agg_fn) ;
652666 }
653667
654- maybe_fn. ok_or ( DataFusionError :: Common ( format ! ( "window function `{name}` not found" ) ) . into ( ) )
668+ // search built in window functions (soon to be deprecated)
669+ let df_window_func = find_df_window_func ( name) ;
670+ if let Some ( df_window_func) = df_window_func {
671+ return Ok ( df_window_func) ;
672+ }
673+
674+ Err ( DataFusionError :: Common ( format ! ( "window function `{name}` not found" ) ) . into ( ) )
655675}
656676
657677/// Creates a new Window function expression
@@ -1206,4 +1226,4 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
12061226 m. add_wrapped ( wrap_pyfunction ! ( flatten) ) ?;
12071227
12081228 Ok ( ( ) )
1209- }
1229+ }
0 commit comments