diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index 492a33c57461..723efbea715e 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -5493,6 +5493,11 @@ "message" : [ "Unsupported offset sequence version <version>. Please make sure the checkpoint is from a supported Spark version (Spark 4.0+)." ] + }, + "UNSUPPORTED_PROVIDER" : { + "message" : [ + "<provider> is not supported" + ] } }, "sqlState" : "55019" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala index 6af418e1ddc2..f49ced7a1c22 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala @@ -41,7 +41,8 @@ import org.apache.spark.sql.execution.streaming.operators.stateful.transformwith import org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.timers.TimerStateUtils import org.apache.spark.sql.execution.streaming.runtime.StreamingCheckpointConstants.DIR_NAME_STATE import org.apache.spark.sql.execution.streaming.runtime.StreamingQueryCheckpointMetadata -import org.apache.spark.sql.execution.streaming.state.{InMemoryStateSchemaProvider, KeyStateEncoderSpec, NoPrefixKeyStateEncoderSpec, PrefixKeyScanStateEncoderSpec, StateSchemaCompatibilityChecker, StateSchemaMetadata, StateSchemaProvider, StateStore, StateStoreColFamilySchema, StateStoreConf, StateStoreId, StateStoreProviderId} +import org.apache.spark.sql.execution.streaming.state.{InMemoryStateSchemaProvider, KeyStateEncoderSpec, NoPrefixKeyStateEncoderSpec, PrefixKeyScanStateEncoderSpec, RocksDBStateStoreProvider, StateSchemaCompatibilityChecker, StateSchemaMetadata, StateSchemaProvider, StateStore, StateStoreColFamilySchema, StateStoreConf, StateStoreId, StateStoreProviderId} +import org.apache.spark.sql.execution.streaming.state.OfflineStateRepartitionErrors import org.apache.spark.sql.execution.streaming.utils.StreamingUtils import org.apache.spark.sql.sources.DataSourceRegister import org.apache.spark.sql.streaming.TimeMode @@ -66,6 +67,14 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging val sourceOptions = StateSourceOptions.modifySourceOptions(hadoopConf, StateSourceOptions.apply(session, hadoopConf, properties)) val stateConf = buildStateStoreConf(sourceOptions.resolvedCpLocation, sourceOptions.batchId) + // We only support RocksDB because the repartition work that this option + // is built for only supports RocksDB + if (sourceOptions.internalOnlyReadAllColumnFamilies + && stateConf.providerClass != classOf[RocksDBStateStoreProvider].getName) { + throw OfflineStateRepartitionErrors.unsupportedStateStoreProviderError( + sourceOptions.resolvedCpLocation, + stateConf.providerClass) + } val stateStoreReaderInfo: StateStoreReaderInfo = getStoreMetadataAndRunChecks( sourceOptions) @@ -372,6 +381,7 @@ case class StateSourceOptions( stateVarName: Option[String], readRegisteredTimers: Boolean, flattenCollectionTypes: Boolean, + internalOnlyReadAllColumnFamilies: Boolean = false, startOperatorStateUniqueIds: Option[Array[Array[String]]] = None, endOperatorStateUniqueIds: Option[Array[Array[String]]] = None) { def stateCheckpointLocation: Path = new 
Path(resolvedCpLocation, DIR_NAME_STATE) @@ -379,8 +389,9 @@ case class StateSourceOptions( override def toString: String = { var desc = s"StateSourceOptions(checkpointLocation=$resolvedCpLocation, batchId=$batchId, " + s"operatorId=$operatorId, storeName=$storeName, joinSide=$joinSide, " + - s"stateVarName=${stateVarName.getOrElse("None")}, +" + - s"flattenCollectionTypes=$flattenCollectionTypes" + s"stateVarName=${stateVarName.getOrElse("None")}, " + + s"flattenCollectionTypes=$flattenCollectionTypes, " + + s"internalOnlyReadAllColumnFamilies=$internalOnlyReadAllColumnFamilies" if (fromSnapshotOptions.isDefined) { desc += s", snapshotStartBatchId=${fromSnapshotOptions.get.snapshotStartBatchId}" desc += s", snapshotPartitionId=${fromSnapshotOptions.get.snapshotPartitionId}" @@ -407,6 +418,7 @@ object StateSourceOptions extends DataSourceOptions { val STATE_VAR_NAME = newOption("stateVarName") val READ_REGISTERED_TIMERS = newOption("readRegisteredTimers") val FLATTEN_COLLECTION_TYPES = newOption("flattenCollectionTypes") + val INTERNAL_ONLY_READ_ALL_COLUMN_FAMILIES = newOption("internalOnlyReadAllColumnFamilies") object JoinSideValues extends Enumeration { type JoinSideValues = Value @@ -478,6 +490,7 @@ object StateSourceOptions extends DataSourceOptions { s"Valid values are ${JoinSideValues.values.mkString(",")}") } + // Use storeName rather than joinSide to identify the specific join store if (joinSide != JoinSideValues.none && storeName != StateStoreId.DEFAULT_STORE_NAME) { throw StateDataSourceErrors.conflictOptions(Seq(JOIN_SIDE, STORE_NAME)) } @@ -492,6 +505,29 @@ object StateSourceOptions extends DataSourceOptions { val readChangeFeed = Option(options.get(READ_CHANGE_FEED)).exists(_.toBoolean) + val internalOnlyReadAllColumnFamilies = try { + Option(options.get(INTERNAL_ONLY_READ_ALL_COLUMN_FAMILIES)).exists(_.toBoolean) + } catch { + case _: IllegalArgumentException => + throw StateDataSourceErrors.invalidOptionValue(INTERNAL_ONLY_READ_ALL_COLUMN_FAMILIES, + "Boolean value is expected") + } + + if (internalOnlyReadAllColumnFamilies && stateVarName.isDefined) { + throw StateDataSourceErrors.conflictOptions( + Seq(INTERNAL_ONLY_READ_ALL_COLUMN_FAMILIES, STATE_VAR_NAME)) + } + + if (internalOnlyReadAllColumnFamilies && joinSide != JoinSideValues.none) { + throw StateDataSourceErrors.conflictOptions( + Seq(INTERNAL_ONLY_READ_ALL_COLUMN_FAMILIES, JOIN_SIDE)) + } + + if (internalOnlyReadAllColumnFamilies && readChangeFeed) { + throw StateDataSourceErrors.conflictOptions( + Seq(INTERNAL_ONLY_READ_ALL_COLUMN_FAMILIES, READ_CHANGE_FEED)) + } + val changeStartBatchId = Option(options.get(CHANGE_START_BATCH_ID)).map(_.toLong) var changeEndBatchId = Option(options.get(CHANGE_END_BATCH_ID)).map(_.toLong) @@ -615,7 +651,7 @@ object StateSourceOptions extends DataSourceOptions { StateSourceOptions( resolvedCpLocation, batchId.get, operatorId, storeName, joinSide, readChangeFeed, fromSnapshotOptions, readChangeFeedOptions, - stateVarName, readRegisteredTimers, flattenCollectionTypes, + stateVarName, readRegisteredTimers, flattenCollectionTypes, internalOnlyReadAllColumnFamilies, startOperatorStateUniqueIds, endOperatorStateUniqueIds) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala index 619e374c00de..9fc3c081173f 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala @@ -49,7 +49,10 @@ class StatePartitionReaderFactory( override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { val stateStoreInputPartition = partition.asInstanceOf[StateStoreInputPartition] - if (stateStoreInputPartition.sourceOptions.readChangeFeed) { + if (stateStoreInputPartition.sourceOptions.internalOnlyReadAllColumnFamilies) { + new StatePartitionAllColumnFamiliesReader(storeConf, hadoopConf, + stateStoreInputPartition, schema, keyStateEncoderSpec, stateStoreColFamilySchemaOpt) + } else if (stateStoreInputPartition.sourceOptions.readChangeFeed) { new StateStoreChangeDataPartitionReader(storeConf, hadoopConf, stateStoreInputPartition, schema, keyStateEncoderSpec, stateVariableInfoOpt, stateStoreColFamilySchemaOpt, stateSchemaProviderOpt, joinColFamilyOpt) @@ -81,16 +84,22 @@ abstract class StatePartitionReaderBase( private val schemaForValueRow: StructType = StructType(Array(StructField("__dummy__", NullType))) - protected val keySchema = { + protected val keySchema : StructType = { if (SchemaUtil.checkVariableType(stateVariableInfoOpt, StateVariableType.MapState)) { SchemaUtil.getCompositeKeySchema(schema, partition.sourceOptions) + } else if (partition.sourceOptions.internalOnlyReadAllColumnFamilies) { + require(stateStoreColFamilySchemaOpt.isDefined) + stateStoreColFamilySchemaOpt.map(_.keySchema).get } else { SchemaUtil.getSchemaAsDataType(schema, "key").asInstanceOf[StructType] } } - protected val valueSchema = if (stateVariableInfoOpt.isDefined) { + protected val valueSchema : StructType = if (stateVariableInfoOpt.isDefined) { schemaForValueRow + } else if (partition.sourceOptions.internalOnlyReadAllColumnFamilies) { + require(stateStoreColFamilySchemaOpt.isDefined) + stateStoreColFamilySchemaOpt.map(_.valueSchema).get } else { SchemaUtil.getSchemaAsDataType( schema, "value").asInstanceOf[StructType] @@ -237,6 +246,48 @@ class StatePartitionReader( } } +/** + * An implementation of [[StatePartitionReaderBase]] for reading all column families + * in binary format. This reader returns raw key and value bytes along with column family names. 
+ * We are returning key/value bytes because each column family can have different schema + * It will also return the partition key + */ +class StatePartitionAllColumnFamiliesReader( + storeConf: StateStoreConf, + hadoopConf: SerializableConfiguration, + partition: StateStoreInputPartition, + schema: StructType, + keyStateEncoderSpec: KeyStateEncoderSpec, + stateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema]) + extends StatePartitionReaderBase( + storeConf, + hadoopConf, partition, schema, + keyStateEncoderSpec, None, stateStoreColFamilySchemaOpt, None, None) { + + private lazy val store: ReadStateStore = { + assert(getStartStoreUniqueId == getEndStoreUniqueId, + "Start and end store unique IDs must be the same when reading all column families") + provider.getReadStore( + partition.sourceOptions.batchId + 1, + getStartStoreUniqueId + ) + } + + override lazy val iter: Iterator[InternalRow] = { + store + .iterator() + .map { pair => + SchemaUtil.unifyStateRowPairAsRawBytes( + (pair.key, pair.value), StateStore.DEFAULT_COL_FAMILY_NAME) + } + } + + override def close(): Unit = { + store.release() + super.close() + } +} + /** * An implementation of [[StatePartitionReaderBase]] for the readChangeFeed mode of State Data * Source. It reads the change of state over batches of a particular partition. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/utils/SchemaUtil.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/utils/SchemaUtil.scala index 52df016791d4..44d83fc99b57 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/utils/SchemaUtil.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/utils/SchemaUtil.scala @@ -28,7 +28,8 @@ import org.apache.spark.sql.execution.datasources.v2.state.{StateDataSourceError import org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.{StateVariableType, TransformWithStateVariableInfo} import org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.StateVariableType._ import org.apache.spark.sql.execution.streaming.state.{ReadStateStore, StateStoreColFamilySchema, UnsafeRowPair} -import org.apache.spark.sql.types.{ArrayType, DataType, IntegerType, LongType, MapType, StringType, StructType} +import org.apache.spark.sql.types.{ArrayType, BinaryType, DataType, IntegerType, LongType, MapType, StringType, StructType} +import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.ArrayImplicits._ object SchemaUtil { @@ -60,6 +61,14 @@ object SchemaUtil { .add("key", keySchema) .add("value", valueSchema) .add("partition_id", IntegerType) + } else if (sourceOptions.internalOnlyReadAllColumnFamilies) { + new StructType() + // todo [SPARK-54443]: change keySchema to a more specific type after we + // can extract partition key from keySchema + .add("partition_key", keySchema) + .add("key_bytes", BinaryType) + .add("value_bytes", BinaryType) + .add("column_family_name", StringType) } else { new StructType() .add("key", keySchema) @@ -76,6 +85,26 @@ object SchemaUtil { row } + /** + * Returns an InternalRow representing + * 1. partitionKey + * 2. key in bytes + * 3. value in bytes + * 4. 
column family name + */ + def unifyStateRowPairAsRawBytes( + pair: (UnsafeRow, UnsafeRow), + colFamilyName: String): InternalRow = { + val row = new GenericInternalRow(4) + // todo [SPARK-54443]: change keySchema to more specific type after we + // can extract partition key from keySchema + row.update(0, pair._1) + row.update(1, pair._1.getBytes) + row.update(2, pair._2.getBytes) + row.update(3, UTF8String.fromString(colFamilyName)) + row + } + def unifyStateRowPairWithMultipleValues( pair: (UnsafeRow, GenericArrayData), partition: Int): InternalRow = { @@ -231,7 +260,11 @@ object SchemaUtil { "user_map_key" -> classOf[StructType], "user_map_value" -> classOf[StructType], "expiration_timestamp_ms" -> classOf[LongType], - "partition_id" -> classOf[IntegerType]) + "partition_id" -> classOf[IntegerType], + "partition_key" -> classOf[StructType], + "key_bytes" -> classOf[BinaryType], + "value_bytes" -> classOf[BinaryType], + "column_family_name" -> classOf[StringType]) val expectedFieldNames = if (transformWithStateVariableInfoOpt.isDefined) { val stateVarInfo = transformWithStateVariableInfoOpt.get @@ -272,6 +305,8 @@ object SchemaUtil { } } else if (sourceOptions.readChangeFeed) { Seq("batch_id", "change_type", "key", "value", "partition_id") + } else if (sourceOptions.internalOnlyReadAllColumnFamilies) { + Seq("partition_key", "key_bytes", "value_bytes", "column_family_name") } else { Seq("key", "value", "partition_id") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionErrors.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionErrors.scala index 95b273826877..0e9b8ad8a63b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionErrors.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionErrors.scala @@ -85,6 +85,12 @@ object OfflineStateRepartitionErrors { version: Int): StateRepartitionInvalidCheckpointError = { new StateRepartitionUnsupportedOffsetSeqVersionError(checkpointLocation, version) } + + def unsupportedStateStoreProviderError( + checkpointLocation: String, + providerClass: String): StateRepartitionInvalidCheckpointError = { + new StateRepartitionUnsupportedProviderError(checkpointLocation, providerClass) + } } /** @@ -201,3 +207,11 @@ class StateRepartitionUnsupportedOffsetSeqVersionError( checkpointLocation, subClass = "UNSUPPORTED_OFFSET_SEQ_VERSION", messageParameters = Map("version" -> version.toString)) + +class StateRepartitionUnsupportedProviderError( + checkpointLocation: String, + provider: String) + extends StateRepartitionInvalidCheckpointError( + checkpointLocation, + subClass = "UNSUPPORTED_PROVIDER", + messageParameters = Map("provider" -> provider)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionAllColumnFamiliesReaderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionAllColumnFamiliesReaderSuite.scala new file mode 100644 index 000000000000..c4b59b149b96 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionAllColumnFamiliesReaderSuite.scala @@ -0,0 +1,536 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.datasources.v2.state + +import java.nio.ByteOrder +import java.util.Arrays + +import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.catalyst.CatalystTypeConverters +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.UnsafeProjection +import org.apache.spark.sql.execution.streaming.runtime.MemoryStream +import org.apache.spark.sql.execution.streaming.state.{HDFSBackedStateStoreProvider, RocksDBStateStoreProvider, StateRepartitionUnsupportedProviderError, StateStore} +import org.apache.spark.sql.functions.{count, sum} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.streaming.OutputMode +import org.apache.spark.sql.types.{BooleanType, IntegerType, LongType, NullType, StructField, StructType, TimestampType} + +/** + * Note: This extends StateDataSourceTestBase to access + * helper methods like runDropDuplicatesQuery without inheriting all predefined tests. + */ +class StatePartitionAllColumnFamiliesReaderSuite extends StateDataSourceTestBase { + + import testImplicits._ + + override def beforeAll(): Unit = { + super.beforeAll() + spark.conf.set(SQLConf.STATE_STORE_PROVIDER_CLASS.key, + classOf[RocksDBStateStoreProvider].getName) + } + + private def getNormalReadDf( + checkpointDir: String, + storeName: Option[String] = Option.empty[String]): DataFrame = { + spark.read + .format("statestore") + .option(StateSourceOptions.PATH, checkpointDir) + .option(StateSourceOptions.STORE_NAME, storeName.orNull) + .load() + .selectExpr("partition_id", "key", "value") + } + + private def getBytesReadDf( + checkpointDir: String, + storeName: Option[String] = Option.empty[String]): DataFrame = { + spark.read + .format("statestore") + .option(StateSourceOptions.PATH, checkpointDir) + .option(StateSourceOptions.INTERNAL_ONLY_READ_ALL_COLUMN_FAMILIES, "true") + .option(StateSourceOptions.STORE_NAME, storeName.orNull) + .load() + } + + /** + * Validates the schema and column families of the bytes read DataFrame. + */ + private def validateBytesReadDfSchema(df: DataFrame): Unit = { + // Verify schema + val schema = df.schema + assert(schema.fieldNames === Array( + "partition_key", "key_bytes", "value_bytes", "column_family_name")) + assert(schema("partition_key").dataType.typeName === "struct") + assert(schema("key_bytes").dataType.typeName === "binary") + assert(schema("value_bytes").dataType.typeName === "binary") + assert(schema("column_family_name").dataType.typeName === "string") + } + + /** + * Compares normal read data with bytes read data for a specific column family. + * Converts normal rows to bytes then compares with bytes read. 
+ */ + private def compareNormalAndBytesData( + normalDf: Array[Row], + bytesDf: Array[Row], + columnFamily: String, + keySchema: StructType, + valueSchema: StructType): Unit = { + + // Filter bytes data for the specified column family and extract raw bytes directly + val filteredBytesData = bytesDf.filter { row => + row.getString(3) == columnFamily + } + + // Verify same number of rows + assert(filteredBytesData.length == normalDf.length, + s"Row count mismatch for column family '$columnFamily': " + + s"normal read has ${normalDf.length} rows, " + + s"bytes read has ${filteredBytesData.length} rows") + + // Create projections to convert Row to UnsafeRow bytes + val keyProjection = UnsafeProjection.create(keySchema) + val valueProjection = UnsafeProjection.create(valueSchema) + + // Create converters to convert external Row types to internal Catalyst types + val keyConverter = CatalystTypeConverters.createToCatalystConverter(keySchema) + val valueConverter = CatalystTypeConverters.createToCatalystConverter(valueSchema) + + // Convert normal data to bytes + val normalAsBytes = normalDf.toSeq.map { row => + val key = row.getStruct(1) + val value = if (row.isNullAt(2)) null else row.getStruct(2) + + // Convert key to InternalRow, then to UnsafeRow, then get bytes + val keyInternalRow = keyConverter(key).asInstanceOf[InternalRow] + val keyUnsafeRow = keyProjection(keyInternalRow) + // IMPORTANT: Must clone the bytes array since getBytes() returns a reference + // that may be overwritten by subsequent UnsafeRow operations + val keyBytes = keyUnsafeRow.getBytes.clone() + + // Convert value to bytes + val valueBytes = if (value == null) { + Array.empty[Byte] + } else { + val valueInternalRow = valueConverter(value).asInstanceOf[InternalRow] + val valueUnsafeRow = valueProjection(valueInternalRow) + // IMPORTANT: Must clone the bytes array since getBytes() returns a reference + // that may be overwritten by subsequent UnsafeRow operations + valueUnsafeRow.getBytes.clone() + } + + (keyBytes, valueBytes) + } + + // Extract raw bytes from bytes read data (no deserialization/reserialization) + val bytesAsBytes = filteredBytesData.map { row => + val keyBytes = row.getAs[Array[Byte]](1) + val valueBytes = row.getAs[Array[Byte]](2) + (keyBytes, valueBytes) + } + + // Sort both for comparison (since Set equality doesn't work well with byte arrays) + val normalSorted = normalAsBytes.sortBy(x => (x._1.mkString(","), x._2.mkString(","))) + val bytesSorted = bytesAsBytes.sortBy(x => (x._1.mkString(","), x._2.mkString(","))) + + assert(normalSorted.length == bytesSorted.length, + s"Size mismatch: normal has ${normalSorted.length}, bytes has ${bytesSorted.length}") + + // Compare each pair + normalSorted.zip(bytesSorted).zipWithIndex.foreach { + case (((normalKey, normalValue), (bytesKey, bytesValue)), idx) => + assert(Arrays.equals(normalKey, bytesKey), + s"Key mismatch at index $idx:\n" + + s" Normal: ${normalKey.mkString("[", ",", "]")}\n" + + s" Bytes: ${bytesKey.mkString("[", ",", "]")}") + assert(Arrays.equals(normalValue, bytesValue), + s"Value mismatch at index $idx:\n" + + s" Normal: ${normalValue.mkString("[", ",", "]")}\n" + + s" Bytes: ${bytesValue.mkString("[", ",", "]")}") + } + } + + // Run all tests with both changelog checkpointing enabled and disabled + Seq(true, false).foreach { changelogCheckpointingEnabled => + val testSuffix = if (changelogCheckpointingEnabled) { + "with changelog checkpointing" + } else { + "without changelog checkpointing" + } + + def testWithChangelogConfig(testName: 
String)(testFun: => Unit): Unit = { + test(s"$testName ($testSuffix)") { + withSQLConf( + "spark.sql.streaming.stateStore.rocksdb.changelogCheckpointing.enabled" -> + changelogCheckpointingEnabled.toString) { + testFun + } + } + } + + testWithChangelogConfig("SPARK-54388: simple aggregation state ver 1") { + withSQLConf(SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_VERSION.key -> "1") { + withTempDir { tempDir => + runLargeDataStreamingAggregationQuery(tempDir.getAbsolutePath) + + val keySchema = StructType(Array(StructField("groupKey", IntegerType, nullable = false))) + // State version 1 includes key columns in the value + val valueSchema = StructType(Array( + StructField("groupKey", IntegerType, nullable = false), + StructField("count", LongType, nullable = false), + StructField("sum", LongType, nullable = false), + StructField("max", IntegerType, nullable = false), + StructField("min", IntegerType, nullable = false) + )) + + val normalData = getNormalReadDf(tempDir.getAbsolutePath).collect() + val bytesDf = getBytesReadDf(tempDir.getAbsolutePath) + + validateBytesReadDfSchema(bytesDf) + compareNormalAndBytesData(normalData, bytesDf.collect(), "default", keySchema, valueSchema) + } + } + } + + testWithChangelogConfig("SPARK-54388: simple aggregation state ver 2") { + withSQLConf(SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_VERSION.key -> "2") { + withTempDir { tempDir => + runLargeDataStreamingAggregationQuery(tempDir.getAbsolutePath) + + val keySchema = StructType(Array(StructField("groupKey", IntegerType, nullable = false))) + val valueSchema = StructType(Array( + StructField("count", LongType, nullable = false), + StructField("sum", LongType, nullable = false), + StructField("max", IntegerType, nullable = false), + StructField("min", IntegerType, nullable = false) + )) + + val normalData = getNormalReadDf(tempDir.getAbsolutePath).collect() + val bytesDf = getBytesReadDf(tempDir.getAbsolutePath) + + validateBytesReadDfSchema(bytesDf) + compareNormalAndBytesData(normalData, bytesDf.collect(), "default", keySchema, valueSchema) + } + } + } + + testWithChangelogConfig("SPARK-54388: composite key aggregation state ver 1") { + withSQLConf(SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_VERSION.key -> "1") { + withTempDir { tempDir => + runCompositeKeyStreamingAggregationQuery(tempDir.getAbsolutePath) + + val keySchema = StructType(Array( + StructField("groupKey", IntegerType, nullable = false), + StructField("fruit", org.apache.spark.sql.types.StringType, nullable = true) + )) + // State version 1 includes key columns in the value + val valueSchema = StructType(Array( + StructField("groupKey", IntegerType, nullable = false), + StructField("fruit", org.apache.spark.sql.types.StringType, nullable = true), + StructField("count", LongType, nullable = false), + StructField("sum", LongType, nullable = false), + StructField("max", IntegerType, nullable = false), + StructField("min", IntegerType, nullable = false) + )) + + val normalData = getNormalReadDf(tempDir.getAbsolutePath).collect() + val bytesDf = getBytesReadDf(tempDir.getAbsolutePath) + + validateBytesReadDfSchema(bytesDf) + compareNormalAndBytesData(normalData, bytesDf.collect(), "default", keySchema, valueSchema) + } + } + } + + testWithChangelogConfig("SPARK-54388: composite key aggregation state ver 2") { + withSQLConf(SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_VERSION.key -> "2") { + withTempDir { tempDir => + runCompositeKeyStreamingAggregationQuery(tempDir.getAbsolutePath) + + val keySchema = StructType(Array( + StructField("groupKey", 
IntegerType, nullable = false), + StructField("fruit", org.apache.spark.sql.types.StringType, nullable = true) + )) + val valueSchema = StructType(Array( + StructField("count", LongType, nullable = false), + StructField("sum", LongType, nullable = false), + StructField("max", IntegerType, nullable = false), + StructField("min", IntegerType, nullable = false) + )) + + val normalData = getNormalReadDf(tempDir.getAbsolutePath).collect() + val bytesDf = getBytesReadDf(tempDir.getAbsolutePath) + + validateBytesReadDfSchema(bytesDf) + compareNormalAndBytesData(normalData, bytesDf.collect(), "default", keySchema, valueSchema) + } + } + } + + testWithChangelogConfig("SPARK-54388: dropDuplicates validation") { + withTempDir { tempDir => + runDropDuplicatesQuery(tempDir.getAbsolutePath) + + val keySchema = StructType(Array( + StructField("value", IntegerType, nullable = false), + StructField("eventTime", org.apache.spark.sql.types.TimestampType) + )) + val valueSchema = StructType(Array( + StructField("__dummy__", NullType, nullable = true) + )) + + val normalData = getNormalReadDf(tempDir.getAbsolutePath).collect() + val bytesDf = getBytesReadDf(tempDir.getAbsolutePath) + + validateBytesReadDfSchema(bytesDf) + compareNormalAndBytesData(normalData, bytesDf.collect(), "default", keySchema, valueSchema) + } + } + + testWithChangelogConfig("SPARK-54388: dropDuplicates with column specified") { + withTempDir { tempDir => + runDropDuplicatesQueryWithColumnSpecified(tempDir.getAbsolutePath) + + val keySchema = StructType(Array( + StructField("col1", org.apache.spark.sql.types.StringType, nullable = true) + )) + val valueSchema = StructType(Array( + StructField("__dummy__", NullType, nullable = true) + )) + + val normalData = getNormalReadDf(tempDir.getAbsolutePath).collect() + val bytesDf = getBytesReadDf(tempDir.getAbsolutePath) + + validateBytesReadDfSchema(bytesDf) + compareNormalAndBytesData(normalData, bytesDf.collect(), "default", keySchema, valueSchema) + } + } + + testWithChangelogConfig("SPARK-54388: dropDuplicatesWithinWatermark") { + withTempDir { tempDir => + runDropDuplicatesWithinWatermarkQuery(tempDir.getAbsolutePath) + + val keySchema = StructType(Array( + StructField("_1", org.apache.spark.sql.types.StringType, nullable = true) + )) + val valueSchema = StructType(Array( + StructField("expiresAtMicros", LongType, nullable = false) + )) + + val normalData = getNormalReadDf(tempDir.getAbsolutePath).collect() + val bytesDf = getBytesReadDf(tempDir.getAbsolutePath) + + validateBytesReadDfSchema(bytesDf) + compareNormalAndBytesData(normalData, bytesDf.collect(), "default", keySchema, valueSchema) + } + } + + testWithChangelogConfig("SPARK-54388: session window aggregation") { + withTempDir { tempDir => + runSessionWindowAggregationQuery(tempDir.getAbsolutePath) + + val keySchema = StructType(Array( + StructField("sessionId", org.apache.spark.sql.types.StringType, nullable = false), + StructField("sessionStartTime", + org.apache.spark.sql.types.TimestampType, nullable = false) + )) + val valueSchema = StructType(Array( + StructField("session_window", org.apache.spark.sql.types.StructType(Array( + StructField("start", org.apache.spark.sql.types.TimestampType), + StructField("end", org.apache.spark.sql.types.TimestampType) + )), nullable = false), + StructField("sessionId", org.apache.spark.sql.types.StringType, nullable = false), + StructField("count", LongType, nullable = false) + )) + + val normalData = getNormalReadDf(tempDir.getAbsolutePath).collect() + val bytesDf = 
getBytesReadDf(tempDir.getAbsolutePath) + + validateBytesReadDfSchema(bytesDf) + compareNormalAndBytesData(normalData, bytesDf.collect(), "default", keySchema, valueSchema) + } + } + + testWithChangelogConfig("SPARK-54388: flatMapGroupsWithState, state ver 1") { + // Skip this test on big endian platforms + assume(java.nio.ByteOrder.nativeOrder().equals(java.nio.ByteOrder.LITTLE_ENDIAN)) + withSQLConf(SQLConf.FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION.key -> "1") { + withTempDir { tempDir => + assume(ByteOrder.nativeOrder().equals(ByteOrder.LITTLE_ENDIAN)) + runFlatMapGroupsWithStateQuery(tempDir.getAbsolutePath) + + val keySchema = StructType(Array( + StructField("value", org.apache.spark.sql.types.StringType, nullable = true) + )) + val valueSchema = StructType(Array( + StructField("numEvents", IntegerType, nullable = false), + StructField("startTimestampMs", LongType, nullable = false), + StructField("endTimestampMs", LongType, nullable = false), + StructField("timeoutTimestamp", IntegerType, nullable = false) + )) + + val normalData = getNormalReadDf(tempDir.getAbsolutePath).collect() + val bytesDf = getBytesReadDf(tempDir.getAbsolutePath) + + validateBytesReadDfSchema(bytesDf) + compareNormalAndBytesData( + normalData, bytesDf.collect(), "default", keySchema, valueSchema) + } + } + } + + testWithChangelogConfig("SPARK-54388: flatMapGroupsWithState, state ver 2") { + withSQLConf(SQLConf.FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION.key -> "2") { + withTempDir { tempDir => + runFlatMapGroupsWithStateQuery(tempDir.getAbsolutePath) + + val keySchema = StructType(Array( + StructField("value", org.apache.spark.sql.types.StringType, nullable = true) + )) + val valueSchema = StructType(Array( + StructField("groupState", org.apache.spark.sql.types.StructType(Array( + StructField("numEvents", IntegerType, nullable = false), + StructField("startTimestampMs", LongType, nullable = false), + StructField("endTimestampMs", LongType, nullable = false) + )), nullable = false), + StructField("timeoutTimestamp", LongType, nullable = false) + )) + + val normalData = getNormalReadDf(tempDir.getAbsolutePath).collect() + val bytesDf = getBytesReadDf(tempDir.getAbsolutePath) + + validateBytesReadDfSchema(bytesDf) + compareNormalAndBytesData( + normalData, bytesDf.collect(), "default", keySchema, valueSchema) + } + } + } + + def testStreamStreamJoin(stateVersion: Int): Unit = { + withSQLConf(SQLConf.STREAMING_JOIN_STATE_FORMAT_VERSION.key -> stateVersion.toString) { + withTempDir { tempDir => + runStreamStreamJoinQuery(tempDir.getAbsolutePath) + + Seq("right-keyToNumValues", "left-keyToNumValues").foreach(storeName => { + val stateReaderForRight = getNormalReadDf( + tempDir.getAbsolutePath, Option(storeName)) + val stateBytesDfForRight = getBytesReadDf( + tempDir.getAbsolutePath, Option(storeName)) + + val keyToNumValuesKeySchema = StructType(Array( + StructField("key", IntegerType) + )) + val keyToNumValueValueSchema = StructType(Array( + StructField("value", LongType) + )) + + validateBytesReadDfSchema(stateBytesDfForRight) + compareNormalAndBytesData( + stateReaderForRight.collect(), + stateBytesDfForRight.collect(), + StateStore.DEFAULT_COL_FAMILY_NAME, + keyToNumValuesKeySchema, + keyToNumValueValueSchema) + }) + + Seq("right-keyWithIndexToValue", "left-keyWithIndexToValue").foreach(storeName => { + val stateReaderForRight = getNormalReadDf( + tempDir.getAbsolutePath, Option(storeName)) + val stateBytesDfForRight = getBytesReadDf( + tempDir.getAbsolutePath, Option(storeName)) + + val 
keyToNumValuesKeySchema = StructType(Array( + StructField("key", IntegerType, nullable = false), + StructField("index", LongType) + )) + val keyToNumValueValueSchema = if (stateVersion == 2) { + StructType(Array( + StructField("value", IntegerType, nullable = false), + StructField("time", TimestampType, nullable = false), + StructField("matched", BooleanType) + )) + } else { + StructType(Array( + StructField("value", IntegerType, nullable = false), + StructField("time", TimestampType, nullable = false) + )) + } + + validateBytesReadDfSchema(stateBytesDfForRight) + compareNormalAndBytesData( + stateReaderForRight.collect(), + stateBytesDfForRight.collect(), + StateStore.DEFAULT_COL_FAMILY_NAME, + keyToNumValuesKeySchema, + keyToNumValueValueSchema) + }) + } + } + } + + testWithChangelogConfig("stream-stream join, state ver 1") { + testStreamStreamJoin(1) + } + + testWithChangelogConfig("stream-stream join, state ver 2") { + testStreamStreamJoin(2) + } + } // End of foreach loop for changelog checkpointing dimension + + test("internalOnlyReadAllColumnFamilies should fail with HDFS-backed state store") { + withTempDir { tempDir => + withSQLConf( + SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[HDFSBackedStateStoreProvider].getName, + SQLConf.SHUFFLE_PARTITIONS.key -> "2") { + + val inputData = MemoryStream[Int] + val aggregated = inputData.toDF() + .selectExpr("value", "value % 10 AS groupKey") + .groupBy($"groupKey") + .agg( + count("*").as("cnt"), + sum("value").as("sum") + ) + .as[(Int, Long, Long)] + + testStream(aggregated, OutputMode.Update)( + StartStream(checkpointLocation = tempDir.getAbsolutePath), + AddData(inputData, 0 until 1: _*), + CheckLastBatch((0, 1, 0)), + StopStream + ) + + checkError( + exception = intercept[StateRepartitionUnsupportedProviderError] { + spark.read + .format("statestore") + .option(StateSourceOptions.PATH, tempDir.getAbsolutePath) + .option(StateSourceOptions.INTERNAL_ONLY_READ_ALL_COLUMN_FAMILIES, "true") + .load() + .collect() + }, + condition = "STATE_REPARTITION_INVALID_CHECKPOINT.UNSUPPORTED_PROVIDER", + parameters = Map( + "checkpointLocation" -> s".*${tempDir.getAbsolutePath}", + "provider" -> classOf[HDFSBackedStateStoreProvider].getName + ), + matchPVals = true + ) + } + } + } +}
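
Reviewer note (not part of the patch): the sketch below shows one way the raw bytes returned by the new `internalOnlyReadAllColumnFamilies` option could be decoded back into typed rows when the caller already knows the key/value schemas, essentially the reverse of what `compareNormalAndBytesData` in the suite above does. The checkpoint path, object name, and schemas are made up for illustration and assume a streaming aggregation checkpoint with state format version 2.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.types.{IntegerType, LongType, StructType}

object ReadAllColumnFamiliesSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").getOrCreate()

    // Hypothetical checkpoint written by a streaming aggregation query.
    val checkpointDir = "/tmp/agg-checkpoint"

    // Internal option added by this PR: one row per state entry with
    // partition_key, key_bytes, value_bytes and column_family_name.
    val raw = spark.read
      .format("statestore")
      .option("path", checkpointDir)
      .option("internalOnlyReadAllColumnFamilies", "true")
      .load()

    // Schemas the caller is assumed to know for the default column family
    // (streaming aggregation, state format version 2).
    val keySchema = new StructType().add("groupKey", IntegerType)
    val valueSchema = new StructType()
      .add("count", LongType)
      .add("sum", LongType)
      .add("max", IntegerType)
      .add("min", IntegerType)

    raw.collect().foreach { row =>
      val keyBytes = row.getAs[Array[Byte]]("key_bytes")
      val valueBytes = row.getAs[Array[Byte]]("value_bytes")

      // Re-materialize UnsafeRows over the raw bytes; the field counts come
      // from the schemas the caller supplies.
      val key = new UnsafeRow(keySchema.length)
      key.pointTo(keyBytes, keyBytes.length)
      val value = new UnsafeRow(valueSchema.length)
      value.pointTo(valueBytes, valueBytes.length)

      println(s"cf=${row.getAs[String]("column_family_name")} " +
        s"groupKey=${key.getInt(0)} count=${value.getLong(0)} sum=${value.getLong(1)}")
    }

    spark.stop()
  }
}
```

The new suite should also be runnable on its own, e.g. with `build/sbt "sql/testOnly *StatePartitionAllColumnFamiliesReaderSuite"`.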