[SPARK-54392][SS] Optimize JVM-Python communication for TWS initial state #53122
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1946,6 +1946,7 @@ def __init__( | |
|
|
||
| def load_stream(self, stream): | ||
| import pyarrow as pa | ||
| import pandas as pd | ||
| from pyspark.sql.streaming.stateful_processor_util import ( | ||
| TransformWithStateInPandasFuncMode, | ||
| ) | ||
|
|
@@ -1964,6 +1965,12 @@ def generate_data_batches(batches): | |
|
|
||
| def flatten_columns(cur_batch, col_name): | ||
| state_column = cur_batch.column(cur_batch.schema.get_field_index(col_name)) | ||
|
|
||
| # Check if the entire column is null | ||
| if state_column.null_count == len(state_column): | ||
| # Return empty table with no columns | ||
| return pa.Table.from_arrays([], names=[]) | ||
|
|
||
| state_field_names = [ | ||
| state_column.type[i].name for i in range(state_column.type.num_fields) | ||
| ] | ||
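For readers unfamiliar with the Arrow side of this, here is a minimal standalone sketch of what `flatten_columns` does: it pulls one struct column out of the batch and promotes its fields to top-level columns, returning an empty table when the column is entirely null. The toy schema (`id`/`value` under `inputData`, an all-null `initState`) is illustrative only, not the PR's real schema.

```python
import pyarrow as pa

tbl = pa.table({
    "inputData": pa.array([{"id": 1, "value": "a"}, {"id": 2, "value": "b"}]),
    "initState": pa.array([None, None], type=pa.struct([("key", pa.string())])),
})

def flatten(cur_batch, col_name):
    col = cur_batch.column(cur_batch.schema.get_field_index(col_name))
    if col.null_count == len(col):
        # The whole struct column is null: nothing to flatten.
        return pa.Table.from_arrays([], names=[])
    names = [col.type[i].name for i in range(col.type.num_fields)]
    # flatten() on a struct column yields one child array per field.
    return pa.Table.from_arrays(col.flatten(), names=names)

print(flatten(tbl, "inputData").column_names)  # ['id', 'value']
print(flatten(tbl, "initState").num_columns)   # 0
```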
|
|
@@ -1981,30 +1988,71 @@ def flatten_columns(cur_batch, col_name): | |
| .add("inputData", dataSchema) | ||
| .add("initState", initStateSchema) | ||
| We'll parse batch into Tuples of (key, inputData, initState) and pass into the Python | ||
| data generator. All rows in the same batch have the same grouping key. | ||
| data generator. Rows in the same batch may have different grouping keys, | ||
| but each batch will contain either init_data or input_data, never a mix of the two. | ||
| """ | ||
| for batch in batches: | ||
| flatten_state_table = flatten_columns(batch, "inputData") | ||
| data_pandas = [ | ||
| self.arrow_to_pandas(c, i) | ||
| for i, c in enumerate(flatten_state_table.itercolumns()) | ||
| ] | ||
| def row_stream(): | ||
| for batch in batches: | ||
| if self.arrow_max_bytes_per_batch != 2**31 - 1 and batch.num_rows > 0: | ||
| batch_bytes = sum( | ||
| buf.size | ||
| for col in batch.columns | ||
| for buf in col.buffers() | ||
| if buf is not None | ||
| ) | ||
| self.total_bytes += batch_bytes | ||
| self.total_rows += batch.num_rows | ||
| self.average_arrow_row_size = self.total_bytes / self.total_rows | ||
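As a hedged aside on the size estimate kept above: each batch's byte count is the sum of the sizes of the non-null Arrow buffers backing its columns, and dividing the running byte total by the running row count gives the average row size used by the flush check further down. A standalone sketch of that measurement (names are illustrative):

```python
import pyarrow as pa

def batch_size_in_bytes(batch):
    # Sum the validity/offset/data buffers backing every column of the batch.
    return sum(
        buf.size
        for col in batch.columns
        for buf in col.buffers()
        if buf is not None
    )

batch = pa.record_batch(
    [pa.array([1, 2, 3]), pa.array(["a", "b", "c"])],
    names=["id", "name"],
)
avg_row_size = batch_size_in_bytes(batch) / batch.num_rows
print(avg_row_size)  # average bytes per row for this batch
```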
|
|
||
| flatten_init_table = flatten_columns(batch, "initState") | ||
| init_data_pandas = [ | ||
| self.arrow_to_pandas(c, i) | ||
| for i, c in enumerate(flatten_init_table.itercolumns()) | ||
| ] | ||
| key_series = [data_pandas[o] for o in self.key_offsets] | ||
| init_key_series = [init_data_pandas[o] for o in self.init_key_offsets] | ||
| flatten_state_table = flatten_columns(batch, "inputData") | ||
| data_pandas = [ | ||
| self.arrow_to_pandas(c, i) | ||
| for i, c in enumerate(flatten_state_table.itercolumns()) | ||
| ] | ||
|
|
||
| if any(s.empty for s in key_series): | ||
| # If any row is empty, assign batch_key using init_key_series | ||
| batch_key = tuple(s[0] for s in init_key_series) | ||
| else: | ||
| # If all rows are non-empty, create batch_key from key_series | ||
| batch_key = tuple(s[0] for s in key_series) | ||
| yield (batch_key, data_pandas, init_data_pandas) | ||
| if bool(data_pandas): | ||
| for row in pd.concat(data_pandas, axis=1).itertuples(index=False): | ||
| batch_key = tuple(row[s] for s in self.key_offsets) | ||
| yield (batch_key, row, None) | ||
| else: | ||
| flatten_init_table = flatten_columns(batch, "initState") | ||
| init_data_pandas = [ | ||
| self.arrow_to_pandas(c, i) | ||
| for i, c in enumerate(flatten_init_table.itercolumns()) | ||
| ] | ||
| if bool(init_data_pandas): | ||
| for row in pd.concat(init_data_pandas, axis=1).itertuples(index=False): | ||
| batch_key = tuple(row[s] for s in self.init_key_offsets) | ||
| yield (batch_key, None, row) | ||
|
|
||
| EMPTY_DATAFRAME = pd.DataFrame() | ||
| for batch_key, group_rows in groupby(row_stream(), key=lambda x: x[0]): | ||
| rows = [] | ||
| init_state_rows = [] | ||
| for _, row, init_state_row in group_rows: | ||
| if row is not None: | ||
| rows.append(row) | ||
| if init_state_row is not None: | ||
| init_state_rows.append(init_state_row) | ||
|
|
||
| total_len = len(rows) + len(init_state_rows) | ||
| if ( | ||
| total_len >= self.arrow_max_records_per_batch | ||
| or total_len * self.average_arrow_row_size >= self.arrow_max_bytes_per_batch | ||
| ): | ||
| yield ( | ||
| batch_key, | ||
| pd.DataFrame(rows) if len(rows) > 0 else EMPTY_DATAFRAME.copy(), | ||
| pd.DataFrame(init_state_rows) if len(init_state_rows) > 0 else EMPTY_DATAFRAME.copy() | ||
| ) | ||
| rows = [] | ||
| init_state_rows = [] | ||
| if rows or init_state_rows: | ||
| yield ( | ||
| batch_key, | ||
| pd.DataFrame(rows) if len(rows) > 0 else EMPTY_DATAFRAME.copy(), | ||
| pd.DataFrame(init_state_rows) if len(init_state_rows) > 0 else EMPTY_DATAFRAME.copy() | ||
| ) | ||
|
|
||
| _batches = super(ArrowStreamPandasSerializer, self).load_stream(stream) | ||
| data_batches = generate_data_batches(_batches) | ||
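The grouping-and-flushing loop above follows a general pattern: the row stream arrives already ordered by grouping key, `itertools.groupby` slices it per key, and a chunk is flushed as soon as the record cap or the estimated byte cap is reached, with a final flush for any remainder. A self-contained sketch of that pattern under assumed names and a fixed average row size:

```python
from itertools import groupby

def chunked_groups(row_stream, max_records, max_bytes, avg_row_size):
    # row_stream yields (key, row) pairs already sorted by key.
    for key, group in groupby(row_stream, key=lambda kv: kv[0]):
        chunk = []
        for _, row in group:
            chunk.append(row)
            if (len(chunk) >= max_records
                    or len(chunk) * avg_row_size >= max_bytes):
                yield key, chunk  # flush a full chunk for this key
                chunk = []
        if chunk:
            yield key, chunk      # flush whatever is left for this key

rows = [("k1", 1), ("k1", 2), ("k1", 3), ("k2", 4)]
print(list(chunked_groups(rows, max_records=2, max_bytes=1 << 20, avg_row_size=8)))
# [('k1', [1, 2]), ('k1', [3]), ('k2', [4])]
```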
|
|
@@ -2122,7 +2170,6 @@ def __init__(self, arrow_max_records_per_batch): | |
| self.init_key_offsets = None | ||
|
|
||
| def load_stream(self, stream): | ||
| import itertools | ||
| import pyarrow as pa | ||
| from pyspark.sql.streaming.stateful_processor_util import ( | ||
| TransformWithStateInPandasFuncMode, | ||
|
|
@@ -2141,6 +2188,11 @@ def generate_data_batches(batches): | |
|
|
||
| def extract_rows(cur_batch, col_name, key_offsets): | ||
| data_column = cur_batch.column(cur_batch.schema.get_field_index(col_name)) | ||
|
|
||
| # Check if the entire column is null | ||
| if data_column.null_count == len(data_column): | ||
| return None | ||
|
Contributor
Given we've changed the implicit type signature of the function, let's maybe add a type annotation on generate_data_batches for readability.
Contributor (Author)
Done
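A sketch of the kind of annotation being asked for; the element types are an assumption based on the new tuple shape (grouping key, optional input row, optional initial-state row), not necessarily the PR's final signature.

```python
from typing import Any, Iterator, Optional, Tuple

import pyarrow as pa

RowTriple = Tuple[Tuple[Any, ...], Optional[Any], Optional[Any]]

def generate_data_batches(batches: Iterator[pa.RecordBatch]) -> Iterator[RowTriple]:
    """Yield (grouping_key, input_row, init_state_row) triples."""
    ...
```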
||
|
|
||
| data_field_names = [ | ||
| data_column.type[i].name for i in range(data_column.type.num_fields) | ||
| ] | ||
|
|
@@ -2153,68 +2205,62 @@ def extract_rows(cur_batch, col_name, key_offsets): | |
| table = pa.Table.from_arrays(data_field_arrays, names=data_field_names) | ||
|
|
||
| if table.num_rows == 0: | ||
| return (None, iter([])) | ||
| else: | ||
| batch_key = tuple(table.column(o)[0].as_py() for o in key_offsets) | ||
| return None | ||
|
|
||
| rows = [] | ||
| def row_iterator(): | ||
| for row_idx in range(table.num_rows): | ||
| key = tuple(table.column(o)[row_idx].as_py() for o in key_offsets) | ||
| row = DataRow( | ||
| *(table.column(i)[row_idx].as_py() for i in range(table.num_columns)) | ||
| ) | ||
| rows.append(row) | ||
| yield (key, row) | ||
|
|
||
| return (batch_key, iter(rows)) | ||
| return row_iterator() | ||
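The reworked `extract_rows` now hands back a lazy per-row generator rather than materializing every row: each yielded pair is the row's grouping key (selected via the key offsets) plus the whole row converted to Python values. A minimal standalone sketch of that shape, using a plain tuple where the PR uses its `DataRow` helper:

```python
import pyarrow as pa

def rows_with_keys(table, key_offsets):
    # Lazily yield (key, row) pairs; nothing is materialized up front.
    for row_idx in range(table.num_rows):
        key = tuple(table.column(o)[row_idx].as_py() for o in key_offsets)
        row = tuple(table.column(i)[row_idx].as_py() for i in range(table.num_columns))
        yield key, row

table = pa.table({"user": ["a", "a", "b"], "value": [1, 2, 3]})
for key, row in rows_with_keys(table, key_offsets=[0]):
    print(key, row)  # ('a',) ('a', 1) ... one line per row
```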
|
|
||
| """ | ||
| The arrow batch is written in the schema: | ||
| schema: StructType = new StructType() | ||
| .add("inputData", dataSchema) | ||
| .add("initState", initStateSchema) | ||
| We'll parse batch into Tuples of (key, inputData, initState) and pass into the Python | ||
| data generator. All rows in the same batch have the same grouping key. | ||
| data generator. Each batch will contain either init_data or input_data, never a mix of the two. | ||
| """ | ||
| for batch in batches: | ||
| (input_batch_key, input_data_iter) = extract_rows( | ||
| batch, "inputData", self.key_offsets | ||
| ) | ||
| (init_batch_key, init_state_iter) = extract_rows( | ||
| batch, "initState", self.init_key_offsets | ||
| ) | ||
| # Detect which column has data - each batch contains only one type | ||
| input_result = extract_rows(batch, "inputData", self.key_offsets) | ||
|
|
||
| if input_batch_key is None: | ||
| batch_key = init_batch_key | ||
| if input_result is not None: | ||
| for key, input_data_row in input_result: | ||
| yield (key, input_data_row, None) | ||
| else: | ||
| batch_key = input_batch_key | ||
|
|
||
| for init_state_row in init_state_iter: | ||
| yield (batch_key, None, init_state_row) | ||
|
|
||
| for input_data_row in input_data_iter: | ||
| yield (batch_key, input_data_row, None) | ||
| init_result = extract_rows(batch, "initState", self.init_key_offsets) | ||
| if init_result is not None: | ||
| for key, init_state_row in init_result: | ||
| yield (key, None, init_state_row) | ||
|
|
||
| _batches = super(ArrowStreamUDFSerializer, self).load_stream(stream) | ||
| data_batches = generate_data_batches(_batches) | ||
|
|
||
| for k, g in groupby(data_batches, key=lambda x: x[0]): | ||
| # g: list(batch_key, input_data_iter, init_state_iter) | ||
|
|
||
| # they are sharing the iterator, hence need to copy | ||
| input_values_iter, init_state_iter = itertools.tee(g, 2) | ||
|
|
||
| chained_input_values = itertools.chain(map(lambda x: x[1], input_values_iter)) | ||
| chained_init_state_values = itertools.chain(map(lambda x: x[2], init_state_iter)) | ||
|
|
||
| chained_input_values_without_none = filter( | ||
| lambda x: x is not None, chained_input_values | ||
| ) | ||
| chained_init_state_values_without_none = filter( | ||
| lambda x: x is not None, chained_init_state_values | ||
| ) | ||
|
|
||
| ret_tuple = (chained_input_values_without_none, chained_init_state_values_without_none) | ||
|
|
||
| yield (TransformWithStateInPandasFuncMode.PROCESS_DATA, k, ret_tuple) | ||
| input_rows = [] | ||
| init_rows = [] | ||
|
|
||
| for batch_key, input_row, init_row in g: | ||
| if input_row is not None: | ||
| input_rows.append(input_row) | ||
| if init_row is not None: | ||
| init_rows.append(init_row) | ||
|
|
||
| total_len = len(input_rows) + len(init_rows) | ||
| if total_len >= self.arrow_max_records_per_batch: | ||
|
Contributor
The SQLConf config param says that if it is set to zero or a negative number there is no limit; with this code, a zero or negative value would instead make us output a fresh batch per row. Let's change the behaviour and add a test covering this.
Contributor (Author)
Oh, right.
Contributor (Author)
Done
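A minimal sketch of the behaviour the review asks for: a zero or negative `arrow_max_records_per_batch` should mean "no limit" rather than flushing a batch per row. The normalization below is illustrative, not necessarily the exact code added in the follow-up commit.

```python
import sys

def normalize_max_records(configured):
    # Treat zero or negative values as "unlimited", per the SQLConf contract.
    return configured if configured > 0 else sys.maxsize

assert normalize_max_records(10) == 10
assert normalize_max_records(0) == sys.maxsize
assert normalize_max_records(-1) == sys.maxsize
```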
||
| ret_tuple = (iter(input_rows), iter(init_rows)) | ||
| yield (TransformWithStateInPandasFuncMode.PROCESS_DATA, k, ret_tuple) | ||
| input_rows = [] | ||
| init_rows = [] | ||
|
|
||
| if input_rows or init_rows: | ||
| ret_tuple = (iter(input_rows), iter(init_rows)) | ||
| yield (TransformWithStateInPandasFuncMode.PROCESS_DATA, k, ret_tuple) | ||
|
|
||
| yield (TransformWithStateInPandasFuncMode.PROCESS_TIMER, None, None) | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -797,12 +797,14 @@ def wrapped(stateful_processor_api_client, mode, key, value_series_gen): | |
|
|
||
| def wrap_grouped_transform_with_state_pandas_init_state_udf(f, return_type, runner_conf): | ||
| def wrapped(stateful_processor_api_client, mode, key, value_series_gen): | ||
| import pandas as pd | ||
|
|
||
| # Split the generator into two using itertools.tee | ||
| state_values_gen, init_states_gen = itertools.tee(value_series_gen, 2) | ||
| state_values = (df for x, _ in state_values_gen if not (df := pd.concat(x, axis=1)).empty) | ||
| init_states = (df for _, x in init_states_gen if not (df := pd.concat(x, axis=1)).empty) | ||
|
|
||
| # Extract just the data DataFrames (first element of each tuple) | ||
| state_values = (data_df for data_df, _ in state_values_gen if not data_df.empty) | ||
|
|
||
| # Extract just the init DataFrames (second element of each tuple) | ||
| init_states = (init_df for _, init_df in init_states_gen if not init_df.empty) | ||
| result_iter = f(stateful_processor_api_client, mode, key, state_values, init_states) | ||
|
|
||
| # TODO(SPARK-49100): add verification that elements in result_iter are | ||
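For context, a hedged sketch of the `itertools.tee` split used in the wrapper above: a single stream of (data_df, init_df) pairs is duplicated so the input-data frames and the initial-state frames can be consumed as two independent, lazily filtered generators. The toy frames below are placeholders.

```python
import itertools

import pandas as pd

pairs = iter([
    (pd.DataFrame({"v": [1, 2]}), pd.DataFrame()),   # input data only
    (pd.DataFrame(), pd.DataFrame({"s": [0]})),      # initial state only
])

data_gen, init_gen = itertools.tee(pairs, 2)
state_values = (df for df, _ in data_gen if not df.empty)
init_states = (df for _, df in init_gen if not df.empty)

print([len(df) for df in state_values])  # [2]
print([len(df) for df in init_states])   # [1]
```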
|
|
@@ -3067,6 +3069,7 @@ def values_gen(): | |
| ser.init_key_offsets = parsed_offsets[1][0] | ||
| stateful_processor_api_client = StatefulProcessorApiClient(state_server_port, key_schema) | ||
|
|
||
| import pandas as pd | ||
| def mapper(a): | ||
| mode = a[0] | ||
|
|
||
|
|
@@ -3075,8 +3078,8 @@ def mapper(a): | |
|
|
||
| def values_gen(): | ||
| for x in a[2]: | ||
| retVal = [x[1][o] for o in parsed_offsets[0][1]] | ||
| initVal = [x[2][o] for o in parsed_offsets[1][1]] | ||
| retVal = x[1] | ||
| initVal = x[2] | ||
| yield retVal, initVal | ||
|
|
||
| # This must be generator comprehension - do not materialize. | ||
|
|
@@ -3230,6 +3233,7 @@ def mapper(a): | |
|
|
||
| parsed_offsets = extract_key_value_indexes(arg_offsets) | ||
|
|
||
| import pandas as pd | ||
|
||
| def mapper(a): | ||
| df1_keys = [a[0][o] for o in parsed_offsets[0][0]] | ||
| df1_vals = [a[0][o] for o in parsed_offsets[0][1]] | ||
|
|
||
This logic seems to be duplicated from elsewhere in the file; maybe we can add it to a base class?
Done