Skip to content

Commit 27f5c73

Browse files
migrate regr_* functions to UDAF
Ref: apache/datafusion#10898
1 parent f9d699c commit 27f5c73

File tree

2 files changed

+99
-18
lines changed

2 files changed

+99
-18
lines changed

python/datafusion/functions.py

Lines changed: 9 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -1398,50 +1398,50 @@ def regr_avgx(y: Expr, x: Expr, distinct: bool = False) -> Expr:
13981398
13991399
Only non-null pairs of the inputs are evaluated.
14001400
"""
1401-
return Expr(f.regr_avgx[y.expr, x.expr], distinct)
1401+
return Expr(f.regr_avgx(y.expr, x.expr, distinct))
14021402

14031403

14041404
def regr_avgy(y: Expr, x: Expr, distinct: bool = False) -> Expr:
14051405
"""Computes the average of the dependent variable ``y``.
14061406
14071407
Only non-null pairs of the inputs are evaluated.
14081408
"""
1409-
return Expr(f.regr_avgy[y.expr, x.expr], distinct)
1409+
return Expr(f.regr_avgy(y.expr, x.expr, distinct))
14101410

14111411

14121412
def regr_count(y: Expr, x: Expr, distinct: bool = False) -> Expr:
14131413
"""Counts the number of rows in which both expressions are not null."""
1414-
return Expr(f.regr_count[y.expr, x.expr], distinct)
1414+
return Expr(f.regr_count(y.expr, x.expr, distinct))
14151415

14161416

14171417
def regr_intercept(y: Expr, x: Expr, distinct: bool = False) -> Expr:
14181418
"""Computes the intercept from the linear regression."""
1419-
return Expr(f.regr_intercept[y.expr, x.expr], distinct)
1419+
return Expr(f.regr_intercept(y.expr, x.expr, distinct))
14201420

14211421

14221422
def regr_r2(y: Expr, x: Expr, distinct: bool = False) -> Expr:
14231423
"""Computes the R-squared value from linear regression."""
1424-
return Expr(f.regr_r2[y.expr, x.expr], distinct)
1424+
return Expr(f.regr_r2(y.expr, x.expr, distinct))
14251425

14261426

14271427
def regr_slope(y: Expr, x: Expr, distinct: bool = False) -> Expr:
14281428
"""Computes the slope from linear regression."""
1429-
return Expr(f.regr_slope[y.expr, x.expr], distinct)
1429+
return Expr(f.regr_slope(y.expr, x.expr, distinct))
14301430

14311431

14321432
def regr_sxx(y: Expr, x: Expr, distinct: bool = False) -> Expr:
14331433
"""Computes the sum of squares of the independent variable `x`."""
1434-
return Expr(f.regr_sxx[y.expr, x.expr], distinct)
1434+
return Expr(f.regr_sxx(y.expr, x.expr, distinct))
14351435

14361436

14371437
def regr_sxy(y: Expr, x: Expr, distinct: bool = False) -> Expr:
14381438
"""Computes the sum of products of pairs of numbers."""
1439-
return Expr(f.regr_sxy[y.expr, x.expr], distinct)
1439+
return Expr(f.regr_sxy(y.expr, x.expr, distinct))
14401440

14411441

14421442
def regr_syy(y: Expr, x: Expr, distinct: bool = False) -> Expr:
14431443
"""Computes the sum of squares of the dependent variable `y`."""
1444-
return Expr(f.regr_syy[y.expr, x.expr], distinct)
1444+
return Expr(f.regr_syy(y.expr, x.expr, distinct))
14451445

14461446

14471447
def first_value(

src/functions.rs

Lines changed: 90 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -191,6 +191,96 @@ pub fn var_pop(expression: PyExpr, distinct: bool) -> PyResult<PyExpr> {
191191
}
192192
}
193193

194+
#[pyfunction]
195+
pub fn regr_avgx(expr_y: PyExpr, expr_x: PyExpr, distinct: bool) -> PyResult<PyExpr> {
196+
let expr = functions_aggregate::expr_fn::regr_avgx(expr_y.expr, expr_x.expr);
197+
if distinct {
198+
Ok(expr.distinct().build()?.into())
199+
} else {
200+
Ok(expr.into())
201+
}
202+
}
203+
204+
#[pyfunction]
205+
pub fn regr_avgy(expr_y: PyExpr, expr_x: PyExpr, distinct: bool) -> PyResult<PyExpr> {
206+
let expr = functions_aggregate::expr_fn::regr_avgy(expr_y.expr, expr_x.expr);
207+
if distinct {
208+
Ok(expr.distinct().build()?.into())
209+
} else {
210+
Ok(expr.into())
211+
}
212+
}
213+
214+
#[pyfunction]
215+
pub fn regr_count(expr_y: PyExpr, expr_x: PyExpr, distinct: bool) -> PyResult<PyExpr> {
216+
let expr = functions_aggregate::expr_fn::regr_count(expr_y.expr, expr_x.expr);
217+
if distinct {
218+
Ok(expr.distinct().build()?.into())
219+
} else {
220+
Ok(expr.into())
221+
}
222+
}
223+
224+
#[pyfunction]
225+
pub fn regr_intercept(expr_y: PyExpr, expr_x: PyExpr, distinct: bool) -> PyResult<PyExpr> {
226+
let expr = functions_aggregate::expr_fn::regr_intercept(expr_y.expr, expr_x.expr);
227+
if distinct {
228+
Ok(expr.distinct().build()?.into())
229+
} else {
230+
Ok(expr.into())
231+
}
232+
}
233+
234+
#[pyfunction]
235+
pub fn regr_r2(expr_y: PyExpr, expr_x: PyExpr, distinct: bool) -> PyResult<PyExpr> {
236+
let expr = functions_aggregate::expr_fn::regr_r2(expr_y.expr, expr_x.expr);
237+
if distinct {
238+
Ok(expr.distinct().build()?.into())
239+
} else {
240+
Ok(expr.into())
241+
}
242+
}
243+
244+
#[pyfunction]
245+
pub fn regr_slope(expr_y: PyExpr, expr_x: PyExpr, distinct: bool) -> PyResult<PyExpr> {
246+
let expr = functions_aggregate::expr_fn::regr_slope(expr_y.expr, expr_x.expr);
247+
if distinct {
248+
Ok(expr.distinct().build()?.into())
249+
} else {
250+
Ok(expr.into())
251+
}
252+
}
253+
254+
#[pyfunction]
255+
pub fn regr_sxx(expr_y: PyExpr, expr_x: PyExpr, distinct: bool) -> PyResult<PyExpr> {
256+
let expr = functions_aggregate::expr_fn::regr_sxx(expr_y.expr, expr_x.expr);
257+
if distinct {
258+
Ok(expr.distinct().build()?.into())
259+
} else {
260+
Ok(expr.into())
261+
}
262+
}
263+
264+
#[pyfunction]
265+
pub fn regr_sxy(expr_y: PyExpr, expr_x: PyExpr, distinct: bool) -> PyResult<PyExpr> {
266+
let expr = functions_aggregate::expr_fn::regr_sxy(expr_y.expr, expr_x.expr);
267+
if distinct {
268+
Ok(expr.distinct().build()?.into())
269+
} else {
270+
Ok(expr.into())
271+
}
272+
}
273+
274+
#[pyfunction]
275+
pub fn regr_syy(expr_y: PyExpr, expr_x: PyExpr, distinct: bool) -> PyResult<PyExpr> {
276+
let expr = functions_aggregate::expr_fn::regr_syy(expr_y.expr, expr_x.expr);
277+
if distinct {
278+
Ok(expr.distinct().build()?.into())
279+
} else {
280+
Ok(expr.into())
281+
}
282+
}
283+
194284
#[pyfunction]
195285
#[pyo3(signature = (expr, distinct = false, filter = None, order_by = None, null_treatment = None))]
196286
pub fn first_value(
@@ -847,15 +937,6 @@ array_fn!(range, start stop step);
847937
aggregate_function!(array_agg, ArrayAgg);
848938
aggregate_function!(max, Max);
849939
aggregate_function!(min, Min);
850-
aggregate_function!(regr_avgx, RegrAvgx);
851-
aggregate_function!(regr_avgy, RegrAvgy);
852-
aggregate_function!(regr_count, RegrCount);
853-
aggregate_function!(regr_intercept, RegrIntercept);
854-
aggregate_function!(regr_r2, RegrR2);
855-
aggregate_function!(regr_slope, RegrSlope);
856-
aggregate_function!(regr_sxx, RegrSXX);
857-
aggregate_function!(regr_sxy, RegrSXY);
858-
aggregate_function!(regr_syy, RegrSYY);
859940
aggregate_function!(bit_and, BitAnd);
860941
aggregate_function!(bit_or, BitOr);
861942
aggregate_function!(bit_xor, BitXor);

0 commit comments

Comments
 (0)