Commit 36db0b6

Author: Ubuntu
Commit message: address comment

1 parent 2b8cef3 commit 36db0b6

File tree: 6 files changed, +108 -222 lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala

Lines changed: 38 additions & 38 deletions
@@ -66,38 +66,35 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging
     val sourceOptions = StateSourceOptions.modifySourceOptions(hadoopConf,
       StateSourceOptions.apply(session, hadoopConf, properties))
     val stateConf = buildStateStoreConf(sourceOptions.resolvedCpLocation, sourceOptions.batchId)
-    if (sourceOptions.readAllColumnFamilies) {
-      // For readAllColumnFamilies mode, we don't need specific encoder because it returns raw data
-      val keyStateEncoderSpec = NoPrefixKeyStateEncoderSpec(new StructType())
-      new StateTable(session, schema, sourceOptions, stateConf, keyStateEncoderSpec,
-        None, None, None, None)
-    } else {
-      val stateStoreReaderInfo: StateStoreReaderInfo = getStoreMetadataAndRunChecks(
-        sourceOptions)
+    if (sourceOptions.internalOnlyReadAllColumnFamilies
+      && !stateConf.providerClass.contains("RocksDB")) {
+      throw StateDataSourceErrors.invalidOptionValue(
+        StateSourceOptions.INTERNAL_ONLY_READ_ALL_COLUMN_FAMILIES,
+        "internalOnlyReadAllColumnFamilies is only supported with RocksDBStateStoreProvider. " +
+          s"Current provider: ${stateConf.providerClass}")
+    }
+    val stateStoreReaderInfo: StateStoreReaderInfo = getStoreMetadataAndRunChecks(
+      sourceOptions)
 
-      // The key state encoder spec should be available for all operators except stream-stream joins
-      val keyStateEncoderSpec = if (stateStoreReaderInfo.keyStateEncoderSpecOpt.isDefined) {
-        stateStoreReaderInfo.keyStateEncoderSpecOpt.get
-      } else {
-        val keySchema = SchemaUtil.getSchemaAsDataType(schema, "key").asInstanceOf[StructType]
-        NoPrefixKeyStateEncoderSpec(keySchema)
-      }
-      new StateTable(session, schema, sourceOptions, stateConf, keyStateEncoderSpec,
-        stateStoreReaderInfo.transformWithStateVariableInfoOpt,
-        stateStoreReaderInfo.stateStoreColFamilySchemaOpt,
-        stateStoreReaderInfo.stateSchemaProviderOpt,
-        stateStoreReaderInfo.joinColFamilyOpt)
+    // The key state encoder spec should be available for all operators except stream-stream joins
+    val keyStateEncoderSpec = if (stateStoreReaderInfo.keyStateEncoderSpecOpt.isDefined) {
+      stateStoreReaderInfo.keyStateEncoderSpecOpt.get
+    } else {
+      val keySchema = SchemaUtil.getSchemaAsDataType(schema, "key").asInstanceOf[StructType]
+      NoPrefixKeyStateEncoderSpec(keySchema)
     }
+
+    new StateTable(session, schema, sourceOptions, stateConf, keyStateEncoderSpec,
+      stateStoreReaderInfo.transformWithStateVariableInfoOpt,
+      stateStoreReaderInfo.stateStoreColFamilySchemaOpt,
+      stateStoreReaderInfo.stateSchemaProviderOpt,
+      stateStoreReaderInfo.joinColFamilyOpt)
   }
 
   override def inferSchema(options: CaseInsensitiveStringMap): StructType = {
     val sourceOptions = StateSourceOptions.modifySourceOptions(hadoopConf,
       StateSourceOptions.apply(session, hadoopConf, options))
-    if (sourceOptions.readAllColumnFamilies) {
-      // For readAllColumnFamilies mode, return the binary schema directly
-      return SchemaUtil.getSourceSchema(
-        sourceOptions, new StructType(), new StructType(), None, None)
-    }
+
     val stateStoreReaderInfo: StateStoreReaderInfo = getStoreMetadataAndRunChecks(
       sourceOptions)
     val oldSchemaFilePaths = StateDataSource.getOldSchemaFilePaths(sourceOptions, hadoopConf)

@@ -382,7 +379,7 @@ case class StateSourceOptions(
     stateVarName: Option[String],
     readRegisteredTimers: Boolean,
     flattenCollectionTypes: Boolean,
-    readAllColumnFamilies: Boolean,
+    internalOnlyReadAllColumnFamilies: Boolean,
     startOperatorStateUniqueIds: Option[Array[Array[String]]] = None,
     endOperatorStateUniqueIds: Option[Array[Array[String]]] = None) {
   def stateCheckpointLocation: Path = new Path(resolvedCpLocation, DIR_NAME_STATE)

@@ -392,7 +389,7 @@ case class StateSourceOptions(
       s"operatorId=$operatorId, storeName=$storeName, joinSide=$joinSide, " +
       s"stateVarName=${stateVarName.getOrElse("None")}, +" +
       s"flattenCollectionTypes=$flattenCollectionTypes" +
-      s"readAllColumnFamilies=$readAllColumnFamilies"
+      s"internalOnlyReadAllColumnFamilies=$internalOnlyReadAllColumnFamilies"
     if (fromSnapshotOptions.isDefined) {
       desc += s", snapshotStartBatchId=${fromSnapshotOptions.get.snapshotStartBatchId}"
       desc += s", snapshotPartitionId=${fromSnapshotOptions.get.snapshotPartitionId}"

@@ -419,7 +416,7 @@ object StateSourceOptions extends DataSourceOptions {
   val STATE_VAR_NAME = newOption("stateVarName")
   val READ_REGISTERED_TIMERS = newOption("readRegisteredTimers")
   val FLATTEN_COLLECTION_TYPES = newOption("flattenCollectionTypes")
-  val READ_ALL_COLUMN_FAMILIES = newOption("readAllColumnFamilies")
+  val INTERNAL_ONLY_READ_ALL_COLUMN_FAMILIES = newOption("internalOnlyReadAllColumnFamilies")
 
   object JoinSideValues extends Enumeration {
     type JoinSideValues = Value

@@ -505,25 +502,28 @@ object StateSourceOptions extends DataSourceOptions {
 
     val readChangeFeed = Option(options.get(READ_CHANGE_FEED)).exists(_.toBoolean)
 
-    val readAllColumnFamilies = try {
-      Option(options.get(READ_ALL_COLUMN_FAMILIES))
+    val internalOnlyReadAllColumnFamilies = try {
+      Option(options.get(INTERNAL_ONLY_READ_ALL_COLUMN_FAMILIES))
        .map(_.toBoolean).getOrElse(false)
    } catch {
      case _: IllegalArgumentException =>
-        throw StateDataSourceErrors.invalidOptionValue(READ_ALL_COLUMN_FAMILIES,
+        throw StateDataSourceErrors.invalidOptionValue(INTERNAL_ONLY_READ_ALL_COLUMN_FAMILIES,
          "Boolean value is expected")
    }
 
-    if (readAllColumnFamilies && stateVarName.isDefined) {
-      throw StateDataSourceErrors.conflictOptions(Seq(READ_ALL_COLUMN_FAMILIES, STATE_VAR_NAME))
+    if (internalOnlyReadAllColumnFamilies && stateVarName.isDefined) {
+      throw StateDataSourceErrors.conflictOptions(
+        Seq(INTERNAL_ONLY_READ_ALL_COLUMN_FAMILIES, STATE_VAR_NAME))
     }
 
-    if (readAllColumnFamilies && joinSide != JoinSideValues.none) {
-      throw StateDataSourceErrors.conflictOptions(Seq(READ_ALL_COLUMN_FAMILIES, JOIN_SIDE))
+    if (internalOnlyReadAllColumnFamilies && joinSide != JoinSideValues.none) {
+      throw StateDataSourceErrors.conflictOptions(
+        Seq(INTERNAL_ONLY_READ_ALL_COLUMN_FAMILIES, JOIN_SIDE))
     }
 
-    if (readAllColumnFamilies && readChangeFeed) {
-      throw StateDataSourceErrors.conflictOptions(Seq(READ_ALL_COLUMN_FAMILIES, READ_CHANGE_FEED))
+    if (internalOnlyReadAllColumnFamilies && readChangeFeed) {
+      throw StateDataSourceErrors.conflictOptions(
+        Seq(INTERNAL_ONLY_READ_ALL_COLUMN_FAMILIES, READ_CHANGE_FEED))
     }
 
     val changeStartBatchId = Option(options.get(CHANGE_START_BATCH_ID)).map(_.toLong)

@@ -650,7 +650,7 @@ object StateSourceOptions extends DataSourceOptions {
       resolvedCpLocation, batchId.get, operatorId, storeName, joinSide,
       readChangeFeed, fromSnapshotOptions, readChangeFeedOptions,
       stateVarName, readRegisteredTimers, flattenCollectionTypes,
-      readAllColumnFamilies, startOperatorStateUniqueIds, endOperatorStateUniqueIds)
+      internalOnlyReadAllColumnFamilies, startOperatorStateUniqueIds, endOperatorStateUniqueIds)
   }
 
   private def getLastCommittedBatch(session: SparkSession, checkpointLocation: String): Long = {
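
For orientation, a minimal usage sketch of the renamed option. The checkpoint path and the SparkSession are assumptions, not part of this commit; "statestore" is the data source's registered short name, and per the new guard the option only works with RocksDBStateStoreProvider:

// Sketch only: assumes a SparkSession named spark and an existing streaming checkpoint
// written with RocksDBStateStoreProvider. The path below is a hypothetical location.
val rawState = spark.read
  .format("statestore")
  .option("path", "/tmp/streaming-checkpoint")
  .option("internalOnlyReadAllColumnFamilies", "true")
  .load()

rawState.printSchema()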

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala

Lines changed: 12 additions & 61 deletions
@@ -20,7 +20,6 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeRow}
 import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory}
-import org.apache.spark.sql.execution.datasources.v2.state.metadata.StateMetadataPartitionReader
 import org.apache.spark.sql.execution.datasources.v2.state.utils.SchemaUtil
 import org.apache.spark.sql.execution.streaming.operators.stateful.join.SymmetricHashJoinStateManager
 import org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.{StateVariableType, TransformWithStateVariableInfo}

@@ -50,9 +49,9 @@ class StatePartitionReaderFactory(
 
   override def createReader(partition: InputPartition): PartitionReader[InternalRow] = {
     val stateStoreInputPartition = partition.asInstanceOf[StateStoreInputPartition]
-    if (stateStoreInputPartition.sourceOptions.readAllColumnFamilies) {
+    if (stateStoreInputPartition.sourceOptions.internalOnlyReadAllColumnFamilies) {
       new StatePartitionReaderAllColumnFamilies(storeConf, hadoopConf,
-        stateStoreInputPartition, schema)
+        stateStoreInputPartition, schema, keyStateEncoderSpec)
     } else if (stateStoreInputPartition.sourceOptions.readChangeFeed) {
       new StateStoreChangeDataPartitionReader(storeConf, hadoopConf,
         stateStoreInputPartition, schema, keyStateEncoderSpec, stateVariableInfoOpt,

@@ -85,15 +84,15 @@ abstract class StatePartitionReaderBase(
   private val schemaForValueRow: StructType =
     StructType(Array(StructField("__dummy__", NullType)))
 
-  protected lazy val keySchema = {
+  protected val keySchema = {
     if (SchemaUtil.checkVariableType(stateVariableInfoOpt, StateVariableType.MapState)) {
       SchemaUtil.getCompositeKeySchema(schema, partition.sourceOptions)
     } else {
       SchemaUtil.getSchemaAsDataType(schema, "key").asInstanceOf[StructType]
     }
   }
 
-  protected lazy val valueSchema = if (stateVariableInfoOpt.isDefined) {
+  protected val valueSchema = if (stateVariableInfoOpt.isDefined) {
     schemaForValueRow
   } else {
     SchemaUtil.getSchemaAsDataType(

@@ -249,16 +248,10 @@ class StatePartitionReaderAllColumnFamilies(
     storeConf: StateStoreConf,
     hadoopConf: SerializableConfiguration,
     partition: StateStoreInputPartition,
-    schema: StructType)
+    schema: StructType,
+    keyStateEncoderSpec: KeyStateEncoderSpec)
   extends StatePartitionReaderBase(storeConf, hadoopConf, partition, schema,
-    NoPrefixKeyStateEncoderSpec(new StructType()), None, None, None, None) {
-
-  val allStateStoreMetadata = {
-    new StateMetadataPartitionReader(
-      partition.sourceOptions.resolvedCpLocation,
-      new SerializableConfiguration(hadoopConf.value),
-      partition.sourceOptions.batchId).stateMetadata.toArray
-  }
+    keyStateEncoderSpec, None, None, None, None) {
 
   private lazy val store: ReadStateStore = {
     assert(getStartStoreUniqueId == getEndStoreUniqueId,

@@ -269,56 +262,14 @@ class StatePartitionReaderAllColumnFamilies(
     )
   }
 
-  val colFamilyNames: Seq[String] = {
-    // todo: Support operator with multiple column family names in next PR
-    Seq[String]()
-  }
-
-  override protected lazy val provider: StateStoreProvider = {
-    val stateStoreId = StateStoreId(partition.sourceOptions.stateCheckpointLocation.toString,
-      partition.sourceOptions.operatorId, partition.partition, partition.sourceOptions.storeName)
-    val stateStoreProviderId = StateStoreProviderId(stateStoreId, partition.queryId)
-
-    // Disable format validation when reading raw bytes.
-    // We use binary schemas (keyBytes/valueBytes) which don't match the actual schema
-    // of the stored data. Validation would fail in HDFSBackedStateStoreProvider when
-    // loading data from disk, so we disable it for raw bytes mode.
-    val modifiedStoreConf = storeConf.withFormatValidationDisabled()
-
-    val keyStateEncoderSpec = NoPrefixKeyStateEncoderSpec(new StructType())
-    // Pass in empty keySchema, valueSchema and dummy encoder because we don't encode any data
-    val provider = StateStoreProvider.createAndInit(
-      stateStoreProviderId, new StructType(), new StructType(), keyStateEncoderSpec,
-      useColumnFamilies = colFamilyNames.nonEmpty, modifiedStoreConf, hadoopConf.value, false, None)
-
-    provider
-  }
-
   override lazy val iter: Iterator[InternalRow] = {
     // Single store with column families (join v3, transformWithState, or simple operators)
-    require(store.isInstanceOf[SupportsRawBytesRead],
-      s"State store ${store.getClass.getName} does not support raw bytes reading")
-
-    val rawStore = store.asInstanceOf[SupportsRawBytesRead]
-    if (colFamilyNames.isEmpty) {
-      rawStore
-        .rawIterator()
-        .map { case (keyBytes, valueBytes) =>
-          SchemaUtil.unifyStateRowPairAsRawBytes(
-            partition.partition, keyBytes, valueBytes, StateStore.DEFAULT_COL_FAMILY_NAME)
-        }
-    } else {
-      colFamilyNames.iterator.flatMap { colFamilyName =>
-        rawStore
-          .rawIterator(colFamilyName)
-          .map { case (keyBytes, valueBytes) =>
-            SchemaUtil.unifyStateRowPairAsRawBytes(partition.partition,
-              keyBytes,
-              valueBytes,
-              colFamilyName)
-          }
+    store
+      .iterator()
+      .map { pair =>
+        SchemaUtil.unifyStateRowPairAsRawBytes(
+          (pair.key, pair.value), StateStore.DEFAULT_COL_FAMILY_NAME)
       }
-    }
   }
 
   override def close(): Unit = {
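
With the SupportsRawBytesRead trait gone, the reader relies on the fact that an UnsafeRow already carries its serialized form. A minimal sketch of that idea (the helper name is illustrative, not part of this commit):

import org.apache.spark.sql.catalyst.expressions.UnsafeRow

// Illustrative helper: given the key/value UnsafeRow pair returned by
// ReadStateStore.iterator(), the raw serialized bytes are recovered with
// UnsafeRow.getBytes, which is what unifyStateRowPairAsRawBytes uses.
def toRawBytes(pair: (UnsafeRow, UnsafeRow)): (Array[Byte], Array[Byte]) =
  (pair._1.getBytes, pair._2.getBytes)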

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/utils/SchemaUtil.scala

Lines changed: 16 additions & 11 deletions
@@ -61,12 +61,17 @@ object SchemaUtil {
         .add("key", keySchema)
         .add("value", valueSchema)
         .add("partition_id", IntegerType)
-    } else if (sourceOptions.readAllColumnFamilies) {
+    } else if (sourceOptions.internalOnlyReadAllColumnFamilies) {
       new StructType()
-        .add("partition_id", IntegerType)
+        // todo: change this to some more specific type after we
+        // can extract partition key from keySchema
+        .add("partition_key", keySchema)
         .add("key_bytes", BinaryType)
         .add("value_bytes", BinaryType)
         .add("column_family_name", StringType)
+        // need key and value schema so that state store can encode data
+        .add("value", valueSchema)
+        .add("key", keySchema)
     } else {
       new StructType()
         .add("key", keySchema)

@@ -89,15 +94,14 @@ object SchemaUtil {
    * instead of a tuple for better readability.
    */
   def unifyStateRowPairAsRawBytes(
-      partition: Int,
-      keyBytes: Array[Byte],
-      valueBytes: Array[Byte],
+      pair: (UnsafeRow, UnsafeRow),
       colFamilyName: String): InternalRow = {
-    val row = new GenericInternalRow(4)
-    row.update(0, partition)
-    row.update(1, keyBytes)
-    row.update(2, valueBytes)
+    val row = new GenericInternalRow(6)
+    row.update(0, pair._1)
+    row.update(1, pair._1.getBytes)
+    row.update(2, pair._2.getBytes)
     row.update(3, UTF8String.fromString(colFamilyName))
+    // row.update(4, pair._2)
     row
   }
 

@@ -257,6 +261,7 @@ object SchemaUtil {
     "user_map_value" -> classOf[StructType],
     "expiration_timestamp_ms" -> classOf[LongType],
     "partition_id" -> classOf[IntegerType],
+    "partition_key" -> classOf[StructType],
     "key_bytes"->classOf[BinaryType],
     "value_bytes"->classOf[BinaryType],
     "column_family_name"->classOf[StringType])

@@ -300,8 +305,8 @@ object SchemaUtil {
       }
     } else if (sourceOptions.readChangeFeed) {
       Seq("batch_id", "change_type", "key", "value", "partition_id")
-    } else if (sourceOptions.readAllColumnFamilies) {
-      Seq("partition_id", "key_bytes", "value_bytes", "column_family_name")
+    } else if (sourceOptions.internalOnlyReadAllColumnFamilies) {
+      Seq("partition_key", "key_bytes", "value_bytes", "column_family_name", "value", "key")
     } else {
       Seq("key", "value", "partition_id")
     }
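
For reference, the source schema this mode now exposes, spelled out as a plain StructType sketch. keySchema and valueSchema are placeholders for the operator's actual key and value schemas; the field order follows the Seq above:

import org.apache.spark.sql.types._

// Sketch of the internalOnlyReadAllColumnFamilies source schema as built in getSourceSchema.
def allColumnFamiliesSchema(keySchema: StructType, valueSchema: StructType): StructType =
  new StructType()
    .add("partition_key", keySchema)
    .add("key_bytes", BinaryType)
    .add("value_bytes", BinaryType)
    .add("column_family_name", StringType)
    .add("value", valueSchema)
    .add("key", keySchema)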

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala

Lines changed: 0 additions & 14 deletions
@@ -891,20 +891,6 @@ object StateStoreProvider extends Logging {
   }
 }
 
-/**
- * Trait for state stores that support reading raw bytes without decoding.
- * This is useful for copying state data during repartitioning
- */
-trait SupportsRawBytesRead {
-  /**
-   * Returns an iterator of raw key-value bytes for a column family.
-   * @param colFamilyName the name of the column family to iterate over
-   * @return an iterator of (keyBytes, valueBytes) tuples
-   */
-  def rawIterator(colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME):
-    Iterator[(Array[Byte], Array[Byte])]
-}
-
 /**
  * This is an optional trait to be implemented by [[StateStoreProvider]]s that can read the change
  * of state store over batches. This is used by State Data Source with additional options like

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala

Lines changed: 0 additions & 20 deletions
@@ -163,26 +163,6 @@ class StateStoreConf(
    */
   val sqlConfs: Map[String, String] =
     sqlConf.getAllConfs.filter(_._1.startsWith("spark.sql.streaming.stateStore."))
-
-  /**
-   * Creates a copy of this StateStoreConf with format validation disabled.
-   * This is useful when reading raw bytes where the schema used (binary) doesn't match
-   * the actual stored data schema.
-   */
-  def withFormatValidationDisabled(): StateStoreConf = {
-    val reconstructedSqlConf = {
-      // Reconstruct a SQLConf with the all settings preserved because sqlConf is transient
-      val conf = new SQLConf()
-      // Restore all state store related settings
-      sqlConfs.foreach { case (key, value) =>
-        conf.setConfString(key, value)
-      }
-      conf
-    }
-    new StateStoreConf(reconstructedSqlConf, extraOptions) {
-      override val formatValidationEnabled: Boolean = false
-    }
-  }
 }
 
 object StateStoreConf {

0 commit comments
