|
16 | 16 | // under the License. |
17 | 17 |
|
18 | 18 | use datafusion::functions_aggregate::all_default_aggregate_functions; |
| 19 | +use datafusion_expr::AggregateExt; |
19 | 20 | use pyo3::{prelude::*, wrap_pyfunction}; |
20 | 21 |
|
21 | 22 | use crate::common::data_type::NullTreatment; |
@@ -75,47 +76,80 @@ pub fn var(y: PyExpr) -> PyExpr { |
75 | 76 | } |
76 | 77 |
|
77 | 78 | #[pyfunction] |
78 | | -#[pyo3(signature = (*args, distinct = false, filter = None, order_by = None, null_treatment = None))] |
| 79 | +#[pyo3(signature = (expr, distinct = false, filter = None, order_by = None, null_treatment = None))] |
79 | 80 | pub fn first_value( |
80 | | - args: Vec<PyExpr>, |
| 81 | + expr: PyExpr, |
81 | 82 | distinct: bool, |
82 | 83 | filter: Option<PyExpr>, |
83 | 84 | order_by: Option<Vec<PyExpr>>, |
84 | 85 | null_treatment: Option<NullTreatment>, |
85 | | -) -> PyExpr { |
86 | | - let null_treatment = null_treatment.map(Into::into); |
87 | | - let args = args.into_iter().map(|x| x.expr).collect::<Vec<_>>(); |
| 86 | +) -> PyResult<PyExpr> { |
88 | 87 | let order_by = order_by.map(|x| x.into_iter().map(|x| x.expr).collect::<Vec<_>>()); |
89 | | - functions_aggregate::expr_fn::first_value( |
90 | | - args, |
91 | | - distinct, |
92 | | - filter.map(|x| Box::new(x.expr)), |
93 | | - order_by, |
94 | | - null_treatment, |
95 | | - ) |
96 | | - .into() |
| 88 | + |
| 89 | + // TODO: add `builder()` to `AggregateExt` to avoid this boilerplate |
| 90 | + let builder = functions_aggregate::expr_fn::first_value(expr.expr, order_by); |
| 91 | + |
| 92 | + let builder = if let Some(filter) = filter { |
| 93 | + let filter = filter.expr; |
| 94 | + builder.filter(filter).build()? |
| 95 | + } else { |
| 96 | + builder |
| 97 | + }; |
| 98 | + |
| 99 | + let builder = if distinct { |
| 100 | + builder.distinct().build()? |
| 101 | + } else { |
| 102 | + builder |
| 103 | + }; |
| 104 | + |
| 105 | + let builder = if let Some(null_treatment) = null_treatment { |
| 106 | + builder.null_treatment(null_treatment.into()).build()? |
| 107 | + } else { |
| 108 | + builder |
| 109 | + }; |
| 110 | + |
| 111 | + Ok(builder.into()) |
97 | 112 | } |
98 | 113 |
|
99 | 114 | #[pyfunction] |
100 | | -#[pyo3(signature = (*args, distinct = false, filter = None, order_by = None, null_treatment = None))] |
| 115 | +#[pyo3(signature = (expr, distinct = false, filter = None, order_by = None, null_treatment = None))] |
101 | 116 | pub fn last_value( |
102 | | - args: Vec<PyExpr>, |
| 117 | + expr: PyExpr, |
103 | 118 | distinct: bool, |
104 | 119 | filter: Option<PyExpr>, |
105 | 120 | order_by: Option<Vec<PyExpr>>, |
106 | 121 | null_treatment: Option<NullTreatment>, |
107 | | -) -> PyExpr { |
108 | | - let null_treatment = null_treatment.map(Into::into); |
109 | | - let args = args.into_iter().map(|x| x.expr).collect::<Vec<_>>(); |
110 | | - let order_by = order_by.map(|x| x.into_iter().map(|x| x.expr).collect::<Vec<_>>()); |
111 | | - functions_aggregate::expr_fn::last_value( |
112 | | - args, |
113 | | - distinct, |
114 | | - filter.map(|x| Box::new(x.expr)), |
115 | | - order_by, |
116 | | - null_treatment, |
117 | | - ) |
118 | | - .into() |
| 122 | +) -> PyResult<PyExpr> { |
| 123 | + // TODO: add `builder()` to `AggregateExt` to avoid this boilerplate |
| 124 | + let builder = functions_aggregate::expr_fn::last_value(vec![expr.expr]); |
| 125 | + |
| 126 | + let builder = if distinct { |
| 127 | + builder.distinct().build()? |
| 128 | + } else { |
| 129 | + builder |
| 130 | + }; |
| 131 | + |
| 132 | + let builder = if let Some(filter) = filter { |
| 133 | + let filter = filter.expr; |
| 134 | + builder.filter(filter).build()? |
| 135 | + } else { |
| 136 | + builder |
| 137 | + }; |
| 138 | + |
| 139 | + let builder = if let Some(order_by) = order_by { |
| 140 | + let order_by = order_by.into_iter().map(|x| x.expr).collect::<Vec<_>>(); |
| 141 | + builder.order_by(order_by).build()? |
| 142 | + } else { |
| 143 | + builder |
| 144 | + }; |
| 145 | + |
| 146 | + let builder = if let Some(null_treatment) = null_treatment { |
| 147 | + builder.null_treatment(null_treatment.into()).build()? |
| 148 | + } else { |
| 149 | + builder |
| 150 | + }; |
| 151 | + |
| 152 | + Ok(builder.into()) |
119 | 153 | } |
120 | 154 |
|
121 | 155 | #[pyfunction] |
|
0 commit comments