diff --git a/Cargo.lock b/Cargo.lock
index 2e345e71b..a291189fc 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1596,6 +1596,7 @@ name = "datafusion-python"
 version = "50.1.0"
 dependencies = [
  "arrow",
+ "arrow-select",
  "async-trait",
  "cstr",
  "datafusion",
diff --git a/Cargo.toml b/Cargo.toml
index 3b7a4caaa..1e8c3366d 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -56,6 +56,7 @@ pyo3 = { version = "0.25", features = [
 pyo3-async-runtimes = { version = "0.25", features = ["tokio-runtime"] }
 pyo3-log = "0.12.4"
 arrow = { version = "56", features = ["pyarrow"] }
+arrow-select = { version = "56" }
 datafusion = { version = "50", features = ["avro", "unicode_expressions"] }
 datafusion-substrait = { version = "50", optional = true }
 datafusion-proto = { version = "50" }
diff --git a/docs/source/user-guide/dataframe/index.rst b/docs/source/user-guide/dataframe/index.rst
index 659589cf0..510bcbc68 100644
--- a/docs/source/user-guide/dataframe/index.rst
+++ b/docs/source/user-guide/dataframe/index.rst
@@ -200,6 +200,9 @@ To materialize the results of your DataFrame operations:
     # Count rows
     count = df.count()
 
+    # Collect a single column of data as a PyArrow Array
+    arr = df.collect_column("age")
+
 Zero-copy streaming to Arrow-based Python libraries
 ---------------------------------------------------
 
@@ -238,7 +241,7 @@ PyArrow:
 
 Each batch exposes ``to_pyarrow()``, allowing conversion to a
 PyArrow table. ``pa.table(df)`` collects the entire DataFrame eagerly into a
-PyArrow table::
+PyArrow table:
 
 .. code-block:: python
 
@@ -246,7 +249,7 @@ PyArrow table::
     import pyarrow as pa
     table = pa.table(df)
 
 Asynchronous iteration is supported as well, allowing integration with
-``asyncio`` event loops::
+``asyncio`` event loops:
 
 .. code-block:: python
diff --git a/python/datafusion/dataframe.py b/python/datafusion/dataframe.py
index c6ff7eda5..e847932b3 100644
--- a/python/datafusion/dataframe.py
+++ b/python/datafusion/dataframe.py
@@ -728,6 +728,10 @@ def collect(self) -> list[pa.RecordBatch]:
         """
         return self.df.collect()
 
+    def collect_column(self, column_name: str) -> pa.Array | pa.ChunkedArray:
+        """Executes this :py:class:`DataFrame` for a single column."""
+        return self.df.collect_column(column_name)
+
     def cache(self) -> DataFrame:
         """Cache the DataFrame as a memory table.
 
diff --git a/python/tests/test_dataframe.py b/python/tests/test_dataframe.py
index 101dfc5b2..fbdef4c9a 100644
--- a/python/tests/test_dataframe.py
+++ b/python/tests/test_dataframe.py
@@ -1733,6 +1733,18 @@ def test_collect_partitioned():
     assert [[batch]] == ctx.create_dataframe([[batch]]).collect_partitioned()
 
 
+def test_collect_column(ctx: SessionContext):
+    batch_1 = pa.RecordBatch.from_pydict({"a": [1, 2, 3]})
+    batch_2 = pa.RecordBatch.from_pydict({"a": [4, 5, 6]})
+    batch_3 = pa.RecordBatch.from_pydict({"a": [7, 8, 9]})
+
+    ctx.register_record_batches("t", [[batch_1, batch_2], [batch_3]])
+
+    result = ctx.table("t").sort(column("a")).collect_column("a")
+    expected = pa.array([1, 2, 3, 4, 5, 6, 7, 8, 9])
+    assert result == expected
+
+
 def test_union(ctx):
     batch = pa.RecordBatch.from_arrays(
         [pa.array([1, 2, 3]), pa.array([4, 5, 6])],
diff --git a/src/dataframe.rs b/src/dataframe.rs
index a93aa0185..d855f8a9d 100644
--- a/src/dataframe.rs
+++ b/src/dataframe.rs
@@ -20,7 +20,7 @@ use std::collections::HashMap;
 use std::ffi::{CStr, CString};
 use std::sync::Arc;
 
-use arrow::array::{new_null_array, RecordBatch, RecordBatchReader};
+use arrow::array::{new_null_array, Array, ArrayRef, RecordBatch, RecordBatchReader};
 use arrow::compute::can_cast_types;
 use arrow::error::ArrowError;
 use arrow::ffi::FFI_ArrowSchema;
@@ -343,6 +343,23 @@ impl PyDataFrame {
 
         Ok(html_str)
     }
+
+    async fn collect_column_inner(&self, column: &str) -> Result<ArrayRef, DataFusionError> {
+        let batches = self
+            .df
+            .as_ref()
+            .clone()
+            .select_columns(&[column])?
+            .collect()
+            .await?;
+
+        let arrays = batches
+            .iter()
+            .map(|b| b.column(0).as_ref())
+            .collect::<Vec<_>>();
+
+        arrow_select::concat::concat(&arrays).map_err(Into::into)
+    }
 }
 
 /// Synchronous wrapper around partitioned [`SendableRecordBatchStream`]s used
@@ -610,6 +627,13 @@ impl PyDataFrame {
             .collect()
     }
 
+    fn collect_column(&self, py: Python, column: &str) -> PyResult<PyObject> {
+        wait_for_future(py, self.collect_column_inner(column))?
+            .map_err(PyDataFusionError::from)?
+            .to_data()
+            .to_pyarrow(py)
+    }
+
     /// Print the result, 20 lines by default
     #[pyo3(signature = (num=20))]
     fn show(&self, py: Python, num: usize) -> PyDataFusionResult<()> {