5 changes: 5 additions & 0 deletions common/utils/src/main/resources/error/error-conditions.json
@@ -5493,6 +5493,11 @@
"message" : [
"Unsupported offset sequence version <version>. Please make sure the checkpoint is from a supported Spark version (Spark 4.0+)."
]
},
"UNSUPPORTED_PROVIDER" : {
"message" : [
"<provider> is not supported"
]
}
},
"sqlState" : "55019"
@@ -41,7 +41,8 @@ import org.apache.spark.sql.execution.streaming.operators.stateful.transformwith
import org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.timers.TimerStateUtils
import org.apache.spark.sql.execution.streaming.runtime.StreamingCheckpointConstants.DIR_NAME_STATE
import org.apache.spark.sql.execution.streaming.runtime.StreamingQueryCheckpointMetadata
import org.apache.spark.sql.execution.streaming.state.{InMemoryStateSchemaProvider, KeyStateEncoderSpec, NoPrefixKeyStateEncoderSpec, PrefixKeyScanStateEncoderSpec, StateSchemaCompatibilityChecker, StateSchemaMetadata, StateSchemaProvider, StateStore, StateStoreColFamilySchema, StateStoreConf, StateStoreId, StateStoreProviderId}
import org.apache.spark.sql.execution.streaming.state.{InMemoryStateSchemaProvider, KeyStateEncoderSpec, NoPrefixKeyStateEncoderSpec, PrefixKeyScanStateEncoderSpec, RocksDBStateStoreProvider, StateSchemaCompatibilityChecker, StateSchemaMetadata, StateSchemaProvider, StateStore, StateStoreColFamilySchema, StateStoreConf, StateStoreId, StateStoreProviderId}
import org.apache.spark.sql.execution.streaming.state.OfflineStateRepartitionErrors
import org.apache.spark.sql.execution.streaming.utils.StreamingUtils
import org.apache.spark.sql.sources.DataSourceRegister
import org.apache.spark.sql.streaming.TimeMode
@@ -66,6 +67,14 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging
val sourceOptions = StateSourceOptions.modifySourceOptions(hadoopConf,
StateSourceOptions.apply(session, hadoopConf, properties))
val stateConf = buildStateStoreConf(sourceOptions.resolvedCpLocation, sourceOptions.batchId)
// Only RocksDB is supported because the repartitioning work this option
// was built for only supports RocksDB.
if (sourceOptions.internalOnlyReadAllColumnFamilies
&& stateConf.providerClass != classOf[RocksDBStateStoreProvider].getName) {
throw OfflineStateRepartitionErrors.unsupportedStateStoreProviderError(
sourceOptions.resolvedCpLocation,
stateConf.providerClass)
}
val stateStoreReaderInfo: StateStoreReaderInfo = getStoreMetadataAndRunChecks(
sourceOptions)

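As a usage-level sketch of the gate above (not taken from this change): reading a checkpoint whose effective provider is anything other than RocksDB, with the internal option set, should fail with the new UNSUPPORTED_PROVIDER error. The statestore short name, the spark session, and checkpointDir below are assumptions.

// Hypothetical repro against a checkpoint written with a non-RocksDB provider
// (e.g. the HDFS-backed provider).
spark.read
  .format("statestore")
  .option("internalOnlyReadAllColumnFamilies", "true")
  .load(checkpointDir)  // expected: OfflineStateRepartitionErrors.unsupportedStateStoreProviderError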
@@ -372,15 +381,17 @@ case class StateSourceOptions(
stateVarName: Option[String],
readRegisteredTimers: Boolean,
flattenCollectionTypes: Boolean,
internalOnlyReadAllColumnFamilies: Boolean = false,
startOperatorStateUniqueIds: Option[Array[Array[String]]] = None,
endOperatorStateUniqueIds: Option[Array[Array[String]]] = None) {
def stateCheckpointLocation: Path = new Path(resolvedCpLocation, DIR_NAME_STATE)

override def toString: String = {
var desc = s"StateSourceOptions(checkpointLocation=$resolvedCpLocation, batchId=$batchId, " +
s"operatorId=$operatorId, storeName=$storeName, joinSide=$joinSide, " +
s"stateVarName=${stateVarName.getOrElse("None")}, +" +
s"flattenCollectionTypes=$flattenCollectionTypes"
s"stateVarName=${stateVarName.getOrElse("None")}, " +
s"flattenCollectionTypes=$flattenCollectionTypes, " +
s"internalOnlyReadAllColumnFamilies=$internalOnlyReadAllColumnFamilies"
if (fromSnapshotOptions.isDefined) {
desc += s", snapshotStartBatchId=${fromSnapshotOptions.get.snapshotStartBatchId}"
desc += s", snapshotPartitionId=${fromSnapshotOptions.get.snapshotPartitionId}"
@@ -407,6 +418,7 @@ object StateSourceOptions extends DataSourceOptions {
val STATE_VAR_NAME = newOption("stateVarName")
val READ_REGISTERED_TIMERS = newOption("readRegisteredTimers")
val FLATTEN_COLLECTION_TYPES = newOption("flattenCollectionTypes")
val INTERNAL_ONLY_READ_ALL_COLUMN_FAMILIES = newOption("internalOnlyReadAllColumnFamilies")

object JoinSideValues extends Enumeration {
type JoinSideValues = Value
@@ -478,6 +490,7 @@ object StateSourceOptions extends DataSourceOptions {
s"Valid values are ${JoinSideValues.values.mkString(",")}")
}

// Use storeName rather than joinSide to identify the specific join store
if (joinSide != JoinSideValues.none && storeName != StateStoreId.DEFAULT_STORE_NAME) {
throw StateDataSourceErrors.conflictOptions(Seq(JOIN_SIDE, STORE_NAME))
}
@@ -492,6 +505,29 @@ object StateSourceOptions extends DataSourceOptions {

val readChangeFeed = Option(options.get(READ_CHANGE_FEED)).exists(_.toBoolean)

val internalOnlyReadAllColumnFamilies = try {
Option(options.get(INTERNAL_ONLY_READ_ALL_COLUMN_FAMILIES)).exists(_.toBoolean)
} catch {
case _: IllegalArgumentException =>
throw StateDataSourceErrors.invalidOptionValue(INTERNAL_ONLY_READ_ALL_COLUMN_FAMILIES,
"Boolean value is expected")
}

if (internalOnlyReadAllColumnFamilies && stateVarName.isDefined) {
throw StateDataSourceErrors.conflictOptions(
Seq(INTERNAL_ONLY_READ_ALL_COLUMN_FAMILIES, STATE_VAR_NAME))
}

if (internalOnlyReadAllColumnFamilies && joinSide != JoinSideValues.none) {
throw StateDataSourceErrors.conflictOptions(
Seq(INTERNAL_ONLY_READ_ALL_COLUMN_FAMILIES, JOIN_SIDE))
}

if (internalOnlyReadAllColumnFamilies && readChangeFeed) {
throw StateDataSourceErrors.conflictOptions(
Seq(INTERNAL_ONLY_READ_ALL_COLUMN_FAMILIES, READ_CHANGE_FEED))
}

val changeStartBatchId = Option(options.get(CHANGE_START_BATCH_ID)).map(_.toLong)
var changeEndBatchId = Option(options.get(CHANGE_END_BATCH_ID)).map(_.toLong)

@@ -615,7 +651,7 @@ object StateSourceOptions extends DataSourceOptions {
StateSourceOptions(
resolvedCpLocation, batchId.get, operatorId, storeName, joinSide,
readChangeFeed, fromSnapshotOptions, readChangeFeedOptions,
stateVarName, readRegisteredTimers, flattenCollectionTypes,
stateVarName, readRegisteredTimers, flattenCollectionTypes, internalOnlyReadAllColumnFamilies,
startOperatorStateUniqueIds, endOperatorStateUniqueIds)
}
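A quick caller-side sketch of the mutual-exclusion checks above; the option names come from this file, while the statestore short name, the spark session, and checkpointDir are assumptions.

// Rejected: internalOnlyReadAllColumnFamilies cannot be combined with readChangeFeed
// (the same applies to stateVarName and joinSide).
spark.read
  .format("statestore")
  .option("internalOnlyReadAllColumnFamilies", "true")
  .option("readChangeFeed", "true")
  .option("changeStartBatchId", "0")
  .load(checkpointDir)  // expected: StateDataSourceErrors.conflictOptions(...)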

@@ -49,7 +49,10 @@ class StatePartitionReaderFactory(

override def createReader(partition: InputPartition): PartitionReader[InternalRow] = {
val stateStoreInputPartition = partition.asInstanceOf[StateStoreInputPartition]
if (stateStoreInputPartition.sourceOptions.readChangeFeed) {
if (stateStoreInputPartition.sourceOptions.internalOnlyReadAllColumnFamilies) {
new StatePartitionAllColumnFamiliesReader(storeConf, hadoopConf,
stateStoreInputPartition, schema, keyStateEncoderSpec, stateStoreColFamilySchemaOpt)
} else if (stateStoreInputPartition.sourceOptions.readChangeFeed) {
new StateStoreChangeDataPartitionReader(storeConf, hadoopConf,
stateStoreInputPartition, schema, keyStateEncoderSpec, stateVariableInfoOpt,
stateStoreColFamilySchemaOpt, stateSchemaProviderOpt, joinColFamilyOpt)
@@ -81,16 +84,22 @@ abstract class StatePartitionReaderBase(
private val schemaForValueRow: StructType =
StructType(Array(StructField("__dummy__", NullType)))

protected val keySchema = {
protected val keySchema: StructType = {
if (SchemaUtil.checkVariableType(stateVariableInfoOpt, StateVariableType.MapState)) {
SchemaUtil.getCompositeKeySchema(schema, partition.sourceOptions)
} else if (partition.sourceOptions.internalOnlyReadAllColumnFamilies) {
require(stateStoreColFamilySchemaOpt.isDefined)
stateStoreColFamilySchemaOpt.map(_.keySchema).get
} else {
SchemaUtil.getSchemaAsDataType(schema, "key").asInstanceOf[StructType]
}
}

protected val valueSchema = if (stateVariableInfoOpt.isDefined) {
protected val valueSchema: StructType = if (stateVariableInfoOpt.isDefined) {
schemaForValueRow
} else if (partition.sourceOptions.internalOnlyReadAllColumnFamilies) {
require(stateStoreColFamilySchemaOpt.isDefined)
stateStoreColFamilySchemaOpt.map(_.valueSchema).get
} else {
SchemaUtil.getSchemaAsDataType(
schema, "value").asInstanceOf[StructType]
@@ -237,6 +246,48 @@ class StatePartitionReader(
}
}

/**
* An implementation of [[StatePartitionReaderBase]] for reading all column families
* in binary format. This reader returns raw key and value bytes along with column family names.
* Raw key/value bytes are returned because each column family can have a different schema.
* The partition key is also returned.
*/
class StatePartitionAllColumnFamiliesReader(
storeConf: StateStoreConf,
hadoopConf: SerializableConfiguration,
partition: StateStoreInputPartition,
schema: StructType,
keyStateEncoderSpec: KeyStateEncoderSpec,
stateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema])
extends StatePartitionReaderBase(
storeConf,
hadoopConf, partition, schema,
keyStateEncoderSpec, None, stateStoreColFamilySchemaOpt, None, None) {

private lazy val store: ReadStateStore = {
assert(getStartStoreUniqueId == getEndStoreUniqueId,
"Start and end store unique IDs must be the same when reading all column families")
provider.getReadStore(
partition.sourceOptions.batchId + 1,
getStartStoreUniqueId
)
}

override lazy val iter: Iterator[InternalRow] = {
store
.iterator()
.map { pair =>
SchemaUtil.unifyStateRowPairAsRawBytes(
(pair.key, pair.value), StateStore.DEFAULT_COL_FAMILY_NAME)
}
}

override def close(): Unit = {
store.release()
super.close()
}
}
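For orientation, a sketch of what a successful read through this reader looks like; the columns match SchemaUtil.getSourceSchema for this mode, while the statestore short name, the spark session, and checkpointDir are assumptions.

val allFamilies = spark.read
  .format("statestore")
  .option("internalOnlyReadAllColumnFamilies", "true")
  .load(checkpointDir)

// Columns: partition_key (struct), key_bytes (binary), value_bytes (binary),
// column_family_name (string); each row carries raw key/value bytes tagged
// with its column family name.
allFamilies.groupBy("column_family_name").count().show()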

/**
* An implementation of [[StatePartitionReaderBase]] for the readChangeFeed mode of State Data
* Source. It reads the change of state over batches of a particular partition.
@@ -28,7 +28,8 @@ import org.apache.spark.sql.execution.datasources.v2.state.{StateDataSourceError
import org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.{StateVariableType, TransformWithStateVariableInfo}
import org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.StateVariableType._
import org.apache.spark.sql.execution.streaming.state.{ReadStateStore, StateStoreColFamilySchema, UnsafeRowPair}
import org.apache.spark.sql.types.{ArrayType, DataType, IntegerType, LongType, MapType, StringType, StructType}
import org.apache.spark.sql.types.{ArrayType, BinaryType, DataType, IntegerType, LongType, MapType, StringType, StructType}
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.ArrayImplicits._

object SchemaUtil {
@@ -60,6 +61,14 @@ object SchemaUtil {
.add("key", keySchema)
.add("value", valueSchema)
.add("partition_id", IntegerType)
} else if (sourceOptions.internalOnlyReadAllColumnFamilies) {
new StructType()
// todo [SPARK-54443]: change keySchema to a more specific type after we
// can extract partition key from keySchema
.add("partition_key", keySchema)
.add("key_bytes", BinaryType)
.add("value_bytes", BinaryType)
.add("column_family_name", StringType)
} else {
new StructType()
.add("key", keySchema)
@@ -76,6 +85,26 @@ object SchemaUtil {
row
}

/**
* Returns an InternalRow representing
* 1. partitionKey
* 2. key in bytes
* 3. value in bytes
* 4. column family name
*/
def unifyStateRowPairAsRawBytes(
pair: (UnsafeRow, UnsafeRow),
colFamilyName: String): InternalRow = {
val row = new GenericInternalRow(4)
// todo [SPARK-54443]: change keySchema to a more specific type after we
// can extract partition key from keySchema
row.update(0, pair._1)
row.update(1, pair._1.getBytes)
row.update(2, pair._2.getBytes)
row.update(3, UTF8String.fromString(colFamilyName))
row
}
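A downstream consumer (for example the repartitioning job this is built for) could rehydrate the emitted key_bytes/value_bytes back into UnsafeRows roughly as below. This is an illustrative sketch only; the field count has to come from the matching column family schema, which is assumed to be available.

import org.apache.spark.sql.catalyst.expressions.UnsafeRow

// bytes is a key_bytes or value_bytes value produced above; numFields is the
// number of fields in the corresponding key or value schema.
def rehydrate(bytes: Array[Byte], numFields: Int): UnsafeRow = {
  val row = new UnsafeRow(numFields)
  row.pointTo(bytes, bytes.length)
  row
}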

def unifyStateRowPairWithMultipleValues(
pair: (UnsafeRow, GenericArrayData),
partition: Int): InternalRow = {
@@ -231,7 +260,11 @@ object SchemaUtil {
"user_map_key" -> classOf[StructType],
"user_map_value" -> classOf[StructType],
"expiration_timestamp_ms" -> classOf[LongType],
"partition_id" -> classOf[IntegerType])
"partition_id" -> classOf[IntegerType],
"partition_key" -> classOf[StructType],
"key_bytes" -> classOf[BinaryType],
"value_bytes" -> classOf[BinaryType],
"column_family_name" -> classOf[StringType])

val expectedFieldNames = if (transformWithStateVariableInfoOpt.isDefined) {
val stateVarInfo = transformWithStateVariableInfoOpt.get
Expand Down Expand Up @@ -272,6 +305,8 @@ object SchemaUtil {
}
} else if (sourceOptions.readChangeFeed) {
Seq("batch_id", "change_type", "key", "value", "partition_id")
} else if (sourceOptions.internalOnlyReadAllColumnFamilies) {
Seq("partition_key", "key_bytes", "value_bytes", "column_family_name")
} else {
Seq("key", "value", "partition_id")
}
@@ -85,6 +85,12 @@ object OfflineStateRepartitionErrors {
version: Int): StateRepartitionInvalidCheckpointError = {
new StateRepartitionUnsupportedOffsetSeqVersionError(checkpointLocation, version)
}

def unsupportedStateStoreProviderError(
checkpointLocation: String,
providerClass: String): StateRepartitionInvalidCheckpointError = {
new StateRepartitionUnsupportedProviderError(checkpointLocation, providerClass)
}
}

/**
@@ -201,3 +207,11 @@ class StateRepartitionUnsupportedOffsetSeqVersionError(
checkpointLocation,
subClass = "UNSUPPORTED_OFFSET_SEQ_VERSION",
messageParameters = Map("version" -> version.toString))

class StateRepartitionUnsupportedProviderError(
Contributor: still haven't fixed this indentation

Author: (screenshot: Screenshot 2025-12-01 at 9 57 46 PM) I asked Claude and still couldn't figure out what's the issue with the indentation 😅, could you specify? Thanks!

Contributor: @zifeif2 probably my eyes. You can resolve it.
checkpointLocation: String,
provider: String)
extends StateRepartitionInvalidCheckpointError(
checkpointLocation,
subClass = "UNSUPPORTED_PROVIDER",
messageParameters = Map("provider" -> provider))
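Putting the pieces together, a small sketch of how the new error class fills the <provider> placeholder defined in error-conditions.json; the checkpoint path and provider class name used here are illustrative assumptions.

val err = OfflineStateRepartitionErrors.unsupportedStateStoreProviderError(
  checkpointLocation = "/tmp/query/checkpoint",
  providerClass = "org.apache.spark.sql.execution.streaming.state.HDFSBackedStateStoreProvider")
// subClass = "UNSUPPORTED_PROVIDER" selects the "<provider> is not supported"
// template, with messageParameters supplying the provider class name; the
// surrounding error condition carries sqlState 55019.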