File tree Expand file tree Collapse file tree 4 files changed +17
-13
lines changed
docs/source/user-guide/common-operations Expand file tree Collapse file tree 4 files changed +17
-13
lines changed Original file line number Diff line number Diff line change @@ -50,6 +50,7 @@ Additionally the :py:func:`~datafusion.udf.AggregateUDF.udaf` function allows yo
5050 import pyarrow.compute
5151 import datafusion
5252 from datafusion import col, udaf, Accumulator
53+ from typing import List
5354
5455 class MyAccumulator (Accumulator ):
5556 """
@@ -62,9 +63,9 @@ Additionally the :py:func:`~datafusion.udf.AggregateUDF.udaf` function allows yo
6263 # not nice since pyarrow scalars can't be summed yet. This breaks on `None`
6364 self ._sum = pyarrow.scalar(self ._sum.as_py() + pyarrow.compute.sum(values).as_py())
6465
65- def merge (self , states : pyarrow.Array) -> None :
66+ def merge (self , states : List[ pyarrow.Array] ) -> None :
6667 # not nice since pyarrow scalars can't be summed yet. This breaks on `None`
67- self ._sum = pyarrow.scalar(self ._sum.as_py() + pyarrow.compute.sum(states).as_py())
68+ self ._sum = pyarrow.scalar(self ._sum.as_py() + pyarrow.compute.sum(states[ 0 ] ).as_py())
6869
6970 def state (self ) -> pyarrow.Array:
7071 return pyarrow.array([self ._sum.as_py()])
Original file line number Diff line number Diff line change @@ -38,10 +38,10 @@ def update(self, values: pa.Array) -> None:
3838 # This breaks on `None`
3939 self ._sum = pa .scalar (self ._sum .as_py () + pc .sum (values ).as_py ())
4040
41- def merge (self , states : pa .Array ) -> None :
41+ def merge (self , states : List [ pa .Array ] ) -> None :
4242 # Not nice since pyarrow scalars can't be summed yet.
4343 # This breaks on `None`
44- self ._sum = pa .scalar (self ._sum .as_py () + pc .sum (states ).as_py ())
44+ self ._sum = pa .scalar (self ._sum .as_py () + pc .sum (states [ 0 ] ).as_py ())
4545
4646 def evaluate (self ) -> pa .Scalar :
4747 return self ._sum
Original file line number Diff line number Diff line change @@ -157,7 +157,7 @@ def update(self, values: pyarrow.Array) -> None:
157157 pass
158158
159159 @abstractmethod
160- def merge (self , states : pyarrow .Array ) -> None :
160+ def merge (self , states : List [ pyarrow .Array ] ) -> None :
161161 """Merge a set of states."""
162162 pass
163163
Original file line number Diff line number Diff line change @@ -72,18 +72,21 @@ impl Accumulator for RustAccumulator {
7272
7373 fn merge_batch ( & mut self , states : & [ ArrayRef ] ) -> Result < ( ) > {
7474 Python :: with_gil ( |py| {
75- let state = & states[ 0 ] ;
76-
77- // 1. cast states to Pyarrow array
78- let state = state
79- . into_data ( )
80- . to_pyarrow ( py)
81- . map_err ( |e| DataFusionError :: Execution ( format ! ( "{e}" ) ) ) ?;
75+ // // 1. cast states to Pyarrow arrays
76+ let py_states: Result < Vec < PyObject > > = states
77+ . iter ( )
78+ . map ( |state| {
79+ state
80+ . into_data ( )
81+ . to_pyarrow ( py)
82+ . map_err ( |e| DataFusionError :: Execution ( format ! ( "{e}" ) ) )
83+ } )
84+ . collect ( ) ;
8285
8386 // 2. call merge
8487 self . accum
8588 . bind ( py)
86- . call_method1 ( "merge" , ( state , ) )
89+ . call_method1 ( "merge" , ( py_states? , ) )
8790 . map_err ( |e| DataFusionError :: Execution ( format ! ( "{e}" ) ) ) ?;
8891
8992 Ok ( ( ) )
You can’t perform that action at this time.
0 commit comments