Commit 0a878e9

Author: Ubuntu (committed)

address comment

1 parent 69f5b19 commit 0a878e9

File tree: 8 files changed (+111, -258 lines)


sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala

Lines changed: 38 additions & 38 deletions
@@ -65,38 +65,35 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging
     val sourceOptions = StateSourceOptions.modifySourceOptions(hadoopConf,
       StateSourceOptions.apply(session, hadoopConf, properties))
     val stateConf = buildStateStoreConf(sourceOptions.resolvedCpLocation, sourceOptions.batchId)
-    if (sourceOptions.readAllColumnFamilies) {
-      // For readAllColumnFamilies mode, we don't need specific encoder because it returns raw data
-      val keyStateEncoderSpec = NoPrefixKeyStateEncoderSpec(new StructType())
-      new StateTable(session, schema, sourceOptions, stateConf, keyStateEncoderSpec,
-        None, None, None, None)
-    } else {
-      val stateStoreReaderInfo: StateStoreReaderInfo = getStoreMetadataAndRunChecks(
-        sourceOptions)
+    if (sourceOptions.internalOnlyReadAllColumnFamilies
+      && !stateConf.providerClass.contains("RocksDB")) {
+      throw StateDataSourceErrors.invalidOptionValue(
+        StateSourceOptions.INTERNAL_ONLY_READ_ALL_COLUMN_FAMILIES,
+        "internalOnlyReadAllColumnFamilies is only supported with RocksDBStateStoreProvider. " +
+          s"Current provider: ${stateConf.providerClass}")
+    }
+    val stateStoreReaderInfo: StateStoreReaderInfo = getStoreMetadataAndRunChecks(
+      sourceOptions)
 
-      // The key state encoder spec should be available for all operators except stream-stream joins
-      val keyStateEncoderSpec = if (stateStoreReaderInfo.keyStateEncoderSpecOpt.isDefined) {
-        stateStoreReaderInfo.keyStateEncoderSpecOpt.get
-      } else {
-        val keySchema = SchemaUtil.getSchemaAsDataType(schema, "key").asInstanceOf[StructType]
-        NoPrefixKeyStateEncoderSpec(keySchema)
-      }
-      new StateTable(session, schema, sourceOptions, stateConf, keyStateEncoderSpec,
-        stateStoreReaderInfo.transformWithStateVariableInfoOpt,
-        stateStoreReaderInfo.stateStoreColFamilySchemaOpt,
-        stateStoreReaderInfo.stateSchemaProviderOpt,
-        stateStoreReaderInfo.joinColFamilyOpt)
+    // The key state encoder spec should be available for all operators except stream-stream joins
+    val keyStateEncoderSpec = if (stateStoreReaderInfo.keyStateEncoderSpecOpt.isDefined) {
+      stateStoreReaderInfo.keyStateEncoderSpecOpt.get
+    } else {
+      val keySchema = SchemaUtil.getSchemaAsDataType(schema, "key").asInstanceOf[StructType]
+      NoPrefixKeyStateEncoderSpec(keySchema)
     }
+
+    new StateTable(session, schema, sourceOptions, stateConf, keyStateEncoderSpec,
+      stateStoreReaderInfo.transformWithStateVariableInfoOpt,
+      stateStoreReaderInfo.stateStoreColFamilySchemaOpt,
+      stateStoreReaderInfo.stateSchemaProviderOpt,
+      stateStoreReaderInfo.joinColFamilyOpt)
   }
 
   override def inferSchema(options: CaseInsensitiveStringMap): StructType = {
     val sourceOptions = StateSourceOptions.modifySourceOptions(hadoopConf,
       StateSourceOptions.apply(session, hadoopConf, options))
-    if (sourceOptions.readAllColumnFamilies) {
-      // For readAllColumnFamilies mode, return the binary schema directly
-      return SchemaUtil.getSourceSchema(
-        sourceOptions, new StructType(), new StructType(), None, None)
-    }
+
     val stateStoreReaderInfo: StateStoreReaderInfo = getStoreMetadataAndRunChecks(
       sourceOptions)
     val oldSchemaFilePaths = StateDataSource.getOldSchemaFilePaths(sourceOptions, hadoopConf)

@@ -381,7 +378,7 @@ case class StateSourceOptions(
     stateVarName: Option[String],
     readRegisteredTimers: Boolean,
     flattenCollectionTypes: Boolean,
-    readAllColumnFamilies: Boolean,
+    internalOnlyReadAllColumnFamilies: Boolean,
     startOperatorStateUniqueIds: Option[Array[Array[String]]] = None,
     endOperatorStateUniqueIds: Option[Array[Array[String]]] = None) {
   def stateCheckpointLocation: Path = new Path(resolvedCpLocation, DIR_NAME_STATE)

@@ -391,7 +388,7 @@ case class StateSourceOptions(
       s"operatorId=$operatorId, storeName=$storeName, joinSide=$joinSide, " +
       s"stateVarName=${stateVarName.getOrElse("None")}, +" +
       s"flattenCollectionTypes=$flattenCollectionTypes" +
-      s"readAllColumnFamilies=$readAllColumnFamilies"
+      s"internalOnlyReadAllColumnFamilies=$internalOnlyReadAllColumnFamilies"
     if (fromSnapshotOptions.isDefined) {
       desc += s", snapshotStartBatchId=${fromSnapshotOptions.get.snapshotStartBatchId}"
       desc += s", snapshotPartitionId=${fromSnapshotOptions.get.snapshotPartitionId}"

@@ -418,7 +415,7 @@ object StateSourceOptions extends DataSourceOptions {
   val STATE_VAR_NAME = newOption("stateVarName")
   val READ_REGISTERED_TIMERS = newOption("readRegisteredTimers")
   val FLATTEN_COLLECTION_TYPES = newOption("flattenCollectionTypes")
-  val READ_ALL_COLUMN_FAMILIES = newOption("readAllColumnFamilies")
+  val INTERNAL_ONLY_READ_ALL_COLUMN_FAMILIES = newOption("internalOnlyReadAllColumnFamilies")
 
   object JoinSideValues extends Enumeration {
     type JoinSideValues = Value

@@ -503,25 +500,28 @@ object StateSourceOptions extends DataSourceOptions {
 
     val readChangeFeed = Option(options.get(READ_CHANGE_FEED)).exists(_.toBoolean)
 
-    val readAllColumnFamilies = try {
-      Option(options.get(READ_ALL_COLUMN_FAMILIES))
+    val internalOnlyReadAllColumnFamilies = try {
+      Option(options.get(INTERNAL_ONLY_READ_ALL_COLUMN_FAMILIES))
         .map(_.toBoolean).getOrElse(false)
     } catch {
       case _: IllegalArgumentException =>
-        throw StateDataSourceErrors.invalidOptionValue(READ_ALL_COLUMN_FAMILIES,
+        throw StateDataSourceErrors.invalidOptionValue(INTERNAL_ONLY_READ_ALL_COLUMN_FAMILIES,
          "Boolean value is expected")
     }
 
-    if (readAllColumnFamilies && stateVarName.isDefined) {
-      throw StateDataSourceErrors.conflictOptions(Seq(READ_ALL_COLUMN_FAMILIES, STATE_VAR_NAME))
+    if (internalOnlyReadAllColumnFamilies && stateVarName.isDefined) {
+      throw StateDataSourceErrors.conflictOptions(
+        Seq(INTERNAL_ONLY_READ_ALL_COLUMN_FAMILIES, STATE_VAR_NAME))
     }
 
-    if (readAllColumnFamilies && joinSide != JoinSideValues.none) {
-      throw StateDataSourceErrors.conflictOptions(Seq(READ_ALL_COLUMN_FAMILIES, JOIN_SIDE))
+    if (internalOnlyReadAllColumnFamilies && joinSide != JoinSideValues.none) {
+      throw StateDataSourceErrors.conflictOptions(
+        Seq(INTERNAL_ONLY_READ_ALL_COLUMN_FAMILIES, JOIN_SIDE))
     }
 
-    if (readAllColumnFamilies && readChangeFeed) {
-      throw StateDataSourceErrors.conflictOptions(Seq(READ_ALL_COLUMN_FAMILIES, READ_CHANGE_FEED))
+    if (internalOnlyReadAllColumnFamilies && readChangeFeed) {
+      throw StateDataSourceErrors.conflictOptions(
+        Seq(INTERNAL_ONLY_READ_ALL_COLUMN_FAMILIES, READ_CHANGE_FEED))
     }
 
     val changeStartBatchId = Option(options.get(CHANGE_START_BATCH_ID)).map(_.toLong)

@@ -648,7 +648,7 @@ object StateSourceOptions extends DataSourceOptions {
       resolvedCpLocation, batchId.get, operatorId, storeName, joinSide,
       readChangeFeed, fromSnapshotOptions, readChangeFeedOptions,
       stateVarName, readRegisteredTimers, flattenCollectionTypes,
-      readAllColumnFamilies, startOperatorStateUniqueIds, endOperatorStateUniqueIds)
+      internalOnlyReadAllColumnFamilies, startOperatorStateUniqueIds, endOperatorStateUniqueIds)
   }
 
   private def resolvedCheckpointLocation(
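
As context for the rename, here is a minimal usage sketch of how the option would be passed to the state data source reader, assuming an active SparkSession named spark. This is not part of the commit; the checkpoint path, operator id, and batch id are placeholders. Per the check added above, the option is only accepted with RocksDBStateStoreProvider and cannot be combined with stateVarName, joinSide, or readChangeFeed.

// Hypothetical example: read raw key/value bytes for every column family of one operator.
// Assumes spark.sql.streaming.stateStore.providerClass is set to RocksDBStateStoreProvider.
val rawStateDf = spark.read
  .format("statestore")
  .option("operatorId", "0")
  .option("batchId", "5")
  .option("internalOnlyReadAllColumnFamilies", "true")
  .load("/path/to/checkpoint")

rawStateDf.select("partition_key", "key_bytes", "value_bytes", "column_family_name").show()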

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala

Lines changed: 12 additions & 61 deletions
@@ -20,7 +20,6 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeRow}
 import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory}
-import org.apache.spark.sql.execution.datasources.v2.state.metadata.StateMetadataPartitionReader
 import org.apache.spark.sql.execution.datasources.v2.state.utils.SchemaUtil
 import org.apache.spark.sql.execution.streaming.operators.stateful.join.SymmetricHashJoinStateManager
 import org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.{StateVariableType, TransformWithStateVariableInfo}

@@ -50,9 +49,9 @@ class StatePartitionReaderFactory(
 
   override def createReader(partition: InputPartition): PartitionReader[InternalRow] = {
     val stateStoreInputPartition = partition.asInstanceOf[StateStoreInputPartition]
-    if (stateStoreInputPartition.sourceOptions.readAllColumnFamilies) {
+    if (stateStoreInputPartition.sourceOptions.internalOnlyReadAllColumnFamilies) {
       new StatePartitionReaderAllColumnFamilies(storeConf, hadoopConf,
-        stateStoreInputPartition, schema)
+        stateStoreInputPartition, schema, keyStateEncoderSpec)
     } else if (stateStoreInputPartition.sourceOptions.readChangeFeed) {
       new StateStoreChangeDataPartitionReader(storeConf, hadoopConf,
         stateStoreInputPartition, schema, keyStateEncoderSpec, stateVariableInfoOpt,

@@ -85,15 +84,15 @@ abstract class StatePartitionReaderBase(
   private val schemaForValueRow: StructType =
     StructType(Array(StructField("__dummy__", NullType)))
 
-  protected lazy val keySchema = {
+  protected val keySchema = {
     if (SchemaUtil.checkVariableType(stateVariableInfoOpt, StateVariableType.MapState)) {
       SchemaUtil.getCompositeKeySchema(schema, partition.sourceOptions)
     } else {
       SchemaUtil.getSchemaAsDataType(schema, "key").asInstanceOf[StructType]
     }
   }
 
-  protected lazy val valueSchema = if (stateVariableInfoOpt.isDefined) {
+  protected val valueSchema = if (stateVariableInfoOpt.isDefined) {
     schemaForValueRow
   } else {
     SchemaUtil.getSchemaAsDataType(

@@ -249,16 +248,10 @@ class StatePartitionReaderAllColumnFamilies(
    storeConf: StateStoreConf,
    hadoopConf: SerializableConfiguration,
    partition: StateStoreInputPartition,
-    schema: StructType)
+    schema: StructType,
+    keyStateEncoderSpec: KeyStateEncoderSpec)
   extends StatePartitionReaderBase(storeConf, hadoopConf, partition, schema,
-    NoPrefixKeyStateEncoderSpec(new StructType()), None, None, None, None) {
-
-  val allStateStoreMetadata = {
-    new StateMetadataPartitionReader(
-      partition.sourceOptions.resolvedCpLocation,
-      new SerializableConfiguration(hadoopConf.value),
-      partition.sourceOptions.batchId).stateMetadata.toArray
-  }
+    keyStateEncoderSpec, None, None, None, None) {
 
   private lazy val store: ReadStateStore = {
     assert(getStartStoreUniqueId == getEndStoreUniqueId,

@@ -269,56 +262,14 @@ class StatePartitionReaderAllColumnFamilies(
     )
   }
 
-  val colFamilyNames: Seq[String] = {
-    // todo: Support operator with multiple column family names in next PR
-    Seq[String]()
-  }
-
-  override protected lazy val provider: StateStoreProvider = {
-    val stateStoreId = StateStoreId(partition.sourceOptions.stateCheckpointLocation.toString,
-      partition.sourceOptions.operatorId, partition.partition, partition.sourceOptions.storeName)
-    val stateStoreProviderId = StateStoreProviderId(stateStoreId, partition.queryId)
-
-    // Disable format validation when reading raw bytes.
-    // We use binary schemas (keyBytes/valueBytes) which don't match the actual schema
-    // of the stored data. Validation would fail in HDFSBackedStateStoreProvider when
-    // loading data from disk, so we disable it for raw bytes mode.
-    val modifiedStoreConf = storeConf.withFormatValidationDisabled()
-
-    val keyStateEncoderSpec = NoPrefixKeyStateEncoderSpec(new StructType())
-    // Pass in empty keySchema, valueSchema and dummy encoder because we don't encode any data
-    val provider = StateStoreProvider.createAndInit(
-      stateStoreProviderId, new StructType(), new StructType(), keyStateEncoderSpec,
-      useColumnFamilies = colFamilyNames.nonEmpty, modifiedStoreConf, hadoopConf.value, false, None)
-
-    provider
-  }
-
   override lazy val iter: Iterator[InternalRow] = {
     // Single store with column families (join v3, transformWithState, or simple operators)
-    require(store.isInstanceOf[SupportsRawBytesRead],
-      s"State store ${store.getClass.getName} does not support raw bytes reading")
-
-    val rawStore = store.asInstanceOf[SupportsRawBytesRead]
-    if (colFamilyNames.isEmpty) {
-      rawStore
-        .rawIterator()
-        .map { case (keyBytes, valueBytes) =>
-          SchemaUtil.unifyStateRowPairAsRawBytes(
-            partition.partition, keyBytes, valueBytes, StateStore.DEFAULT_COL_FAMILY_NAME)
-        }
-    } else {
-      colFamilyNames.iterator.flatMap { colFamilyName =>
-        rawStore
-          .rawIterator(colFamilyName)
-          .map { case (keyBytes, valueBytes) =>
-            SchemaUtil.unifyStateRowPairAsRawBytes(partition.partition,
-              keyBytes,
-              valueBytes,
-              colFamilyName)
-          }
+    store
+      .iterator()
+      .map { pair =>
+        SchemaUtil.unifyStateRowPairAsRawBytes(
+          (pair.key, pair.value), StateStore.DEFAULT_COL_FAMILY_NAME)
       }
-    }
   }
 
   override def close(): Unit = {

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/utils/SchemaUtil.scala

Lines changed: 16 additions & 11 deletions
@@ -61,12 +61,17 @@ object SchemaUtil {
       .add("key", keySchema)
       .add("value", valueSchema)
       .add("partition_id", IntegerType)
-    } else if (sourceOptions.readAllColumnFamilies) {
+    } else if (sourceOptions.internalOnlyReadAllColumnFamilies) {
       new StructType()
-        .add("partition_id", IntegerType)
+        // todo: change this to some more specific type after we
+        // can extract partition key from keySchema
+        .add("partition_key", keySchema)
         .add("key_bytes", BinaryType)
         .add("value_bytes", BinaryType)
         .add("column_family_name", StringType)
+        // need key and value schema so that state store can encode data
+        .add("value", valueSchema)
+        .add("key", keySchema)
     } else {
       new StructType()
         .add("key", keySchema)

@@ -89,15 +94,14 @@
    * instead of a tuple for better readability.
    */
   def unifyStateRowPairAsRawBytes(
-      partition: Int,
-      keyBytes: Array[Byte],
-      valueBytes: Array[Byte],
+      pair: (UnsafeRow, UnsafeRow),
       colFamilyName: String): InternalRow = {
-    val row = new GenericInternalRow(4)
-    row.update(0, partition)
-    row.update(1, keyBytes)
-    row.update(2, valueBytes)
+    val row = new GenericInternalRow(6)
+    row.update(0, pair._1)
+    row.update(1, pair._1.getBytes)
+    row.update(2, pair._2.getBytes)
     row.update(3, UTF8String.fromString(colFamilyName))
+    // row.update(4, pair._2)
     row
   }
 

@@ -257,6 +261,7 @@ object SchemaUtil {
       "user_map_value" -> classOf[StructType],
       "expiration_timestamp_ms" -> classOf[LongType],
       "partition_id" -> classOf[IntegerType],
+      "partition_key" -> classOf[StructType],
       "key_bytes"->classOf[BinaryType],
       "value_bytes"->classOf[BinaryType],
       "column_family_name"->classOf[StringType])

@@ -300,8 +305,8 @@ object SchemaUtil {
       }
     } else if (sourceOptions.readChangeFeed) {
       Seq("batch_id", "change_type", "key", "value", "partition_id")
-    } else if (sourceOptions.readAllColumnFamilies) {
-      Seq("partition_id", "key_bytes", "value_bytes", "column_family_name")
+    } else if (sourceOptions.internalOnlyReadAllColumnFamilies) {
+      Seq("partition_key", "key_bytes", "value_bytes", "column_family_name", "value", "key")
     } else {
       Seq("key", "value", "partition_id")
     }
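
Because the raw rows above carry bytes produced by UnsafeRow.getBytes, together with the key and value schemas now included in the output, a consumer of this output could rebuild typed rows from the key_bytes and value_bytes columns. A minimal sketch of such a decoder follows; it is not part of this commit and the helper name is only illustrative.

import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.types.StructType

// Hypothetical helper: reconstruct an UnsafeRow from bytes obtained via UnsafeRow.getBytes
// (as in unifyStateRowPairAsRawBytes). The field count must come from the matching
// key or value schema exposed in the new output schema.
def decodeUnsafeRow(bytes: Array[Byte], schema: StructType): UnsafeRow = {
  val row = new UnsafeRow(schema.fields.length)
  row.pointTo(bytes, bytes.length)
  row
}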

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala

Lines changed: 2 additions & 19 deletions
@@ -75,7 +75,7 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with
   private val providerName = "HDFSBackedStateStoreProvider"
 
   class HDFSBackedReadStateStore(val version: Long, map: HDFSBackedStateStoreMap)
-    extends ReadStateStore with SupportsRawBytesRead {
+    extends ReadStateStore {
 
     override def id: StateStoreId = HDFSBackedStateStoreProvider.this.stateStoreId
 

@@ -104,22 +104,14 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with
     override def valuesIterator(key: UnsafeRow, colFamilyName: String): Iterator[UnsafeRow] = {
       throw StateStoreErrors.unsupportedOperationException("multipleValuesPerKey", "HDFSStateStore")
     }
-
-    override def rawIterator(colFamilyName: String): Iterator[(Array[Byte], Array[Byte])] = {
-      // For HDFS, we get UnsafeRows and convert them to bytes
-      // The bytes will be properly aligned since they come from valid UnsafeRows
-      map.iterator().map { pair =>
-        (pair.key.getBytes(), pair.value.getBytes())
-      }
-    }
   }
 
   /** Implementation of [[StateStore]] API which is backed by an HDFS-compatible file system */
   class HDFSBackedStateStore(
       val version: Long,
       private val mapToUpdate: HDFSBackedStateStoreMap,
       shouldForceSnapshot: Boolean = false)
-    extends StateStore with SupportsRawBytesRead {
+    extends StateStore {
 
     /** Trait and classes representing the internal state of the store */
     trait STATE

@@ -247,15 +239,6 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with
       new StateStoreIterator(iter)
     }
 
-    override def rawIterator(colFamilyName: String): Iterator[(Array[Byte], Array[Byte])] = {
-      assertUseOfDefaultColFamily(colFamilyName)
-      // For HDFS, we get UnsafeRows and convert them to bytes
-      // The bytes will be properly aligned since they come from valid UnsafeRows
-      mapToUpdate.iterator().map { pair =>
-        (pair.key.getBytes(), pair.value.getBytes())
-      }
-    }
-
     override def prefixScan(
         prefixKey: UnsafeRow,
         colFamilyName: String): StateStoreIterator[UnsafeRowPair] = {

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala

Lines changed: 1 addition & 17 deletions
@@ -47,8 +47,7 @@ private[sql] class RocksDBStateStoreProvider
       lastVersion: Long,
       private[RocksDBStateStoreProvider] val stamp: Long,
       private[RocksDBStateStoreProvider] var readOnly: Boolean,
-      private[RocksDBStateStoreProvider] var forceSnapshotOnCommit: Boolean) extends StateStore
-    with SupportsRawBytesRead {
+      private[RocksDBStateStoreProvider] var forceSnapshotOnCommit: Boolean) extends StateStore {
 
     private sealed trait OPERATION
     private case object UPDATE extends OPERATION

@@ -420,21 +419,6 @@ private[sql] class RocksDBStateStoreProvider
       }
     }
 
-    override def rawIterator(colFamilyName: String): Iterator[(Array[Byte], Array[Byte])] = {
-      validateAndTransitionState(UPDATE)
-      verifyColFamilyOperations("rawIterator", colFamilyName)
-
-      if (useColumnFamilies) {
-        rocksDB.iterator(colFamilyName).map { pair =>
-          (pair.key, pair.value)
-        }
-      } else {
-        rocksDB.iterator().map { pair =>
-          (pair.key, pair.value)
-        }
-      }
-    }
-
     override def prefixScan(
         prefixKey: UnsafeRow,
         colFamilyName: String): StateStoreIterator[UnsafeRowPair] = {
