
Commit 7e75b40

Author: Ubuntu (committed)
scan simple operator state
1 parent 8cab074 commit 7e75b40

File tree: 5 files changed (+427, -21 lines)

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala

Lines changed: 51 additions & 17 deletions
@@ -66,28 +66,38 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging
     val sourceOptions = StateSourceOptions.modifySourceOptions(hadoopConf,
       StateSourceOptions.apply(session, hadoopConf, properties))
     val stateConf = buildStateStoreConf(sourceOptions.resolvedCpLocation, sourceOptions.batchId)
-    val stateStoreReaderInfo: StateStoreReaderInfo = getStoreMetadataAndRunChecks(
-      sourceOptions)
-
-    // The key state encoder spec should be available for all operators except stream-stream joins
-    val keyStateEncoderSpec = if (stateStoreReaderInfo.keyStateEncoderSpecOpt.isDefined) {
-      stateStoreReaderInfo.keyStateEncoderSpecOpt.get
+    if (sourceOptions.readAllColumnFamilies) {
+      // For readAllColumnFamilies mode, we don't need specific metadata
+      val keyStateEncoderSpec = NoPrefixKeyStateEncoderSpec(new StructType())
+      new StateTable(session, schema, sourceOptions, stateConf, keyStateEncoderSpec,
+        None, None, None, None)
     } else {
-      val keySchema = SchemaUtil.getSchemaAsDataType(schema, "key").asInstanceOf[StructType]
-      NoPrefixKeyStateEncoderSpec(keySchema)
-    }
+      val stateStoreReaderInfo: StateStoreReaderInfo = getStoreMetadataAndRunChecks(
+        sourceOptions)
 
-    new StateTable(session, schema, sourceOptions, stateConf, keyStateEncoderSpec,
-      stateStoreReaderInfo.transformWithStateVariableInfoOpt,
-      stateStoreReaderInfo.stateStoreColFamilySchemaOpt,
-      stateStoreReaderInfo.stateSchemaProviderOpt,
-      stateStoreReaderInfo.joinColFamilyOpt)
+      // The key state encoder spec should be available for all operators except stream-stream joins
+      val keyStateEncoderSpec = if (stateStoreReaderInfo.keyStateEncoderSpecOpt.isDefined) {
+        stateStoreReaderInfo.keyStateEncoderSpecOpt.get
+      } else {
+        val keySchema = SchemaUtil.getSchemaAsDataType(schema, "key").asInstanceOf[StructType]
+        NoPrefixKeyStateEncoderSpec(keySchema)
+      }
+      new StateTable(session, schema, sourceOptions, stateConf, keyStateEncoderSpec,
+        stateStoreReaderInfo.transformWithStateVariableInfoOpt,
+        stateStoreReaderInfo.stateStoreColFamilySchemaOpt,
+        stateStoreReaderInfo.stateSchemaProviderOpt,
+        stateStoreReaderInfo.joinColFamilyOpt)
+    }
   }
 
   override def inferSchema(options: CaseInsensitiveStringMap): StructType = {
     val sourceOptions = StateSourceOptions.modifySourceOptions(hadoopConf,
       StateSourceOptions.apply(session, hadoopConf, options))
-
+    if (sourceOptions.readAllColumnFamilies) {
+      // For readAllColumnFamilies mode, return the binary schema directly
+      return SchemaUtil.getSourceSchema(
+        sourceOptions, new StructType(), new StructType(), None, None)
+    }
     val stateStoreReaderInfo: StateStoreReaderInfo = getStoreMetadataAndRunChecks(
       sourceOptions)
     val oldSchemaFilePaths = StateDataSource.getOldSchemaFilePaths(sourceOptions, hadoopConf)
@@ -372,6 +382,7 @@ case class StateSourceOptions(
     stateVarName: Option[String],
     readRegisteredTimers: Boolean,
     flattenCollectionTypes: Boolean,
+    readAllColumnFamilies: Boolean,
     startOperatorStateUniqueIds: Option[Array[Array[String]]] = None,
     endOperatorStateUniqueIds: Option[Array[Array[String]]] = None) {
   def stateCheckpointLocation: Path = new Path(resolvedCpLocation, DIR_NAME_STATE)
@@ -380,7 +391,8 @@ case class StateSourceOptions(
     var desc = s"StateSourceOptions(checkpointLocation=$resolvedCpLocation, batchId=$batchId, " +
       s"operatorId=$operatorId, storeName=$storeName, joinSide=$joinSide, " +
       s"stateVarName=${stateVarName.getOrElse("None")}, +" +
-      s"flattenCollectionTypes=$flattenCollectionTypes"
+      s"flattenCollectionTypes=$flattenCollectionTypes" +
+      s"readAllColumnFamilies=$readAllColumnFamilies"
     if (fromSnapshotOptions.isDefined) {
       desc += s", snapshotStartBatchId=${fromSnapshotOptions.get.snapshotStartBatchId}"
       desc += s", snapshotPartitionId=${fromSnapshotOptions.get.snapshotPartitionId}"
@@ -407,6 +419,7 @@ object StateSourceOptions extends DataSourceOptions {
   val STATE_VAR_NAME = newOption("stateVarName")
   val READ_REGISTERED_TIMERS = newOption("readRegisteredTimers")
   val FLATTEN_COLLECTION_TYPES = newOption("flattenCollectionTypes")
+  val READ_ALL_COLUMN_FAMILIES = newOption("readAllColumnFamilies")
 
   object JoinSideValues extends Enumeration {
     type JoinSideValues = Value
@@ -492,6 +505,27 @@ object StateSourceOptions extends DataSourceOptions {
 
     val readChangeFeed = Option(options.get(READ_CHANGE_FEED)).exists(_.toBoolean)
 
+    val readAllColumnFamilies = try {
+      Option(options.get(READ_ALL_COLUMN_FAMILIES))
+        .map(_.toBoolean).getOrElse(false)
+    } catch {
+      case _: IllegalArgumentException =>
+        throw StateDataSourceErrors.invalidOptionValue(READ_ALL_COLUMN_FAMILIES,
+          "Boolean value is expected")
+    }
+
+    if (readAllColumnFamilies && stateVarName.isDefined) {
+      throw StateDataSourceErrors.conflictOptions(Seq(READ_ALL_COLUMN_FAMILIES, STATE_VAR_NAME))
+    }
+
+    if (readAllColumnFamilies && joinSide != JoinSideValues.none) {
+      throw StateDataSourceErrors.conflictOptions(Seq(READ_ALL_COLUMN_FAMILIES, JOIN_SIDE))
+    }
+
+    if (readAllColumnFamilies && readChangeFeed) {
+      throw StateDataSourceErrors.conflictOptions(Seq(READ_ALL_COLUMN_FAMILIES, READ_CHANGE_FEED))
+    }
+
     val changeStartBatchId = Option(options.get(CHANGE_START_BATCH_ID)).map(_.toLong)
     var changeEndBatchId = Option(options.get(CHANGE_END_BATCH_ID)).map(_.toLong)
 
@@ -616,7 +650,7 @@ object StateSourceOptions extends DataSourceOptions {
       resolvedCpLocation, batchId.get, operatorId, storeName, joinSide,
       readChangeFeed, fromSnapshotOptions, readChangeFeedOptions,
       stateVarName, readRegisteredTimers, flattenCollectionTypes,
-      startOperatorStateUniqueIds, endOperatorStateUniqueIds)
+      readAllColumnFamilies, startOperatorStateUniqueIds, endOperatorStateUniqueIds)
   }
 
   private def getLastCommittedBatch(session: SparkSession, checkpointLocation: String): Long = {
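
Usage note (not part of this diff): a minimal sketch of how the new option is expected to be passed through the state data source reader. The checkpoint location and batch id below are hypothetical, and per the checks above, combining readAllColumnFamilies with stateVarName, joinSide, or readChangeFeed raises a conflictOptions error.

    // Hypothetical usage sketch of the readAllColumnFamilies option added in this commit.
    val rawStateDf = spark.read
      .format("statestore")                      // state data source short name
      .option("batchId", 5)                      // hypothetical batch to read
      .option("readAllColumnFamilies", "true")   // new option
      .load("/tmp/checkpoints/stateful-query")   // hypothetical checkpoint location

    // In this mode the inferred schema is the binary one built by SchemaUtil.getSourceSchema:
    // partition_id, key_bytes, value_bytes, column_family_name.
    rawStateDf.printSchema()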

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala

Lines changed: 89 additions & 2 deletions
@@ -20,12 +20,13 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeRow}
 import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory}
+import org.apache.spark.sql.execution.datasources.v2.state.metadata.StateMetadataPartitionReader
 import org.apache.spark.sql.execution.datasources.v2.state.utils.SchemaUtil
 import org.apache.spark.sql.execution.streaming.operators.stateful.join.SymmetricHashJoinStateManager
 import org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.{StateVariableType, TransformWithStateVariableInfo}
 import org.apache.spark.sql.execution.streaming.state._
 import org.apache.spark.sql.execution.streaming.state.RecordType.{getRecordTypeAsString, RecordType}
-import org.apache.spark.sql.types.{NullType, StructField, StructType}
+import org.apache.spark.sql.types.{BinaryType, NullType, StructField, StructType}
 import org.apache.spark.unsafe.types.UTF8String
 import org.apache.spark.util.{NextIterator, SerializableConfiguration}
 
@@ -49,7 +50,10 @@ class StatePartitionReaderFactory(
 
   override def createReader(partition: InputPartition): PartitionReader[InternalRow] = {
     val stateStoreInputPartition = partition.asInstanceOf[StateStoreInputPartition]
-    if (stateStoreInputPartition.sourceOptions.readChangeFeed) {
+    if (stateStoreInputPartition.sourceOptions.readAllColumnFamilies) {
+      new StatePartitionReaderAllColumnFamilies(storeConf, hadoopConf,
+        stateStoreInputPartition, schema)
+    } else if (stateStoreInputPartition.sourceOptions.readChangeFeed) {
       new StateStoreChangeDataPartitionReader(storeConf, hadoopConf,
         stateStoreInputPartition, schema, keyStateEncoderSpec, stateVariableInfoOpt,
         stateStoreColFamilySchemaOpt, stateSchemaProviderOpt, joinColFamilyOpt)
@@ -84,13 +88,17 @@ abstract class StatePartitionReaderBase(
   protected val keySchema = {
     if (SchemaUtil.checkVariableType(stateVariableInfoOpt, StateVariableType.MapState)) {
       SchemaUtil.getCompositeKeySchema(schema, partition.sourceOptions)
+    } else if (partition.sourceOptions.readAllColumnFamilies) {
+      new StructType().add("keyBytes", BinaryType, nullable = false)
     } else {
       SchemaUtil.getSchemaAsDataType(schema, "key").asInstanceOf[StructType]
     }
   }
 
   protected val valueSchema = if (stateVariableInfoOpt.isDefined) {
     schemaForValueRow
+  } else if (partition.sourceOptions.readAllColumnFamilies) {
+    new StructType().add("valueBytes", BinaryType, nullable = false)
   } else {
     SchemaUtil.getSchemaAsDataType(
       schema, "value").asInstanceOf[StructType]
@@ -237,6 +245,85 @@ class StatePartitionReader(
   }
 }
 
+/**
+ * An implementation of [[StatePartitionReaderBase]] for reading all column families
+ * in binary format. This reader returns raw key and value bytes along with column family names.
+ */
+class StatePartitionReaderAllColumnFamilies(
+    storeConf: StateStoreConf,
+    hadoopConf: SerializableConfiguration,
+    partition: StateStoreInputPartition,
+    schema: StructType)
+  extends StatePartitionReaderBase(storeConf, hadoopConf, partition, schema,
+    NoPrefixKeyStateEncoderSpec(new StructType()), None, None, None, None) {
+
+  val allStateStoreMetadata = {
+    new StateMetadataPartitionReader(
+      partition.sourceOptions.resolvedCpLocation,
+      new SerializableConfiguration(hadoopConf.value),
+      partition.sourceOptions.batchId).stateMetadata.toArray
+  }
+
+  private lazy val store: ReadStateStore = {
+    assert(getStartStoreUniqueId == getEndStoreUniqueId,
+      "Start and end store unique IDs must be the same when reading all column families")
+    provider.getReadStore(
+      partition.sourceOptions.batchId + 1,
+      getStartStoreUniqueId
+    )
+  }
+
+  val colFamilyNames: Seq[String] = {
+    // todo: Support operator with multiple column family names in next PR
+    Seq[String]()
+  }
+
+  override protected lazy val provider: StateStoreProvider = {
+    val stateStoreId = StateStoreId(partition.sourceOptions.stateCheckpointLocation.toString,
+      partition.sourceOptions.operatorId, partition.partition, partition.sourceOptions.storeName)
+    val stateStoreProviderId = StateStoreProviderId(stateStoreId, partition.queryId)
+
+    val keyStateEncoderSpec = NoPrefixKeyStateEncoderSpec(keySchema)
+    val provider = StateStoreProvider.createAndInit(
+      stateStoreProviderId, keySchema, valueSchema, keyStateEncoderSpec,
+      useColumnFamilies = colFamilyNames.nonEmpty, storeConf, hadoopConf.value, false, None)
+
+    provider
+  }
+
+  override lazy val iter: Iterator[InternalRow] = {
+    // Single store with column families (join v3, transformWithState, or simple operators)
+    require(store.isInstanceOf[SupportsRawBytesRead],
+      s"State store ${store.getClass.getName} does not support raw bytes reading")
+
+    val rawStore = store.asInstanceOf[SupportsRawBytesRead]
+    if (colFamilyNames.isEmpty) {
+      rawStore
+        .rawIterator()
+        .map { case (keyBytes, valueBytes) =>
+          SchemaUtil.unifyStateRowPairAsRawBytes(
+            partition.partition, keyBytes, valueBytes, StateStore.DEFAULT_COL_FAMILY_NAME)
+        }
+    } else {
+      colFamilyNames.iterator.flatMap { colFamilyName =>
+        rawStore
+          .rawIterator(colFamilyName)
+          .map { case (keyBytes, valueBytes) =>
+            SchemaUtil.unifyStateRowPairAsRawBytes(partition.partition,
+              keyBytes,
+              valueBytes,
+              colFamilyName)
+          }
+      }
+    }
+  }
+
+  override def close(): Unit = {
+    store.release()
+    super.close()
+  }
+}
+
 /**
  * An implementation of [[StatePartitionReaderBase]] for the readChangeFeed mode of State Data
  * Source. It reads the change of state over batches of a particular partition.
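
Reader note (not part of this diff): a sketch of what the rows produced by StatePartitionReaderAllColumnFamilies look like to a consumer. The checkpoint location is hypothetical; the column names come from the readAllColumnFamilies branch of SchemaUtil.getSourceSchema.

    // Illustrative only: inspect the raw rows returned in readAllColumnFamilies mode.
    spark.read
      .format("statestore")
      .option("readAllColumnFamilies", "true")
      .load("/tmp/checkpoints/stateful-query")   // hypothetical checkpoint location
      .collect()
      .foreach { row =>
        val partitionId = row.getAs[Int]("partition_id")
        val keyBytes = row.getAs[Array[Byte]]("key_bytes")
        val valueBytes = row.getAs[Array[Byte]]("value_bytes")
        val colFamily = row.getAs[String]("column_family_name")
        println(s"partition=$partitionId family=$colFamily " +
          s"key=${keyBytes.length}B value=${valueBytes.length}B")
      }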

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/utils/SchemaUtil.scala

Lines changed: 32 additions & 2 deletions
@@ -28,7 +28,8 @@ import org.apache.spark.sql.execution.datasources.v2.state.{StateDataSourceError
 import org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.{StateVariableType, TransformWithStateVariableInfo}
 import org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.StateVariableType._
 import org.apache.spark.sql.execution.streaming.state.{ReadStateStore, StateStoreColFamilySchema, UnsafeRowPair}
-import org.apache.spark.sql.types.{ArrayType, DataType, IntegerType, LongType, MapType, StringType, StructType}
+import org.apache.spark.sql.types.{ArrayType, BinaryType, DataType, IntegerType, LongType, MapType, StringType, StructType}
+import org.apache.spark.unsafe.types.UTF8String
 import org.apache.spark.util.ArrayImplicits._
 
 object SchemaUtil {
@@ -60,6 +61,12 @@ object SchemaUtil {
         .add("key", keySchema)
         .add("value", valueSchema)
         .add("partition_id", IntegerType)
+    } else if (sourceOptions.readAllColumnFamilies) {
+      new StructType()
+        .add("partition_id", IntegerType)
+        .add("key_bytes", BinaryType)
+        .add("value_bytes", BinaryType)
+        .add("column_family_name", StringType)
     } else {
       new StructType()
         .add("key", keySchema)
@@ -76,6 +83,24 @@
     row
   }
 
+  /**
+   * Creates a unified row from raw key and value bytes.
+   * This is an alias for unifyStateRowPairAsBytes that takes individual byte arrays
+   * instead of a tuple for better readability.
+   */
+  def unifyStateRowPairAsRawBytes(
+      partition: Int,
+      keyBytes: Array[Byte],
+      valueBytes: Array[Byte],
+      colFamilyName: String): InternalRow = {
+    val row = new GenericInternalRow(4)
+    row.update(0, partition)
+    row.update(1, keyBytes)
+    row.update(2, valueBytes)
+    row.update(3, UTF8String.fromString(colFamilyName))
+    row
+  }
+
   def unifyStateRowPairWithMultipleValues(
       pair: (UnsafeRow, GenericArrayData),
       partition: Int): InternalRow = {
@@ -231,7 +256,10 @@
     "user_map_key" -> classOf[StructType],
     "user_map_value" -> classOf[StructType],
     "expiration_timestamp_ms" -> classOf[LongType],
-    "partition_id" -> classOf[IntegerType])
+    "partition_id" -> classOf[IntegerType],
+    "key_bytes"->classOf[BinaryType],
+    "value_bytes"->classOf[BinaryType],
+    "column_family_name"->classOf[StringType])
 
   val expectedFieldNames = if (transformWithStateVariableInfoOpt.isDefined) {
     val stateVarInfo = transformWithStateVariableInfoOpt.get
@@ -272,6 +300,8 @@
       }
     } else if (sourceOptions.readChangeFeed) {
       Seq("batch_id", "change_type", "key", "value", "partition_id")
+    } else if (sourceOptions.readAllColumnFamilies) {
+      Seq("partition_id", "key_bytes", "value_bytes", "column_family_name")
     } else {
       Seq("key", "value", "partition_id")
     }
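
Consumer note (not part of this diff): the key_bytes and value_bytes columns carry whatever byte layout the underlying store uses. Assuming a store whose values are plain UnsafeRow bytes (an assumption; some encoders add version or prefix bytes), a consumer could reinterpret them roughly like this:

    import org.apache.spark.sql.catalyst.expressions.UnsafeRow

    // Sketch only: reinterpret value_bytes as an UnsafeRow with `numFields` columns,
    // assuming the provider stored plain UnsafeRow bytes (not guaranteed for every encoder).
    def decodeValueBytes(valueBytes: Array[Byte], numFields: Int): UnsafeRow = {
      val row = new UnsafeRow(numFields)
      row.pointTo(valueBytes, valueBytes.length)
      row
    }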

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala

Lines changed: 14 additions & 0 deletions
@@ -891,6 +891,20 @@ object StateStoreProvider extends Logging {
   }
 }
 
+/**
+ * Trait for state stores that support reading raw bytes without decoding.
+ * This is useful for copying state data during repartitioning
+ */
+trait SupportsRawBytesRead {
+  /**
+   * Returns an iterator of raw key-value bytes for a column family.
+   * @param colFamilyName the name of the column family to iterate over
+   * @return an iterator of (keyBytes, valueBytes) tuples
+   */
+  def rawIterator(colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME):
+    Iterator[(Array[Byte], Array[Byte])]
+}
+
 /**
  * This is an optional trait to be implemented by [[StateStoreProvider]]s that can read the change
  * of state store over batches. This is used by State Data Source with additional options like
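
Implementation note (not part of this diff): a minimal sketch of what satisfying the new trait could look like, using a hypothetical in-memory map keyed by column family name. Real providers (for example the RocksDB-backed store) would serve these pairs from their own storage and byte encoding.

    import org.apache.spark.sql.execution.streaming.state.{StateStore, SupportsRawBytesRead}

    // Hypothetical toy store used only to illustrate the rawIterator contract.
    class InMemoryRawBytesStore(
        data: Map[String, Seq[(Array[Byte], Array[Byte])]]) extends SupportsRawBytesRead {

      // Return the stored (keyBytes, valueBytes) pairs for the requested column family,
      // or an empty iterator if this toy store has no such family.
      override def rawIterator(colFamilyName: String): Iterator[(Array[Byte], Array[Byte])] =
        data.getOrElse(colFamilyName, Seq.empty).iterator
    }

    // Callers going through the trait still get the default column family:
    //   val store: SupportsRawBytesRead = new InMemoryRawBytesStore(Map.empty)
    //   store.rawIterator()   // uses StateStore.DEFAULT_COL_FAMILY_NAME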
