From ac4bd31e6cc081d0b385617cbb63e249d95b1024 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Mon, 17 Nov 2025 22:58:47 +0000 Subject: [PATCH 01/12] scan simple operator state --- .../v2/state/StateDataSource.scala | 68 +++-- .../v2/state/StatePartitionReader.scala | 91 ++++++- .../v2/state/utils/SchemaUtil.scala | 34 ++- .../streaming/state/StateStore.scala | 14 + ...artitionReaderAllColumnFamiliesSuite.scala | 241 ++++++++++++++++++ 5 files changed, 427 insertions(+), 21 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReaderAllColumnFamiliesSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala index 6af418e1ddc2..5878e295369a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala @@ -66,28 +66,38 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging val sourceOptions = StateSourceOptions.modifySourceOptions(hadoopConf, StateSourceOptions.apply(session, hadoopConf, properties)) val stateConf = buildStateStoreConf(sourceOptions.resolvedCpLocation, sourceOptions.batchId) - val stateStoreReaderInfo: StateStoreReaderInfo = getStoreMetadataAndRunChecks( - sourceOptions) - - // The key state encoder spec should be available for all operators except stream-stream joins - val keyStateEncoderSpec = if (stateStoreReaderInfo.keyStateEncoderSpecOpt.isDefined) { - stateStoreReaderInfo.keyStateEncoderSpecOpt.get + if (sourceOptions.readAllColumnFamilies) { + // For readAllColumnFamilies mode, we don't need specific metadata + val keyStateEncoderSpec = NoPrefixKeyStateEncoderSpec(new StructType()) + new StateTable(session, schema, sourceOptions, stateConf, keyStateEncoderSpec, + None, None, None, None) } else { - val keySchema = SchemaUtil.getSchemaAsDataType(schema, "key").asInstanceOf[StructType] - NoPrefixKeyStateEncoderSpec(keySchema) - } + val stateStoreReaderInfo: StateStoreReaderInfo = getStoreMetadataAndRunChecks( + sourceOptions) - new StateTable(session, schema, sourceOptions, stateConf, keyStateEncoderSpec, - stateStoreReaderInfo.transformWithStateVariableInfoOpt, - stateStoreReaderInfo.stateStoreColFamilySchemaOpt, - stateStoreReaderInfo.stateSchemaProviderOpt, - stateStoreReaderInfo.joinColFamilyOpt) + // The key state encoder spec should be available for all operators except stream-stream joins + val keyStateEncoderSpec = if (stateStoreReaderInfo.keyStateEncoderSpecOpt.isDefined) { + stateStoreReaderInfo.keyStateEncoderSpecOpt.get + } else { + val keySchema = SchemaUtil.getSchemaAsDataType(schema, "key").asInstanceOf[StructType] + NoPrefixKeyStateEncoderSpec(keySchema) + } + new StateTable(session, schema, sourceOptions, stateConf, keyStateEncoderSpec, + stateStoreReaderInfo.transformWithStateVariableInfoOpt, + stateStoreReaderInfo.stateStoreColFamilySchemaOpt, + stateStoreReaderInfo.stateSchemaProviderOpt, + stateStoreReaderInfo.joinColFamilyOpt) + } } override def inferSchema(options: CaseInsensitiveStringMap): StructType = { val sourceOptions = StateSourceOptions.modifySourceOptions(hadoopConf, StateSourceOptions.apply(session, hadoopConf, options)) - + if (sourceOptions.readAllColumnFamilies) { + // For readAllColumnFamilies mode, return the binary schema directly 
+ return SchemaUtil.getSourceSchema( + sourceOptions, new StructType(), new StructType(), None, None) + } val stateStoreReaderInfo: StateStoreReaderInfo = getStoreMetadataAndRunChecks( sourceOptions) val oldSchemaFilePaths = StateDataSource.getOldSchemaFilePaths(sourceOptions, hadoopConf) @@ -372,6 +382,7 @@ case class StateSourceOptions( stateVarName: Option[String], readRegisteredTimers: Boolean, flattenCollectionTypes: Boolean, + readAllColumnFamilies: Boolean, startOperatorStateUniqueIds: Option[Array[Array[String]]] = None, endOperatorStateUniqueIds: Option[Array[Array[String]]] = None) { def stateCheckpointLocation: Path = new Path(resolvedCpLocation, DIR_NAME_STATE) @@ -380,7 +391,8 @@ case class StateSourceOptions( var desc = s"StateSourceOptions(checkpointLocation=$resolvedCpLocation, batchId=$batchId, " + s"operatorId=$operatorId, storeName=$storeName, joinSide=$joinSide, " + s"stateVarName=${stateVarName.getOrElse("None")}, +" + - s"flattenCollectionTypes=$flattenCollectionTypes" + s"flattenCollectionTypes=$flattenCollectionTypes" + + s"readAllColumnFamilies=$readAllColumnFamilies" if (fromSnapshotOptions.isDefined) { desc += s", snapshotStartBatchId=${fromSnapshotOptions.get.snapshotStartBatchId}" desc += s", snapshotPartitionId=${fromSnapshotOptions.get.snapshotPartitionId}" @@ -407,6 +419,7 @@ object StateSourceOptions extends DataSourceOptions { val STATE_VAR_NAME = newOption("stateVarName") val READ_REGISTERED_TIMERS = newOption("readRegisteredTimers") val FLATTEN_COLLECTION_TYPES = newOption("flattenCollectionTypes") + val READ_ALL_COLUMN_FAMILIES = newOption("readAllColumnFamilies") object JoinSideValues extends Enumeration { type JoinSideValues = Value @@ -492,6 +505,27 @@ object StateSourceOptions extends DataSourceOptions { val readChangeFeed = Option(options.get(READ_CHANGE_FEED)).exists(_.toBoolean) + val readAllColumnFamilies = try { + Option(options.get(READ_ALL_COLUMN_FAMILIES)) + .map(_.toBoolean).getOrElse(false) + } catch { + case _: IllegalArgumentException => + throw StateDataSourceErrors.invalidOptionValue(READ_ALL_COLUMN_FAMILIES, + "Boolean value is expected") + } + + if (readAllColumnFamilies && stateVarName.isDefined) { + throw StateDataSourceErrors.conflictOptions(Seq(READ_ALL_COLUMN_FAMILIES, STATE_VAR_NAME)) + } + + if (readAllColumnFamilies && joinSide != JoinSideValues.none) { + throw StateDataSourceErrors.conflictOptions(Seq(READ_ALL_COLUMN_FAMILIES, JOIN_SIDE)) + } + + if (readAllColumnFamilies && readChangeFeed) { + throw StateDataSourceErrors.conflictOptions(Seq(READ_ALL_COLUMN_FAMILIES, READ_CHANGE_FEED)) + } + val changeStartBatchId = Option(options.get(CHANGE_START_BATCH_ID)).map(_.toLong) var changeEndBatchId = Option(options.get(CHANGE_END_BATCH_ID)).map(_.toLong) @@ -616,7 +650,7 @@ object StateSourceOptions extends DataSourceOptions { resolvedCpLocation, batchId.get, operatorId, storeName, joinSide, readChangeFeed, fromSnapshotOptions, readChangeFeedOptions, stateVarName, readRegisteredTimers, flattenCollectionTypes, - startOperatorStateUniqueIds, endOperatorStateUniqueIds) + readAllColumnFamilies, startOperatorStateUniqueIds, endOperatorStateUniqueIds) } private def getLastCommittedBatch(session: SparkSession, checkpointLocation: String): Long = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala index 619e374c00de..ca1a6cf295b8 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala @@ -20,12 +20,13 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeRow} import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory} +import org.apache.spark.sql.execution.datasources.v2.state.metadata.StateMetadataPartitionReader import org.apache.spark.sql.execution.datasources.v2.state.utils.SchemaUtil import org.apache.spark.sql.execution.streaming.operators.stateful.join.SymmetricHashJoinStateManager import org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.{StateVariableType, TransformWithStateVariableInfo} import org.apache.spark.sql.execution.streaming.state._ import org.apache.spark.sql.execution.streaming.state.RecordType.{getRecordTypeAsString, RecordType} -import org.apache.spark.sql.types.{NullType, StructField, StructType} +import org.apache.spark.sql.types.{BinaryType, NullType, StructField, StructType} import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.{NextIterator, SerializableConfiguration} @@ -49,7 +50,10 @@ class StatePartitionReaderFactory( override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { val stateStoreInputPartition = partition.asInstanceOf[StateStoreInputPartition] - if (stateStoreInputPartition.sourceOptions.readChangeFeed) { + if (stateStoreInputPartition.sourceOptions.readAllColumnFamilies) { + new StatePartitionReaderAllColumnFamilies(storeConf, hadoopConf, + stateStoreInputPartition, schema) + } else if (stateStoreInputPartition.sourceOptions.readChangeFeed) { new StateStoreChangeDataPartitionReader(storeConf, hadoopConf, stateStoreInputPartition, schema, keyStateEncoderSpec, stateVariableInfoOpt, stateStoreColFamilySchemaOpt, stateSchemaProviderOpt, joinColFamilyOpt) @@ -84,6 +88,8 @@ abstract class StatePartitionReaderBase( protected val keySchema = { if (SchemaUtil.checkVariableType(stateVariableInfoOpt, StateVariableType.MapState)) { SchemaUtil.getCompositeKeySchema(schema, partition.sourceOptions) + } else if (partition.sourceOptions.readAllColumnFamilies) { + new StructType().add("keyBytes", BinaryType, nullable = false) } else { SchemaUtil.getSchemaAsDataType(schema, "key").asInstanceOf[StructType] } @@ -91,6 +97,8 @@ abstract class StatePartitionReaderBase( protected val valueSchema = if (stateVariableInfoOpt.isDefined) { schemaForValueRow + } else if (partition.sourceOptions.readAllColumnFamilies) { + new StructType().add("valueBytes", BinaryType, nullable = false) } else { SchemaUtil.getSchemaAsDataType( schema, "value").asInstanceOf[StructType] @@ -237,6 +245,85 @@ class StatePartitionReader( } } +/** + * An implementation of [[StatePartitionReaderBase]] for reading all column families + * in binary format. This reader returns raw key and value bytes along with column family names. 
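+ * Each output row has the form (partition_id, key_bytes, value_bytes, column_family_name),
+ * where the byte arrays are the key/value bytes exactly as stored by the underlying provider.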
+ */ +class StatePartitionReaderAllColumnFamilies( + storeConf: StateStoreConf, + hadoopConf: SerializableConfiguration, + partition: StateStoreInputPartition, + schema: StructType) + extends StatePartitionReaderBase(storeConf, hadoopConf, partition, schema, + NoPrefixKeyStateEncoderSpec(new StructType()), None, None, None, None) { + + val allStateStoreMetadata = { + new StateMetadataPartitionReader( + partition.sourceOptions.resolvedCpLocation, + new SerializableConfiguration(hadoopConf.value), + partition.sourceOptions.batchId).stateMetadata.toArray + } + + private lazy val store: ReadStateStore = { + assert(getStartStoreUniqueId == getEndStoreUniqueId, + "Start and end store unique IDs must be the same when reading all column families") + provider.getReadStore( + partition.sourceOptions.batchId + 1, + getStartStoreUniqueId + ) + } + + val colFamilyNames: Seq[String] = { + // todo: Support operator with multiple column family names in next PR + Seq[String]() + } + + override protected lazy val provider: StateStoreProvider = { + val stateStoreId = StateStoreId(partition.sourceOptions.stateCheckpointLocation.toString, + partition.sourceOptions.operatorId, partition.partition, partition.sourceOptions.storeName) + val stateStoreProviderId = StateStoreProviderId(stateStoreId, partition.queryId) + + val keyStateEncoderSpec = NoPrefixKeyStateEncoderSpec(keySchema) + val provider = StateStoreProvider.createAndInit( + stateStoreProviderId, keySchema, valueSchema, keyStateEncoderSpec, + useColumnFamilies = colFamilyNames.nonEmpty, storeConf, hadoopConf.value, false, None) + + provider + } + + override lazy val iter: Iterator[InternalRow] = { + // Single store with column families (join v3, transformWithState, or simple operators) + require(store.isInstanceOf[SupportsRawBytesRead], + s"State store ${store.getClass.getName} does not support raw bytes reading") + + val rawStore = store.asInstanceOf[SupportsRawBytesRead] + if (colFamilyNames.isEmpty) { + rawStore + .rawIterator() + .map { case (keyBytes, valueBytes) => + SchemaUtil.unifyStateRowPairAsRawBytes( + partition.partition, keyBytes, valueBytes, StateStore.DEFAULT_COL_FAMILY_NAME) + } + } else { + colFamilyNames.iterator.flatMap { colFamilyName => + rawStore + .rawIterator(colFamilyName) + .map { case (keyBytes, valueBytes) => + SchemaUtil.unifyStateRowPairAsRawBytes(partition.partition, + keyBytes, + valueBytes, + colFamilyName) + } + } + } + } + + override def close(): Unit = { + store.release() + super.close() + } +} + /** * An implementation of [[StatePartitionReaderBase]] for the readChangeFeed mode of State Data * Source. It reads the change of state over batches of a particular partition. 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/utils/SchemaUtil.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/utils/SchemaUtil.scala index 52df016791d4..095bcff6fb07 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/utils/SchemaUtil.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/utils/SchemaUtil.scala @@ -28,7 +28,8 @@ import org.apache.spark.sql.execution.datasources.v2.state.{StateDataSourceError import org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.{StateVariableType, TransformWithStateVariableInfo} import org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.StateVariableType._ import org.apache.spark.sql.execution.streaming.state.{ReadStateStore, StateStoreColFamilySchema, UnsafeRowPair} -import org.apache.spark.sql.types.{ArrayType, DataType, IntegerType, LongType, MapType, StringType, StructType} +import org.apache.spark.sql.types.{ArrayType, BinaryType, DataType, IntegerType, LongType, MapType, StringType, StructType} +import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.ArrayImplicits._ object SchemaUtil { @@ -60,6 +61,12 @@ object SchemaUtil { .add("key", keySchema) .add("value", valueSchema) .add("partition_id", IntegerType) + } else if (sourceOptions.readAllColumnFamilies) { + new StructType() + .add("partition_id", IntegerType) + .add("key_bytes", BinaryType) + .add("value_bytes", BinaryType) + .add("column_family_name", StringType) } else { new StructType() .add("key", keySchema) @@ -76,6 +83,24 @@ object SchemaUtil { row } + /** + * Creates a unified row from raw key and value bytes. + * This is an alias for unifyStateRowPairAsBytes that takes individual byte arrays + * instead of a tuple for better readability. 
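+ * The resulting row follows the readAllColumnFamilies output schema:
+ * (partition_id, key_bytes, value_bytes, column_family_name).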
+ */ + def unifyStateRowPairAsRawBytes( + partition: Int, + keyBytes: Array[Byte], + valueBytes: Array[Byte], + colFamilyName: String): InternalRow = { + val row = new GenericInternalRow(4) + row.update(0, partition) + row.update(1, keyBytes) + row.update(2, valueBytes) + row.update(3, UTF8String.fromString(colFamilyName)) + row + } + def unifyStateRowPairWithMultipleValues( pair: (UnsafeRow, GenericArrayData), partition: Int): InternalRow = { @@ -231,7 +256,10 @@ object SchemaUtil { "user_map_key" -> classOf[StructType], "user_map_value" -> classOf[StructType], "expiration_timestamp_ms" -> classOf[LongType], - "partition_id" -> classOf[IntegerType]) + "partition_id" -> classOf[IntegerType], + "key_bytes"->classOf[BinaryType], + "value_bytes"->classOf[BinaryType], + "column_family_name"->classOf[StringType]) val expectedFieldNames = if (transformWithStateVariableInfoOpt.isDefined) { val stateVarInfo = transformWithStateVariableInfoOpt.get @@ -272,6 +300,8 @@ object SchemaUtil { } } else if (sourceOptions.readChangeFeed) { Seq("batch_id", "change_type", "key", "value", "partition_id") + } else if (sourceOptions.readAllColumnFamilies) { + Seq("partition_id", "key_bytes", "value_bytes", "column_family_name") } else { Seq("key", "value", "partition_id") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index 43b95766882f..d180010e355e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -891,6 +891,20 @@ object StateStoreProvider extends Logging { } } +/** + * Trait for state stores that support reading raw bytes without decoding. + * This is useful for copying state data during repartitioning + */ +trait SupportsRawBytesRead { + /** + * Returns an iterator of raw key-value bytes for a column family. + * @param colFamilyName the name of the column family to iterate over + * @return an iterator of (keyBytes, valueBytes) tuples + */ + def rawIterator(colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): + Iterator[(Array[Byte], Array[Byte])] +} + /** * This is an optional trait to be implemented by [[StateStoreProvider]]s that can read the change * of state store over batches. This is used by State Data Source with additional options like diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReaderAllColumnFamiliesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReaderAllColumnFamiliesSuite.scala new file mode 100644 index 000000000000..20fe0162a424 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReaderAllColumnFamiliesSuite.scala @@ -0,0 +1,241 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.datasources.v2.state + +import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.execution.streaming.runtime.MemoryStream +import org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider +import org.apache.spark.sql.functions.{count, sum} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.streaming.OutputMode +import org.apache.spark.sql.types.{IntegerType, LongType, StructField, StructType} +import org.apache.spark.tags.SlowSQLTest +import org.apache.spark.unsafe.Platform + +/** + * Test suite to verify StatePartitionReaderAllColumnFamilies functionality. + */ +@SlowSQLTest +class StatePartitionReaderAllColumnFamiliesSuite extends StateDataSourceTestBase { + + import testImplicits._ + + /** + * Returns a set of (partitionId, key, value) tuples from a normal state read. + */ + private def getNormalReadData(checkpointDir: String): Set[(Int, Row, Row)] = { + val normalReadDf = spark.read + .format("statestore") + .option(StateSourceOptions.PATH, checkpointDir) + .load() + .selectExpr("partition_id", "key", "value") + + normalReadDf.collect() + .map { row => + val partitionId = row.getInt(0) + val key = row.getStruct(1) + val value = row.getStruct(2) + (partitionId, key, value) + } + .toSet + } + + /** + * Returns a DataFrame with raw bytes mode (READ_ALL_COLUMN_FAMILIES = true). + */ + private def getBytesReadDf(checkpointDir: String): DataFrame = { + spark.read + .format("statestore") + .option(StateSourceOptions.PATH, checkpointDir) + .option(StateSourceOptions.READ_ALL_COLUMN_FAMILIES, "true") + .load() + } + + /** + * Validates the schema and column families of the bytes read DataFrame. 
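+ * Checks the expected column names and types, the row count, the set of column family
+ * names, and that every row has a non-negative partition id and non-null key/value bytes.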
+ */ + private def validateBytesReadSchema( + df: DataFrame, + expectedRowCount: Int, + expectedColumnFamilies: Seq[String]): Unit = { + // Verify schema + val schema = df.schema + assert(schema.fieldNames === Array( + "partition_id", "key_bytes", "value_bytes", "column_family_name")) + assert(schema("partition_id").dataType.typeName === "integer") + assert(schema("key_bytes").dataType.typeName === "binary") + assert(schema("value_bytes").dataType.typeName === "binary") + assert(schema("column_family_name").dataType.typeName === "string") + + // Verify data + val rows = df.collect() + assert(rows.length == expectedRowCount, + s"Expected $expectedRowCount rows but got: ${rows.length}") + + val columnFamilies = rows.map(r => Option(r.getString(3)).getOrElse("null")).distinct.sorted + assert(columnFamilies.length == expectedColumnFamilies.length, + s"Expected ${expectedColumnFamilies.length} column families, " + + s"but got ${columnFamilies.length}: ${columnFamilies.mkString(", ")}") + + expectedColumnFamilies.foreach { expectedCF => + assert(columnFamilies.contains(expectedCF), + s"Expected column family '$expectedCF', " + + s"but got: ${columnFamilies.mkString(", ")}") + } + + // Verify all rows have non-null values + rows.foreach { row => + assert(row.getInt(0) >= 0) // partition_id non-negative + assert(row.get(1) != null) // key_bytes not null + assert(row.get(2) != null) // value_bytes not null + } + } + + /** + * Parses the bytes read DataFrame into a set of (partitionId, key, value, columnFamily) tuples. + */ + private def parseBytesReadData( + df: DataFrame, numOfKey: Int, numOfValue: Int): Set[(Int, UnsafeRow, UnsafeRow, String)] = { + df.selectExpr("partition_id", "key_bytes", "value_bytes", "column_family_name") + .collect() + .map { row => + val partitionId = row.getInt(0) + val keyBytes = row.getAs[Array[Byte]](1) + val valueBytes = row.getAs[Array[Byte]](2) + val columnFamily = row.getString(3) + + // Deserialize key bytes to UnsafeRow + // Skip the version byte (STATE_ENCODING_NUM_VERSION_BYTES) at the beginning + val keyRow = new UnsafeRow(numOfKey) + keyRow.pointTo( + keyBytes, + Platform.BYTE_ARRAY_OFFSET + RocksDBStateStoreProvider.STATE_ENCODING_NUM_VERSION_BYTES, + keyBytes.length - RocksDBStateStoreProvider.STATE_ENCODING_NUM_VERSION_BYTES) + + // Deserialize value bytes to UnsafeRow + // Skip the version byte (STATE_ENCODING_NUM_VERSION_BYTES) at the beginning + val valueRow = new UnsafeRow(numOfValue) + valueRow.pointTo( + valueBytes, + Platform.BYTE_ARRAY_OFFSET + RocksDBStateStoreProvider.STATE_ENCODING_NUM_VERSION_BYTES, + valueBytes.length - RocksDBStateStoreProvider.STATE_ENCODING_NUM_VERSION_BYTES) + + (partitionId, keyRow.copy(), valueRow.copy(), columnFamily) + } + .toSet + } + + /** + * Compares normal read data with bytes read data for a specific column family. 
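+ * Both sides are reduced to (partition_id, key field values, value field values) tuples
+ * before comparison.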
+ */ + private def compareNormalAndBytesData( + normalData: Set[(Int, Row, Row)], + bytesData: Set[(Int, UnsafeRow, UnsafeRow, String)], + columnFamily: String, + keySchema: StructType, + valueSchema: StructType): Unit = { + // Filter bytes data for the specified column family + val filteredBytesData = bytesData.filter(_._4 == columnFamily) + + // Verify same number of rows + assert(filteredBytesData.size == normalData.size, + s"Row count mismatch for column family '$columnFamily': " + + s"normal read has ${filteredBytesData.size} rows, bytes read has ${normalData.size} rows") + // Convert to comparable format (extract field values) + val normalSet = normalData.map { case (partId, key, value) => + val keyFields = (0 until key.length).map(i => key.get(i)) + val valueFields = (0 until value.length).map(i => value.get(i)) + (partId, keyFields, valueFields) + } + + val bytesSet = filteredBytesData.map { case (partId, keyRow, valueRow, _) => + val keyFields = (0 until keySchema.length).map(i => + keyRow.get(i, keySchema(i).dataType)) + val valueFields = (0 until valueSchema.length).map(i => + valueRow.get(i, valueSchema(i).dataType)) + (partId, keyFields, valueFields) + } + + assert(normalSet == bytesSet) + } + + test("read all column families with simple operator") { + withTempDir { tempDir => + withSQLConf( + SQLConf.STATE_STORE_PROVIDER_CLASS.key -> + classOf[RocksDBStateStoreProvider].getName, + SQLConf.SHUFFLE_PARTITIONS.key -> "2") { + + val inputData = MemoryStream[Int] + val aggregated = inputData.toDF() + .selectExpr("value", "value % 10 AS groupKey") + .groupBy($"groupKey") + .agg( + count("*").as("cnt"), + sum("value").as("sum") + ) + .as[(Int, Long, Long)] + + testStream(aggregated, OutputMode.Update)( + StartStream(checkpointLocation = tempDir.getAbsolutePath), + // batch 0 + AddData(inputData, 0 until 20: _*), + CheckLastBatch( + (0, 2, 10), // 0, 10 + (1, 2, 12), // 1, 11 + (2, 2, 14), // 2, 12 + (3, 2, 16), // 3, 13 + (4, 2, 18), // 4, 14 + (5, 2, 20), // 5, 15 + (6, 2, 22), // 6, 16 + (7, 2, 24), // 7, 17 + (8, 2, 26), // 8, 18 + (9, 2, 28) // 9, 19 + ), + StopStream + ) + + // Read state data once with READ_ALL_COLUMN_FAMILIES = true + val bytesReadDf = getBytesReadDf(tempDir.getAbsolutePath) + + // Verify schema and column families + validateBytesReadSchema(bytesReadDf, + expectedRowCount = 10, + expectedColumnFamilies = Seq("default")) + + // Get normal read data for comparison + val normalData = getNormalReadData(tempDir.getAbsolutePath) + + // Compare normal and bytes data for default column family + val keySchema: StructType = StructType(Array( + StructField("key", IntegerType, nullable = false) + )) + + // Value schema for the aggregation: count and sum columns + val valueSchema: StructType = StructType(Array( + StructField("count", LongType, nullable = false), + StructField("sum", LongType, nullable = false) + )) + // Parse bytes read data + val bytesData = parseBytesReadData(bytesReadDf, keySchema.length, valueSchema.length) + + compareNormalAndBytesData(normalData, bytesData, "default", keySchema, valueSchema) + } + } + } +} From f14e02496482bc40fe1e3c8a7a13a97d54996e7a Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Tue, 18 Nov 2025 19:58:41 +0000 Subject: [PATCH 02/12] add test and support for HDFS --- .../v2/state/StatePartitionReader.scala | 8 +- .../streaming/state/StateStoreConf.scala | 20 ++ ...artitionReaderAllColumnFamiliesSuite.scala | 212 ++++++++++-------- 3 files changed, 145 insertions(+), 95 deletions(-) diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala index ca1a6cf295b8..c80046e60d3e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala @@ -283,10 +283,16 @@ class StatePartitionReaderAllColumnFamilies( partition.sourceOptions.operatorId, partition.partition, partition.sourceOptions.storeName) val stateStoreProviderId = StateStoreProviderId(stateStoreId, partition.queryId) + // Disable format validation when reading raw bytes. + // We use binary schemas (keyBytes/valueBytes) which don't match the actual schema + // of the stored data. Validation would fail in HDFSBackedStateStoreProvider when + // loading data from disk, so we disable it for raw bytes mode. + val modifiedStoreConf = storeConf.withFormatValidationDisabled() + val keyStateEncoderSpec = NoPrefixKeyStateEncoderSpec(keySchema) val provider = StateStoreProvider.createAndInit( stateStoreProviderId, keySchema, valueSchema, keyStateEncoderSpec, - useColumnFamilies = colFamilyNames.nonEmpty, storeConf, hadoopConf.value, false, None) + useColumnFamilies = colFamilyNames.nonEmpty, modifiedStoreConf, hadoopConf.value, false, None) provider } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala index 3991f8d93f2c..35f8d2118643 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala @@ -163,6 +163,26 @@ class StateStoreConf( */ val sqlConfs: Map[String, String] = sqlConf.getAllConfs.filter(_._1.startsWith("spark.sql.streaming.stateStore.")) + + /** + * Creates a copy of this StateStoreConf with format validation disabled. + * This is useful when reading raw bytes where the schema used (binary) doesn't match + * the actual stored data schema. 
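+ * Only the spark.sql.streaming.stateStore.* entries captured in `sqlConfs` are copied
+ * into the reconstructed conf, since the original SQLConf is transient.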
+ */ + def withFormatValidationDisabled(): StateStoreConf = { + val reconstructedSqlConf = { + // Reconstruct a SQLConf with the all settings preserved because sqlConf is transient + val conf = new SQLConf() + // Restore all state store related settings + sqlConfs.foreach { case (key, value) => + conf.setConfString(key, value) + } + conf + } + new StateStoreConf(reconstructedSqlConf, extraOptions) { + override val formatValidationEnabled: Boolean = false + } + } } object StateStoreConf { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReaderAllColumnFamiliesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReaderAllColumnFamiliesSuite.scala index 20fe0162a424..1b3da0498623 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReaderAllColumnFamiliesSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReaderAllColumnFamiliesSuite.scala @@ -16,10 +16,10 @@ */ package org.apache.spark.sql.execution.datasources.v2.state -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.DataFrame import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.execution.streaming.runtime.MemoryStream -import org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider +import org.apache.spark.sql.execution.streaming.state.{HDFSBackedStateStoreProvider, RocksDBStateStoreProvider} import org.apache.spark.sql.functions.{count, sum} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.OutputMode @@ -38,21 +38,12 @@ class StatePartitionReaderAllColumnFamiliesSuite extends StateDataSourceTestBase /** * Returns a set of (partitionId, key, value) tuples from a normal state read. */ - private def getNormalReadData(checkpointDir: String): Set[(Int, Row, Row)] = { - val normalReadDf = spark.read + private def getNormalReadData(checkpointDir: String): DataFrame = { + spark.read .format("statestore") .option(StateSourceOptions.PATH, checkpointDir) .load() .selectExpr("partition_id", "key", "value") - - normalReadDf.collect() - .map { row => - val partitionId = row.getInt(0) - val key = row.getStruct(1) - val value = row.getStruct(2) - (partitionId, key, value) - } - .toSet } /** @@ -108,9 +99,14 @@ class StatePartitionReaderAllColumnFamiliesSuite extends StateDataSourceTestBase /** * Parses the bytes read DataFrame into a set of (partitionId, key, value, columnFamily) tuples. + * For RocksDB provider, skipVersionBytes should be true. + * For HDFS provider, skipVersionBytes should be false. 
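+ * RocksDB prefixes each stored key/value with a version prefix of
+ * STATE_ENCODING_NUM_VERSION_BYTES bytes, while the HDFS-backed provider stores the
+ * UnsafeRow bytes directly.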
*/ private def parseBytesReadData( - df: DataFrame, numOfKey: Int, numOfValue: Int): Set[(Int, UnsafeRow, UnsafeRow, String)] = { + df: DataFrame, + numOfKey: Int, + numOfValue: Int, + skipVersionBytes: Boolean = true): Set[(Int, UnsafeRow, UnsafeRow, String)] = { df.selectExpr("partition_id", "key_bytes", "value_bytes", "column_family_name") .collect() .map { row => @@ -120,20 +116,38 @@ class StatePartitionReaderAllColumnFamiliesSuite extends StateDataSourceTestBase val columnFamily = row.getString(3) // Deserialize key bytes to UnsafeRow - // Skip the version byte (STATE_ENCODING_NUM_VERSION_BYTES) at the beginning val keyRow = new UnsafeRow(numOfKey) - keyRow.pointTo( - keyBytes, - Platform.BYTE_ARRAY_OFFSET + RocksDBStateStoreProvider.STATE_ENCODING_NUM_VERSION_BYTES, - keyBytes.length - RocksDBStateStoreProvider.STATE_ENCODING_NUM_VERSION_BYTES) + if (skipVersionBytes) { + // Skip the version byte (STATE_ENCODING_NUM_VERSION_BYTES) at the beginning + // This is for RocksDB provider + keyRow.pointTo( + keyBytes, + Platform.BYTE_ARRAY_OFFSET + RocksDBStateStoreProvider.STATE_ENCODING_NUM_VERSION_BYTES, + keyBytes.length - RocksDBStateStoreProvider.STATE_ENCODING_NUM_VERSION_BYTES) + } else { + // HDFS provider doesn't add version bytes, use bytes directly + keyRow.pointTo( + keyBytes, + Platform.BYTE_ARRAY_OFFSET, + keyBytes.length) + } // Deserialize value bytes to UnsafeRow - // Skip the version byte (STATE_ENCODING_NUM_VERSION_BYTES) at the beginning val valueRow = new UnsafeRow(numOfValue) - valueRow.pointTo( - valueBytes, - Platform.BYTE_ARRAY_OFFSET + RocksDBStateStoreProvider.STATE_ENCODING_NUM_VERSION_BYTES, - valueBytes.length - RocksDBStateStoreProvider.STATE_ENCODING_NUM_VERSION_BYTES) + if (skipVersionBytes) { + // Skip the version byte (STATE_ENCODING_NUM_VERSION_BYTES) at the beginning + // This is for RocksDB provider + valueRow.pointTo( + valueBytes, + Platform.BYTE_ARRAY_OFFSET + RocksDBStateStoreProvider.STATE_ENCODING_NUM_VERSION_BYTES, + valueBytes.length - RocksDBStateStoreProvider.STATE_ENCODING_NUM_VERSION_BYTES) + } else { + // HDFS provider doesn't add version bytes, use bytes directly + valueRow.pointTo( + valueBytes, + Platform.BYTE_ARRAY_OFFSET, + valueBytes.length) + } (partitionId, keyRow.copy(), valueRow.copy(), columnFamily) } @@ -144,24 +158,31 @@ class StatePartitionReaderAllColumnFamiliesSuite extends StateDataSourceTestBase * Compares normal read data with bytes read data for a specific column family. 
*/ private def compareNormalAndBytesData( - normalData: Set[(Int, Row, Row)], - bytesData: Set[(Int, UnsafeRow, UnsafeRow, String)], + normalReadDf: DataFrame, + bytesReadDf: DataFrame, columnFamily: String, keySchema: StructType, - valueSchema: StructType): Unit = { + valueSchema: StructType, + skipVersionBytes: Boolean): Unit = { + // Filter bytes data for the specified column family + val bytesData = parseBytesReadData(bytesReadDf, keySchema.length, valueSchema.length, + skipVersionBytes) val filteredBytesData = bytesData.filter(_._4 == columnFamily) - // Verify same number of rows - assert(filteredBytesData.size == normalData.size, - s"Row count mismatch for column family '$columnFamily': " + - s"normal read has ${filteredBytesData.size} rows, bytes read has ${normalData.size} rows") // Convert to comparable format (extract field values) - val normalSet = normalData.map { case (partId, key, value) => + val normalSet = normalReadDf.collect().map { row => + val partitionId = row.getInt(0) + val key = row.getStruct(1) + val value = row.getStruct(2) val keyFields = (0 until key.length).map(i => key.get(i)) val valueFields = (0 until value.length).map(i => value.get(i)) - (partId, keyFields, valueFields) - } + (partitionId, keyFields, valueFields) + }.toSet + // Verify same number of rows + assert(filteredBytesData.size == normalSet.size, + s"Row count mismatch for column family '$columnFamily': " + + s"normal read has ${filteredBytesData.size} rows, bytes read has ${normalSet.size} rows") val bytesSet = filteredBytesData.map { case (partId, keyRow, valueRow, _) => val keyFields = (0 until keySchema.length).map(i => @@ -174,67 +195,70 @@ class StatePartitionReaderAllColumnFamiliesSuite extends StateDataSourceTestBase assert(normalSet == bytesSet) } - test("read all column families with simple operator") { - withTempDir { tempDir => - withSQLConf( - SQLConf.STATE_STORE_PROVIDER_CLASS.key -> - classOf[RocksDBStateStoreProvider].getName, - SQLConf.SHUFFLE_PARTITIONS.key -> "2") { - - val inputData = MemoryStream[Int] - val aggregated = inputData.toDF() - .selectExpr("value", "value % 10 AS groupKey") - .groupBy($"groupKey") - .agg( - count("*").as("cnt"), - sum("value").as("sum") + Seq( + ("RocksDBStateStoreProvider", classOf[RocksDBStateStoreProvider], true), + ("HDFSBackedStateStoreProvider", classOf[HDFSBackedStateStoreProvider], false) + ).foreach { case (providerName, providerClass, skipVersionBytes) => + test(s"read all column families with simple operator - $providerName") { + withTempDir { tempDir => + withSQLConf( + SQLConf.STATE_STORE_PROVIDER_CLASS.key -> providerClass.getName, + SQLConf.SHUFFLE_PARTITIONS.key -> "2") { + + val inputData = MemoryStream[Int] + val aggregated = inputData.toDF() + .selectExpr("value", "value % 10 AS groupKey") + .groupBy($"groupKey") + .agg( + count("*").as("cnt"), + sum("value").as("sum") + ) + .as[(Int, Long, Long)] + + testStream(aggregated, OutputMode.Update)( + StartStream(checkpointLocation = tempDir.getAbsolutePath), + // batch 0 + AddData(inputData, 0 until 20: _*), + CheckLastBatch( + (0, 2, 10), // 0, 10 + (1, 2, 12), // 1, 11 + (2, 2, 14), // 2, 12 + (3, 2, 16), // 3, 13 + (4, 2, 18), // 4, 14 + (5, 2, 20), // 5, 15 + (6, 2, 22), // 6, 16 + (7, 2, 24), // 7, 17 + (8, 2, 26), // 8, 18 + (9, 2, 28) // 9, 19 + ), + StopStream ) - .as[(Int, Long, Long)] - - testStream(aggregated, OutputMode.Update)( - StartStream(checkpointLocation = tempDir.getAbsolutePath), - // batch 0 - AddData(inputData, 0 until 20: _*), - CheckLastBatch( - (0, 2, 10), // 
0, 10 - (1, 2, 12), // 1, 11 - (2, 2, 14), // 2, 12 - (3, 2, 16), // 3, 13 - (4, 2, 18), // 4, 14 - (5, 2, 20), // 5, 15 - (6, 2, 22), // 6, 16 - (7, 2, 24), // 7, 17 - (8, 2, 26), // 8, 18 - (9, 2, 28) // 9, 19 - ), - StopStream - ) - - // Read state data once with READ_ALL_COLUMN_FAMILIES = true - val bytesReadDf = getBytesReadDf(tempDir.getAbsolutePath) - - // Verify schema and column families - validateBytesReadSchema(bytesReadDf, - expectedRowCount = 10, - expectedColumnFamilies = Seq("default")) - - // Get normal read data for comparison - val normalData = getNormalReadData(tempDir.getAbsolutePath) - - // Compare normal and bytes data for default column family - val keySchema: StructType = StructType(Array( - StructField("key", IntegerType, nullable = false) - )) - - // Value schema for the aggregation: count and sum columns - val valueSchema: StructType = StructType(Array( - StructField("count", LongType, nullable = false), - StructField("sum", LongType, nullable = false) - )) - // Parse bytes read data - val bytesData = parseBytesReadData(bytesReadDf, keySchema.length, valueSchema.length) - - compareNormalAndBytesData(normalData, bytesData, "default", keySchema, valueSchema) + + // Read state data once with READ_ALL_COLUMN_FAMILIES = true + val bytesReadDf = getBytesReadDf(tempDir.getAbsolutePath) + + // Verify schema and column families + validateBytesReadSchema(bytesReadDf, + expectedRowCount = 10, + expectedColumnFamilies = Seq("default")) + + // Compare normal and bytes data for default column family + val keySchema: StructType = StructType(Array( + StructField("key", IntegerType, nullable = false) + )) + + // Value schema for the aggregation: count and sum columns + val valueSchema: StructType = StructType(Array( + StructField("count", LongType, nullable = false), + StructField("sum", LongType, nullable = false) + )) + // Parse bytes read data + + // Get normal read data for comparison + val normalData = getNormalReadData(tempDir.getAbsolutePath) + compareNormalAndBytesData( + normalData, bytesReadDf, "default", keySchema, valueSchema, skipVersionBytes) + } } } } From 0cd133049ae216f54480964d8b852e67f3683747 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Tue, 18 Nov 2025 22:04:38 +0000 Subject: [PATCH 03/12] remove unused code --- .../datasources/v2/state/StateDataSource.scala | 2 +- .../v2/state/StatePartitionReader.scala | 15 ++++++--------- 2 files changed, 7 insertions(+), 10 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala index 5878e295369a..760b5d7eec66 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala @@ -67,7 +67,7 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging StateSourceOptions.apply(session, hadoopConf, properties)) val stateConf = buildStateStoreConf(sourceOptions.resolvedCpLocation, sourceOptions.batchId) if (sourceOptions.readAllColumnFamilies) { - // For readAllColumnFamilies mode, we don't need specific metadata + // For readAllColumnFamilies mode, we don't need specific encoder because it returns raw data val keyStateEncoderSpec = NoPrefixKeyStateEncoderSpec(new StructType()) new StateTable(session, schema, sourceOptions, stateConf, keyStateEncoderSpec, None, None, None, None) diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala index c80046e60d3e..3740201986da 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.execution.streaming.operators.stateful.join.Symmetri import org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.{StateVariableType, TransformWithStateVariableInfo} import org.apache.spark.sql.execution.streaming.state._ import org.apache.spark.sql.execution.streaming.state.RecordType.{getRecordTypeAsString, RecordType} -import org.apache.spark.sql.types.{BinaryType, NullType, StructField, StructType} +import org.apache.spark.sql.types.{NullType, StructField, StructType} import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.{NextIterator, SerializableConfiguration} @@ -85,20 +85,16 @@ abstract class StatePartitionReaderBase( private val schemaForValueRow: StructType = StructType(Array(StructField("__dummy__", NullType))) - protected val keySchema = { + protected lazy val keySchema = { if (SchemaUtil.checkVariableType(stateVariableInfoOpt, StateVariableType.MapState)) { SchemaUtil.getCompositeKeySchema(schema, partition.sourceOptions) - } else if (partition.sourceOptions.readAllColumnFamilies) { - new StructType().add("keyBytes", BinaryType, nullable = false) } else { SchemaUtil.getSchemaAsDataType(schema, "key").asInstanceOf[StructType] } } - protected val valueSchema = if (stateVariableInfoOpt.isDefined) { + protected lazy val valueSchema = if (stateVariableInfoOpt.isDefined) { schemaForValueRow - } else if (partition.sourceOptions.readAllColumnFamilies) { - new StructType().add("valueBytes", BinaryType, nullable = false) } else { SchemaUtil.getSchemaAsDataType( schema, "value").asInstanceOf[StructType] @@ -289,9 +285,10 @@ class StatePartitionReaderAllColumnFamilies( // loading data from disk, so we disable it for raw bytes mode. 
val modifiedStoreConf = storeConf.withFormatValidationDisabled() - val keyStateEncoderSpec = NoPrefixKeyStateEncoderSpec(keySchema) + val keyStateEncoderSpec = NoPrefixKeyStateEncoderSpec(new StructType()) + // Pass in empty keySchema, valueSchema and dummy encoder because we don't encode any data val provider = StateStoreProvider.createAndInit( - stateStoreProviderId, keySchema, valueSchema, keyStateEncoderSpec, + stateStoreProviderId, new StructType(), new StructType(), keyStateEncoderSpec, useColumnFamilies = colFamilyNames.nonEmpty, modifiedStoreConf, hadoopConf.value, false, None) provider From f6e15ed6d6fbf76edc1800714e8ec5bcdd9f602c Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Wed, 19 Nov 2025 01:35:02 +0000 Subject: [PATCH 04/12] address comment --- .../v2/state/StateDataSource.scala | 76 +++++------ .../v2/state/StatePartitionReader.scala | 73 ++--------- .../v2/state/utils/SchemaUtil.scala | 27 ++-- .../streaming/state/StateStore.scala | 14 -- .../streaming/state/StateStoreConf.scala | 20 --- ...artitionReaderAllColumnFamiliesSuite.scala | 120 ++++++------------ 6 files changed, 108 insertions(+), 222 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala index 760b5d7eec66..5bafb3d64b04 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala @@ -66,38 +66,35 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging val sourceOptions = StateSourceOptions.modifySourceOptions(hadoopConf, StateSourceOptions.apply(session, hadoopConf, properties)) val stateConf = buildStateStoreConf(sourceOptions.resolvedCpLocation, sourceOptions.batchId) - if (sourceOptions.readAllColumnFamilies) { - // For readAllColumnFamilies mode, we don't need specific encoder because it returns raw data - val keyStateEncoderSpec = NoPrefixKeyStateEncoderSpec(new StructType()) - new StateTable(session, schema, sourceOptions, stateConf, keyStateEncoderSpec, - None, None, None, None) - } else { - val stateStoreReaderInfo: StateStoreReaderInfo = getStoreMetadataAndRunChecks( - sourceOptions) + if (sourceOptions.internalOnlyReadAllColumnFamilies + && !stateConf.providerClass.contains("RocksDB")) { + throw StateDataSourceErrors.invalidOptionValue( + StateSourceOptions.INTERNAL_ONLY_READ_ALL_COLUMN_FAMILIES, + "internalOnlyReadAllColumnFamilies is only supported with RocksDBStateStoreProvider. 
" + + s"Current provider: ${stateConf.providerClass}") + } + val stateStoreReaderInfo: StateStoreReaderInfo = getStoreMetadataAndRunChecks( + sourceOptions) - // The key state encoder spec should be available for all operators except stream-stream joins - val keyStateEncoderSpec = if (stateStoreReaderInfo.keyStateEncoderSpecOpt.isDefined) { - stateStoreReaderInfo.keyStateEncoderSpecOpt.get - } else { - val keySchema = SchemaUtil.getSchemaAsDataType(schema, "key").asInstanceOf[StructType] - NoPrefixKeyStateEncoderSpec(keySchema) - } - new StateTable(session, schema, sourceOptions, stateConf, keyStateEncoderSpec, - stateStoreReaderInfo.transformWithStateVariableInfoOpt, - stateStoreReaderInfo.stateStoreColFamilySchemaOpt, - stateStoreReaderInfo.stateSchemaProviderOpt, - stateStoreReaderInfo.joinColFamilyOpt) + // The key state encoder spec should be available for all operators except stream-stream joins + val keyStateEncoderSpec = if (stateStoreReaderInfo.keyStateEncoderSpecOpt.isDefined) { + stateStoreReaderInfo.keyStateEncoderSpecOpt.get + } else { + val keySchema = SchemaUtil.getSchemaAsDataType(schema, "key").asInstanceOf[StructType] + NoPrefixKeyStateEncoderSpec(keySchema) } + + new StateTable(session, schema, sourceOptions, stateConf, keyStateEncoderSpec, + stateStoreReaderInfo.transformWithStateVariableInfoOpt, + stateStoreReaderInfo.stateStoreColFamilySchemaOpt, + stateStoreReaderInfo.stateSchemaProviderOpt, + stateStoreReaderInfo.joinColFamilyOpt) } override def inferSchema(options: CaseInsensitiveStringMap): StructType = { val sourceOptions = StateSourceOptions.modifySourceOptions(hadoopConf, StateSourceOptions.apply(session, hadoopConf, options)) - if (sourceOptions.readAllColumnFamilies) { - // For readAllColumnFamilies mode, return the binary schema directly - return SchemaUtil.getSourceSchema( - sourceOptions, new StructType(), new StructType(), None, None) - } + val stateStoreReaderInfo: StateStoreReaderInfo = getStoreMetadataAndRunChecks( sourceOptions) val oldSchemaFilePaths = StateDataSource.getOldSchemaFilePaths(sourceOptions, hadoopConf) @@ -382,7 +379,7 @@ case class StateSourceOptions( stateVarName: Option[String], readRegisteredTimers: Boolean, flattenCollectionTypes: Boolean, - readAllColumnFamilies: Boolean, + internalOnlyReadAllColumnFamilies: Boolean, startOperatorStateUniqueIds: Option[Array[Array[String]]] = None, endOperatorStateUniqueIds: Option[Array[Array[String]]] = None) { def stateCheckpointLocation: Path = new Path(resolvedCpLocation, DIR_NAME_STATE) @@ -392,7 +389,7 @@ case class StateSourceOptions( s"operatorId=$operatorId, storeName=$storeName, joinSide=$joinSide, " + s"stateVarName=${stateVarName.getOrElse("None")}, +" + s"flattenCollectionTypes=$flattenCollectionTypes" + - s"readAllColumnFamilies=$readAllColumnFamilies" + s"internalOnlyReadAllColumnFamilies=$internalOnlyReadAllColumnFamilies" if (fromSnapshotOptions.isDefined) { desc += s", snapshotStartBatchId=${fromSnapshotOptions.get.snapshotStartBatchId}" desc += s", snapshotPartitionId=${fromSnapshotOptions.get.snapshotPartitionId}" @@ -419,7 +416,7 @@ object StateSourceOptions extends DataSourceOptions { val STATE_VAR_NAME = newOption("stateVarName") val READ_REGISTERED_TIMERS = newOption("readRegisteredTimers") val FLATTEN_COLLECTION_TYPES = newOption("flattenCollectionTypes") - val READ_ALL_COLUMN_FAMILIES = newOption("readAllColumnFamilies") + val INTERNAL_ONLY_READ_ALL_COLUMN_FAMILIES = newOption("internalOnlyReadAllColumnFamilies") object JoinSideValues extends Enumeration { type 
JoinSideValues = Value @@ -505,25 +502,28 @@ object StateSourceOptions extends DataSourceOptions { val readChangeFeed = Option(options.get(READ_CHANGE_FEED)).exists(_.toBoolean) - val readAllColumnFamilies = try { - Option(options.get(READ_ALL_COLUMN_FAMILIES)) + val internalOnlyReadAllColumnFamilies = try { + Option(options.get(INTERNAL_ONLY_READ_ALL_COLUMN_FAMILIES)) .map(_.toBoolean).getOrElse(false) } catch { case _: IllegalArgumentException => - throw StateDataSourceErrors.invalidOptionValue(READ_ALL_COLUMN_FAMILIES, + throw StateDataSourceErrors.invalidOptionValue(INTERNAL_ONLY_READ_ALL_COLUMN_FAMILIES, "Boolean value is expected") } - if (readAllColumnFamilies && stateVarName.isDefined) { - throw StateDataSourceErrors.conflictOptions(Seq(READ_ALL_COLUMN_FAMILIES, STATE_VAR_NAME)) + if (internalOnlyReadAllColumnFamilies && stateVarName.isDefined) { + throw StateDataSourceErrors.conflictOptions( + Seq(INTERNAL_ONLY_READ_ALL_COLUMN_FAMILIES, STATE_VAR_NAME)) } - if (readAllColumnFamilies && joinSide != JoinSideValues.none) { - throw StateDataSourceErrors.conflictOptions(Seq(READ_ALL_COLUMN_FAMILIES, JOIN_SIDE)) + if (internalOnlyReadAllColumnFamilies && joinSide != JoinSideValues.none) { + throw StateDataSourceErrors.conflictOptions( + Seq(INTERNAL_ONLY_READ_ALL_COLUMN_FAMILIES, JOIN_SIDE)) } - if (readAllColumnFamilies && readChangeFeed) { - throw StateDataSourceErrors.conflictOptions(Seq(READ_ALL_COLUMN_FAMILIES, READ_CHANGE_FEED)) + if (internalOnlyReadAllColumnFamilies && readChangeFeed) { + throw StateDataSourceErrors.conflictOptions( + Seq(INTERNAL_ONLY_READ_ALL_COLUMN_FAMILIES, READ_CHANGE_FEED)) } val changeStartBatchId = Option(options.get(CHANGE_START_BATCH_ID)).map(_.toLong) @@ -650,7 +650,7 @@ object StateSourceOptions extends DataSourceOptions { resolvedCpLocation, batchId.get, operatorId, storeName, joinSide, readChangeFeed, fromSnapshotOptions, readChangeFeedOptions, stateVarName, readRegisteredTimers, flattenCollectionTypes, - readAllColumnFamilies, startOperatorStateUniqueIds, endOperatorStateUniqueIds) + internalOnlyReadAllColumnFamilies, startOperatorStateUniqueIds, endOperatorStateUniqueIds) } private def getLastCommittedBatch(session: SparkSession, checkpointLocation: String): Long = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala index 3740201986da..475249e3fd27 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala @@ -20,7 +20,6 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeRow} import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory} -import org.apache.spark.sql.execution.datasources.v2.state.metadata.StateMetadataPartitionReader import org.apache.spark.sql.execution.datasources.v2.state.utils.SchemaUtil import org.apache.spark.sql.execution.streaming.operators.stateful.join.SymmetricHashJoinStateManager import org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.{StateVariableType, TransformWithStateVariableInfo} @@ -50,9 +49,9 @@ class StatePartitionReaderFactory( override def createReader(partition: InputPartition): 
PartitionReader[InternalRow] = { val stateStoreInputPartition = partition.asInstanceOf[StateStoreInputPartition] - if (stateStoreInputPartition.sourceOptions.readAllColumnFamilies) { + if (stateStoreInputPartition.sourceOptions.internalOnlyReadAllColumnFamilies) { new StatePartitionReaderAllColumnFamilies(storeConf, hadoopConf, - stateStoreInputPartition, schema) + stateStoreInputPartition, schema, keyStateEncoderSpec) } else if (stateStoreInputPartition.sourceOptions.readChangeFeed) { new StateStoreChangeDataPartitionReader(storeConf, hadoopConf, stateStoreInputPartition, schema, keyStateEncoderSpec, stateVariableInfoOpt, @@ -85,7 +84,7 @@ abstract class StatePartitionReaderBase( private val schemaForValueRow: StructType = StructType(Array(StructField("__dummy__", NullType))) - protected lazy val keySchema = { + protected val keySchema = { if (SchemaUtil.checkVariableType(stateVariableInfoOpt, StateVariableType.MapState)) { SchemaUtil.getCompositeKeySchema(schema, partition.sourceOptions) } else { @@ -93,7 +92,7 @@ abstract class StatePartitionReaderBase( } } - protected lazy val valueSchema = if (stateVariableInfoOpt.isDefined) { + protected val valueSchema = if (stateVariableInfoOpt.isDefined) { schemaForValueRow } else { SchemaUtil.getSchemaAsDataType( @@ -249,16 +248,10 @@ class StatePartitionReaderAllColumnFamilies( storeConf: StateStoreConf, hadoopConf: SerializableConfiguration, partition: StateStoreInputPartition, - schema: StructType) + schema: StructType, + keyStateEncoderSpec: KeyStateEncoderSpec) extends StatePartitionReaderBase(storeConf, hadoopConf, partition, schema, - NoPrefixKeyStateEncoderSpec(new StructType()), None, None, None, None) { - - val allStateStoreMetadata = { - new StateMetadataPartitionReader( - partition.sourceOptions.resolvedCpLocation, - new SerializableConfiguration(hadoopConf.value), - partition.sourceOptions.batchId).stateMetadata.toArray - } + keyStateEncoderSpec, None, None, None, None) { private lazy val store: ReadStateStore = { assert(getStartStoreUniqueId == getEndStoreUniqueId, @@ -269,56 +262,14 @@ class StatePartitionReaderAllColumnFamilies( ) } - val colFamilyNames: Seq[String] = { - // todo: Support operator with multiple column family names in next PR - Seq[String]() - } - - override protected lazy val provider: StateStoreProvider = { - val stateStoreId = StateStoreId(partition.sourceOptions.stateCheckpointLocation.toString, - partition.sourceOptions.operatorId, partition.partition, partition.sourceOptions.storeName) - val stateStoreProviderId = StateStoreProviderId(stateStoreId, partition.queryId) - - // Disable format validation when reading raw bytes. - // We use binary schemas (keyBytes/valueBytes) which don't match the actual schema - // of the stored data. Validation would fail in HDFSBackedStateStoreProvider when - // loading data from disk, so we disable it for raw bytes mode. 
- val modifiedStoreConf = storeConf.withFormatValidationDisabled() - - val keyStateEncoderSpec = NoPrefixKeyStateEncoderSpec(new StructType()) - // Pass in empty keySchema, valueSchema and dummy encoder because we don't encode any data - val provider = StateStoreProvider.createAndInit( - stateStoreProviderId, new StructType(), new StructType(), keyStateEncoderSpec, - useColumnFamilies = colFamilyNames.nonEmpty, modifiedStoreConf, hadoopConf.value, false, None) - - provider - } - override lazy val iter: Iterator[InternalRow] = { // Single store with column families (join v3, transformWithState, or simple operators) - require(store.isInstanceOf[SupportsRawBytesRead], - s"State store ${store.getClass.getName} does not support raw bytes reading") - - val rawStore = store.asInstanceOf[SupportsRawBytesRead] - if (colFamilyNames.isEmpty) { - rawStore - .rawIterator() - .map { case (keyBytes, valueBytes) => - SchemaUtil.unifyStateRowPairAsRawBytes( - partition.partition, keyBytes, valueBytes, StateStore.DEFAULT_COL_FAMILY_NAME) - } - } else { - colFamilyNames.iterator.flatMap { colFamilyName => - rawStore - .rawIterator(colFamilyName) - .map { case (keyBytes, valueBytes) => - SchemaUtil.unifyStateRowPairAsRawBytes(partition.partition, - keyBytes, - valueBytes, - colFamilyName) - } + store + .iterator() + .map { pair => + SchemaUtil.unifyStateRowPairAsRawBytes( + (pair.key, pair.value), StateStore.DEFAULT_COL_FAMILY_NAME) } - } } override def close(): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/utils/SchemaUtil.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/utils/SchemaUtil.scala index 095bcff6fb07..16a8a2e5d3fd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/utils/SchemaUtil.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/utils/SchemaUtil.scala @@ -61,12 +61,17 @@ object SchemaUtil { .add("key", keySchema) .add("value", valueSchema) .add("partition_id", IntegerType) - } else if (sourceOptions.readAllColumnFamilies) { + } else if (sourceOptions.internalOnlyReadAllColumnFamilies) { new StructType() - .add("partition_id", IntegerType) + // todo: change this to some more specific type after we + // can extract partition key from keySchema + .add("partition_key", keySchema) .add("key_bytes", BinaryType) .add("value_bytes", BinaryType) .add("column_family_name", StringType) + // need key and value schema so that state store can encode data + .add("value", valueSchema) + .add("key", keySchema) } else { new StructType() .add("key", keySchema) @@ -89,15 +94,14 @@ object SchemaUtil { * instead of a tuple for better readability. 
*/ def unifyStateRowPairAsRawBytes( - partition: Int, - keyBytes: Array[Byte], - valueBytes: Array[Byte], + pair: (UnsafeRow, UnsafeRow), colFamilyName: String): InternalRow = { - val row = new GenericInternalRow(4) - row.update(0, partition) - row.update(1, keyBytes) - row.update(2, valueBytes) + val row = new GenericInternalRow(6) + row.update(0, pair._1) + row.update(1, pair._1.getBytes) + row.update(2, pair._2.getBytes) row.update(3, UTF8String.fromString(colFamilyName)) +// row.update(4, pair._2) row } @@ -257,6 +261,7 @@ object SchemaUtil { "user_map_value" -> classOf[StructType], "expiration_timestamp_ms" -> classOf[LongType], "partition_id" -> classOf[IntegerType], + "partition_key" -> classOf[StructType], "key_bytes"->classOf[BinaryType], "value_bytes"->classOf[BinaryType], "column_family_name"->classOf[StringType]) @@ -300,8 +305,8 @@ object SchemaUtil { } } else if (sourceOptions.readChangeFeed) { Seq("batch_id", "change_type", "key", "value", "partition_id") - } else if (sourceOptions.readAllColumnFamilies) { - Seq("partition_id", "key_bytes", "value_bytes", "column_family_name") + } else if (sourceOptions.internalOnlyReadAllColumnFamilies) { + Seq("partition_key", "key_bytes", "value_bytes", "column_family_name", "value", "key") } else { Seq("key", "value", "partition_id") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index d180010e355e..43b95766882f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -891,20 +891,6 @@ object StateStoreProvider extends Logging { } } -/** - * Trait for state stores that support reading raw bytes without decoding. - * This is useful for copying state data during repartitioning - */ -trait SupportsRawBytesRead { - /** - * Returns an iterator of raw key-value bytes for a column family. - * @param colFamilyName the name of the column family to iterate over - * @return an iterator of (keyBytes, valueBytes) tuples - */ - def rawIterator(colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): - Iterator[(Array[Byte], Array[Byte])] -} - /** * This is an optional trait to be implemented by [[StateStoreProvider]]s that can read the change * of state store over batches. This is used by State Data Source with additional options like diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala index 35f8d2118643..3991f8d93f2c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala @@ -163,26 +163,6 @@ class StateStoreConf( */ val sqlConfs: Map[String, String] = sqlConf.getAllConfs.filter(_._1.startsWith("spark.sql.streaming.stateStore.")) - - /** - * Creates a copy of this StateStoreConf with format validation disabled. - * This is useful when reading raw bytes where the schema used (binary) doesn't match - * the actual stored data schema. 
- */ - def withFormatValidationDisabled(): StateStoreConf = { - val reconstructedSqlConf = { - // Reconstruct a SQLConf with the all settings preserved because sqlConf is transient - val conf = new SQLConf() - // Restore all state store related settings - sqlConfs.foreach { case (key, value) => - conf.setConfString(key, value) - } - conf - } - new StateStoreConf(reconstructedSqlConf, extraOptions) { - override val formatValidationEnabled: Boolean = false - } - } } object StateStoreConf { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReaderAllColumnFamiliesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReaderAllColumnFamiliesSuite.scala index 1b3da0498623..49d5d46842de 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReaderAllColumnFamiliesSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReaderAllColumnFamiliesSuite.scala @@ -17,9 +17,9 @@ package org.apache.spark.sql.execution.datasources.v2.state import org.apache.spark.sql.DataFrame -import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.catalyst.expressions.{GenericRowWithSchema, UnsafeRow} import org.apache.spark.sql.execution.streaming.runtime.MemoryStream -import org.apache.spark.sql.execution.streaming.state.{HDFSBackedStateStoreProvider, RocksDBStateStoreProvider} +import org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider import org.apache.spark.sql.functions.{count, sum} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.OutputMode @@ -47,13 +47,13 @@ class StatePartitionReaderAllColumnFamiliesSuite extends StateDataSourceTestBase } /** - * Returns a DataFrame with raw bytes mode (READ_ALL_COLUMN_FAMILIES = true). + * Returns a DataFrame with raw bytes mode (INTERNAL_ONLY_READ_ALL_COLUMN_FAMILIES = true). 
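+   * For example (illustrative only; `dir` stands for the test checkpoint path), rows for a
+   * single family can be sliced out of the returned frame with:
+   * {{{
+   *   getBytesReadDf(dir).where("column_family_name = 'default'")
+   * }}}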
*/ private def getBytesReadDf(checkpointDir: String): DataFrame = { spark.read .format("statestore") .option(StateSourceOptions.PATH, checkpointDir) - .option(StateSourceOptions.READ_ALL_COLUMN_FAMILIES, "true") + .option(StateSourceOptions.INTERNAL_ONLY_READ_ALL_COLUMN_FAMILIES, "true") .load() } @@ -67,14 +67,16 @@ class StatePartitionReaderAllColumnFamiliesSuite extends StateDataSourceTestBase // Verify schema val schema = df.schema assert(schema.fieldNames === Array( - "partition_id", "key_bytes", "value_bytes", "column_family_name")) - assert(schema("partition_id").dataType.typeName === "integer") + "partition_key", "key_bytes", "value_bytes", "column_family_name", "value", "key")) + assert(schema("partition_key").dataType.typeName === "struct") assert(schema("key_bytes").dataType.typeName === "binary") assert(schema("value_bytes").dataType.typeName === "binary") assert(schema("column_family_name").dataType.typeName === "string") // Verify data - val rows = df.collect() + val rows = df + .selectExpr("partition_key", "key_bytes", "value_bytes", "column_family_name") + .collect() assert(rows.length == expectedRowCount, s"Expected $expectedRowCount rows but got: ${rows.length}") @@ -88,68 +90,37 @@ class StatePartitionReaderAllColumnFamiliesSuite extends StateDataSourceTestBase s"Expected column family '$expectedCF', " + s"but got: ${columnFamilies.mkString(", ")}") } - - // Verify all rows have non-null values - rows.foreach { row => - assert(row.getInt(0) >= 0) // partition_id non-negative - assert(row.get(1) != null) // key_bytes not null - assert(row.get(2) != null) // value_bytes not null - } } - /** - * Parses the bytes read DataFrame into a set of (partitionId, key, value, columnFamily) tuples. - * For RocksDB provider, skipVersionBytes should be true. - * For HDFS provider, skipVersionBytes should be false. 
- */ private def parseBytesReadData( - df: DataFrame, - numOfKey: Int, - numOfValue: Int, - skipVersionBytes: Boolean = true): Set[(Int, UnsafeRow, UnsafeRow, String)] = { - df.selectExpr("partition_id", "key_bytes", "value_bytes", "column_family_name") + df: DataFrame) + : Set[(GenericRowWithSchema, UnsafeRow, UnsafeRow, String)] = { + df.selectExpr("partition_key", "key_bytes", "value_bytes", "column_family_name") .collect() .map { row => - val partitionId = row.getInt(0) + val partitionKey = row.getAs[GenericRowWithSchema](0) val keyBytes = row.getAs[Array[Byte]](1) val valueBytes = row.getAs[Array[Byte]](2) val columnFamily = row.getString(3) // Deserialize key bytes to UnsafeRow - val keyRow = new UnsafeRow(numOfKey) - if (skipVersionBytes) { - // Skip the version byte (STATE_ENCODING_NUM_VERSION_BYTES) at the beginning - // This is for RocksDB provider - keyRow.pointTo( - keyBytes, - Platform.BYTE_ARRAY_OFFSET + RocksDBStateStoreProvider.STATE_ENCODING_NUM_VERSION_BYTES, - keyBytes.length - RocksDBStateStoreProvider.STATE_ENCODING_NUM_VERSION_BYTES) - } else { - // HDFS provider doesn't add version bytes, use bytes directly - keyRow.pointTo( - keyBytes, - Platform.BYTE_ARRAY_OFFSET, - keyBytes.length) - } + val keyRow = new UnsafeRow(1) + // Skip the version byte (STATE_ENCODING_NUM_VERSION_BYTES) at the beginning + // This is for RocksDB provider + keyRow.pointTo( + keyBytes, + Platform.BYTE_ARRAY_OFFSET, + keyBytes.length) // Deserialize value bytes to UnsafeRow - val valueRow = new UnsafeRow(numOfValue) - if (skipVersionBytes) { - // Skip the version byte (STATE_ENCODING_NUM_VERSION_BYTES) at the beginning - // This is for RocksDB provider - valueRow.pointTo( - valueBytes, - Platform.BYTE_ARRAY_OFFSET + RocksDBStateStoreProvider.STATE_ENCODING_NUM_VERSION_BYTES, - valueBytes.length - RocksDBStateStoreProvider.STATE_ENCODING_NUM_VERSION_BYTES) - } else { - // HDFS provider doesn't add version bytes, use bytes directly - valueRow.pointTo( - valueBytes, - Platform.BYTE_ARRAY_OFFSET, - valueBytes.length) - } - - (partitionId, keyRow.copy(), valueRow.copy(), columnFamily) + val valueRow = new UnsafeRow(2) + // Skip the version byte (STATE_ENCODING_NUM_VERSION_BYTES) at the beginning + // This is for RocksDB provider + valueRow.pointTo( + valueBytes, + Platform.BYTE_ARRAY_OFFSET, + valueBytes.length) + (partitionKey, keyRow.copy(), valueRow.copy(), columnFamily) } .toSet } @@ -162,47 +133,41 @@ class StatePartitionReaderAllColumnFamiliesSuite extends StateDataSourceTestBase bytesReadDf: DataFrame, columnFamily: String, keySchema: StructType, - valueSchema: StructType, - skipVersionBytes: Boolean): Unit = { + valueSchema: StructType): Unit = { // Filter bytes data for the specified column family - val bytesData = parseBytesReadData(bytesReadDf, keySchema.length, valueSchema.length, - skipVersionBytes) + val bytesData = parseBytesReadData(bytesReadDf) val filteredBytesData = bytesData.filter(_._4 == columnFamily) + // Apply the projection // Convert to comparable format (extract field values) val normalSet = normalReadDf.collect().map { row => - val partitionId = row.getInt(0) val key = row.getStruct(1) val value = row.getStruct(2) val keyFields = (0 until key.length).map(i => key.get(i)) val valueFields = (0 until value.length).map(i => value.get(i)) - (partitionId, keyFields, valueFields) + (keyFields, valueFields) }.toSet - // Verify same number of rows - assert(filteredBytesData.size == normalSet.size, - s"Row count mismatch for column family '$columnFamily': " + - s"normal read has 
${filteredBytesData.size} rows, bytes read has ${normalSet.size} rows") - val bytesSet = filteredBytesData.map { case (partId, keyRow, valueRow, _) => + val bytesSet = filteredBytesData.map { case (_, keyRow, valueRow, _) => val keyFields = (0 until keySchema.length).map(i => keyRow.get(i, keySchema(i).dataType)) val valueFields = (0 until valueSchema.length).map(i => valueRow.get(i, valueSchema(i).dataType)) - (partId, keyFields, valueFields) + (keyFields, valueFields) } + // Verify same number of rows + assert(filteredBytesData.size == normalSet.size, + s"Row count mismatch for column family '$columnFamily': " + + s"normal read has ${filteredBytesData.size} rows, bytes read has ${normalSet.size} rows") assert(normalSet == bytesSet) } - Seq( - ("RocksDBStateStoreProvider", classOf[RocksDBStateStoreProvider], true), - ("HDFSBackedStateStoreProvider", classOf[HDFSBackedStateStoreProvider], false) - ).foreach { case (providerName, providerClass, skipVersionBytes) => - test(s"read all column families with simple operator - $providerName") { + test(s"read all column families with simple operator") { withTempDir { tempDir => withSQLConf( - SQLConf.STATE_STORE_PROVIDER_CLASS.key -> providerClass.getName, + SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, SQLConf.SHUFFLE_PARTITIONS.key -> "2") { val inputData = MemoryStream[Int] @@ -234,7 +199,7 @@ class StatePartitionReaderAllColumnFamiliesSuite extends StateDataSourceTestBase StopStream ) - // Read state data once with READ_ALL_COLUMN_FAMILIES = true + // Read state data once with INTERNAL_ONLY_READ_ALL_COLUMN_FAMILIES = true val bytesReadDf = getBytesReadDf(tempDir.getAbsolutePath) // Verify schema and column families @@ -257,8 +222,7 @@ class StatePartitionReaderAllColumnFamiliesSuite extends StateDataSourceTestBase // Get normal read data for comparison val normalData = getNormalReadData(tempDir.getAbsolutePath) compareNormalAndBytesData( - normalData, bytesReadDf, "default", keySchema, valueSchema, skipVersionBytes) - } + normalData, bytesReadDf, "default", keySchema, valueSchema) } } } From 158c8465855c65495b9ee97b0425f0732ef86a56 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Wed, 19 Nov 2025 03:18:11 +0000 Subject: [PATCH 05/12] refactor test --- ...artitionReaderAllColumnFamiliesSuite.scala | 54 +++++-------------- 1 file changed, 14 insertions(+), 40 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReaderAllColumnFamiliesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReaderAllColumnFamiliesSuite.scala index 49d5d46842de..e4ca6610f0c6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReaderAllColumnFamiliesSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReaderAllColumnFamiliesSuite.scala @@ -16,7 +16,7 @@ */ package org.apache.spark.sql.execution.datasources.v2.state -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.catalyst.expressions.{GenericRowWithSchema, UnsafeRow} import org.apache.spark.sql.execution.streaming.runtime.MemoryStream import org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider @@ -60,10 +60,7 @@ class StatePartitionReaderAllColumnFamiliesSuite extends StateDataSourceTestBase /** * Validates the schema and column families of the bytes read DataFrame. 
*/ - private def validateBytesReadSchema( - df: DataFrame, - expectedRowCount: Int, - expectedColumnFamilies: Seq[String]): Unit = { + private def validateBytesReadSchema(df: DataFrame): Unit = { // Verify schema val schema = df.schema assert(schema.fieldNames === Array( @@ -72,50 +69,25 @@ class StatePartitionReaderAllColumnFamiliesSuite extends StateDataSourceTestBase assert(schema("key_bytes").dataType.typeName === "binary") assert(schema("value_bytes").dataType.typeName === "binary") assert(schema("column_family_name").dataType.typeName === "string") - - // Verify data - val rows = df - .selectExpr("partition_key", "key_bytes", "value_bytes", "column_family_name") - .collect() - assert(rows.length == expectedRowCount, - s"Expected $expectedRowCount rows but got: ${rows.length}") - - val columnFamilies = rows.map(r => Option(r.getString(3)).getOrElse("null")).distinct.sorted - assert(columnFamilies.length == expectedColumnFamilies.length, - s"Expected ${expectedColumnFamilies.length} column families, " + - s"but got ${columnFamilies.length}: ${columnFamilies.mkString(", ")}") - - expectedColumnFamilies.foreach { expectedCF => - assert(columnFamilies.contains(expectedCF), - s"Expected column family '$expectedCF', " + - s"but got: ${columnFamilies.mkString(", ")}") - } } - private def parseBytesReadData( - df: DataFrame) + private def parseBytesReadData(df: Array[Row], keyLength: Int, valueLength: Int) : Set[(GenericRowWithSchema, UnsafeRow, UnsafeRow, String)] = { - df.selectExpr("partition_key", "key_bytes", "value_bytes", "column_family_name") - .collect() - .map { row => + df.map { row => val partitionKey = row.getAs[GenericRowWithSchema](0) val keyBytes = row.getAs[Array[Byte]](1) val valueBytes = row.getAs[Array[Byte]](2) val columnFamily = row.getString(3) // Deserialize key bytes to UnsafeRow - val keyRow = new UnsafeRow(1) - // Skip the version byte (STATE_ENCODING_NUM_VERSION_BYTES) at the beginning - // This is for RocksDB provider + val keyRow = new UnsafeRow(keyLength) keyRow.pointTo( keyBytes, Platform.BYTE_ARRAY_OFFSET, keyBytes.length) // Deserialize value bytes to UnsafeRow - val valueRow = new UnsafeRow(2) - // Skip the version byte (STATE_ENCODING_NUM_VERSION_BYTES) at the beginning - // This is for RocksDB provider + val valueRow = new UnsafeRow(valueLength) valueRow.pointTo( valueBytes, Platform.BYTE_ARRAY_OFFSET, @@ -134,9 +106,15 @@ class StatePartitionReaderAllColumnFamiliesSuite extends StateDataSourceTestBase columnFamily: String, keySchema: StructType, valueSchema: StructType): Unit = { + // Verify data + val bytesDf = bytesReadDf + .selectExpr("partition_key", "key_bytes", "value_bytes", "column_family_name") + .collect() + assert(bytesDf.length == 10, + s"Expected 10 rows but got: ${bytesDf.length}") // Filter bytes data for the specified column family - val bytesData = parseBytesReadData(bytesReadDf) + val bytesData = parseBytesReadData(bytesDf, keySchema.length, valueSchema.length) val filteredBytesData = bytesData.filter(_._4 == columnFamily) // Apply the projection @@ -203,10 +181,7 @@ class StatePartitionReaderAllColumnFamiliesSuite extends StateDataSourceTestBase val bytesReadDf = getBytesReadDf(tempDir.getAbsolutePath) // Verify schema and column families - validateBytesReadSchema(bytesReadDf, - expectedRowCount = 10, - expectedColumnFamilies = Seq("default")) - + validateBytesReadSchema(bytesReadDf) // Compare normal and bytes data for default column family val keySchema: StructType = StructType(Array( StructField("key", IntegerType, nullable = false) 
@@ -217,7 +192,6 @@ class StatePartitionReaderAllColumnFamiliesSuite extends StateDataSourceTestBase StructField("count", LongType, nullable = false), StructField("sum", LongType, nullable = false) )) - // Parse bytes read data // Get normal read data for comparison val normalData = getNormalReadData(tempDir.getAbsolutePath) From 2129dcbfc2e5c1803efb3100d1b199d8d80ebd15 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Wed, 19 Nov 2025 03:38:55 +0000 Subject: [PATCH 06/12] add more test --- .../v2/state/StateDataSource.scala | 4 +- .../v2/state/utils/SchemaUtil.scala | 1 - ...artitionReaderAllColumnFamiliesSuite.scala | 43 ++++++++++++++++++- 3 files changed, 44 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala index 5bafb3d64b04..584b544ed5c9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala @@ -70,8 +70,8 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging && !stateConf.providerClass.contains("RocksDB")) { throw StateDataSourceErrors.invalidOptionValue( StateSourceOptions.INTERNAL_ONLY_READ_ALL_COLUMN_FAMILIES, - "internalOnlyReadAllColumnFamilies is only supported with RocksDBStateStoreProvider. " + - s"Current provider: ${stateConf.providerClass}") + "internalOnlyReadAllColumnFamilies=true is only supported with " + + s"RocksDBStateStoreProvider. Current provider: ${stateConf.providerClass}") } val stateStoreReaderInfo: StateStoreReaderInfo = getStoreMetadataAndRunChecks( sourceOptions) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/utils/SchemaUtil.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/utils/SchemaUtil.scala index 16a8a2e5d3fd..d28c64f1b113 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/utils/SchemaUtil.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/utils/SchemaUtil.scala @@ -101,7 +101,6 @@ object SchemaUtil { row.update(1, pair._1.getBytes) row.update(2, pair._2.getBytes) row.update(3, UTF8String.fromString(colFamilyName)) -// row.update(4, pair._2) row } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReaderAllColumnFamiliesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReaderAllColumnFamiliesSuite.scala index e4ca6610f0c6..46e0f225c49a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReaderAllColumnFamiliesSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReaderAllColumnFamiliesSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.datasources.v2.state import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.catalyst.expressions.{GenericRowWithSchema, UnsafeRow} import org.apache.spark.sql.execution.streaming.runtime.MemoryStream -import org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider +import org.apache.spark.sql.execution.streaming.state.{HDFSBackedStateStoreProvider, RocksDBStateStoreProvider} import org.apache.spark.sql.functions.{count, sum} import 
org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.OutputMode @@ -200,4 +200,45 @@ class StatePartitionReaderAllColumnFamiliesSuite extends StateDataSourceTestBase } } } + + test("internalOnlyReadAllColumnFamilies should fail with HDFS-backed state store") { + withTempDir { tempDir => + withSQLConf( + SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[HDFSBackedStateStoreProvider].getName, + SQLConf.SHUFFLE_PARTITIONS.key -> "2") { + + val inputData = MemoryStream[Int] + val aggregated = inputData.toDF() + .selectExpr("value", "value % 10 AS groupKey") + .groupBy($"groupKey") + .agg( + count("*").as("cnt"), + sum("value").as("sum") + ) + .as[(Int, Long, Long)] + + testStream(aggregated, OutputMode.Update)( + StartStream(checkpointLocation = tempDir.getAbsolutePath), + // batch 0 + AddData(inputData, 0 until 1: _*), + CheckLastBatch( + (0, 1, 0) + ), + StopStream + ) + + // Attempt to read with internalOnlyReadAllColumnFamilies=true should fail + val e = intercept[StateDataSourceException] { + spark.read + .format("statestore") + .option(StateSourceOptions.PATH, tempDir.getAbsolutePath) + .option(StateSourceOptions.INTERNAL_ONLY_READ_ALL_COLUMN_FAMILIES, "true") + .load() + .collect() + } + assert(e.getMessage.contains("internalOnlyReadAllColumnFamilies=true is only " + + s"supported with RocksDBStateStoreProvider")) + } + } + } } From 251f3060c457120cd7d0eccd4a0741164a158f14 Mon Sep 17 00:00:00 2001 From: zifeif2 Date: Fri, 21 Nov 2025 19:08:29 +0000 Subject: [PATCH 07/12] address comment --- .../resources/error/error-conditions.json | 13 + .../v2/state/StateDataSource.scala | 13 +- .../v2/state/StatePartitionReader.scala | 18 +- .../v2/state/utils/SchemaUtil.scala | 29 +- .../state/OfflineStateRepartitionErrors.scala | 24 +- .../streaming/state/StateStoreConf.scala | 17 +- ...artitionAllColumnFamiliesReaderSuite.scala | 453 ++++++++++++++++++ ...artitionReaderAllColumnFamiliesSuite.scala | 244 ---------- 8 files changed, 540 insertions(+), 271 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionAllColumnFamiliesReaderSuite.scala delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReaderAllColumnFamiliesSuite.scala diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index 492a33c57461..077f474484fa 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -5520,6 +5520,19 @@ }, "sqlState" : "42616" }, + "STATE_REPARTITION_INVALID_STATE_STORE_CONFIG": { + "message" : [ + "StateStoreConfig is invalid:" + ], + "subClass" : { + "UNSUPPORTED_PROVIDER" : { + "message" : [ + " is not supported" + ] + } + }, + "sqlState" : "42617" + }, "STATE_STORE_CANNOT_CREATE_COLUMN_FAMILY_WITH_RESERVED_CHARS" : { "message" : [ "Failed to create column family with unsupported starting character and name=." 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala index 584b544ed5c9..713088393249 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala @@ -41,7 +41,8 @@ import org.apache.spark.sql.execution.streaming.operators.stateful.transformwith import org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.timers.TimerStateUtils import org.apache.spark.sql.execution.streaming.runtime.StreamingCheckpointConstants.DIR_NAME_STATE import org.apache.spark.sql.execution.streaming.runtime.StreamingQueryCheckpointMetadata -import org.apache.spark.sql.execution.streaming.state.{InMemoryStateSchemaProvider, KeyStateEncoderSpec, NoPrefixKeyStateEncoderSpec, PrefixKeyScanStateEncoderSpec, StateSchemaCompatibilityChecker, StateSchemaMetadata, StateSchemaProvider, StateStore, StateStoreColFamilySchema, StateStoreConf, StateStoreId, StateStoreProviderId} +import org.apache.spark.sql.execution.streaming.state.{InMemoryStateSchemaProvider, KeyStateEncoderSpec, NoPrefixKeyStateEncoderSpec, PrefixKeyScanStateEncoderSpec, RocksDBStateStoreProvider, StateSchemaCompatibilityChecker, StateSchemaMetadata, StateSchemaProvider, StateStore, StateStoreColFamilySchema, StateStoreConf, StateStoreId, StateStoreProviderId} +import org.apache.spark.sql.execution.streaming.state.OfflineStateRepartitionErrors import org.apache.spark.sql.execution.streaming.utils.StreamingUtils import org.apache.spark.sql.sources.DataSourceRegister import org.apache.spark.sql.streaming.TimeMode @@ -67,11 +68,9 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging StateSourceOptions.apply(session, hadoopConf, properties)) val stateConf = buildStateStoreConf(sourceOptions.resolvedCpLocation, sourceOptions.batchId) if (sourceOptions.internalOnlyReadAllColumnFamilies - && !stateConf.providerClass.contains("RocksDB")) { - throw StateDataSourceErrors.invalidOptionValue( - StateSourceOptions.INTERNAL_ONLY_READ_ALL_COLUMN_FAMILIES, - "internalOnlyReadAllColumnFamilies=true is only supported with " + - s"RocksDBStateStoreProvider. 
Current provider: ${stateConf.providerClass}") + && stateConf.providerClass != classOf[RocksDBStateStoreProvider].getName) { + throw OfflineStateRepartitionErrors.unsupportedStateStoreProviderError( + stateConf.providerClass) } val stateStoreReaderInfo: StateStoreReaderInfo = getStoreMetadataAndRunChecks( sourceOptions) @@ -379,7 +378,7 @@ case class StateSourceOptions( stateVarName: Option[String], readRegisteredTimers: Boolean, flattenCollectionTypes: Boolean, - internalOnlyReadAllColumnFamilies: Boolean, + internalOnlyReadAllColumnFamilies: Boolean = false, startOperatorStateUniqueIds: Option[Array[Array[String]]] = None, endOperatorStateUniqueIds: Option[Array[Array[String]]] = None) { def stateCheckpointLocation: Path = new Path(resolvedCpLocation, DIR_NAME_STATE) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala index 475249e3fd27..e9d70c848a0c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala @@ -50,7 +50,11 @@ class StatePartitionReaderFactory( override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { val stateStoreInputPartition = partition.asInstanceOf[StateStoreInputPartition] if (stateStoreInputPartition.sourceOptions.internalOnlyReadAllColumnFamilies) { - new StatePartitionReaderAllColumnFamilies(storeConf, hadoopConf, + val modifiedStoreConf = storeConf.withExtraOptions(Map( + StateStoreConf.FORMAT_VALIDATION_ENABLED_CONFIG -> "false", + StateStoreConf.FORMAT_VALIDATION_CHECK_VALUE_CONFIG -> "false" + )) + new StatePartitionAllColumnFamiliesReader(modifiedStoreConf, hadoopConf, stateStoreInputPartition, schema, keyStateEncoderSpec) } else if (stateStoreInputPartition.sourceOptions.readChangeFeed) { new StateStoreChangeDataPartitionReader(storeConf, hadoopConf, @@ -87,6 +91,8 @@ abstract class StatePartitionReaderBase( protected val keySchema = { if (SchemaUtil.checkVariableType(stateVariableInfoOpt, StateVariableType.MapState)) { SchemaUtil.getCompositeKeySchema(schema, partition.sourceOptions) + } else if (partition.sourceOptions.internalOnlyReadAllColumnFamilies) { + schemaForValueRow } else { SchemaUtil.getSchemaAsDataType(schema, "key").asInstanceOf[StructType] } @@ -94,6 +100,8 @@ abstract class StatePartitionReaderBase( protected val valueSchema = if (stateVariableInfoOpt.isDefined) { schemaForValueRow + } else if (partition.sourceOptions.internalOnlyReadAllColumnFamilies) { + schemaForValueRow } else { SchemaUtil.getSchemaAsDataType( schema, "value").asInstanceOf[StructType] @@ -243,14 +251,17 @@ class StatePartitionReader( /** * An implementation of [[StatePartitionReaderBase]] for reading all column families * in binary format. This reader returns raw key and value bytes along with column family names. 
+ * We are returning key/value bytes because each column family can have different schema */ -class StatePartitionReaderAllColumnFamilies( +class StatePartitionAllColumnFamiliesReader( storeConf: StateStoreConf, hadoopConf: SerializableConfiguration, partition: StateStoreInputPartition, schema: StructType, keyStateEncoderSpec: KeyStateEncoderSpec) - extends StatePartitionReaderBase(storeConf, hadoopConf, partition, schema, + extends StatePartitionReaderBase( + storeConf, + hadoopConf, partition, schema, keyStateEncoderSpec, None, None, None, None) { private lazy val store: ReadStateStore = { @@ -263,7 +274,6 @@ class StatePartitionReaderAllColumnFamilies( } override lazy val iter: Iterator[InternalRow] = { - // Single store with column families (join v3, transformWithState, or simple operators) store .iterator() .map { pair => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/utils/SchemaUtil.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/utils/SchemaUtil.scala index d28c64f1b113..9019a2e61449 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/utils/SchemaUtil.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/utils/SchemaUtil.scala @@ -63,15 +63,12 @@ object SchemaUtil { .add("partition_id", IntegerType) } else if (sourceOptions.internalOnlyReadAllColumnFamilies) { new StructType() - // todo: change this to some more specific type after we + // todo [SPARK-54443]: change keySchema to more specific type after we // can extract partition key from keySchema .add("partition_key", keySchema) .add("key_bytes", BinaryType) .add("value_bytes", BinaryType) .add("column_family_name", StringType) - // need key and value schema so that state store can encode data - .add("value", valueSchema) - .add("key", keySchema) } else { new StructType() .add("key", keySchema) @@ -89,14 +86,18 @@ object SchemaUtil { } /** - * Creates a unified row from raw key and value bytes. - * This is an alias for unifyStateRowPairAsBytes that takes individual byte arrays - * instead of a tuple for better readability. + * Returns an InternalRow representing + * 1. partitionKey + * 2. key in bytes + * 3. value in bytes + * 4. 
column family name */ def unifyStateRowPairAsRawBytes( - pair: (UnsafeRow, UnsafeRow), - colFamilyName: String): InternalRow = { - val row = new GenericInternalRow(6) + pair: (UnsafeRow, UnsafeRow), + colFamilyName: String): InternalRow = { + val row = new GenericInternalRow(4) + // todo [SPARK-54443]: change keySchema to more specific type after we + // can extract partition key from keySchema row.update(0, pair._1) row.update(1, pair._1.getBytes) row.update(2, pair._2.getBytes) @@ -261,9 +262,9 @@ object SchemaUtil { "expiration_timestamp_ms" -> classOf[LongType], "partition_id" -> classOf[IntegerType], "partition_key" -> classOf[StructType], - "key_bytes"->classOf[BinaryType], - "value_bytes"->classOf[BinaryType], - "column_family_name"->classOf[StringType]) + "key_bytes" -> classOf[BinaryType], + "value_bytes" -> classOf[BinaryType], + "column_family_name" -> classOf[StringType]) val expectedFieldNames = if (transformWithStateVariableInfoOpt.isDefined) { val stateVarInfo = transformWithStateVariableInfoOpt.get @@ -305,7 +306,7 @@ object SchemaUtil { } else if (sourceOptions.readChangeFeed) { Seq("batch_id", "change_type", "key", "value", "partition_id") } else if (sourceOptions.internalOnlyReadAllColumnFamilies) { - Seq("partition_key", "key_bytes", "value_bytes", "column_family_name", "value", "key") + Seq("partition_key", "key_bytes", "value_bytes", "column_family_name") } else { Seq("key", "value", "partition_id") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionErrors.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionErrors.scala index 95b273826877..030c81ee3ea4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionErrors.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionErrors.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution.streaming.state -import org.apache.spark.{SparkIllegalArgumentException, SparkIllegalStateException} +import org.apache.spark.{SparkIllegalArgumentException, SparkIllegalStateException, SparkRuntimeException} /** * Errors thrown by Offline state repartitioning. 
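+ * For orientation, a sketch (illustrative provider value) of how the factory method added below
+ * is used by the StateDataSource change in this patch; it surfaces as condition
+ * STATE_REPARTITION_INVALID_STATE_STORE_CONFIG.UNSUPPORTED_PROVIDER with parameters
+ * `configName` and `provider`:
+ * {{{
+ *   throw OfflineStateRepartitionErrors.unsupportedStateStoreProviderError(
+ *     classOf[HDFSBackedStateStoreProvider].getName)
+ * }}}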
@@ -85,6 +85,12 @@ object OfflineStateRepartitionErrors { version: Int): StateRepartitionInvalidCheckpointError = { new StateRepartitionUnsupportedOffsetSeqVersionError(checkpointLocation, version) } + + def unsupportedStateStoreProviderError( + providerClass: String + ): StateRepartitionInvalidStateStoreConfigUnsupportedProviderError = { + new StateRepartitionInvalidStateStoreConfigUnsupportedProviderError(providerClass) + } } /** @@ -201,3 +207,19 @@ class StateRepartitionUnsupportedOffsetSeqVersionError( checkpointLocation, subClass = "UNSUPPORTED_OFFSET_SEQ_VERSION", messageParameters = Map("version" -> version.toString)) + +abstract class StateRepartitionInvalidStateStoreConfigError( + configName: String, + subClass: String, + messageParameters: Map[String, String] = Map.empty, + cause: Throwable = null) + extends SparkRuntimeException( + errorClass = s"STATE_REPARTITION_INVALID_STATE_STORE_CONFIG.$subClass", + messageParameters = Map("configName" -> configName) ++ messageParameters, + cause = cause) + +class StateRepartitionInvalidStateStoreConfigUnsupportedProviderError( + provider: String) extends StateRepartitionInvalidStateStoreConfigError( + "SQLConf.STATE_STORE_PROVIDER_CLASS.key", + subClass = "UNSUPPORTED_PROVIDER", + messageParameters = Map("provider" -> provider)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala index 3991f8d93f2c..790deb099496 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala @@ -30,6 +30,18 @@ class StateStoreConf( def this() = this(new SQLConf) + def withExtraOptions(additionalOptions: Map[String, String]): StateStoreConf = { + val reconstructedSqlConf = { + // Reconstruct a SQLConf with the all settings preserved because sqlConf is transient + val conf = new SQLConf() + // Restore all state store related settings + sqlConfs.foreach { case (key, value) => + conf.setConfString(key, value) + } + conf + } + new StateStoreConf(reconstructedSqlConf, extraOptions ++ additionalOptions) + } /** * Size of MaintenanceThreadPool to perform maintenance tasks for StateStore */ @@ -83,7 +95,9 @@ class StateStoreConf( val providerClass: String = sqlConf.stateStoreProviderClass /** Whether validate the underlying format or not. 
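+   * The all-column-families read path disables this through extra options rather than the SQL
+   * conf; a sketch of the call made by the reader factory in this series:
+   * {{{
+   *   storeConf.withExtraOptions(Map(
+   *     StateStoreConf.FORMAT_VALIDATION_ENABLED_CONFIG -> "false",
+   *     StateStoreConf.FORMAT_VALIDATION_CHECK_VALUE_CONFIG -> "false"))
+   * }}}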
*/ - val formatValidationEnabled: Boolean = sqlConf.stateStoreFormatValidationEnabled + val formatValidationEnabled: Boolean = extraOptions.getOrElse( + StateStoreConf.FORMAT_VALIDATION_ENABLED_CONFIG, + sqlConf.stateStoreFormatValidationEnabled) == "true" /** * Whether to validate StateStore commits for ForeachBatch sinks to ensure all partitions @@ -166,6 +180,7 @@ class StateStoreConf( } object StateStoreConf { + val FORMAT_VALIDATION_ENABLED_CONFIG = "formatValidationEnabled" val FORMAT_VALIDATION_CHECK_VALUE_CONFIG = "formatValidationCheckValue" val empty = new StateStoreConf() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionAllColumnFamiliesReaderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionAllColumnFamiliesReaderSuite.scala new file mode 100644 index 000000000000..8e3743914cb6 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionAllColumnFamiliesReaderSuite.scala @@ -0,0 +1,453 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.datasources.v2.state + +import java.nio.ByteOrder +import java.util.Arrays + +import org.apache.spark.SparkRuntimeException +import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.catalyst.CatalystTypeConverters +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.UnsafeProjection +import org.apache.spark.sql.execution.streaming.runtime.MemoryStream +import org.apache.spark.sql.execution.streaming.state.{HDFSBackedStateStoreProvider, RocksDBStateStoreProvider} +import org.apache.spark.sql.functions.{count, sum} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.streaming.OutputMode +import org.apache.spark.sql.types.{IntegerType, LongType, NullType, StructField, StructType} + +/** + * Note: This extends StateDataSourceTestBase to access + * helper methods like runDropDuplicatesQuery without inheriting all predefined tests. 
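+ * The comparison strategy used throughout (sketch; `dir` is a checkpoint written by one of the
+ * streaming queries below): read the same checkpoint in both modes and check that re-encoding
+ * the normal rows with an UnsafeProjection yields byte-identical key/value pairs:
+ * {{{
+ *   val normal = getNormalReadDf(dir).collect()
+ *   val bytes  = getBytesReadDf(dir).collect()
+ *   compareNormalAndBytesData(normal, bytes, "default", keySchema, valueSchema)
+ * }}}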
+ */ +class StatePartitionAllColumnFamiliesReaderSuite extends StateDataSourceTestBase { + + import testImplicits._ + + override def beforeAll(): Unit = { + super.beforeAll() + spark.conf.set(SQLConf.STATE_STORE_PROVIDER_CLASS.key, + classOf[RocksDBStateStoreProvider].getName) + } + + private def getNormalReadDf(checkpointDir: String): DataFrame = { + spark.read + .format("statestore") + .option(StateSourceOptions.PATH, checkpointDir) + .load() + .selectExpr("partition_id", "key", "value") + } + + private def getBytesReadDf(checkpointDir: String): DataFrame = { + spark.read + .format("statestore") + .option(StateSourceOptions.PATH, checkpointDir) + .option(StateSourceOptions.INTERNAL_ONLY_READ_ALL_COLUMN_FAMILIES, "true") + .load() + .selectExpr("partition_key", "key_bytes", "value_bytes", "column_family_name") + } + + /** + * Validates the schema and column families of the bytes read DataFrame. + */ + private def validateBytesReadDfSchema(df: DataFrame): Unit = { + // Verify schema + val schema = df.schema + assert(schema.fieldNames === Array( + "partition_key", "key_bytes", "value_bytes", "column_family_name")) + assert(schema("partition_key").dataType.typeName === "struct") + assert(schema("key_bytes").dataType.typeName === "binary") + assert(schema("value_bytes").dataType.typeName === "binary") + assert(schema("column_family_name").dataType.typeName === "string") + } + + /** + * Compares normal read data with bytes read data for a specific column family. + * Converts normal rows to bytes then compares with bytes read. + */ + private def compareNormalAndBytesData( + normalDf: Array[Row], + bytesDf: Array[Row], + columnFamily: String, + keySchema: StructType, + valueSchema: StructType): Unit = { + + // Filter bytes data for the specified column family and extract raw bytes directly + val filteredBytesData = bytesDf.filter { row => + row.getString(3) == columnFamily + } + + // Verify same number of rows + assert(filteredBytesData.length == normalDf.length, + s"Row count mismatch for column family '$columnFamily': " + + s"normal read has ${normalDf.length} rows, " + + s"bytes read has ${filteredBytesData.length} rows") + + // Create projections to convert Row to UnsafeRow bytes + val keyProjection = UnsafeProjection.create(keySchema) + val valueProjection = if (valueSchema.isEmpty) null else UnsafeProjection.create(valueSchema) + + // Create converters to convert external Row types to internal Catalyst types + val keyConverter = CatalystTypeConverters.createToCatalystConverter(keySchema) + val valueConverter = if (valueSchema.isEmpty) { + null + } else { + CatalystTypeConverters.createToCatalystConverter(valueSchema) + } + + // Convert normal data to bytes + val normalAsBytes = normalDf.map { row => + val key = row.getStruct(1) + val value = if (row.isNullAt(2) || valueSchema.isEmpty) null else row.getStruct(2) + + // Convert key to InternalRow, then to UnsafeRow, then get bytes + val keyInternalRow = keyConverter(key).asInstanceOf[InternalRow] + val keyUnsafeRow = keyProjection(keyInternalRow) + // IMPORTANT: Must clone the bytes array since getBytes() returns a reference + // that may be overwritten by subsequent UnsafeRow operations + val keyBytes = keyUnsafeRow.getBytes.clone() + + // Convert value to bytes + val valueBytes = if (value == null || valueSchema.isEmpty) { + Array.empty[Byte] + } else { + val valueInternalRow = valueConverter(value).asInstanceOf[InternalRow] + val valueUnsafeRow = valueProjection(valueInternalRow) + // IMPORTANT: Must clone the bytes array since getBytes() 
returns a reference + // that may be overwritten by subsequent UnsafeRow operations + valueUnsafeRow.getBytes.clone() + } + + (keyBytes, valueBytes) + } + + // Extract raw bytes from bytes read data (no deserialization/reserialization) + val bytesAsBytes = filteredBytesData.map { row => + val keyBytes = row.getAs[Array[Byte]](1) + val valueBytes = row.getAs[Array[Byte]](2) + (keyBytes, valueBytes) + } + + // Sort both for comparison (since Set equality doesn't work well with byte arrays) + val normalSorted = normalAsBytes.sortBy(x => (x._1.mkString(","), x._2.mkString(","))) + val bytesSorted = bytesAsBytes.sortBy(x => (x._1.mkString(","), x._2.mkString(","))) + + assert(normalSorted.length == bytesSorted.length, + s"Size mismatch: normal has ${normalSorted.length}, bytes has ${bytesSorted.length}") + + // Compare each pair + normalSorted.zip(bytesSorted).zipWithIndex.foreach { + case (((normalKey, normalValue), (bytesKey, bytesValue)), idx) => + assert(Arrays.equals(normalKey, bytesKey), + s"Key mismatch at index $idx:\n" + + s" Normal: ${normalKey.mkString("[", ",", "]")}\n" + + s" Bytes: ${bytesKey.mkString("[", ",", "]")}") + assert(Arrays.equals(normalValue, bytesValue), + s"Value mismatch at index $idx:\n" + + s" Normal: ${normalValue.mkString("[", ",", "]")}\n" + + s" Bytes: ${bytesValue.mkString("[", ",", "]")}") + } + } + + // Run all tests with both changelog checkpointing enabled and disabled + Seq(true, false).foreach { changelogCheckpointingEnabled => + val testSuffix = if (changelogCheckpointingEnabled) { + "with changelog checkpointing" + } else { + "without changelog checkpointing" + } + + def testWithChangelogConfig(testName: String)(testFun: => Unit): Unit = { + test(s"$testName ($testSuffix)") { + withSQLConf( + "spark.sql.streaming.stateStore.rocksdb.changelogCheckpointing.enabled" -> + changelogCheckpointingEnabled.toString) { + testFun + } + } + } + + testWithChangelogConfig("all-column-families: simple aggregation state ver 1") { + withSQLConf(SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_VERSION.key -> "1") { + withTempDir { tempDir => + runLargeDataStreamingAggregationQuery(tempDir.getAbsolutePath) + + val keySchema = StructType(Array(StructField("groupKey", IntegerType, nullable = false))) + // State version 1 includes key columns in the value + val valueSchema = StructType(Array( + StructField("groupKey", IntegerType, nullable = false), + StructField("count", LongType, nullable = false), + StructField("sum", LongType, nullable = false), + StructField("max", IntegerType, nullable = false), + StructField("min", IntegerType, nullable = false) + )) + + val normalData = getNormalReadDf(tempDir.getAbsolutePath).collect() + val bytesDf = getBytesReadDf(tempDir.getAbsolutePath) + + validateBytesReadDfSchema(bytesDf) + compareNormalAndBytesData(normalData, bytesDf.collect(), "default", keySchema, valueSchema) + } + } + } + + testWithChangelogConfig("all-column-families: simple aggregation state ver 2") { + withSQLConf(SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_VERSION.key -> "2") { + withTempDir { tempDir => + runLargeDataStreamingAggregationQuery(tempDir.getAbsolutePath) + + val keySchema = StructType(Array(StructField("groupKey", IntegerType, nullable = false))) + val valueSchema = StructType(Array( + StructField("count", LongType, nullable = false), + StructField("sum", LongType, nullable = false), + StructField("max", IntegerType, nullable = false), + StructField("min", IntegerType, nullable = false) + )) + + val normalData = 
getNormalReadDf(tempDir.getAbsolutePath).collect() + val bytesDf = getBytesReadDf(tempDir.getAbsolutePath).collect() + + compareNormalAndBytesData(normalData, bytesDf, "default", keySchema, valueSchema) + } + } + } + + testWithChangelogConfig("all-column-families: composite key aggregation state ver 1") { + withSQLConf(SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_VERSION.key -> "1") { + withTempDir { tempDir => + runCompositeKeyStreamingAggregationQuery(tempDir.getAbsolutePath) + + val keySchema = StructType(Array( + StructField("groupKey", IntegerType, nullable = false), + StructField("fruit", org.apache.spark.sql.types.StringType, nullable = true) + )) + // State version 1 includes key columns in the value + val valueSchema = StructType(Array( + StructField("groupKey", IntegerType, nullable = false), + StructField("fruit", org.apache.spark.sql.types.StringType, nullable = true), + StructField("count", LongType, nullable = false), + StructField("sum", LongType, nullable = false), + StructField("max", IntegerType, nullable = false), + StructField("min", IntegerType, nullable = false) + )) + + val normalData = getNormalReadDf(tempDir.getAbsolutePath).collect() + val bytesDf = getBytesReadDf(tempDir.getAbsolutePath).collect() + + compareNormalAndBytesData(normalData, bytesDf, "default", keySchema, valueSchema) + } + } + } + + testWithChangelogConfig("all-column-families: composite key aggregation state ver 2") { + withSQLConf(SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_VERSION.key -> "2") { + withTempDir { tempDir => + runCompositeKeyStreamingAggregationQuery(tempDir.getAbsolutePath) + + val keySchema = StructType(Array( + StructField("groupKey", IntegerType, nullable = false), + StructField("fruit", org.apache.spark.sql.types.StringType, nullable = true) + )) + val valueSchema = StructType(Array( + StructField("count", LongType, nullable = false), + StructField("sum", LongType, nullable = false), + StructField("max", IntegerType, nullable = false), + StructField("min", IntegerType, nullable = false) + )) + + val normalData = getNormalReadDf(tempDir.getAbsolutePath).collect() + val bytesDf = getBytesReadDf(tempDir.getAbsolutePath).collect() + + compareNormalAndBytesData(normalData, bytesDf, "default", keySchema, valueSchema) + } + } + } + + testWithChangelogConfig("all-column-families: dropDuplicates validation") { + withTempDir { tempDir => + runDropDuplicatesQuery(tempDir.getAbsolutePath) + + val keySchema = StructType(Array( + StructField("value", IntegerType, nullable = false), + StructField("eventTime", org.apache.spark.sql.types.TimestampType) + )) + val valueSchema = StructType(Array( + StructField("__dummy__", NullType, nullable = true) + )) + + val normalData = getNormalReadDf(tempDir.getAbsolutePath).collect() + val bytesDf = getBytesReadDf(tempDir.getAbsolutePath).collect() + + compareNormalAndBytesData(normalData, bytesDf, "default", keySchema, valueSchema) + } + } + + testWithChangelogConfig("all-column-families: dropDuplicates with column specified") { + withTempDir { tempDir => + runDropDuplicatesQueryWithColumnSpecified(tempDir.getAbsolutePath) + + val keySchema = StructType(Array( + StructField("col1", org.apache.spark.sql.types.StringType, nullable = true) + )) + val valueSchema = StructType(Array( + StructField("__dummy__", NullType, nullable = true) + )) + + val normalData = getNormalReadDf(tempDir.getAbsolutePath).collect() + val bytesDf = getBytesReadDf(tempDir.getAbsolutePath).collect() + + compareNormalAndBytesData(normalData, bytesDf, "default", keySchema, valueSchema) 
+ } + } + + testWithChangelogConfig("all-column-families: dropDuplicatesWithinWatermark") { + withTempDir { tempDir => + runDropDuplicatesWithinWatermarkQuery(tempDir.getAbsolutePath) + + val keySchema = StructType(Array( + StructField("_1", org.apache.spark.sql.types.StringType, nullable = true) + )) + val valueSchema = StructType(Array( + StructField("expiresAtMicros", LongType, nullable = false) + )) + + val normalData = getNormalReadDf(tempDir.getAbsolutePath).collect() + val bytesDf = getBytesReadDf(tempDir.getAbsolutePath).collect() + + compareNormalAndBytesData(normalData, bytesDf, "default", keySchema, valueSchema) + } + } + + testWithChangelogConfig("all-column-families: session window aggregation") { + withTempDir { tempDir => + runSessionWindowAggregationQuery(tempDir.getAbsolutePath) + + val keySchema = StructType(Array( + StructField("sessionId", org.apache.spark.sql.types.StringType, nullable = false), + StructField("sessionStartTime", org.apache.spark.sql.types.TimestampType, nullable = false) + )) + val valueSchema = StructType(Array( + StructField("session_window", org.apache.spark.sql.types.StructType(Array( + StructField("start", org.apache.spark.sql.types.TimestampType, nullable = true), + StructField("end", org.apache.spark.sql.types.TimestampType, nullable = true) + )), nullable = false), + StructField("sessionId", org.apache.spark.sql.types.StringType, nullable = false), + StructField("count", LongType, nullable = false) + )) + + val normalData = getNormalReadDf(tempDir.getAbsolutePath).collect() + val bytesDf = getBytesReadDf(tempDir.getAbsolutePath).collect() + + compareNormalAndBytesData(normalData, bytesDf, "default", keySchema, valueSchema) + } + } + + testWithChangelogConfig("all-column-families: flatMapGroupsWithState, state ver 1") { + // Skip this test on big endian platforms + assume(java.nio.ByteOrder.nativeOrder().equals(java.nio.ByteOrder.LITTLE_ENDIAN)) + withSQLConf(SQLConf.FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION.key -> "1") { + withTempDir { tempDir => + assume(ByteOrder.nativeOrder().equals(ByteOrder.LITTLE_ENDIAN)) + runFlatMapGroupsWithStateQuery(tempDir.getAbsolutePath) + + val keySchema = StructType(Array( + StructField("value", org.apache.spark.sql.types.StringType, nullable = true) + )) + val valueSchema = StructType(Array( + StructField("numEvents", IntegerType, nullable = false), + StructField("startTimestampMs", LongType, nullable = false), + StructField("endTimestampMs", LongType, nullable = false), + StructField("timeoutTimestamp", IntegerType, nullable = false) + )) + + val normalData = getNormalReadDf(tempDir.getAbsolutePath).collect() + val bytesDf = getBytesReadDf(tempDir.getAbsolutePath).collect() + + compareNormalAndBytesData(normalData, bytesDf, "default", keySchema, valueSchema) + } + } + } + + testWithChangelogConfig("all-column-families: flatMapGroupsWithState, state ver 2") { + withSQLConf(SQLConf.FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION.key -> "2") { + withTempDir { tempDir => + runFlatMapGroupsWithStateQuery(tempDir.getAbsolutePath) + + val keySchema = StructType(Array( + StructField("value", org.apache.spark.sql.types.StringType, nullable = true) + )) + val valueSchema = StructType(Array( + StructField("groupState", org.apache.spark.sql.types.StructType(Array( + StructField("numEvents", IntegerType, nullable = false), + StructField("startTimestampMs", LongType, nullable = false), + StructField("endTimestampMs", LongType, nullable = false) + )), nullable = false), + StructField("timeoutTimestamp", LongType, nullable = 
false) + )) + + val normalData = getNormalReadDf(tempDir.getAbsolutePath).collect() + val bytesDf = getBytesReadDf(tempDir.getAbsolutePath).collect() + + compareNormalAndBytesData(normalData, bytesDf, "default", keySchema, valueSchema) + } + } + } + } // End of foreach loop for changelog checkpointing dimension + + test("internalOnlyReadAllColumnFamilies should fail with HDFS-backed state store") { + withTempDir { tempDir => + withSQLConf( + SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[HDFSBackedStateStoreProvider].getName, + SQLConf.SHUFFLE_PARTITIONS.key -> "2") { + + val inputData = MemoryStream[Int] + val aggregated = inputData.toDF() + .selectExpr("value", "value % 10 AS groupKey") + .groupBy($"groupKey") + .agg( + count("*").as("cnt"), + sum("value").as("sum") + ) + .as[(Int, Long, Long)] + + testStream(aggregated, OutputMode.Update)( + StartStream(checkpointLocation = tempDir.getAbsolutePath), + AddData(inputData, 0 until 1: _*), + CheckLastBatch((0, 1, 0)), + StopStream + ) + + checkError( + exception = intercept[SparkRuntimeException] { + spark.read + .format("statestore") + .option(StateSourceOptions.PATH, tempDir.getAbsolutePath) + .option(StateSourceOptions.INTERNAL_ONLY_READ_ALL_COLUMN_FAMILIES, "true") + .load() + .collect() + }, + condition = "STATE_REPARTITION_INVALID_STATE_STORE_CONFIG.UNSUPPORTED_PROVIDER", + parameters = Map( + "configName" -> "SQLConf.STATE_STORE_PROVIDER_CLASS.key", + "provider" -> classOf[HDFSBackedStateStoreProvider].getName + ) + ) + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReaderAllColumnFamiliesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReaderAllColumnFamiliesSuite.scala deleted file mode 100644 index 46e0f225c49a..000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReaderAllColumnFamiliesSuite.scala +++ /dev/null @@ -1,244 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.spark.sql.execution.datasources.v2.state - -import org.apache.spark.sql.{DataFrame, Row} -import org.apache.spark.sql.catalyst.expressions.{GenericRowWithSchema, UnsafeRow} -import org.apache.spark.sql.execution.streaming.runtime.MemoryStream -import org.apache.spark.sql.execution.streaming.state.{HDFSBackedStateStoreProvider, RocksDBStateStoreProvider} -import org.apache.spark.sql.functions.{count, sum} -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.streaming.OutputMode -import org.apache.spark.sql.types.{IntegerType, LongType, StructField, StructType} -import org.apache.spark.tags.SlowSQLTest -import org.apache.spark.unsafe.Platform - -/** - * Test suite to verify StatePartitionReaderAllColumnFamilies functionality. - */ -@SlowSQLTest -class StatePartitionReaderAllColumnFamiliesSuite extends StateDataSourceTestBase { - - import testImplicits._ - - /** - * Returns a set of (partitionId, key, value) tuples from a normal state read. - */ - private def getNormalReadData(checkpointDir: String): DataFrame = { - spark.read - .format("statestore") - .option(StateSourceOptions.PATH, checkpointDir) - .load() - .selectExpr("partition_id", "key", "value") - } - - /** - * Returns a DataFrame with raw bytes mode (INTERNAL_ONLY_READ_ALL_COLUMN_FAMILIES = true). - */ - private def getBytesReadDf(checkpointDir: String): DataFrame = { - spark.read - .format("statestore") - .option(StateSourceOptions.PATH, checkpointDir) - .option(StateSourceOptions.INTERNAL_ONLY_READ_ALL_COLUMN_FAMILIES, "true") - .load() - } - - /** - * Validates the schema and column families of the bytes read DataFrame. - */ - private def validateBytesReadSchema(df: DataFrame): Unit = { - // Verify schema - val schema = df.schema - assert(schema.fieldNames === Array( - "partition_key", "key_bytes", "value_bytes", "column_family_name", "value", "key")) - assert(schema("partition_key").dataType.typeName === "struct") - assert(schema("key_bytes").dataType.typeName === "binary") - assert(schema("value_bytes").dataType.typeName === "binary") - assert(schema("column_family_name").dataType.typeName === "string") - } - - private def parseBytesReadData(df: Array[Row], keyLength: Int, valueLength: Int) - : Set[(GenericRowWithSchema, UnsafeRow, UnsafeRow, String)] = { - df.map { row => - val partitionKey = row.getAs[GenericRowWithSchema](0) - val keyBytes = row.getAs[Array[Byte]](1) - val valueBytes = row.getAs[Array[Byte]](2) - val columnFamily = row.getString(3) - - // Deserialize key bytes to UnsafeRow - val keyRow = new UnsafeRow(keyLength) - keyRow.pointTo( - keyBytes, - Platform.BYTE_ARRAY_OFFSET, - keyBytes.length) - - // Deserialize value bytes to UnsafeRow - val valueRow = new UnsafeRow(valueLength) - valueRow.pointTo( - valueBytes, - Platform.BYTE_ARRAY_OFFSET, - valueBytes.length) - (partitionKey, keyRow.copy(), valueRow.copy(), columnFamily) - } - .toSet - } - - /** - * Compares normal read data with bytes read data for a specific column family. 
- */ - private def compareNormalAndBytesData( - normalReadDf: DataFrame, - bytesReadDf: DataFrame, - columnFamily: String, - keySchema: StructType, - valueSchema: StructType): Unit = { - // Verify data - val bytesDf = bytesReadDf - .selectExpr("partition_key", "key_bytes", "value_bytes", "column_family_name") - .collect() - assert(bytesDf.length == 10, - s"Expected 10 rows but got: ${bytesDf.length}") - - // Filter bytes data for the specified column family - val bytesData = parseBytesReadData(bytesDf, keySchema.length, valueSchema.length) - val filteredBytesData = bytesData.filter(_._4 == columnFamily) - - // Apply the projection - // Convert to comparable format (extract field values) - val normalSet = normalReadDf.collect().map { row => - val key = row.getStruct(1) - val value = row.getStruct(2) - val keyFields = (0 until key.length).map(i => key.get(i)) - val valueFields = (0 until value.length).map(i => value.get(i)) - (keyFields, valueFields) - }.toSet - - val bytesSet = filteredBytesData.map { case (_, keyRow, valueRow, _) => - val keyFields = (0 until keySchema.length).map(i => - keyRow.get(i, keySchema(i).dataType)) - val valueFields = (0 until valueSchema.length).map(i => - valueRow.get(i, valueSchema(i).dataType)) - (keyFields, valueFields) - } - // Verify same number of rows - assert(filteredBytesData.size == normalSet.size, - s"Row count mismatch for column family '$columnFamily': " + - s"normal read has ${filteredBytesData.size} rows, bytes read has ${normalSet.size} rows") - - assert(normalSet == bytesSet) - } - - test(s"read all column families with simple operator") { - withTempDir { tempDir => - withSQLConf( - SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, - SQLConf.SHUFFLE_PARTITIONS.key -> "2") { - - val inputData = MemoryStream[Int] - val aggregated = inputData.toDF() - .selectExpr("value", "value % 10 AS groupKey") - .groupBy($"groupKey") - .agg( - count("*").as("cnt"), - sum("value").as("sum") - ) - .as[(Int, Long, Long)] - - testStream(aggregated, OutputMode.Update)( - StartStream(checkpointLocation = tempDir.getAbsolutePath), - // batch 0 - AddData(inputData, 0 until 20: _*), - CheckLastBatch( - (0, 2, 10), // 0, 10 - (1, 2, 12), // 1, 11 - (2, 2, 14), // 2, 12 - (3, 2, 16), // 3, 13 - (4, 2, 18), // 4, 14 - (5, 2, 20), // 5, 15 - (6, 2, 22), // 6, 16 - (7, 2, 24), // 7, 17 - (8, 2, 26), // 8, 18 - (9, 2, 28) // 9, 19 - ), - StopStream - ) - - // Read state data once with INTERNAL_ONLY_READ_ALL_COLUMN_FAMILIES = true - val bytesReadDf = getBytesReadDf(tempDir.getAbsolutePath) - - // Verify schema and column families - validateBytesReadSchema(bytesReadDf) - // Compare normal and bytes data for default column family - val keySchema: StructType = StructType(Array( - StructField("key", IntegerType, nullable = false) - )) - - // Value schema for the aggregation: count and sum columns - val valueSchema: StructType = StructType(Array( - StructField("count", LongType, nullable = false), - StructField("sum", LongType, nullable = false) - )) - - // Get normal read data for comparison - val normalData = getNormalReadData(tempDir.getAbsolutePath) - compareNormalAndBytesData( - normalData, bytesReadDf, "default", keySchema, valueSchema) - } - } - } - - test("internalOnlyReadAllColumnFamilies should fail with HDFS-backed state store") { - withTempDir { tempDir => - withSQLConf( - SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[HDFSBackedStateStoreProvider].getName, - SQLConf.SHUFFLE_PARTITIONS.key -> "2") { - - val inputData = 
MemoryStream[Int] - val aggregated = inputData.toDF() - .selectExpr("value", "value % 10 AS groupKey") - .groupBy($"groupKey") - .agg( - count("*").as("cnt"), - sum("value").as("sum") - ) - .as[(Int, Long, Long)] - - testStream(aggregated, OutputMode.Update)( - StartStream(checkpointLocation = tempDir.getAbsolutePath), - // batch 0 - AddData(inputData, 0 until 1: _*), - CheckLastBatch( - (0, 1, 0) - ), - StopStream - ) - - // Attempt to read with internalOnlyReadAllColumnFamilies=true should fail - val e = intercept[StateDataSourceException] { - spark.read - .format("statestore") - .option(StateSourceOptions.PATH, tempDir.getAbsolutePath) - .option(StateSourceOptions.INTERNAL_ONLY_READ_ALL_COLUMN_FAMILIES, "true") - .load() - .collect() - } - assert(e.getMessage.contains("internalOnlyReadAllColumnFamilies=true is only " + - s"supported with RocksDBStateStoreProvider")) - } - } - } -} From fa776e1c42eb0418505dc254af697f385896a275 Mon Sep 17 00:00:00 2001 From: zifeif2 Date: Sat, 22 Nov 2025 04:57:46 +0000 Subject: [PATCH 08/12] small changes --- .../v2/state/StateDataSource.scala | 13 +++++++------ .../v2/state/StatePartitionReader.scala | 19 +++++++++++++------ .../v2/state/utils/SchemaUtil.scala | 2 +- .../state/OfflineStateRepartitionErrors.scala | 13 ++++++------- .../streaming/state/StateStoreConf.scala | 7 ++++--- ...artitionAllColumnFamiliesReaderSuite.scala | 18 +++++++----------- 6 files changed, 38 insertions(+), 34 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala index 713088393249..a838c83895e7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala @@ -67,6 +67,8 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging val sourceOptions = StateSourceOptions.modifySourceOptions(hadoopConf, StateSourceOptions.apply(session, hadoopConf, properties)) val stateConf = buildStateStoreConf(sourceOptions.resolvedCpLocation, sourceOptions.batchId) + // We only support RocksDB because the repartition work that this reader + // is built for only supports RocksDB if (sourceOptions.internalOnlyReadAllColumnFamilies && stateConf.providerClass != classOf[RocksDBStateStoreProvider].getName) { throw OfflineStateRepartitionErrors.unsupportedStateStoreProviderError( @@ -386,8 +388,8 @@ case class StateSourceOptions( override def toString: String = { var desc = s"StateSourceOptions(checkpointLocation=$resolvedCpLocation, batchId=$batchId, " + s"operatorId=$operatorId, storeName=$storeName, joinSide=$joinSide, " + - s"stateVarName=${stateVarName.getOrElse("None")}, +" + - s"flattenCollectionTypes=$flattenCollectionTypes" + + s"stateVarName=${stateVarName.getOrElse("None")}, " + + s"flattenCollectionTypes=$flattenCollectionTypes, " + s"internalOnlyReadAllColumnFamilies=$internalOnlyReadAllColumnFamilies" if (fromSnapshotOptions.isDefined) { desc += s", snapshotStartBatchId=${fromSnapshotOptions.get.snapshotStartBatchId}" @@ -502,8 +504,7 @@ object StateSourceOptions extends DataSourceOptions { val readChangeFeed = Option(options.get(READ_CHANGE_FEED)).exists(_.toBoolean) val internalOnlyReadAllColumnFamilies = try { - Option(options.get(INTERNAL_ONLY_READ_ALL_COLUMN_FAMILIES)) - .map(_.toBoolean).getOrElse(false) + 
Option(options.get(INTERNAL_ONLY_READ_ALL_COLUMN_FAMILIES)).exists(_.toBoolean) } catch { case _: IllegalArgumentException => throw StateDataSourceErrors.invalidOptionValue(INTERNAL_ONLY_READ_ALL_COLUMN_FAMILIES, @@ -648,8 +649,8 @@ object StateSourceOptions extends DataSourceOptions { StateSourceOptions( resolvedCpLocation, batchId.get, operatorId, storeName, joinSide, readChangeFeed, fromSnapshotOptions, readChangeFeedOptions, - stateVarName, readRegisteredTimers, flattenCollectionTypes, - internalOnlyReadAllColumnFamilies, startOperatorStateUniqueIds, endOperatorStateUniqueIds) + stateVarName, readRegisteredTimers, flattenCollectionTypes, internalOnlyReadAllColumnFamilies, + startOperatorStateUniqueIds, endOperatorStateUniqueIds) } private def getLastCommittedBatch(session: SparkSession, checkpointLocation: String): Long = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala index e9d70c848a0c..26c9186f2615 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala @@ -50,6 +50,11 @@ class StatePartitionReaderFactory( override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { val stateStoreInputPartition = partition.asInstanceOf[StateStoreInputPartition] if (stateStoreInputPartition.sourceOptions.internalOnlyReadAllColumnFamilies) { + // Disable format validation because the schema returned by + // StatePartitionAllColumnFamiliesReader does not contain the corresponding + // keySchema or valueSchema. + // It's safe to do so we also don't expect the caller of StatePartitionAllColumnFamiliesReader + // to extract specific fields out of the returning row. val modifiedStoreConf = storeConf.withExtraOptions(Map( StateStoreConf.FORMAT_VALIDATION_ENABLED_CONFIG -> "false", StateStoreConf.FORMAT_VALIDATION_CHECK_VALUE_CONFIG -> "false" @@ -85,23 +90,24 @@ abstract class StatePartitionReaderBase( extends PartitionReader[InternalRow] with Logging { // Used primarily as a placeholder for the value schema in the context of // state variables used within the transformWithState operator. 
- private val schemaForValueRow: StructType = + // Also used as a placeholder for both key and value schema for + // StatePartitionAllColumnFamiliesReader + private val placeholderSchema: StructType = StructType(Array(StructField("__dummy__", NullType))) protected val keySchema = { if (SchemaUtil.checkVariableType(stateVariableInfoOpt, StateVariableType.MapState)) { SchemaUtil.getCompositeKeySchema(schema, partition.sourceOptions) } else if (partition.sourceOptions.internalOnlyReadAllColumnFamilies) { - schemaForValueRow + placeholderSchema } else { SchemaUtil.getSchemaAsDataType(schema, "key").asInstanceOf[StructType] } } - protected val valueSchema = if (stateVariableInfoOpt.isDefined) { - schemaForValueRow - } else if (partition.sourceOptions.internalOnlyReadAllColumnFamilies) { - schemaForValueRow + protected val valueSchema = if (stateVariableInfoOpt.isDefined || + partition.sourceOptions.internalOnlyReadAllColumnFamilies) { + placeholderSchema } else { SchemaUtil.getSchemaAsDataType( schema, "value").asInstanceOf[StructType] @@ -252,6 +258,7 @@ class StatePartitionReader( * An implementation of [[StatePartitionReaderBase]] for reading all column families * in binary format. This reader returns raw key and value bytes along with column family names. * We are returning key/value bytes because each column family can have different schema + * It will also return the partition key */ class StatePartitionAllColumnFamiliesReader( storeConf: StateStoreConf, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/utils/SchemaUtil.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/utils/SchemaUtil.scala index 9019a2e61449..44d83fc99b57 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/utils/SchemaUtil.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/utils/SchemaUtil.scala @@ -63,7 +63,7 @@ object SchemaUtil { .add("partition_id", IntegerType) } else if (sourceOptions.internalOnlyReadAllColumnFamilies) { new StructType() - // todo [SPARK-54443]: change keySchema to more specific type after we + // todo [SPARK-54443]: change keySchema to a more specific type after we // can extract partition key from keySchema .add("partition_key", keySchema) .add("key_bytes", BinaryType) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionErrors.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionErrors.scala index 030c81ee3ea4..9c4afdd48d23 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionErrors.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionErrors.scala @@ -17,7 +17,8 @@ package org.apache.spark.sql.execution.streaming.state -import org.apache.spark.{SparkIllegalArgumentException, SparkIllegalStateException, SparkRuntimeException} +import org.apache.spark.{SparkIllegalArgumentException, SparkIllegalStateException, SparkUnsupportedOperationException} +import org.apache.spark.sql.internal.SQLConf /** * Errors thrown by Offline state repartitioning. 
@@ -211,15 +212,13 @@ class StateRepartitionUnsupportedOffsetSeqVersionError( abstract class StateRepartitionInvalidStateStoreConfigError( configName: String, subClass: String, - messageParameters: Map[String, String] = Map.empty, - cause: Throwable = null) - extends SparkRuntimeException( + messageParameters: Map[String, String] = Map.empty) + extends SparkUnsupportedOperationException( errorClass = s"STATE_REPARTITION_INVALID_STATE_STORE_CONFIG.$subClass", - messageParameters = Map("configName" -> configName) ++ messageParameters, - cause = cause) + messageParameters = Map("configName" -> configName) ++ messageParameters) class StateRepartitionInvalidStateStoreConfigUnsupportedProviderError( provider: String) extends StateRepartitionInvalidStateStoreConfigError( - "SQLConf.STATE_STORE_PROVIDER_CLASS.key", + SQLConf.STATE_STORE_PROVIDER_CLASS.key, subClass = "UNSUPPORTED_PROVIDER", messageParameters = Map("provider" -> provider)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala index 790deb099496..22474f55c5d2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala @@ -95,9 +95,10 @@ class StateStoreConf( val providerClass: String = sqlConf.stateStoreProviderClass /** Whether validate the underlying format or not. */ - val formatValidationEnabled: Boolean = extraOptions.getOrElse( - StateStoreConf.FORMAT_VALIDATION_ENABLED_CONFIG, - sqlConf.stateStoreFormatValidationEnabled) == "true" + val formatValidationEnabled: Boolean = extraOptions.get( + StateStoreConf.FORMAT_VALIDATION_ENABLED_CONFIG) + .map(_ == "true") + .getOrElse(sqlConf.stateStoreFormatValidationEnabled) /** * Whether to validate StateStore commits for ForeachBatch sinks to ensure all partitions diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionAllColumnFamiliesReaderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionAllColumnFamiliesReaderSuite.scala index 8e3743914cb6..2054a29a2815 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionAllColumnFamiliesReaderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionAllColumnFamiliesReaderSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.datasources.v2.state import java.nio.ByteOrder import java.util.Arrays -import org.apache.spark.SparkRuntimeException +import org.apache.spark.SparkUnsupportedOperationException import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.InternalRow @@ -100,20 +100,16 @@ class StatePartitionAllColumnFamiliesReaderSuite extends StateDataSourceTestBase // Create projections to convert Row to UnsafeRow bytes val keyProjection = UnsafeProjection.create(keySchema) - val valueProjection = if (valueSchema.isEmpty) null else UnsafeProjection.create(valueSchema) + val valueProjection = UnsafeProjection.create(valueSchema) // Create converters to convert external Row types to internal Catalyst types val keyConverter = CatalystTypeConverters.createToCatalystConverter(keySchema) - val valueConverter = if (valueSchema.isEmpty) { - null - } else { - 
CatalystTypeConverters.createToCatalystConverter(valueSchema) - } + val valueConverter = CatalystTypeConverters.createToCatalystConverter(valueSchema) // Convert normal data to bytes val normalAsBytes = normalDf.map { row => val key = row.getStruct(1) - val value = if (row.isNullAt(2) || valueSchema.isEmpty) null else row.getStruct(2) + val value = if (row.isNullAt(2)) null else row.getStruct(2) // Convert key to InternalRow, then to UnsafeRow, then get bytes val keyInternalRow = keyConverter(key).asInstanceOf[InternalRow] @@ -123,7 +119,7 @@ class StatePartitionAllColumnFamiliesReaderSuite extends StateDataSourceTestBase val keyBytes = keyUnsafeRow.getBytes.clone() // Convert value to bytes - val valueBytes = if (value == null || valueSchema.isEmpty) { + val valueBytes = if (value == null) { Array.empty[Byte] } else { val valueInternalRow = valueConverter(value).asInstanceOf[InternalRow] @@ -433,7 +429,7 @@ class StatePartitionAllColumnFamiliesReaderSuite extends StateDataSourceTestBase ) checkError( - exception = intercept[SparkRuntimeException] { + exception = intercept[SparkUnsupportedOperationException] { spark.read .format("statestore") .option(StateSourceOptions.PATH, tempDir.getAbsolutePath) @@ -443,7 +439,7 @@ class StatePartitionAllColumnFamiliesReaderSuite extends StateDataSourceTestBase }, condition = "STATE_REPARTITION_INVALID_STATE_STORE_CONFIG.UNSUPPORTED_PROVIDER", parameters = Map( - "configName" -> "SQLConf.STATE_STORE_PROVIDER_CLASS.key", + "configName" -> SQLConf.STATE_STORE_PROVIDER_CLASS.key, "provider" -> classOf[HDFSBackedStateStoreProvider].getName ) ) From 63e0753260f9025ef4dca29cb908fce18a52e3c5 Mon Sep 17 00:00:00 2001 From: zifeif2 Date: Mon, 24 Nov 2025 20:03:00 +0000 Subject: [PATCH 09/12] fix small issue --- .../resources/error/error-conditions.json | 2 +- ...artitionAllColumnFamiliesReaderSuite.scala | 100 +++++++++--------- 2 files changed, 51 insertions(+), 51 deletions(-) diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index 077f474484fa..8e5cdd36b2a4 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -5520,7 +5520,7 @@ }, "sqlState" : "42616" }, - "STATE_REPARTITION_INVALID_STATE_STORE_CONFIG": { + "STATE_REPARTITION_INVALID_STATE_STORE_CONFIG" : { "message" : [ "StateStoreConfig is invalid:" ], diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionAllColumnFamiliesReaderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionAllColumnFamiliesReaderSuite.scala index 2054a29a2815..836470c727f4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionAllColumnFamiliesReaderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionAllColumnFamiliesReaderSuite.scala @@ -276,80 +276,80 @@ class StatePartitionAllColumnFamiliesReaderSuite extends StateDataSourceTestBase testWithChangelogConfig("all-column-families: dropDuplicates validation") { withTempDir { tempDir => - runDropDuplicatesQuery(tempDir.getAbsolutePath) + runDropDuplicatesQuery(tempDir.getAbsolutePath) - val keySchema = StructType(Array( - StructField("value", IntegerType, nullable = false), - StructField("eventTime", org.apache.spark.sql.types.TimestampType) - )) - val valueSchema = StructType(Array( - StructField("__dummy__", 
NullType, nullable = true) - )) + val keySchema = StructType(Array( + StructField("value", IntegerType, nullable = false), + StructField("eventTime", org.apache.spark.sql.types.TimestampType) + )) + val valueSchema = StructType(Array( + StructField("__dummy__", NullType, nullable = true) + )) - val normalData = getNormalReadDf(tempDir.getAbsolutePath).collect() - val bytesDf = getBytesReadDf(tempDir.getAbsolutePath).collect() + val normalData = getNormalReadDf(tempDir.getAbsolutePath).collect() + val bytesDf = getBytesReadDf(tempDir.getAbsolutePath).collect() - compareNormalAndBytesData(normalData, bytesDf, "default", keySchema, valueSchema) + compareNormalAndBytesData(normalData, bytesDf, "default", keySchema, valueSchema) } } testWithChangelogConfig("all-column-families: dropDuplicates with column specified") { withTempDir { tempDir => - runDropDuplicatesQueryWithColumnSpecified(tempDir.getAbsolutePath) + runDropDuplicatesQueryWithColumnSpecified(tempDir.getAbsolutePath) - val keySchema = StructType(Array( - StructField("col1", org.apache.spark.sql.types.StringType, nullable = true) - )) - val valueSchema = StructType(Array( - StructField("__dummy__", NullType, nullable = true) - )) + val keySchema = StructType(Array( + StructField("col1", org.apache.spark.sql.types.StringType, nullable = true) + )) + val valueSchema = StructType(Array( + StructField("__dummy__", NullType, nullable = true) + )) - val normalData = getNormalReadDf(tempDir.getAbsolutePath).collect() - val bytesDf = getBytesReadDf(tempDir.getAbsolutePath).collect() + val normalData = getNormalReadDf(tempDir.getAbsolutePath).collect() + val bytesDf = getBytesReadDf(tempDir.getAbsolutePath).collect() - compareNormalAndBytesData(normalData, bytesDf, "default", keySchema, valueSchema) + compareNormalAndBytesData(normalData, bytesDf, "default", keySchema, valueSchema) } } testWithChangelogConfig("all-column-families: dropDuplicatesWithinWatermark") { withTempDir { tempDir => - runDropDuplicatesWithinWatermarkQuery(tempDir.getAbsolutePath) + runDropDuplicatesWithinWatermarkQuery(tempDir.getAbsolutePath) - val keySchema = StructType(Array( - StructField("_1", org.apache.spark.sql.types.StringType, nullable = true) - )) - val valueSchema = StructType(Array( - StructField("expiresAtMicros", LongType, nullable = false) - )) + val keySchema = StructType(Array( + StructField("_1", org.apache.spark.sql.types.StringType, nullable = true) + )) + val valueSchema = StructType(Array( + StructField("expiresAtMicros", LongType, nullable = false) + )) - val normalData = getNormalReadDf(tempDir.getAbsolutePath).collect() - val bytesDf = getBytesReadDf(tempDir.getAbsolutePath).collect() + val normalData = getNormalReadDf(tempDir.getAbsolutePath).collect() + val bytesDf = getBytesReadDf(tempDir.getAbsolutePath).collect() - compareNormalAndBytesData(normalData, bytesDf, "default", keySchema, valueSchema) + compareNormalAndBytesData(normalData, bytesDf, "default", keySchema, valueSchema) } } testWithChangelogConfig("all-column-families: session window aggregation") { withTempDir { tempDir => - runSessionWindowAggregationQuery(tempDir.getAbsolutePath) - - val keySchema = StructType(Array( - StructField("sessionId", org.apache.spark.sql.types.StringType, nullable = false), - StructField("sessionStartTime", org.apache.spark.sql.types.TimestampType, nullable = false) - )) - val valueSchema = StructType(Array( - StructField("session_window", org.apache.spark.sql.types.StructType(Array( - StructField("start", org.apache.spark.sql.types.TimestampType, 
nullable = true), - StructField("end", org.apache.spark.sql.types.TimestampType, nullable = true) - )), nullable = false), - StructField("sessionId", org.apache.spark.sql.types.StringType, nullable = false), - StructField("count", LongType, nullable = false) - )) - - val normalData = getNormalReadDf(tempDir.getAbsolutePath).collect() - val bytesDf = getBytesReadDf(tempDir.getAbsolutePath).collect() - - compareNormalAndBytesData(normalData, bytesDf, "default", keySchema, valueSchema) + runSessionWindowAggregationQuery(tempDir.getAbsolutePath) + + val keySchema = StructType(Array( + StructField("sessionId", org.apache.spark.sql.types.StringType, nullable = false), + StructField("sessionStartTime", org.apache.spark.sql.types.TimestampType, nullable = false) + )) + val valueSchema = StructType(Array( + StructField("session_window", org.apache.spark.sql.types.StructType(Array( + StructField("start", org.apache.spark.sql.types.TimestampType, nullable = true), + StructField("end", org.apache.spark.sql.types.TimestampType, nullable = true) + )), nullable = false), + StructField("sessionId", org.apache.spark.sql.types.StringType, nullable = false), + StructField("count", LongType, nullable = false) + )) + + val normalData = getNormalReadDf(tempDir.getAbsolutePath).collect() + val bytesDf = getBytesReadDf(tempDir.getAbsolutePath).collect() + + compareNormalAndBytesData(normalData, bytesDf, "default", keySchema, valueSchema) } } From aee5732814a795a2ba87712d9036f94af48787ec Mon Sep 17 00:00:00 2001 From: zifeif2 Date: Wed, 26 Nov 2025 08:07:27 +0000 Subject: [PATCH 10/12] address commenet --- .../resources/error/error-conditions.json | 18 +-- .../v2/state/StateDataSource.scala | 3 +- .../v2/state/StatePartitionReader.scala | 34 +++-- .../state/OfflineStateRepartitionErrors.scala | 23 ++-- .../streaming/state/StateStoreConf.scala | 18 +-- ...artitionAllColumnFamiliesReaderSuite.scala | 122 ++++++++++++++---- 6 files changed, 135 insertions(+), 83 deletions(-) diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index 8e5cdd36b2a4..723efbea715e 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -5493,6 +5493,11 @@ "message" : [ "Unsupported offset sequence version . Please make sure the checkpoint is from a supported Spark version (Spark 4.0+)." ] + }, + "UNSUPPORTED_PROVIDER" : { + "message" : [ + " is not supported" + ] } }, "sqlState" : "55019" @@ -5520,19 +5525,6 @@ }, "sqlState" : "42616" }, - "STATE_REPARTITION_INVALID_STATE_STORE_CONFIG" : { - "message" : [ - "StateStoreConfig is invalid:" - ], - "subClass" : { - "UNSUPPORTED_PROVIDER" : { - "message" : [ - " is not supported" - ] - } - }, - "sqlState" : "42617" - }, "STATE_STORE_CANNOT_CREATE_COLUMN_FAMILY_WITH_RESERVED_CHARS" : { "message" : [ "Failed to create column family with unsupported starting character and name=." 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala index a838c83895e7..c07c67e657f5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala @@ -67,11 +67,12 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging val sourceOptions = StateSourceOptions.modifySourceOptions(hadoopConf, StateSourceOptions.apply(session, hadoopConf, properties)) val stateConf = buildStateStoreConf(sourceOptions.resolvedCpLocation, sourceOptions.batchId) - // We only support RocksDB because the repartition work that this reader + // We only support RocksDB because the repartition work that this option // is built for only supports RocksDB if (sourceOptions.internalOnlyReadAllColumnFamilies && stateConf.providerClass != classOf[RocksDBStateStoreProvider].getName) { throw OfflineStateRepartitionErrors.unsupportedStateStoreProviderError( + sourceOptions.resolvedCpLocation, stateConf.providerClass) } val stateStoreReaderInfo: StateStoreReaderInfo = getStoreMetadataAndRunChecks( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala index 26c9186f2615..a71a3aa2d4f3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala @@ -16,6 +16,8 @@ */ package org.apache.spark.sql.execution.datasources.v2.state +import scala.collection.mutable + import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeRow} @@ -50,16 +52,7 @@ class StatePartitionReaderFactory( override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { val stateStoreInputPartition = partition.asInstanceOf[StateStoreInputPartition] if (stateStoreInputPartition.sourceOptions.internalOnlyReadAllColumnFamilies) { - // Disable format validation because the schema returned by - // StatePartitionAllColumnFamiliesReader does not contain the corresponding - // keySchema or valueSchema. - // It's safe to do so we also don't expect the caller of StatePartitionAllColumnFamiliesReader - // to extract specific fields out of the returning row. 
- val modifiedStoreConf = storeConf.withExtraOptions(Map( - StateStoreConf.FORMAT_VALIDATION_ENABLED_CONFIG -> "false", - StateStoreConf.FORMAT_VALIDATION_CHECK_VALUE_CONFIG -> "false" - )) - new StatePartitionAllColumnFamiliesReader(modifiedStoreConf, hadoopConf, + new StatePartitionAllColumnFamiliesReader(storeConf, hadoopConf, stateStoreInputPartition, schema, keyStateEncoderSpec) } else if (stateStoreInputPartition.sourceOptions.readChangeFeed) { new StateStoreChangeDataPartitionReader(storeConf, hadoopConf, @@ -95,19 +88,34 @@ abstract class StatePartitionReaderBase( private val placeholderSchema: StructType = StructType(Array(StructField("__dummy__", NullType))) + private val colFamilyToSchema : mutable.HashMap[String, StateStoreColFamilySchema] = { + val stateStoreId = StateStoreId( + partition.sourceOptions.stateCheckpointLocation.toString, + partition.sourceOptions.operatorId, + StateStore.PARTITION_ID_TO_CHECK_SCHEMA, + partition.sourceOptions.storeName) + val stateStoreProviderId = StateStoreProviderId(stateStoreId, partition.queryId) + val manager = new StateSchemaCompatibilityChecker(stateStoreProviderId, hadoopConf.value) + val schemaFile = manager.readSchemaFile() + val schemaMap = mutable.HashMap[String, StateStoreColFamilySchema]() + schemaFile.foreach { schema => schemaMap.put(schema.colFamilyName, schema)} + schemaMap + } + protected val keySchema = { if (SchemaUtil.checkVariableType(stateVariableInfoOpt, StateVariableType.MapState)) { SchemaUtil.getCompositeKeySchema(schema, partition.sourceOptions) } else if (partition.sourceOptions.internalOnlyReadAllColumnFamilies) { - placeholderSchema + colFamilyToSchema(StateStore.DEFAULT_COL_FAMILY_NAME).keySchema } else { SchemaUtil.getSchemaAsDataType(schema, "key").asInstanceOf[StructType] } } - protected val valueSchema = if (stateVariableInfoOpt.isDefined || - partition.sourceOptions.internalOnlyReadAllColumnFamilies) { + protected val valueSchema = if (stateVariableInfoOpt.isDefined) { placeholderSchema + } else if (partition.sourceOptions.internalOnlyReadAllColumnFamilies) { + colFamilyToSchema(StateStore.DEFAULT_COL_FAMILY_NAME).valueSchema } else { SchemaUtil.getSchemaAsDataType( schema, "value").asInstanceOf[StructType] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionErrors.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionErrors.scala index 9c4afdd48d23..16e4837fc342 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionErrors.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionErrors.scala @@ -17,8 +17,7 @@ package org.apache.spark.sql.execution.streaming.state -import org.apache.spark.{SparkIllegalArgumentException, SparkIllegalStateException, SparkUnsupportedOperationException} -import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.{SparkIllegalArgumentException, SparkIllegalStateException} /** * Errors thrown by Offline state repartitioning. 
@@ -88,9 +87,10 @@ object OfflineStateRepartitionErrors { } def unsupportedStateStoreProviderError( + checkpointLocation: String, providerClass: String - ): StateRepartitionInvalidStateStoreConfigUnsupportedProviderError = { - new StateRepartitionInvalidStateStoreConfigUnsupportedProviderError(providerClass) + ): StateRepartitionUnsupportedProviderError = { + new StateRepartitionUnsupportedProviderError(checkpointLocation, providerClass) } } @@ -209,16 +209,9 @@ class StateRepartitionUnsupportedOffsetSeqVersionError( subClass = "UNSUPPORTED_OFFSET_SEQ_VERSION", messageParameters = Map("version" -> version.toString)) -abstract class StateRepartitionInvalidStateStoreConfigError( - configName: String, - subClass: String, - messageParameters: Map[String, String] = Map.empty) - extends SparkUnsupportedOperationException( - errorClass = s"STATE_REPARTITION_INVALID_STATE_STORE_CONFIG.$subClass", - messageParameters = Map("configName" -> configName) ++ messageParameters) - -class StateRepartitionInvalidStateStoreConfigUnsupportedProviderError( - provider: String) extends StateRepartitionInvalidStateStoreConfigError( - SQLConf.STATE_STORE_PROVIDER_CLASS.key, +class StateRepartitionUnsupportedProviderError( + checkpointLocation: String, + provider: String) extends StateRepartitionInvalidCheckpointError( + checkpointLocation, subClass = "UNSUPPORTED_PROVIDER", messageParameters = Map("provider" -> provider)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala index 22474f55c5d2..3991f8d93f2c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala @@ -30,18 +30,6 @@ class StateStoreConf( def this() = this(new SQLConf) - def withExtraOptions(additionalOptions: Map[String, String]): StateStoreConf = { - val reconstructedSqlConf = { - // Reconstruct a SQLConf with the all settings preserved because sqlConf is transient - val conf = new SQLConf() - // Restore all state store related settings - sqlConfs.foreach { case (key, value) => - conf.setConfString(key, value) - } - conf - } - new StateStoreConf(reconstructedSqlConf, extraOptions ++ additionalOptions) - } /** * Size of MaintenanceThreadPool to perform maintenance tasks for StateStore */ @@ -95,10 +83,7 @@ class StateStoreConf( val providerClass: String = sqlConf.stateStoreProviderClass /** Whether validate the underlying format or not. 
*/ - val formatValidationEnabled: Boolean = extraOptions.get( - StateStoreConf.FORMAT_VALIDATION_ENABLED_CONFIG) - .map(_ == "true") - .getOrElse(sqlConf.stateStoreFormatValidationEnabled) + val formatValidationEnabled: Boolean = sqlConf.stateStoreFormatValidationEnabled /** * Whether to validate StateStore commits for ForeachBatch sinks to ensure all partitions @@ -181,7 +166,6 @@ class StateStoreConf( } object StateStoreConf { - val FORMAT_VALIDATION_ENABLED_CONFIG = "formatValidationEnabled" val FORMAT_VALIDATION_CHECK_VALUE_CONFIG = "formatValidationCheckValue" val empty = new StateStoreConf() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionAllColumnFamiliesReaderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionAllColumnFamiliesReaderSuite.scala index 836470c727f4..1dce7467a7af 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionAllColumnFamiliesReaderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionAllColumnFamiliesReaderSuite.scala @@ -19,17 +19,16 @@ package org.apache.spark.sql.execution.datasources.v2.state import java.nio.ByteOrder import java.util.Arrays -import org.apache.spark.SparkUnsupportedOperationException import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.UnsafeProjection import org.apache.spark.sql.execution.streaming.runtime.MemoryStream -import org.apache.spark.sql.execution.streaming.state.{HDFSBackedStateStoreProvider, RocksDBStateStoreProvider} +import org.apache.spark.sql.execution.streaming.state.{HDFSBackedStateStoreProvider, RocksDBStateStoreProvider, StateRepartitionUnsupportedProviderError, StateStore} import org.apache.spark.sql.functions.{count, sum} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.OutputMode -import org.apache.spark.sql.types.{IntegerType, LongType, NullType, StructField, StructType} +import org.apache.spark.sql.types.{BooleanType, IntegerType, LongType, NullType, StructField, StructType, TimestampType} /** * Note: This extends StateDataSourceTestBase to access @@ -45,19 +44,25 @@ class StatePartitionAllColumnFamiliesReaderSuite extends StateDataSourceTestBase classOf[RocksDBStateStoreProvider].getName) } - private def getNormalReadDf(checkpointDir: String): DataFrame = { + private def getNormalReadDf( + checkpointDir: String, + storeName: Option[String] = Option.empty[String]): DataFrame = { spark.read .format("statestore") .option(StateSourceOptions.PATH, checkpointDir) + .option(StateSourceOptions.STORE_NAME, storeName.orNull) .load() .selectExpr("partition_id", "key", "value") } - private def getBytesReadDf(checkpointDir: String): DataFrame = { + private def getBytesReadDf( + checkpointDir: String, + storeName: Option[String] = Option.empty[String]): DataFrame = { spark.read .format("statestore") .option(StateSourceOptions.PATH, checkpointDir) .option(StateSourceOptions.INTERNAL_ONLY_READ_ALL_COLUMN_FAMILIES, "true") + .option(StateSourceOptions.STORE_NAME, storeName.orNull) .load() .selectExpr("partition_key", "key_bytes", "value_bytes", "column_family_name") } @@ -107,7 +112,7 @@ class StatePartitionAllColumnFamiliesReaderSuite extends StateDataSourceTestBase val valueConverter = 
CatalystTypeConverters.createToCatalystConverter(valueSchema) // Convert normal data to bytes - val normalAsBytes = normalDf.map { row => + val normalAsBytes = normalDf.toSeq.map { row => val key = row.getStruct(1) val value = if (row.isNullAt(2)) null else row.getStruct(2) @@ -178,7 +183,7 @@ class StatePartitionAllColumnFamiliesReaderSuite extends StateDataSourceTestBase } } - testWithChangelogConfig("all-column-families: simple aggregation state ver 1") { + testWithChangelogConfig("SPARK-54388: simple aggregation state ver 1") { withSQLConf(SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_VERSION.key -> "1") { withTempDir { tempDir => runLargeDataStreamingAggregationQuery(tempDir.getAbsolutePath) @@ -202,7 +207,7 @@ class StatePartitionAllColumnFamiliesReaderSuite extends StateDataSourceTestBase } } - testWithChangelogConfig("all-column-families: simple aggregation state ver 2") { + testWithChangelogConfig("SPARK-54388: simple aggregation state ver 2") { withSQLConf(SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_VERSION.key -> "2") { withTempDir { tempDir => runLargeDataStreamingAggregationQuery(tempDir.getAbsolutePath) @@ -223,7 +228,7 @@ class StatePartitionAllColumnFamiliesReaderSuite extends StateDataSourceTestBase } } - testWithChangelogConfig("all-column-families: composite key aggregation state ver 1") { + testWithChangelogConfig("SPARK-54388: composite key aggregation state ver 1") { withSQLConf(SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_VERSION.key -> "1") { withTempDir { tempDir => runCompositeKeyStreamingAggregationQuery(tempDir.getAbsolutePath) @@ -250,7 +255,7 @@ class StatePartitionAllColumnFamiliesReaderSuite extends StateDataSourceTestBase } } - testWithChangelogConfig("all-column-families: composite key aggregation state ver 2") { + testWithChangelogConfig("SPARK-54388: composite key aggregation state ver 2") { withSQLConf(SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_VERSION.key -> "2") { withTempDir { tempDir => runCompositeKeyStreamingAggregationQuery(tempDir.getAbsolutePath) @@ -274,7 +279,7 @@ class StatePartitionAllColumnFamiliesReaderSuite extends StateDataSourceTestBase } } - testWithChangelogConfig("all-column-families: dropDuplicates validation") { + testWithChangelogConfig("SPARK-54388: dropDuplicates validation") { withTempDir { tempDir => runDropDuplicatesQuery(tempDir.getAbsolutePath) @@ -293,7 +298,7 @@ class StatePartitionAllColumnFamiliesReaderSuite extends StateDataSourceTestBase } } - testWithChangelogConfig("all-column-families: dropDuplicates with column specified") { + testWithChangelogConfig("SPARK-54388: dropDuplicates with column specified") { withTempDir { tempDir => runDropDuplicatesQueryWithColumnSpecified(tempDir.getAbsolutePath) @@ -311,7 +316,7 @@ class StatePartitionAllColumnFamiliesReaderSuite extends StateDataSourceTestBase } } - testWithChangelogConfig("all-column-families: dropDuplicatesWithinWatermark") { + testWithChangelogConfig("SPARK-54388: dropDuplicatesWithinWatermark") { withTempDir { tempDir => runDropDuplicatesWithinWatermarkQuery(tempDir.getAbsolutePath) @@ -329,18 +334,19 @@ class StatePartitionAllColumnFamiliesReaderSuite extends StateDataSourceTestBase } } - testWithChangelogConfig("all-column-families: session window aggregation") { + testWithChangelogConfig("SPARK-54388: session window aggregation") { withTempDir { tempDir => runSessionWindowAggregationQuery(tempDir.getAbsolutePath) val keySchema = StructType(Array( StructField("sessionId", org.apache.spark.sql.types.StringType, nullable = false), - StructField("sessionStartTime", 
org.apache.spark.sql.types.TimestampType, nullable = false) + StructField("sessionStartTime", + org.apache.spark.sql.types.TimestampType, nullable = false) )) val valueSchema = StructType(Array( StructField("session_window", org.apache.spark.sql.types.StructType(Array( - StructField("start", org.apache.spark.sql.types.TimestampType, nullable = true), - StructField("end", org.apache.spark.sql.types.TimestampType, nullable = true) + StructField("start", org.apache.spark.sql.types.TimestampType), + StructField("end", org.apache.spark.sql.types.TimestampType) )), nullable = false), StructField("sessionId", org.apache.spark.sql.types.StringType, nullable = false), StructField("count", LongType, nullable = false) @@ -353,7 +359,7 @@ class StatePartitionAllColumnFamiliesReaderSuite extends StateDataSourceTestBase } } - testWithChangelogConfig("all-column-families: flatMapGroupsWithState, state ver 1") { + testWithChangelogConfig("SPARK-54388: flatMapGroupsWithState, state ver 1") { // Skip this test on big endian platforms assume(java.nio.ByteOrder.nativeOrder().equals(java.nio.ByteOrder.LITTLE_ENDIAN)) withSQLConf(SQLConf.FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION.key -> "1") { @@ -379,7 +385,7 @@ class StatePartitionAllColumnFamiliesReaderSuite extends StateDataSourceTestBase } } - testWithChangelogConfig("all-column-families: flatMapGroupsWithState, state ver 2") { + testWithChangelogConfig("SPARK-54388: flatMapGroupsWithState, state ver 2") { withSQLConf(SQLConf.FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION.key -> "2") { withTempDir { tempDir => runFlatMapGroupsWithStateQuery(tempDir.getAbsolutePath) @@ -398,11 +404,78 @@ class StatePartitionAllColumnFamiliesReaderSuite extends StateDataSourceTestBase val normalData = getNormalReadDf(tempDir.getAbsolutePath).collect() val bytesDf = getBytesReadDf(tempDir.getAbsolutePath).collect() - compareNormalAndBytesData(normalData, bytesDf, "default", keySchema, valueSchema) } } } + + def testStreamStreamJoin(stateVersion: Int): Unit = { + withSQLConf(SQLConf.STREAMING_JOIN_STATE_FORMAT_VERSION.key -> stateVersion.toString) { + withTempDir { tempDir => + runStreamStreamJoinQuery(tempDir.getAbsolutePath) + + Seq("right-keyToNumValues", "left-keyToNumValues").foreach(storeName => { + val stateReaderForRight = getNormalReadDf( + tempDir.getAbsolutePath, Option(storeName)) + val stateBytesDfForRight = getBytesReadDf( + tempDir.getAbsolutePath, Option(storeName)) + + val keyToNumValuesKeySchema = StructType(Array( + StructField("key", IntegerType) + )) + val keyToNumValueValueSchema = StructType(Array( + StructField("value", LongType) + )) + + compareNormalAndBytesData( + stateReaderForRight.collect(), + stateBytesDfForRight.collect(), + StateStore.DEFAULT_COL_FAMILY_NAME, + keyToNumValuesKeySchema, + keyToNumValueValueSchema) + }) + + Seq("right-keyWithIndexToValue", "left-keyWithIndexToValue").foreach(storeName => { + val stateReaderForRight = getNormalReadDf( + tempDir.getAbsolutePath, Option(storeName)) + val stateBytesDfForRight = getBytesReadDf( + tempDir.getAbsolutePath, Option(storeName)) + + val keyToNumValuesKeySchema = StructType(Array( + StructField("key", IntegerType, nullable = false), + StructField("index", LongType) + )) + val keyToNumValueValueSchema = if (stateVersion == 2) { + StructType(Array( + StructField("value", IntegerType, nullable = false), + StructField("time", TimestampType, nullable = false), + StructField("matched", BooleanType) + )) + } else { + StructType(Array( + StructField("value", IntegerType, nullable = false), + 
StructField("time", TimestampType, nullable = false) + )) + } + + compareNormalAndBytesData( + stateReaderForRight.collect(), + stateBytesDfForRight.collect(), + StateStore.DEFAULT_COL_FAMILY_NAME, + keyToNumValuesKeySchema, + keyToNumValueValueSchema) + }) + } + } + } + + testWithChangelogConfig("stream-stream join, state ver 1") { + testStreamStreamJoin(1) + } + + testWithChangelogConfig("stream-stream join, state ver 2") { + testStreamStreamJoin(2) + } } // End of foreach loop for changelog checkpointing dimension test("internalOnlyReadAllColumnFamilies should fail with HDFS-backed state store") { @@ -429,7 +502,7 @@ class StatePartitionAllColumnFamiliesReaderSuite extends StateDataSourceTestBase ) checkError( - exception = intercept[SparkUnsupportedOperationException] { + exception = intercept[StateRepartitionUnsupportedProviderError] { spark.read .format("statestore") .option(StateSourceOptions.PATH, tempDir.getAbsolutePath) @@ -437,11 +510,12 @@ class StatePartitionAllColumnFamiliesReaderSuite extends StateDataSourceTestBase .load() .collect() }, - condition = "STATE_REPARTITION_INVALID_STATE_STORE_CONFIG.UNSUPPORTED_PROVIDER", + condition = "STATE_REPARTITION_INVALID_CHECKPOINT.UNSUPPORTED_PROVIDER", parameters = Map( - "configName" -> SQLConf.STATE_STORE_PROVIDER_CLASS.key, + "checkpointLocation" -> s".*${tempDir.getAbsolutePath}", "provider" -> classOf[HDFSBackedStateStoreProvider].getName - ) + ), + matchPVals = true ) } } From 48521c33627b811bd128ef3b127d9477f7968434 Mon Sep 17 00:00:00 2001 From: zifeif2 Date: Mon, 1 Dec 2025 16:59:44 +0000 Subject: [PATCH 11/12] get keySchema from stateStoreColFamilySchemaOpt --- .../v2/state/StatePartitionReader.scala | 39 +++++--------- .../state/OfflineStateRepartitionErrors.scala | 6 +-- ...artitionAllColumnFamiliesReaderSuite.scala | 51 ++++++++++++------- 3 files changed, 47 insertions(+), 49 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala index a71a3aa2d4f3..9fc3c081173f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala @@ -16,8 +16,6 @@ */ package org.apache.spark.sql.execution.datasources.v2.state -import scala.collection.mutable - import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeRow} @@ -53,7 +51,7 @@ class StatePartitionReaderFactory( val stateStoreInputPartition = partition.asInstanceOf[StateStoreInputPartition] if (stateStoreInputPartition.sourceOptions.internalOnlyReadAllColumnFamilies) { new StatePartitionAllColumnFamiliesReader(storeConf, hadoopConf, - stateStoreInputPartition, schema, keyStateEncoderSpec) + stateStoreInputPartition, schema, keyStateEncoderSpec, stateStoreColFamilySchemaOpt) } else if (stateStoreInputPartition.sourceOptions.readChangeFeed) { new StateStoreChangeDataPartitionReader(storeConf, hadoopConf, stateStoreInputPartition, schema, keyStateEncoderSpec, stateVariableInfoOpt, @@ -83,39 +81,25 @@ abstract class StatePartitionReaderBase( extends PartitionReader[InternalRow] with Logging { // Used primarily as a placeholder for the value schema in the context of // state variables used within the transformWithState operator. 
- // Also used as a placeholder for both key and value schema for - // StatePartitionAllColumnFamiliesReader - private val placeholderSchema: StructType = + private val schemaForValueRow: StructType = StructType(Array(StructField("__dummy__", NullType))) - private val colFamilyToSchema : mutable.HashMap[String, StateStoreColFamilySchema] = { - val stateStoreId = StateStoreId( - partition.sourceOptions.stateCheckpointLocation.toString, - partition.sourceOptions.operatorId, - StateStore.PARTITION_ID_TO_CHECK_SCHEMA, - partition.sourceOptions.storeName) - val stateStoreProviderId = StateStoreProviderId(stateStoreId, partition.queryId) - val manager = new StateSchemaCompatibilityChecker(stateStoreProviderId, hadoopConf.value) - val schemaFile = manager.readSchemaFile() - val schemaMap = mutable.HashMap[String, StateStoreColFamilySchema]() - schemaFile.foreach { schema => schemaMap.put(schema.colFamilyName, schema)} - schemaMap - } - - protected val keySchema = { + protected val keySchema : StructType = { if (SchemaUtil.checkVariableType(stateVariableInfoOpt, StateVariableType.MapState)) { SchemaUtil.getCompositeKeySchema(schema, partition.sourceOptions) } else if (partition.sourceOptions.internalOnlyReadAllColumnFamilies) { - colFamilyToSchema(StateStore.DEFAULT_COL_FAMILY_NAME).keySchema + require(stateStoreColFamilySchemaOpt.isDefined) + stateStoreColFamilySchemaOpt.map(_.keySchema).get } else { SchemaUtil.getSchemaAsDataType(schema, "key").asInstanceOf[StructType] } } - protected val valueSchema = if (stateVariableInfoOpt.isDefined) { - placeholderSchema + protected val valueSchema : StructType = if (stateVariableInfoOpt.isDefined) { + schemaForValueRow } else if (partition.sourceOptions.internalOnlyReadAllColumnFamilies) { - colFamilyToSchema(StateStore.DEFAULT_COL_FAMILY_NAME).valueSchema + require(stateStoreColFamilySchemaOpt.isDefined) + stateStoreColFamilySchemaOpt.map(_.valueSchema).get } else { SchemaUtil.getSchemaAsDataType( schema, "value").asInstanceOf[StructType] @@ -273,11 +257,12 @@ class StatePartitionAllColumnFamiliesReader( hadoopConf: SerializableConfiguration, partition: StateStoreInputPartition, schema: StructType, - keyStateEncoderSpec: KeyStateEncoderSpec) + keyStateEncoderSpec: KeyStateEncoderSpec, + stateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema]) extends StatePartitionReaderBase( storeConf, hadoopConf, partition, schema, - keyStateEncoderSpec, None, None, None, None) { + keyStateEncoderSpec, None, stateStoreColFamilySchemaOpt, None, None) { private lazy val store: ReadStateStore = { assert(getStartStoreUniqueId == getEndStoreUniqueId, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionErrors.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionErrors.scala index 16e4837fc342..b8c39b4a160b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionErrors.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionErrors.scala @@ -88,8 +88,7 @@ object OfflineStateRepartitionErrors { def unsupportedStateStoreProviderError( checkpointLocation: String, - providerClass: String - ): StateRepartitionUnsupportedProviderError = { + providerClass: String): StateRepartitionUnsupportedProviderError = { new StateRepartitionUnsupportedProviderError(checkpointLocation, providerClass) } } @@ -211,7 +210,8 @@ class StateRepartitionUnsupportedOffsetSeqVersionError( class 
StateRepartitionUnsupportedProviderError( checkpointLocation: String, - provider: String) extends StateRepartitionInvalidCheckpointError( + provider: String) + extends StateRepartitionInvalidCheckpointError( checkpointLocation, subClass = "UNSUPPORTED_PROVIDER", messageParameters = Map("provider" -> provider)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionAllColumnFamiliesReaderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionAllColumnFamiliesReaderSuite.scala index 1dce7467a7af..c4b59b149b96 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionAllColumnFamiliesReaderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionAllColumnFamiliesReaderSuite.scala @@ -64,7 +64,6 @@ class StatePartitionAllColumnFamiliesReaderSuite extends StateDataSourceTestBase .option(StateSourceOptions.INTERNAL_ONLY_READ_ALL_COLUMN_FAMILIES, "true") .option(StateSourceOptions.STORE_NAME, storeName.orNull) .load() - .selectExpr("partition_key", "key_bytes", "value_bytes", "column_family_name") } /** @@ -221,9 +220,10 @@ class StatePartitionAllColumnFamiliesReaderSuite extends StateDataSourceTestBase )) val normalData = getNormalReadDf(tempDir.getAbsolutePath).collect() - val bytesDf = getBytesReadDf(tempDir.getAbsolutePath).collect() + val bytesDf = getBytesReadDf(tempDir.getAbsolutePath) - compareNormalAndBytesData(normalData, bytesDf, "default", keySchema, valueSchema) + validateBytesReadDfSchema(bytesDf) + compareNormalAndBytesData(normalData, bytesDf.collect(), "default", keySchema, valueSchema) } } } @@ -248,9 +248,10 @@ class StatePartitionAllColumnFamiliesReaderSuite extends StateDataSourceTestBase )) val normalData = getNormalReadDf(tempDir.getAbsolutePath).collect() - val bytesDf = getBytesReadDf(tempDir.getAbsolutePath).collect() + val bytesDf = getBytesReadDf(tempDir.getAbsolutePath) - compareNormalAndBytesData(normalData, bytesDf, "default", keySchema, valueSchema) + validateBytesReadDfSchema(bytesDf) + compareNormalAndBytesData(normalData, bytesDf.collect(), "default", keySchema, valueSchema) } } } @@ -272,9 +273,10 @@ class StatePartitionAllColumnFamiliesReaderSuite extends StateDataSourceTestBase )) val normalData = getNormalReadDf(tempDir.getAbsolutePath).collect() - val bytesDf = getBytesReadDf(tempDir.getAbsolutePath).collect() + val bytesDf = getBytesReadDf(tempDir.getAbsolutePath) - compareNormalAndBytesData(normalData, bytesDf, "default", keySchema, valueSchema) + validateBytesReadDfSchema(bytesDf) + compareNormalAndBytesData(normalData, bytesDf.collect(), "default", keySchema, valueSchema) } } } @@ -292,9 +294,10 @@ class StatePartitionAllColumnFamiliesReaderSuite extends StateDataSourceTestBase )) val normalData = getNormalReadDf(tempDir.getAbsolutePath).collect() - val bytesDf = getBytesReadDf(tempDir.getAbsolutePath).collect() + val bytesDf = getBytesReadDf(tempDir.getAbsolutePath) - compareNormalAndBytesData(normalData, bytesDf, "default", keySchema, valueSchema) + validateBytesReadDfSchema(bytesDf) + compareNormalAndBytesData(normalData, bytesDf.collect(), "default", keySchema, valueSchema) } } @@ -310,9 +313,10 @@ class StatePartitionAllColumnFamiliesReaderSuite extends StateDataSourceTestBase )) val normalData = getNormalReadDf(tempDir.getAbsolutePath).collect() - val bytesDf = getBytesReadDf(tempDir.getAbsolutePath).collect() + val bytesDf = 
getBytesReadDf(tempDir.getAbsolutePath) - compareNormalAndBytesData(normalData, bytesDf, "default", keySchema, valueSchema) + validateBytesReadDfSchema(bytesDf) + compareNormalAndBytesData(normalData, bytesDf.collect(), "default", keySchema, valueSchema) } } @@ -328,9 +332,10 @@ class StatePartitionAllColumnFamiliesReaderSuite extends StateDataSourceTestBase )) val normalData = getNormalReadDf(tempDir.getAbsolutePath).collect() - val bytesDf = getBytesReadDf(tempDir.getAbsolutePath).collect() + val bytesDf = getBytesReadDf(tempDir.getAbsolutePath) - compareNormalAndBytesData(normalData, bytesDf, "default", keySchema, valueSchema) + validateBytesReadDfSchema(bytesDf) + compareNormalAndBytesData(normalData, bytesDf.collect(), "default", keySchema, valueSchema) } } @@ -353,9 +358,10 @@ class StatePartitionAllColumnFamiliesReaderSuite extends StateDataSourceTestBase )) val normalData = getNormalReadDf(tempDir.getAbsolutePath).collect() - val bytesDf = getBytesReadDf(tempDir.getAbsolutePath).collect() + val bytesDf = getBytesReadDf(tempDir.getAbsolutePath) - compareNormalAndBytesData(normalData, bytesDf, "default", keySchema, valueSchema) + validateBytesReadDfSchema(bytesDf) + compareNormalAndBytesData(normalData, bytesDf.collect(), "default", keySchema, valueSchema) } } @@ -378,9 +384,11 @@ class StatePartitionAllColumnFamiliesReaderSuite extends StateDataSourceTestBase )) val normalData = getNormalReadDf(tempDir.getAbsolutePath).collect() - val bytesDf = getBytesReadDf(tempDir.getAbsolutePath).collect() + val bytesDf = getBytesReadDf(tempDir.getAbsolutePath) - compareNormalAndBytesData(normalData, bytesDf, "default", keySchema, valueSchema) + validateBytesReadDfSchema(bytesDf) + compareNormalAndBytesData( + normalData, bytesDf.collect(), "default", keySchema, valueSchema) } } } @@ -403,8 +411,11 @@ class StatePartitionAllColumnFamiliesReaderSuite extends StateDataSourceTestBase )) val normalData = getNormalReadDf(tempDir.getAbsolutePath).collect() - val bytesDf = getBytesReadDf(tempDir.getAbsolutePath).collect() - compareNormalAndBytesData(normalData, bytesDf, "default", keySchema, valueSchema) + val bytesDf = getBytesReadDf(tempDir.getAbsolutePath) + + validateBytesReadDfSchema(bytesDf) + compareNormalAndBytesData( + normalData, bytesDf.collect(), "default", keySchema, valueSchema) } } } @@ -427,6 +438,7 @@ class StatePartitionAllColumnFamiliesReaderSuite extends StateDataSourceTestBase StructField("value", LongType) )) + validateBytesReadDfSchema(stateBytesDfForRight) compareNormalAndBytesData( stateReaderForRight.collect(), stateBytesDfForRight.collect(), @@ -458,6 +470,7 @@ class StatePartitionAllColumnFamiliesReaderSuite extends StateDataSourceTestBase )) } + validateBytesReadDfSchema(stateBytesDfForRight) compareNormalAndBytesData( stateReaderForRight.collect(), stateBytesDfForRight.collect(), From 6003f545cf40f60aeebdb6151db1a0a4403abaec Mon Sep 17 00:00:00 2001 From: zifeif2 Date: Tue, 2 Dec 2025 05:02:41 +0000 Subject: [PATCH 12/12] address comment --- .../sql/execution/datasources/v2/state/StateDataSource.scala | 1 + .../streaming/state/OfflineStateRepartitionErrors.scala | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala index c07c67e657f5..f49ced7a1c22 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala @@ -490,6 +490,7 @@ object StateSourceOptions extends DataSourceOptions { s"Valid values are ${JoinSideValues.values.mkString(",")}") } + // Use storeName rather than joinSide to identify the specific join store if (joinSide != JoinSideValues.none && storeName != StateStoreId.DEFAULT_STORE_NAME) { throw StateDataSourceErrors.conflictOptions(Seq(JOIN_SIDE, STORE_NAME)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionErrors.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionErrors.scala index b8c39b4a160b..0e9b8ad8a63b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionErrors.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionErrors.scala @@ -88,7 +88,7 @@ object OfflineStateRepartitionErrors { def unsupportedStateStoreProviderError( checkpointLocation: String, - providerClass: String): StateRepartitionUnsupportedProviderError = { + providerClass: String): StateRepartitionInvalidCheckpointError = { new StateRepartitionUnsupportedProviderError(checkpointLocation, providerClass) } }
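
A minimal usage sketch of the all-column-families read path introduced in this series. It assumes a RocksDB-backed checkpoint already produced by a streaming aggregation (the checkpoint path below is a placeholder) and reuses only the option constants, output columns, and UnsafeRow decoding exercised in StatePartitionAllColumnFamiliesReaderSuite; the two-field value layout (count, sum) matches the simple aggregation used in those tests and would differ for other operators.

    import org.apache.spark.sql.catalyst.expressions.UnsafeRow
    import org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions
    import org.apache.spark.unsafe.Platform

    // Read every column family of the default operator/store as raw bytes
    // (only supported with RocksDBStateStoreProvider).
    val bytesDf = spark.read
      .format("statestore")
      .option(StateSourceOptions.PATH, "/tmp/agg-checkpoint") // placeholder checkpoint location
      .option(StateSourceOptions.INTERNAL_ONLY_READ_ALL_COLUMN_FAMILIES, "true")
      .load()
      .selectExpr("partition_key", "key_bytes", "value_bytes", "column_family_name")

    // Keys and values come back as UnsafeRow bytes; decode them with the schema of the
    // column family they belong to, e.g. (count: Long, sum: Long) for the aggregation above.
    bytesDf.collect().foreach { row =>
      val valueBytes = row.getAs[Array[Byte]]("value_bytes")
      val valueRow = new UnsafeRow(2) // number of fields in this column family's value schema
      valueRow.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, valueBytes.length)
      println(s"cf=${row.getAs[String]("column_family_name")} " +
        s"count=${valueRow.getLong(0)} sum=${valueRow.getLong(1)}")
    }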