diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index 56d338fc1371..7b2ad4b1c7a3 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -1466,7 +1466,9 @@ def __init__( self.result_state_pdf_arrow_type = to_arrow_type( self.result_state_df_type, prefers_large_types=prefers_large_var_types ) - self.arrow_max_records_per_batch = arrow_max_records_per_batch + self.arrow_max_records_per_batch = ( + arrow_max_records_per_batch if arrow_max_records_per_batch > 0 else 2**31 - 1 + ) def load_stream(self, stream): """ @@ -1821,13 +1823,29 @@ def __init__( int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled, arrow_cast=True, ) - self.arrow_max_records_per_batch = arrow_max_records_per_batch + self.arrow_max_records_per_batch = ( + arrow_max_records_per_batch if arrow_max_records_per_batch > 0 else 2**31 - 1 + ) self.arrow_max_bytes_per_batch = arrow_max_bytes_per_batch self.key_offsets = None self.average_arrow_row_size = 0 self.total_bytes = 0 self.total_rows = 0 + def _update_batch_size_stats(self, batch): + """ + Update batch size statistics for adaptive batching. + """ + # Short circuit batch size calculation if the batch size is + # unlimited as computing batch size is computationally expensive. + if self.arrow_max_bytes_per_batch != 2**31 - 1 and batch.num_rows > 0: + batch_bytes = sum( + buf.size for col in batch.columns for buf in col.buffers() if buf is not None + ) + self.total_bytes += batch_bytes + self.total_rows += batch.num_rows + self.average_arrow_row_size = self.total_bytes / self.total_rows + def load_stream(self, stream): """ Read ArrowRecordBatches from stream, deserialize them to populate a list of data chunk, and @@ -1855,18 +1873,7 @@ def generate_data_batches(batches): def row_stream(): for batch in batches: - # Short circuit batch size calculation if the batch size is - # unlimited as computing batch size is computationally expensive. - if self.arrow_max_bytes_per_batch != 2**31 - 1 and batch.num_rows > 0: - batch_bytes = sum( - buf.size - for col in batch.columns - for buf in col.buffers() - if buf is not None - ) - self.total_bytes += batch_bytes - self.total_rows += batch.num_rows - self.average_arrow_row_size = self.total_bytes / self.total_rows + self._update_batch_size_stats(batch) data_pandas = [ self.arrow_to_pandas(c, i) for i, c in enumerate(pa.Table.from_batches([batch]).itercolumns()) @@ -1946,6 +1953,7 @@ def __init__( def load_stream(self, stream): import pyarrow as pa + import pandas as pd from pyspark.sql.streaming.stateful_processor_util import ( TransformWithStateInPandasFuncMode, ) @@ -1964,6 +1972,12 @@ def generate_data_batches(batches): def flatten_columns(cur_batch, col_name): state_column = cur_batch.column(cur_batch.schema.get_field_index(col_name)) + + # Check if the entire column is null + if state_column.null_count == len(state_column): + # Return empty table with no columns + return pa.Table.from_arrays([], names=[]) + state_field_names = [ state_column.type[i].name for i in range(state_column.type.num_fields) ] @@ -1981,30 +1995,67 @@ def flatten_columns(cur_batch, col_name): .add("inputData", dataSchema) .add("initState", initStateSchema) We'll parse batch into Tuples of (key, inputData, initState) and pass into the Python - data generator. All rows in the same batch have the same grouping key. + data generator. Rows in the same batch may have different grouping keys, + but each batch will have either init_data or input_data, not mix. 
""" - for batch in batches: - flatten_state_table = flatten_columns(batch, "inputData") - data_pandas = [ - self.arrow_to_pandas(c, i) - for i, c in enumerate(flatten_state_table.itercolumns()) - ] - flatten_init_table = flatten_columns(batch, "initState") - init_data_pandas = [ - self.arrow_to_pandas(c, i) - for i, c in enumerate(flatten_init_table.itercolumns()) - ] - key_series = [data_pandas[o] for o in self.key_offsets] - init_key_series = [init_data_pandas[o] for o in self.init_key_offsets] + def row_stream(): + for batch in batches: + self._update_batch_size_stats(batch) - if any(s.empty for s in key_series): - # If any row is empty, assign batch_key using init_key_series - batch_key = tuple(s[0] for s in init_key_series) - else: - # If all rows are non-empty, create batch_key from key_series - batch_key = tuple(s[0] for s in key_series) - yield (batch_key, data_pandas, init_data_pandas) + flatten_state_table = flatten_columns(batch, "inputData") + data_pandas = [ + self.arrow_to_pandas(c, i) + for i, c in enumerate(flatten_state_table.itercolumns()) + ] + + if bool(data_pandas): + for row in pd.concat(data_pandas, axis=1).itertuples(index=False): + batch_key = tuple(row[s] for s in self.key_offsets) + yield (batch_key, row, None) + else: + flatten_init_table = flatten_columns(batch, "initState") + init_data_pandas = [ + self.arrow_to_pandas(c, i) + for i, c in enumerate(flatten_init_table.itercolumns()) + ] + if bool(init_data_pandas): + for row in pd.concat(init_data_pandas, axis=1).itertuples(index=False): + batch_key = tuple(row[s] for s in self.init_key_offsets) + yield (batch_key, None, row) + + EMPTY_DATAFRAME = pd.DataFrame() + for batch_key, group_rows in groupby(row_stream(), key=lambda x: x[0]): + rows = [] + init_state_rows = [] + for _, row, init_state_row in group_rows: + if row is not None: + rows.append(row) + if init_state_row is not None: + init_state_rows.append(init_state_row) + + total_len = len(rows) + len(init_state_rows) + if ( + total_len >= self.arrow_max_records_per_batch + or total_len * self.average_arrow_row_size >= self.arrow_max_bytes_per_batch + ): + yield ( + batch_key, + pd.DataFrame(rows) if len(rows) > 0 else EMPTY_DATAFRAME.copy(), + pd.DataFrame(init_state_rows) + if len(init_state_rows) > 0 + else EMPTY_DATAFRAME.copy(), + ) + rows = [] + init_state_rows = [] + if rows or init_state_rows: + yield ( + batch_key, + pd.DataFrame(rows) if len(rows) > 0 else EMPTY_DATAFRAME.copy(), + pd.DataFrame(init_state_rows) + if len(init_state_rows) > 0 + else EMPTY_DATAFRAME.copy(), + ) _batches = super(ArrowStreamPandasSerializer, self).load_stream(stream) data_batches = generate_data_batches(_batches) @@ -2030,7 +2081,9 @@ class TransformWithStateInPySparkRowSerializer(ArrowStreamUDFSerializer): def __init__(self, arrow_max_records_per_batch): super(TransformWithStateInPySparkRowSerializer, self).__init__() - self.arrow_max_records_per_batch = arrow_max_records_per_batch + self.arrow_max_records_per_batch = ( + arrow_max_records_per_batch if arrow_max_records_per_batch > 0 else 2**31 - 1 + ) self.key_offsets = None def load_stream(self, stream): @@ -2122,13 +2175,13 @@ def __init__(self, arrow_max_records_per_batch): self.init_key_offsets = None def load_stream(self, stream): - import itertools import pyarrow as pa from pyspark.sql.streaming.stateful_processor_util import ( TransformWithStateInPandasFuncMode, ) + from typing import Iterator, Any, Optional, Tuple - def generate_data_batches(batches): + def generate_data_batches(batches) -> 
Iterator[Tuple[Any, Optional[Any], Optional[Any]]]: """ Deserialize ArrowRecordBatches and return a generator of Row. The deserialization logic assumes that Arrow RecordBatches contain the data with the @@ -2139,8 +2192,15 @@ def generate_data_batches(batches): into the data generator. """ - def extract_rows(cur_batch, col_name, key_offsets): + def extract_rows( + cur_batch, col_name, key_offsets + ) -> Optional[Iterator[Tuple[Any, Any]]]: data_column = cur_batch.column(cur_batch.schema.get_field_index(col_name)) + + # Check if the entire column is null + if data_column.null_count == len(data_column): + return None + data_field_names = [ data_column.type[i].name for i in range(data_column.type.num_fields) ] @@ -2153,18 +2213,17 @@ def extract_rows(cur_batch, col_name, key_offsets): table = pa.Table.from_arrays(data_field_arrays, names=data_field_names) if table.num_rows == 0: - return (None, iter([])) - else: - batch_key = tuple(table.column(o)[0].as_py() for o in key_offsets) + return None - rows = [] + def row_iterator(): for row_idx in range(table.num_rows): + key = tuple(table.column(o)[row_idx].as_py() for o in key_offsets) row = DataRow( *(table.column(i)[row_idx].as_py() for i in range(table.num_columns)) ) - rows.append(row) + yield (key, row) - return (batch_key, iter(rows)) + return row_iterator() """ The arrow batch is written in the schema: @@ -2172,49 +2231,44 @@ def extract_rows(cur_batch, col_name, key_offsets): .add("inputData", dataSchema) .add("initState", initStateSchema) We'll parse batch into Tuples of (key, inputData, initState) and pass into the Python - data generator. All rows in the same batch have the same grouping key. + data generator. Each batch will have either init_data or input_data, not mix. """ for batch in batches: - (input_batch_key, input_data_iter) = extract_rows( - batch, "inputData", self.key_offsets - ) - (init_batch_key, init_state_iter) = extract_rows( - batch, "initState", self.init_key_offsets - ) + # Detect which column has data - each batch contains only one type + input_result = extract_rows(batch, "inputData", self.key_offsets) - if input_batch_key is None: - batch_key = init_batch_key + if input_result is not None: + for key, input_data_row in input_result: + yield (key, input_data_row, None) else: - batch_key = input_batch_key - - for init_state_row in init_state_iter: - yield (batch_key, None, init_state_row) - - for input_data_row in input_data_iter: - yield (batch_key, input_data_row, None) + init_result = extract_rows(batch, "initState", self.init_key_offsets) + if init_result is not None: + for key, init_state_row in init_result: + yield (key, None, init_state_row) _batches = super(ArrowStreamUDFSerializer, self).load_stream(stream) data_batches = generate_data_batches(_batches) for k, g in groupby(data_batches, key=lambda x: x[0]): - # g: list(batch_key, input_data_iter, init_state_iter) - - # they are sharing the iterator, hence need to copy - input_values_iter, init_state_iter = itertools.tee(g, 2) - - chained_input_values = itertools.chain(map(lambda x: x[1], input_values_iter)) - chained_init_state_values = itertools.chain(map(lambda x: x[2], init_state_iter)) - - chained_input_values_without_none = filter( - lambda x: x is not None, chained_input_values - ) - chained_init_state_values_without_none = filter( - lambda x: x is not None, chained_init_state_values - ) - - ret_tuple = (chained_input_values_without_none, chained_init_state_values_without_none) - - yield (TransformWithStateInPandasFuncMode.PROCESS_DATA, k, ret_tuple) + 
input_rows = [] + init_rows = [] + + for batch_key, input_row, init_row in g: + if input_row is not None: + input_rows.append(input_row) + if init_row is not None: + init_rows.append(init_row) + + total_len = len(input_rows) + len(init_rows) + if total_len >= self.arrow_max_records_per_batch: + ret_tuple = (iter(input_rows), iter(init_rows)) + yield (TransformWithStateInPandasFuncMode.PROCESS_DATA, k, ret_tuple) + input_rows = [] + init_rows = [] + + if input_rows or init_rows: + ret_tuple = (iter(input_rows), iter(init_rows)) + yield (TransformWithStateInPandasFuncMode.PROCESS_DATA, k, ret_tuple) yield (TransformWithStateInPandasFuncMode.PROCESS_TIMER, None, None) diff --git a/python/pyspark/sql/tests/pandas/streaming/test_pandas_transform_with_state.py b/python/pyspark/sql/tests/pandas/streaming/test_pandas_transform_with_state.py index 57125c7820c0..ecdfcfda3c1d 100644 --- a/python/pyspark/sql/tests/pandas/streaming/test_pandas_transform_with_state.py +++ b/python/pyspark/sql/tests/pandas/streaming/test_pandas_transform_with_state.py @@ -1483,6 +1483,96 @@ def check_results(batch_df, batch_id): ), ) + def test_transform_with_state_with_records_limit(self): + if not self.use_pandas(): + return + + def make_check_results(expected_per_batch): + def check_results(batch_df, batch_id): + batch_df.collect() + if batch_id == 0: + assert set(batch_df.sort("id").collect()) == expected_per_batch[0] + else: + assert set(batch_df.sort("id").collect()) == expected_per_batch[1] + + return check_results + + result_with_small_limit = [ + { + Row(id="0", chunkCount=2), + Row(id="1", chunkCount=2), + }, + { + Row(id="0", chunkCount=3), + Row(id="1", chunkCount=2), + }, + ] + + result_with_large_limit = [ + { + Row(id="0", chunkCount=1), + Row(id="1", chunkCount=1), + }, + { + Row(id="0", chunkCount=1), + Row(id="1", chunkCount=1), + }, + ] + + data = [("0", 789), ("3", 987)] + initial_state = self.spark.createDataFrame(data, "id string, initVal int").groupBy("id") + + with self.sql_conf( + # Set it to a very small number so that every row would be a separate pandas df + {"spark.sql.execution.arrow.maxRecordsPerBatch": "1"} + ): + self._test_transform_with_state_basic( + ChunkCountProcessorFactory(), + make_check_results(result_with_small_limit), + output_schema=StructType( + [ + StructField("id", StringType(), True), + StructField("chunkCount", IntegerType(), True), + ] + ), + ) + + self._test_transform_with_state_basic( + ChunkCountProcessorWithInitialStateFactory(), + make_check_results(result_with_small_limit), + initial_state=initial_state, + output_schema=StructType( + [ + StructField("id", StringType(), True), + StructField("chunkCount", IntegerType(), True), + ] + ), + ) + + with self.sql_conf({"spark.sql.execution.arrow.maxRecordsPerBatch": "-1"}): + self._test_transform_with_state_basic( + ChunkCountProcessorFactory(), + make_check_results(result_with_large_limit), + output_schema=StructType( + [ + StructField("id", StringType(), True), + StructField("chunkCount", IntegerType(), True), + ] + ), + ) + + self._test_transform_with_state_basic( + ChunkCountProcessorWithInitialStateFactory(), + make_check_results(result_with_large_limit), + initial_state=initial_state, + output_schema=StructType( + [ + StructField("id", StringType(), True), + StructField("chunkCount", IntegerType(), True), + ] + ), + ) + # test all state types (value, list, map) with large values (512 KB) def test_transform_with_state_large_values(self): def check_results(batch_df, batch_id): diff --git a/python/pyspark/worker.py 
b/python/pyspark/worker.py index 6e34b041665a..3c70836ea127 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -797,12 +797,14 @@ def wrapped(stateful_processor_api_client, mode, key, value_series_gen): def wrap_grouped_transform_with_state_pandas_init_state_udf(f, return_type, runner_conf): def wrapped(stateful_processor_api_client, mode, key, value_series_gen): - import pandas as pd - + # Split the generator into two using itertools.tee state_values_gen, init_states_gen = itertools.tee(value_series_gen, 2) - state_values = (df for x, _ in state_values_gen if not (df := pd.concat(x, axis=1)).empty) - init_states = (df for _, x in init_states_gen if not (df := pd.concat(x, axis=1)).empty) + # Extract just the data DataFrames (first element of each tuple) + state_values = (data_df for data_df, _ in state_values_gen if not data_df.empty) + + # Extract just the init DataFrames (second element of each tuple) + init_states = (init_df for _, init_df in init_states_gen if not init_df.empty) result_iter = f(stateful_processor_api_client, mode, key, state_values, init_states) # TODO(SPARK-49100): add verification that elements in result_iter are @@ -3075,8 +3077,8 @@ def mapper(a): def values_gen(): for x in a[2]: - retVal = [x[1][o] for o in parsed_offsets[0][1]] - initVal = [x[2][o] for o in parsed_offsets[1][1]] + retVal = x[1] + initVal = x[2] yield retVal, initVal # This must be generator comprehension - do not materialize. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/BaseStreamingArrowWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/BaseStreamingArrowWriter.scala index f0371cafb72a..8c9ab2a8c636 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/BaseStreamingArrowWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/BaseStreamingArrowWriter.scala @@ -88,10 +88,14 @@ class BaseStreamingArrowWriter( protected def isBatchSizeLimitReached: Boolean = { // If we have either reached the records or bytes limit - totalNumRowsForBatch >= arrowMaxRecordsPerBatch || + (arrowMaxRecordsPerBatch > 0 && totalNumRowsForBatch >= arrowMaxRecordsPerBatch) || // Short circuit batch size calculation if the batch size is unlimited as computing batch // size is computationally expensive. 
((arrowMaxBytesPerBatch != Int.MaxValue) && (arrowWriterForData.sizeInBytes() >= arrowMaxBytesPerBatch)) } + + def getTotalNumRowsForBatch: Int = { + totalNumRowsForBatch + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkExec.scala index a1cf71844950..1ceaf6c4bf81 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkExec.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, PythonUDF} import org.apache.spark.sql.catalyst.plans.logical.TransformWithStateInPySpark import org.apache.spark.sql.catalyst.types.DataTypeUtils -import org.apache.spark.sql.execution.{CoGroupedIterator, SparkPlan} +import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.execution.python.ArrowPythonRunner import org.apache.spark.sql.execution.python.PandasGroupUtils.{executePython, groupAndProject, resolveArgOffsets} @@ -347,9 +347,9 @@ case class TransformWithStateInPySparkExec( val initData = groupAndProject(initStateIterator, initialStateGroupingAttrs, initialState.output, initDedupAttributes) - // group input rows and initial state rows by the same grouping key - val groupedData: Iterator[(InternalRow, Iterator[InternalRow], Iterator[InternalRow])] = - new CoGroupedIterator(data, initData, groupingAttributes) + // concatenate input rows and initial state rows iterators + val inputIter: Iterator[((InternalRow, Iterator[InternalRow]), Boolean)] = + initData.map { item => (item, true) } ++ data.map { item => (item, false) } val evalType = { if (userFacingDataType == TransformWithStateInPySpark.UserFacingDataType.PANDAS) { @@ -374,7 +374,7 @@ case class TransformWithStateInPySparkExec( batchTimestampMs, eventTimeWatermarkForEviction ) - executePython(groupedData, output, runner) + executePython(inputIter, output, runner) } CompletionIterator[InternalRow, Iterator[InternalRow]](outputIterator, { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkPythonRunner.scala index 42d4ad68c29a..b526d823ee09 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkPythonRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkPythonRunner.scala @@ -144,6 +144,9 @@ class TransformWithStateInPySparkPythonInitialStateRunner( private var pandasWriter: BaseStreamingArrowWriter = _ + private var currentDataIterator: Iterator[InternalRow] = _ + private var isCurrentIterFromInitState: Option[Boolean] = None + override protected def writeNextBatchToArrowStream( root: VectorSchemaRoot, writer: ArrowStreamWriter, @@ -158,30 +161,52 @@ class TransformWithStateInPySparkPythonInitialStateRunner( ) } - if (inputIterator.hasNext) { - val startData = dataOut.size() - // a new grouping key with data & init state iter - val next = inputIterator.next() - val dataIter = next._2 - val initIter = next._3 - - while (dataIter.hasNext || 
initIter.hasNext) { - val dataRow = - if (dataIter.hasNext) dataIter.next() - else InternalRow.empty - val initRow = - if (initIter.hasNext) initIter.next() - else InternalRow.empty - pandasWriter.writeRow(InternalRow(dataRow, initRow)) + + // If we don't have data left for the current group, move to the next group. + if (currentDataIterator == null && inputIterator.hasNext) { + val ((_, data), isInitState) = inputIterator.next() + currentDataIterator = data + val isPrevIterFromInitState = isCurrentIterFromInitState + isCurrentIterFromInitState = Some(isInitState) + if (isPrevIterFromInitState.isDefined && + isPrevIterFromInitState.get != isInitState && + pandasWriter.getTotalNumRowsForBatch > 0) { + // So we won't have batches with mixed data and init state. + pandasWriter.finalizeCurrentArrowBatch() + return true } - pandasWriter.finalizeCurrentArrowBatch() - val deltaData = dataOut.size() - startData - pythonMetrics("pythonDataSent") += deltaData + } + + val startData = dataOut.size() + + val hasInput = if (currentDataIterator != null) { + var isCurrentBatchFull = false + // Stop writing when the current arrowBatch is finalized/full. If we have rows left + while (currentDataIterator.hasNext && !isCurrentBatchFull) { + val dataRow = currentDataIterator.next() + isCurrentBatchFull = if (isCurrentIterFromInitState.get) { + pandasWriter.writeRow(InternalRow(null, dataRow)) + } else { + pandasWriter.writeRow(InternalRow(dataRow, null)) + } + } + + if (!currentDataIterator.hasNext) { + currentDataIterator = null + } + true } else { + if (pandasWriter.getTotalNumRowsForBatch > 0) { + pandasWriter.finalizeCurrentArrowBatch() + } super[PythonArrowInput].close() false } + + val deltaData = dataOut.size() - startData + pythonMetrics("pythonDataSent") += deltaData + hasInput } } @@ -392,5 +417,7 @@ trait TransformWithStateInPySparkPythonRunnerUtils extends Logging { object TransformWithStateInPySparkPythonRunner { type InType = (InternalRow, Iterator[InternalRow]) - type GroupedInType = (InternalRow, Iterator[InternalRow], Iterator[InternalRow]) + + // ((key, rows), isInitState) + type GroupedInType = ((InternalRow, Iterator[InternalRow]), Boolean) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/streaming/BaseStreamingArrowWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/streaming/BaseStreamingArrowWriterSuite.scala index fc10a102b4f5..49839fb8c985 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/streaming/BaseStreamingArrowWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/streaming/BaseStreamingArrowWriterSuite.scala @@ -95,4 +95,43 @@ class BaseStreamingArrowWriterSuite extends SparkFunSuite with BeforeAndAfterEac verify(writer, times(2)).writeBatch() verify(arrowWriter, times(2)).reset() } + + test("test negative or zero arrowMaxRecordsPerBatch is unlimited") { + val root: VectorSchemaRoot = mock(classOf[VectorSchemaRoot]) + val dataRow = mock(classOf[InternalRow]) + + // Test with negative value + transformWithStateInPySparkWriter = new BaseStreamingArrowWriter( + root, writer, -1, arrowMaxBytesPerBatch, arrowWriter) + + // Write many rows (more than typical batch size) + for (_ <- 1 to 10) { + transformWithStateInPySparkWriter.writeRow(dataRow) + } + + // Verify all rows were written but batch was not finalized + verify(arrowWriter, times(10)).write(dataRow) + verify(writer, never()).writeBatch() + + // Only finalize when explicitly called + 
transformWithStateInPySparkWriter.finalizeCurrentArrowBatch() + verify(writer).writeBatch() + + // Test with zero value + transformWithStateInPySparkWriter = new BaseStreamingArrowWriter( + root, writer, 0, arrowMaxBytesPerBatch, arrowWriter) + + // Write many rows again + for (_ <- 1 to 10) { + transformWithStateInPySparkWriter.writeRow(dataRow) + } + + // Verify rows were written but batch was not finalized + verify(arrowWriter, times(20)).write(dataRow) + verify(writer).writeBatch() // still 1 from before + + // Only finalize when explicitly called + transformWithStateInPySparkWriter.finalizeCurrentArrowBatch() + verify(writer, times(2)).writeBatch() + } }
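
For reference, the per-key chunking behavior that the updated serializers introduce can be illustrated with a minimal standalone sketch. This is not code from the patch; chunk_rows_by_key and the sample data below are hypothetical. The sketch only mirrors the convention adopted above, where a non-positive arrow_max_records_per_batch is treated as unlimited (2**31 - 1) and a group's accumulated rows are flushed whenever the record limit is reached, so one grouping key may span several emitted chunks.

    from itertools import groupby

    def chunk_rows_by_key(key_row_iter, max_records_per_batch):
        # Mirror the patch's convention: a non-positive limit means "unlimited".
        limit = max_records_per_batch if max_records_per_batch > 0 else 2**31 - 1
        # key_row_iter yields (key, row) pairs already ordered by key,
        # matching how row_stream() feeds groupby() in the serializer changes.
        for key, group in groupby(key_row_iter, key=lambda kv: kv[0]):
            rows = []
            for _, row in group:
                rows.append(row)
                if len(rows) >= limit:
                    yield key, rows  # emit a full chunk for this key
                    rows = []
            if rows:
                yield key, rows  # emit the remaining rows for this key

    # With a limit of 2, key "a" is split across two chunks:
    pairs = [("a", 1), ("a", 2), ("a", 3), ("b", 4)]
    assert list(chunk_rows_by_key(iter(pairs), 2)) == [("a", [1, 2]), ("a", [3]), ("b", [4])]
    # With a limit of -1 (unlimited), each key yields a single chunk:
    assert list(chunk_rows_by_key(iter(pairs), -1)) == [("a", [1, 2, 3]), ("b", [4])]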