170 changes: 108 additions & 62 deletions python/pyspark/sql/pandas/serializers.py
@@ -1946,6 +1946,7 @@ def __init__(

def load_stream(self, stream):
import pyarrow as pa
import pandas as pd
from pyspark.sql.streaming.stateful_processor_util import (
TransformWithStateInPandasFuncMode,
)
@@ -1964,6 +1965,12 @@ def generate_data_batches(batches):

def flatten_columns(cur_batch, col_name):
state_column = cur_batch.column(cur_batch.schema.get_field_index(col_name))

# Check if the entire column is null
if state_column.null_count == len(state_column):
# Return empty table with no columns
return pa.Table.from_arrays([], names=[])

state_field_names = [
state_column.type[i].name for i in range(state_column.type.num_fields)
]
@@ -1981,30 +1988,71 @@ def flatten_columns(cur_batch, col_name):
.add("inputData", dataSchema)
.add("initState", initStateSchema)
We'll parse batch into Tuples of (key, inputData, initState) and pass into the Python
data generator. All rows in the same batch have the same grouping key.
data generator. Rows in the same batch may have different grouping keys,
but each batch will contain either init_data or input_data, never a mix of both.
"""
for batch in batches:
flatten_state_table = flatten_columns(batch, "inputData")
data_pandas = [
self.arrow_to_pandas(c, i)
for i, c in enumerate(flatten_state_table.itercolumns())
]
def row_stream():
for batch in batches:
if self.arrow_max_bytes_per_batch != 2**31 - 1 and batch.num_rows > 0:
batch_bytes = sum(
buf.size
for col in batch.columns
for buf in col.buffers()
if buf is not None
)
self.total_bytes += batch_bytes
self.total_rows += batch.num_rows
self.average_arrow_row_size = self.total_bytes / self.total_rows
Review comment (Contributor): This logic seems to be duplicated from elsewhere in the file, maybe we can add it to a base class?

Reply (Author): Done
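As a sketch of the base-class idea from this thread (the mixin and method names below are hypothetical, not necessarily what the final patch uses), the byte-tracking bookkeeping above could be factored out roughly like this:

class ArrowBatchSizeMixin:
    """Hypothetical shared home for the running-average row-size bookkeeping."""

    def __init__(self, arrow_max_bytes_per_batch):
        self.arrow_max_bytes_per_batch = arrow_max_bytes_per_batch
        self.total_bytes = 0
        self.total_rows = 0
        self.average_arrow_row_size = 0

    def update_average_row_size(self, batch):
        # Mirror the check above: only track sizes when a byte limit is set
        # (2**31 - 1 is the "unlimited" sentinel) and the batch is non-empty.
        if self.arrow_max_bytes_per_batch == 2**31 - 1 or batch.num_rows == 0:
            return
        batch_bytes = sum(
            buf.size
            for col in batch.columns
            for buf in col.buffers()
            if buf is not None
        )
        self.total_bytes += batch_bytes
        self.total_rows += batch.num_rows
        self.average_arrow_row_size = self.total_bytes / self.total_rows
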
flatten_init_table = flatten_columns(batch, "initState")
init_data_pandas = [
self.arrow_to_pandas(c, i)
for i, c in enumerate(flatten_init_table.itercolumns())
]
key_series = [data_pandas[o] for o in self.key_offsets]
init_key_series = [init_data_pandas[o] for o in self.init_key_offsets]
flatten_state_table = flatten_columns(batch, "inputData")
data_pandas = [
self.arrow_to_pandas(c, i)
for i, c in enumerate(flatten_state_table.itercolumns())
]

if any(s.empty for s in key_series):
# If any row is empty, assign batch_key using init_key_series
batch_key = tuple(s[0] for s in init_key_series)
else:
# If all rows are non-empty, create batch_key from key_series
batch_key = tuple(s[0] for s in key_series)
yield (batch_key, data_pandas, init_data_pandas)
if bool(data_pandas):
for row in pd.concat(data_pandas, axis=1).itertuples(index=False):
batch_key = tuple(row[s] for s in self.key_offsets)
yield (batch_key, row, None)
else:
flatten_init_table = flatten_columns(batch, "initState")
init_data_pandas = [
self.arrow_to_pandas(c, i)
for i, c in enumerate(flatten_init_table.itercolumns())
]
if bool(init_data_pandas):
for row in pd.concat(init_data_pandas, axis=1).itertuples(index=False):
batch_key = tuple(row[s] for s in self.init_key_offsets)
yield (batch_key, None, row)

EMPTY_DATAFRAME = pd.DataFrame()
for batch_key, group_rows in groupby(row_stream(), key=lambda x: x[0]):
rows = []
init_state_rows = []
for _, row, init_state_row in group_rows:
if row is not None:
rows.append(row)
if init_state_row is not None:
init_state_rows.append(init_state_row)

total_len = len(rows) + len(init_state_rows)
if (
total_len >= self.arrow_max_records_per_batch
or total_len * self.average_arrow_row_size >= self.arrow_max_bytes_per_batch
):
yield (
batch_key,
pd.DataFrame(rows) if len(rows) > 0 else EMPTY_DATAFRAME.copy(),
pd.DataFrame(init_state_rows) if len(init_state_rows) > 0 else EMPTY_DATAFRAME.copy()
)
rows = []
init_state_rows = []
if rows or init_state_rows:
yield (
batch_key,
pd.DataFrame(rows) if len(rows) > 0 else EMPTY_DATAFRAME.copy(),
pd.DataFrame(init_state_rows) if len(init_state_rows) > 0 else EMPTY_DATAFRAME.copy()
)

_batches = super(ArrowStreamPandasSerializer, self).load_stream(stream)
data_batches = generate_data_batches(_batches)
@@ -2122,7 +2170,6 @@ def __init__(self, arrow_max_records_per_batch):
self.init_key_offsets = None

def load_stream(self, stream):
import itertools
import pyarrow as pa
from pyspark.sql.streaming.stateful_processor_util import (
TransformWithStateInPandasFuncMode,
@@ -2141,6 +2188,11 @@ def generate_data_batches(batches):

def extract_rows(cur_batch, col_name, key_offsets):
data_column = cur_batch.column(cur_batch.schema.get_field_index(col_name))

# Check if the entire column is null
if data_column.null_count == len(data_column):
return None
Review comment (Contributor): Given we've changed the implicit type signature of the function, let's maybe add a type annotation on generate_data_batches for readability.

Reply (Author): Done
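One possible shape for the annotation requested above, purely illustrative (the Key alias is made up, and the exact generics depend on what the new code yields: a grouping key plus either an input row or an init-state row):

from typing import Any, Iterator, Optional, Tuple

Key = Tuple[Any, ...]  # hypothetical alias for the grouping-key tuple

def generate_data_batches(
    batches: Iterator["pa.RecordBatch"],
) -> Iterator[Tuple[Key, Optional["DataRow"], Optional["DataRow"]]]:
    # Body elided; see the real implementation in the diff above.
    ...
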

data_field_names = [
data_column.type[i].name for i in range(data_column.type.num_fields)
]
@@ -2153,68 +2205,62 @@ def extract_rows(cur_batch, col_name, key_offsets):
table = pa.Table.from_arrays(data_field_arrays, names=data_field_names)

if table.num_rows == 0:
return (None, iter([]))
else:
batch_key = tuple(table.column(o)[0].as_py() for o in key_offsets)
return None

rows = []
def row_iterator():
for row_idx in range(table.num_rows):
key = tuple(table.column(o)[row_idx].as_py() for o in key_offsets)
row = DataRow(
*(table.column(i)[row_idx].as_py() for i in range(table.num_columns))
)
rows.append(row)
yield (key, row)

return (batch_key, iter(rows))
return row_iterator()

"""
The arrow batch is written in the schema:
schema: StructType = new StructType()
.add("inputData", dataSchema)
.add("initState", initStateSchema)
We'll parse batch into Tuples of (key, inputData, initState) and pass into the Python
data generator. All rows in the same batch have the same grouping key.
data generator. Each batch will contain either init_data or input_data, never a mix of both.
"""
for batch in batches:
(input_batch_key, input_data_iter) = extract_rows(
batch, "inputData", self.key_offsets
)
(init_batch_key, init_state_iter) = extract_rows(
batch, "initState", self.init_key_offsets
)
# Detect which column has data - each batch contains only one type
input_result = extract_rows(batch, "inputData", self.key_offsets)

if input_batch_key is None:
batch_key = init_batch_key
if input_result is not None:
for key, input_data_row in input_result:
yield (key, input_data_row, None)
else:
batch_key = input_batch_key

for init_state_row in init_state_iter:
yield (batch_key, None, init_state_row)

for input_data_row in input_data_iter:
yield (batch_key, input_data_row, None)
init_result = extract_rows(batch, "initState", self.init_key_offsets)
if init_result is not None:
for key, init_state_row in init_result:
yield (key, None, init_state_row)

_batches = super(ArrowStreamUDFSerializer, self).load_stream(stream)
data_batches = generate_data_batches(_batches)

for k, g in groupby(data_batches, key=lambda x: x[0]):
# g: list(batch_key, input_data_iter, init_state_iter)

# they are sharing the iterator, hence need to copy
input_values_iter, init_state_iter = itertools.tee(g, 2)

chained_input_values = itertools.chain(map(lambda x: x[1], input_values_iter))
chained_init_state_values = itertools.chain(map(lambda x: x[2], init_state_iter))

chained_input_values_without_none = filter(
lambda x: x is not None, chained_input_values
)
chained_init_state_values_without_none = filter(
lambda x: x is not None, chained_init_state_values
)

ret_tuple = (chained_input_values_without_none, chained_init_state_values_without_none)

yield (TransformWithStateInPandasFuncMode.PROCESS_DATA, k, ret_tuple)
input_rows = []
init_rows = []

for batch_key, input_row, init_row in g:
if input_row is not None:
input_rows.append(input_row)
if init_row is not None:
init_rows.append(init_row)

total_len = len(input_rows) + len(init_rows)
if total_len >= self.arrow_max_records_per_batch:
Review comment (Contributor): The SQLConf config param says that zero or a negative number means there is no limit, but with this check a zero or negative value would make us emit a fresh batch for every row. Let's change the behaviour and add a test covering this.

Reply (Author, @nyaapa, Nov 20, 2025): Oh, right; copied that from the non-init state handling 🫠 Nice catch!

Reply (Author): Done
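A minimal sketch of the behaviour change discussed in this thread, assuming the SQLConf convention that a zero or negative value means "no limit"; the helper name is hypothetical:

def should_emit_batch(total_len, arrow_max_records_per_batch):
    # Zero or a negative limit means "no limit", so never flush an
    # in-progress group early in that case.
    if arrow_max_records_per_batch <= 0:
        return False
    return total_len >= arrow_max_records_per_batch
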
ret_tuple = (iter(input_rows), iter(init_rows))
yield (TransformWithStateInPandasFuncMode.PROCESS_DATA, k, ret_tuple)
input_rows = []
init_rows = []

if input_rows or init_rows:
ret_tuple = (iter(input_rows), iter(init_rows))
yield (TransformWithStateInPandasFuncMode.PROCESS_DATA, k, ret_tuple)

yield (TransformWithStateInPandasFuncMode.PROCESS_TIMER, None, None)

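To make the batch layout described in the docstrings above concrete, here is a small standalone illustration (field names and values are made up) of a batch whose initState struct column is entirely null, the case the new null-count checks short-circuit:

import pyarrow as pa

# One struct column per side, as in the schema the serializer receives.
input_struct = pa.array(
    [{"key": "a", "value": 1}, {"key": "a", "value": 2}],
    type=pa.struct([("key", pa.string()), ("value", pa.int64())]),
)
# The initState side is entirely null in this batch.
init_struct = pa.array(
    [None, None],
    type=pa.struct([("key", pa.string()), ("state", pa.int64())]),
)
batch = pa.RecordBatch.from_arrays(
    [input_struct, init_struct], names=["inputData", "initState"]
)

state_column = batch.column(batch.schema.get_field_index("initState"))
# This is the condition the serializer uses to treat the column as absent.
print(state_column.null_count == len(state_column))  # True
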
16 changes: 10 additions & 6 deletions python/pyspark/worker.py
@@ -797,12 +797,14 @@ def wrapped(stateful_processor_api_client, mode, key, value_series_gen):

def wrap_grouped_transform_with_state_pandas_init_state_udf(f, return_type, runner_conf):
def wrapped(stateful_processor_api_client, mode, key, value_series_gen):
import pandas as pd

# Split the generator into two using itertools.tee
state_values_gen, init_states_gen = itertools.tee(value_series_gen, 2)
state_values = (df for x, _ in state_values_gen if not (df := pd.concat(x, axis=1)).empty)
init_states = (df for _, x in init_states_gen if not (df := pd.concat(x, axis=1)).empty)

# Extract just the data DataFrames (first element of each tuple)
state_values = (data_df for data_df, _ in state_values_gen if not data_df.empty)

# Extract just the init DataFrames (second element of each tuple)
init_states = (init_df for _, init_df in init_states_gen if not init_df.empty)
result_iter = f(stateful_processor_api_client, mode, key, state_values, init_states)

# TODO(SPARK-49100): add verification that elements in result_iter are
@@ -3067,6 +3069,7 @@ def values_gen():
ser.init_key_offsets = parsed_offsets[1][0]
stateful_processor_api_client = StatefulProcessorApiClient(state_server_port, key_schema)

import pandas as pd
def mapper(a):
mode = a[0]

@@ -3075,8 +3078,8 @@ def mapper(a):

def values_gen():
for x in a[2]:
retVal = [x[1][o] for o in parsed_offsets[0][1]]
initVal = [x[2][o] for o in parsed_offsets[1][1]]
retVal = x[1]
initVal = x[2]
yield retVal, initVal

# This must be generator comprehension - do not materialize.
@@ -3230,6 +3233,7 @@ def mapper(a):

parsed_offsets = extract_key_value_indexes(arg_offsets)

import pandas as pd
Review comment (Contributor): Random place for an import.

def mapper(a):
df1_keys = [a[0][o] for o in parsed_offsets[0][0]]
df1_vals = [a[0][o] for o in parsed_offsets[0][1]]

@@ -94,4 +94,8 @@ class BaseStreamingArrowWriter(
((arrowMaxBytesPerBatch != Int.MaxValue)
&& (arrowWriterForData.sizeInBytes() >= arrowMaxBytesPerBatch))
}

def getTotalNumRowsForBatch: Int = {
totalNumRowsForBatch
}
}

@@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, PythonUDF}
import org.apache.spark.sql.catalyst.plans.logical.TransformWithStateInPySpark
import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.execution.{CoGroupedIterator, SparkPlan}
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.execution.python.ArrowPythonRunner
import org.apache.spark.sql.execution.python.PandasGroupUtils.{executePython, groupAndProject, resolveArgOffsets}
@@ -347,9 +347,9 @@ case class TransformWithStateInPySparkExec(
val initData =
groupAndProject(initStateIterator, initialStateGroupingAttrs,
initialState.output, initDedupAttributes)
// group input rows and initial state rows by the same grouping key
val groupedData: Iterator[(InternalRow, Iterator[InternalRow], Iterator[InternalRow])] =
new CoGroupedIterator(data, initData, groupingAttributes)
// concatenate input rows and initial state rows iterators
val inputIter: Iterator[((InternalRow, Iterator[InternalRow]), Boolean)] =
initData.map { item => (item, true) } ++ data.map { item => (item, false) }

val evalType = {
if (userFacingDataType == TransformWithStateInPySpark.UserFacingDataType.PANDAS) {
@@ -374,7 +374,7 @@
batchTimestampMs,
eventTimeWatermarkForEviction
)
executePython(groupedData, output, runner)
executePython(inputIter, output, runner)
}

CompletionIterator[InternalRow, Iterator[InternalRow]](outputIterator, {