Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ pyo3 = { version = "0.25", features = [
pyo3-async-runtimes = { version = "0.25", features = ["tokio-runtime"] }
pyo3-log = "0.12.4"
arrow = { version = "56", features = ["pyarrow"] }
arrow-select = { version = "56" }
datafusion = { version = "50", features = ["avro", "unicode_expressions"] }
datafusion-substrait = { version = "50", optional = true }
datafusion-proto = { version = "50" }
Expand Down
7 changes: 5 additions & 2 deletions docs/source/user-guide/dataframe/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,9 @@ To materialize the results of your DataFrame operations:
# Count rows
count = df.count()

# Collect a single column of data as a PyArrow Array
arr = df.collect_column("age")

Zero-copy streaming to Arrow-based Python libraries
---------------------------------------------------

Expand Down Expand Up @@ -238,15 +241,15 @@ PyArrow:

Each batch exposes ``to_pyarrow()``, allowing conversion to a PyArrow
table. ``pa.table(df)`` collects the entire DataFrame eagerly into a
PyArrow table::
PyArrow table:

.. code-block:: python

import pyarrow as pa
table = pa.table(df)

Asynchronous iteration is supported as well, allowing integration with
``asyncio`` event loops::
``asyncio`` event loops:

.. code-block:: python

Expand Down
4 changes: 4 additions & 0 deletions python/datafusion/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -728,6 +728,10 @@ def collect(self) -> list[pa.RecordBatch]:
"""
return self.df.collect()

def collect_column(self, column_name: str) -> pa.Array | pa.ChunkedArray:
    """Execute this :py:class:`DataFrame` and collect a single column.

    The frame is projected down to ``column_name`` and fully executed;
    the per-batch results are combined into a single PyArrow array.

    Args:
        column_name: Name of the column to collect.

    Returns:
        The column's values as a PyArrow array.
        # NOTE(review): the Rust side concatenates batches, so in practice
        # this appears to return a single pa.Array — confirm whether
        # ChunkedArray can actually occur before narrowing the annotation.
    """
    return self.df.collect_column(column_name)

def cache(self) -> DataFrame:
"""Cache the DataFrame as a memory table.

Expand Down
12 changes: 12 additions & 0 deletions python/tests/test_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -1733,6 +1733,18 @@ def test_collect_partitioned():
assert [[batch]] == ctx.create_dataframe([[batch]]).collect_partitioned()


def test_collect_column(ctx: SessionContext):
    """collect_column concatenates one column across all partitions."""
    batches = [
        pa.RecordBatch.from_pydict({"a": values})
        for values in ([1, 2, 3], [4, 5, 6], [7, 8, 9])
    ]

    # Two partitions: the first holds two batches, the second holds one.
    ctx.register_record_batches("t", [batches[:2], [batches[2]]])

    result = ctx.table("t").sort(column("a")).collect_column("a")
    assert result == pa.array(list(range(1, 10)))


def test_union(ctx):
batch = pa.RecordBatch.from_arrays(
[pa.array([1, 2, 3]), pa.array([4, 5, 6])],
Expand Down
26 changes: 25 additions & 1 deletion src/dataframe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use std::collections::HashMap;
use std::ffi::{CStr, CString};
use std::sync::Arc;

use arrow::array::{new_null_array, RecordBatch, RecordBatchReader};
use arrow::array::{new_null_array, Array, ArrayRef, RecordBatch, RecordBatchReader};
use arrow::compute::can_cast_types;
use arrow::error::ArrowError;
use arrow::ffi::FFI_ArrowSchema;
Expand Down Expand Up @@ -343,6 +343,23 @@ impl PyDataFrame {

Ok(html_str)
}

async fn collect_column_inner(&self, column: &str) -> Result<ArrayRef, DataFusionError> {
let batches = self
.df
.as_ref()
.clone()
.select_columns(&[column])?
.collect()
.await?;

let arrays = batches
.iter()
.map(|b| b.column(0).as_ref())
.collect::<Vec<_>>();

arrow_select::concat::concat(&arrays).map_err(Into::into)
}
}

/// Synchronous wrapper around partitioned [`SendableRecordBatchStream`]s used
Expand Down Expand Up @@ -610,6 +627,13 @@ impl PyDataFrame {
.collect()
}

/// Execute the plan and return the values of `column` as a PyArrow object.
fn collect_column(&self, py: Python, column: &str) -> PyResult<PyObject> {
    // `wait_for_future` blocks the current thread on the async collection.
    // The first `?` surfaces failures of the wait itself; the second
    // unwraps the DataFusion result, converting its error into a Python
    // exception via `PyDataFusionError`.
    wait_for_future(py, self.collect_column_inner(column))?
        .map_err(PyDataFusionError::from)?
        .to_data()
        .to_pyarrow(py)
}

/// Print the result, 20 lines by default
#[pyo3(signature = (num=20))]
fn show(&self, py: Python, num: usize) -> PyDataFusionResult<()> {
Expand Down