
Commit 42540b1

Author: Ubuntu (committed)
Commit message: address comment
1 parent 69f5b19; commit 42540b1

7 files changed: +60, -222 lines changed


sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala

Lines changed: 14 additions & 25 deletions

@@ -65,38 +65,27 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging
     val sourceOptions = StateSourceOptions.modifySourceOptions(hadoopConf,
       StateSourceOptions.apply(session, hadoopConf, properties))
     val stateConf = buildStateStoreConf(sourceOptions.resolvedCpLocation, sourceOptions.batchId)
-    if (sourceOptions.readAllColumnFamilies) {
-      // For readAllColumnFamilies mode, we don't need specific encoder because it returns raw data
-      val keyStateEncoderSpec = NoPrefixKeyStateEncoderSpec(new StructType())
-      new StateTable(session, schema, sourceOptions, stateConf, keyStateEncoderSpec,
-        None, None, None, None)
-    } else {
-      val stateStoreReaderInfo: StateStoreReaderInfo = getStoreMetadataAndRunChecks(
-        sourceOptions)
+    val stateStoreReaderInfo: StateStoreReaderInfo = getStoreMetadataAndRunChecks(
+      sourceOptions)
 
-      // The key state encoder spec should be available for all operators except stream-stream joins
-      val keyStateEncoderSpec = if (stateStoreReaderInfo.keyStateEncoderSpecOpt.isDefined) {
-        stateStoreReaderInfo.keyStateEncoderSpecOpt.get
-      } else {
-        val keySchema = SchemaUtil.getSchemaAsDataType(schema, "key").asInstanceOf[StructType]
-        NoPrefixKeyStateEncoderSpec(keySchema)
-      }
-      new StateTable(session, schema, sourceOptions, stateConf, keyStateEncoderSpec,
-        stateStoreReaderInfo.transformWithStateVariableInfoOpt,
-        stateStoreReaderInfo.stateStoreColFamilySchemaOpt,
-        stateStoreReaderInfo.stateSchemaProviderOpt,
-        stateStoreReaderInfo.joinColFamilyOpt)
+    // The key state encoder spec should be available for all operators except stream-stream joins
+    val keyStateEncoderSpec = if (stateStoreReaderInfo.keyStateEncoderSpecOpt.isDefined) {
+      stateStoreReaderInfo.keyStateEncoderSpecOpt.get
+    } else {
+      val keySchema = SchemaUtil.getSchemaAsDataType(schema, "key").asInstanceOf[StructType]
+      NoPrefixKeyStateEncoderSpec(keySchema)
     }
+
+    new StateTable(session, schema, sourceOptions, stateConf, keyStateEncoderSpec,
+      stateStoreReaderInfo.transformWithStateVariableInfoOpt,
+      stateStoreReaderInfo.stateStoreColFamilySchemaOpt,
+      stateStoreReaderInfo.stateSchemaProviderOpt,
+      stateStoreReaderInfo.joinColFamilyOpt)
   }
 
   override def inferSchema(options: CaseInsensitiveStringMap): StructType = {
     val sourceOptions = StateSourceOptions.modifySourceOptions(hadoopConf,
       StateSourceOptions.apply(session, hadoopConf, options))
-    if (sourceOptions.readAllColumnFamilies) {
-      // For readAllColumnFamilies mode, return the binary schema directly
-      return SchemaUtil.getSourceSchema(
-        sourceOptions, new StructType(), new StructType(), None, None)
-    }
     val stateStoreReaderInfo: StateStoreReaderInfo = getStoreMetadataAndRunChecks(
       sourceOptions)
     val oldSchemaFilePaths = StateDataSource.getOldSchemaFilePaths(sourceOptions, hadoopConf)
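
For orientation, the sketch below shows roughly how this data source is driven from user code. The "statestore" format name and the path/operatorId options already exist in the State Data Source; the option key used for the new mode (written here as "readAllColumnFamilies") is an assumption based on the flag name in StateSourceOptions and is not confirmed by this diff.

    // Hedged usage sketch: the "readAllColumnFamilies" option key is assumed.
    val stateDf = spark.read
      .format("statestore")
      .option("path", "/tmp/checkpoints/my-query")  // streaming checkpoint location
      .option("operatorId", "0")                    // which stateful operator to read
      .option("readAllColumnFamilies", "true")      // assumed key for the raw-bytes mode
      .load()

    // After this commit the rows carry partition_key, key_bytes, value_bytes,
    // column_family_name and value (see the SchemaUtil change below).
    stateDf.printSchema()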

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala

Lines changed: 12 additions & 64 deletions

@@ -20,7 +20,6 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeRow}
 import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory}
-import org.apache.spark.sql.execution.datasources.v2.state.metadata.StateMetadataPartitionReader
 import org.apache.spark.sql.execution.datasources.v2.state.utils.SchemaUtil
 import org.apache.spark.sql.execution.streaming.operators.stateful.join.SymmetricHashJoinStateManager
 import org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.{StateVariableType, TransformWithStateVariableInfo}
@@ -52,7 +51,7 @@ class StatePartitionReaderFactory(
     val stateStoreInputPartition = partition.asInstanceOf[StateStoreInputPartition]
     if (stateStoreInputPartition.sourceOptions.readAllColumnFamilies) {
       new StatePartitionReaderAllColumnFamilies(storeConf, hadoopConf,
-        stateStoreInputPartition, schema)
+        stateStoreInputPartition, schema, keyStateEncoderSpec)
     } else if (stateStoreInputPartition.sourceOptions.readChangeFeed) {
       new StateStoreChangeDataPartitionReader(storeConf, hadoopConf,
         stateStoreInputPartition, schema, keyStateEncoderSpec, stateVariableInfoOpt,
@@ -88,12 +87,14 @@ abstract class StatePartitionReaderBase(
   protected lazy val keySchema = {
     if (SchemaUtil.checkVariableType(stateVariableInfoOpt, StateVariableType.MapState)) {
       SchemaUtil.getCompositeKeySchema(schema, partition.sourceOptions)
+    } else if (partition.sourceOptions.readAllColumnFamilies) {
+      SchemaUtil.getSchemaAsDataType(schema, "partition_key").asInstanceOf[StructType]
     } else {
       SchemaUtil.getSchemaAsDataType(schema, "key").asInstanceOf[StructType]
     }
   }
 
-  protected lazy val valueSchema = if (stateVariableInfoOpt.isDefined) {
+  protected val valueSchema = if (stateVariableInfoOpt.isDefined) {
     schemaForValueRow
   } else {
     SchemaUtil.getSchemaAsDataType(
@@ -249,16 +250,10 @@ class StatePartitionReaderAllColumnFamilies(
     storeConf: StateStoreConf,
     hadoopConf: SerializableConfiguration,
     partition: StateStoreInputPartition,
-    schema: StructType)
+    schema: StructType,
+    keyStateEncoderSpec: KeyStateEncoderSpec)
   extends StatePartitionReaderBase(storeConf, hadoopConf, partition, schema,
-    NoPrefixKeyStateEncoderSpec(new StructType()), None, None, None, None) {
-
-  val allStateStoreMetadata = {
-    new StateMetadataPartitionReader(
-      partition.sourceOptions.resolvedCpLocation,
-      new SerializableConfiguration(hadoopConf.value),
-      partition.sourceOptions.batchId).stateMetadata.toArray
-  }
+    keyStateEncoderSpec, None, None, None, None) {
 
   private lazy val store: ReadStateStore = {
     assert(getStartStoreUniqueId == getEndStoreUniqueId,
@@ -269,61 +264,14 @@ class StatePartitionReaderAllColumnFamilies(
     )
   }
 
-  val colFamilyNames: Seq[String] = {
-    // todo: Support operator with multiple column family names in next PR
-    Seq[String]()
-  }
-
-  override protected lazy val provider: StateStoreProvider = {
-    val stateStoreId = StateStoreId(partition.sourceOptions.stateCheckpointLocation.toString,
-      partition.sourceOptions.operatorId, partition.partition, partition.sourceOptions.storeName)
-    val stateStoreProviderId = StateStoreProviderId(stateStoreId, partition.queryId)
-
-    // Disable format validation when reading raw bytes.
-    // We use binary schemas (keyBytes/valueBytes) which don't match the actual schema
-    // of the stored data. Validation would fail in HDFSBackedStateStoreProvider when
-    // loading data from disk, so we disable it for raw bytes mode.
-    val modifiedStoreConf = storeConf.withFormatValidationDisabled()
-
-    val keyStateEncoderSpec = NoPrefixKeyStateEncoderSpec(new StructType())
-    // Pass in empty keySchema, valueSchema and dummy encoder because we don't encode any data
-    val provider = StateStoreProvider.createAndInit(
-      stateStoreProviderId, new StructType(), new StructType(), keyStateEncoderSpec,
-      useColumnFamilies = colFamilyNames.nonEmpty, modifiedStoreConf, hadoopConf.value, false, None)
-
-    provider
-  }
-
   override lazy val iter: Iterator[InternalRow] = {
     // Single store with column families (join v3, transformWithState, or simple operators)
-    require(store.isInstanceOf[SupportsRawBytesRead],
-      s"State store ${store.getClass.getName} does not support raw bytes reading")
-
-    val rawStore = store.asInstanceOf[SupportsRawBytesRead]
-    if (colFamilyNames.isEmpty) {
-      rawStore
-        .rawIterator()
-        .map { case (keyBytes, valueBytes) =>
-          SchemaUtil.unifyStateRowPairAsRawBytes(
-            partition.partition, keyBytes, valueBytes, StateStore.DEFAULT_COL_FAMILY_NAME)
-        }
-    } else {
-      colFamilyNames.iterator.flatMap { colFamilyName =>
-        rawStore
-          .rawIterator(colFamilyName)
-          .map { case (keyBytes, valueBytes) =>
-            SchemaUtil.unifyStateRowPairAsRawBytes(partition.partition,
-              keyBytes,
-              valueBytes,
-              colFamilyName)
-          }
+    store
+      .iterator()
+      .map { pair =>
+        SchemaUtil.unifyStateRowPairAsRawBytes(
+          (pair.key, pair.value), StateStore.DEFAULT_COL_FAMILY_NAME)
       }
-    }
-  }
-
-  override def close(): Unit = {
-    store.release()
-    super.close()
   }
 }
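
Stripped of the partition and provider plumbing, the reader's new per-partition loop amounts to the sketch below. The types and the SchemaUtil helper are taken from this diff; the standalone method is illustrative rather than a drop-in replacement for the class above, and like the new code it only covers the default column family.

    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.execution.datasources.v2.state.utils.SchemaUtil
    import org.apache.spark.sql.execution.streaming.state.{ReadStateStore, StateStore}

    // Walk the store's regular iterator and re-encode every UnsafeRowPair into the
    // unified raw-bytes row for the default column family.
    def readDefaultColumnFamily(store: ReadStateStore): Iterator[InternalRow] = {
      store.iterator().map { pair =>
        SchemaUtil.unifyStateRowPairAsRawBytes(
          (pair.key, pair.value), StateStore.DEFAULT_COL_FAMILY_NAME)
      }
    }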

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/utils/SchemaUtil.scala

Lines changed: 10 additions & 9 deletions

@@ -63,10 +63,11 @@ object SchemaUtil {
         .add("partition_id", IntegerType)
     } else if (sourceOptions.readAllColumnFamilies) {
       new StructType()
-        .add("partition_id", IntegerType)
+        .add("partition_key", keySchema)
         .add("key_bytes", BinaryType)
         .add("value_bytes", BinaryType)
         .add("column_family_name", StringType)
+        .add("value", valueSchema)
     } else {
       new StructType()
         .add("key", keySchema)
@@ -89,15 +90,14 @@ object SchemaUtil {
    * instead of a tuple for better readability.
    */
   def unifyStateRowPairAsRawBytes(
-      partition: Int,
-      keyBytes: Array[Byte],
-      valueBytes: Array[Byte],
+      pair: (UnsafeRow, UnsafeRow),
       colFamilyName: String): InternalRow = {
-    val row = new GenericInternalRow(4)
-    row.update(0, partition)
-    row.update(1, keyBytes)
-    row.update(2, valueBytes)
+    val row = new GenericInternalRow(5)
+    row.update(0, pair._1)
+    row.update(1, pair._1.getBytes)
+    row.update(2, pair._2.getBytes)
     row.update(3, UTF8String.fromString(colFamilyName))
+    row.update(4, pair._2)
     row
   }
 
@@ -257,6 +257,7 @@ object SchemaUtil {
     "user_map_value" -> classOf[StructType],
     "expiration_timestamp_ms" -> classOf[LongType],
     "partition_id" -> classOf[IntegerType],
+    "partition_key" -> classOf[StructType],
     "key_bytes"->classOf[BinaryType],
     "value_bytes"->classOf[BinaryType],
     "column_family_name"->classOf[StringType])
@@ -301,7 +302,7 @@ object SchemaUtil {
     } else if (sourceOptions.readChangeFeed) {
       Seq("batch_id", "change_type", "key", "value", "partition_id")
     } else if (sourceOptions.readAllColumnFamilies) {
-      Seq("partition_id", "key_bytes", "value_bytes", "column_family_name")
+      Seq("partition_key", "key_bytes", "value_bytes", "column_family_name", "value")
     } else {
       Seq("key", "value", "partition_id")
     }
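
Pulling the two changed pieces together, here is a condensed sketch of the readAllColumnFamilies schema and the row builder it pairs with. Field names and ordinals are exactly those in the diff; keySchema and valueSchema stand in for the operator's real key and value types, and the helper names are illustrative.

    import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeRow}
    import org.apache.spark.sql.types._
    import org.apache.spark.unsafe.types.UTF8String

    // Output schema for readAllColumnFamilies after this commit.
    def rawBytesSchema(keySchema: StructType, valueSchema: StructType): StructType =
      new StructType()
        .add("partition_key", keySchema)
        .add("key_bytes", BinaryType)
        .add("value_bytes", BinaryType)
        .add("column_family_name", StringType)
        .add("value", valueSchema)

    // Mirror of unifyStateRowPairAsRawBytes: ordinal i of the row is field i above.
    def toRawBytesRow(pair: (UnsafeRow, UnsafeRow), colFamilyName: String): GenericInternalRow = {
      val row = new GenericInternalRow(5)
      row.update(0, pair._1)                               // partition_key (the key row itself)
      row.update(1, pair._1.getBytes)                      // key_bytes
      row.update(2, pair._2.getBytes)                      // value_bytes
      row.update(3, UTF8String.fromString(colFamilyName))  // column_family_name
      row.update(4, pair._2)                               // value (the value row itself)
      row
    }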

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala

Lines changed: 2 additions & 20 deletions

@@ -75,7 +75,7 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with
   private val providerName = "HDFSBackedStateStoreProvider"
 
   class HDFSBackedReadStateStore(val version: Long, map: HDFSBackedStateStoreMap)
-    extends ReadStateStore with SupportsRawBytesRead {
+    extends ReadStateStore {
 
     override def id: StateStoreId = HDFSBackedStateStoreProvider.this.stateStoreId
 
@@ -104,22 +104,13 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with
     override def valuesIterator(key: UnsafeRow, colFamilyName: String): Iterator[UnsafeRow] = {
       throw StateStoreErrors.unsupportedOperationException("multipleValuesPerKey", "HDFSStateStore")
     }
-
-    override def rawIterator(colFamilyName: String): Iterator[(Array[Byte], Array[Byte])] = {
-      // For HDFS, we get UnsafeRows and convert them to bytes
-      // The bytes will be properly aligned since they come from valid UnsafeRows
-      map.iterator().map { pair =>
-        (pair.key.getBytes(), pair.value.getBytes())
-      }
-    }
   }
 
   /** Implementation of [[StateStore]] API which is backed by an HDFS-compatible file system */
   class HDFSBackedStateStore(
       val version: Long,
       private val mapToUpdate: HDFSBackedStateStoreMap,
-      shouldForceSnapshot: Boolean = false)
-    extends StateStore with SupportsRawBytesRead {
+      shouldForceSnapshot: Boolean = false) extends StateStore {
 
     /** Trait and classes representing the internal state of the store */
     trait STATE
@@ -247,15 +238,6 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with
       new StateStoreIterator(iter)
     }
 
-    override def rawIterator(colFamilyName: String): Iterator[(Array[Byte], Array[Byte])] = {
-      assertUseOfDefaultColFamily(colFamilyName)
-      // For HDFS, we get UnsafeRows and convert them to bytes
-      // The bytes will be properly aligned since they come from valid UnsafeRows
-      mapToUpdate.iterator().map { pair =>
-        (pair.key.getBytes(), pair.value.getBytes())
-      }
-    }
-
     override def prefixScan(
         prefixKey: UnsafeRow,
         colFamilyName: String): StateStoreIterator[UnsafeRowPair] = {
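
The deleted rawIterator overrides were thin wrappers: they walked the same map iterator and called getBytes on each UnsafeRow. A caller that still needs raw bytes from an HDFS-backed store can do the equivalent conversion on the public iterator, roughly as sketched here (an illustration, not an API this commit adds):

    import org.apache.spark.sql.execution.streaming.state.ReadStateStore

    // Equivalent of the removed rawIterator, expressed against the public read API.
    def rawBytesIterator(store: ReadStateStore): Iterator[(Array[Byte], Array[Byte])] =
      store.iterator().map { pair =>
        (pair.key.getBytes(), pair.value.getBytes())
      }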

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala

Lines changed: 1 addition & 17 deletions

@@ -47,8 +47,7 @@ private[sql] class RocksDBStateStoreProvider
       lastVersion: Long,
       private[RocksDBStateStoreProvider] val stamp: Long,
       private[RocksDBStateStoreProvider] var readOnly: Boolean,
-      private[RocksDBStateStoreProvider] var forceSnapshotOnCommit: Boolean) extends StateStore
-    with SupportsRawBytesRead {
+      private[RocksDBStateStoreProvider] var forceSnapshotOnCommit: Boolean) extends StateStore {
 
     private sealed trait OPERATION
     private case object UPDATE extends OPERATION
@@ -420,21 +419,6 @@ private[sql] class RocksDBStateStoreProvider
       }
     }
 
-    override def rawIterator(colFamilyName: String): Iterator[(Array[Byte], Array[Byte])] = {
-      validateAndTransitionState(UPDATE)
-      verifyColFamilyOperations("rawIterator", colFamilyName)
-
-      if (useColumnFamilies) {
-        rocksDB.iterator(colFamilyName).map { pair =>
-          (pair.key, pair.value)
-        }
-      } else {
-        rocksDB.iterator().map { pair =>
-          (pair.key, pair.value)
-        }
-      }
-    }
-
     override def prefixScan(
         prefixKey: UnsafeRow,
         colFamilyName: String): StateStoreIterator[UnsafeRowPair] = {

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala

Lines changed: 0 additions & 14 deletions

@@ -891,20 +891,6 @@ object StateStoreProvider extends Logging {
   }
 }
 
-/**
- * Trait for state stores that support reading raw bytes without decoding.
- * This is useful for copying state data during repartitioning
- */
-trait SupportsRawBytesRead {
-  /**
-   * Returns an iterator of raw key-value bytes for a column family.
-   * @param colFamilyName the name of the column family to iterate over
-   * @return an iterator of (keyBytes, valueBytes) tuples
-   */
-  def rawIterator(colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME):
-    Iterator[(Array[Byte], Array[Byte])]
-}
-
 /**
  * This is an optional trait to be implemented by [[StateStoreProvider]]s that can read the change
  * of state store over batches. This is used by State Data Source with additional options like