
Commit ea461ee

Author: Ubuntu

add test and support for HDFS
1 parent 44dcc3f commit ea461ee

3 files changed, 145 insertions(+), 95 deletions(-)


sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala

Lines changed: 7 additions & 1 deletion
@@ -283,10 +283,16 @@ class StatePartitionReaderAllColumnFamilies(
        partition.sourceOptions.operatorId, partition.partition, partition.sourceOptions.storeName)
      val stateStoreProviderId = StateStoreProviderId(stateStoreId, partition.queryId)

+     // Disable format validation when reading raw bytes.
+     // We use binary schemas (keyBytes/valueBytes) which don't match the actual schema
+     // of the stored data. Validation would fail in HDFSBackedStateStoreProvider when
+     // loading data from disk, so we disable it for raw bytes mode.
+     val modifiedStoreConf = storeConf.withFormatValidationDisabled()
+
      val keyStateEncoderSpec = NoPrefixKeyStateEncoderSpec(keySchema)
      val provider = StateStoreProvider.createAndInit(
        stateStoreProviderId, keySchema, valueSchema, keyStateEncoderSpec,
-       useColumnFamilies = colFamilyNames.nonEmpty, storeConf, hadoopConf.value, false, None)
+       useColumnFamilies = colFamilyNames.nonEmpty, modifiedStoreConf, hadoopConf.value, false, None)

      provider
    }
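
The comment being added states the core constraint: the raw-bytes reader declares opaque binary key/value schemas, while HDFSBackedStateStoreProvider validates rows loaded from disk against whatever schemas the provider was created with. A minimal sketch of that mismatch, with illustrative schema definitions only (the exact schemas the reader builds are not part of this diff):

    import org.apache.spark.sql.types._

    // What an aggregation operator actually persisted (illustrative):
    val realKeySchema = StructType(Seq(StructField("groupKey", IntegerType)))
    val realValueSchema = StructType(Seq(
      StructField("count", LongType),
      StructField("sum", LongType)))

    // What the raw-bytes reader declares instead: one opaque binary column each.
    val binaryKeySchema = StructType(Seq(StructField("keyBytes", BinaryType)))
    val binaryValueSchema = StructType(Seq(StructField("valueBytes", BinaryType)))

    // Row-format validation compares stored rows against the declared schemas, so the
    // binary schemas would trip it; hence storeConf.withFormatValidationDisabled() above.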

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala

Lines changed: 20 additions & 0 deletions
@@ -163,6 +163,26 @@ class StateStoreConf(
     */
    val sqlConfs: Map[String, String] =
      sqlConf.getAllConfs.filter(_._1.startsWith("spark.sql.streaming.stateStore."))
+
+   /**
+    * Creates a copy of this StateStoreConf with format validation disabled.
+    * This is useful when reading raw bytes where the schema used (binary) doesn't match
+    * the actual stored data schema.
+    */
+   def withFormatValidationDisabled(): StateStoreConf = {
+     val reconstructedSqlConf = {
+       // Reconstruct a SQLConf with all settings preserved because sqlConf is transient
+       val conf = new SQLConf()
+       // Restore all state store related settings
+       sqlConfs.foreach { case (key, value) =>
+         conf.setConfString(key, value)
+       }
+       conf
+     }
+     new StateStoreConf(reconstructedSqlConf, extraOptions) {
+       override val formatValidationEnabled: Boolean = false
+     }
+   }
  }

  object StateStoreConf {
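
The helper cannot reuse the original SQLConf because that field is transient, so it rebuilds one from the captured sqlConfs map and flips the flag through an anonymous subclass. A rough usage sketch, assuming the two-argument StateStoreConf constructor used above and a SQLConf obtained via SQLConf.get (this call site is not part of the diff):

    import org.apache.spark.sql.execution.streaming.state.StateStoreConf
    import org.apache.spark.sql.internal.SQLConf

    // Build a conf the usual way, then derive a raw-bytes variant from it.
    val baseConf = new StateStoreConf(SQLConf.get, Map.empty)
    val rawBytesConf = baseConf.withFormatValidationDisabled()

    assert(!rawBytesConf.formatValidationEnabled)  // overridden to false in the copy

    // Only "spark.sql.streaming.stateStore." entries survive the rebuild, because the
    // copy reconstructs its SQLConf solely from the captured sqlConfs map.
    assert(rawBytesConf.sqlConfs.keys.forall(_.startsWith("spark.sql.streaming.stateStore.")))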

sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReaderAllColumnFamiliesSuite.scala

Lines changed: 118 additions & 94 deletions
@@ -16,10 +16,10 @@
   */
  package org.apache.spark.sql.execution.datasources.v2.state

- import org.apache.spark.sql.{DataFrame, Row}
+ import org.apache.spark.sql.DataFrame
  import org.apache.spark.sql.catalyst.expressions.UnsafeRow
  import org.apache.spark.sql.execution.streaming.runtime.MemoryStream
- import org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider
+ import org.apache.spark.sql.execution.streaming.state.{HDFSBackedStateStoreProvider, RocksDBStateStoreProvider}
  import org.apache.spark.sql.functions.{count, sum}
  import org.apache.spark.sql.internal.SQLConf
  import org.apache.spark.sql.streaming.OutputMode
@@ -38,21 +38,12 @@ class StatePartitionReaderAllColumnFamiliesSuite extends StateDataSourceTestBase
    /**
     * Returns a set of (partitionId, key, value) tuples from a normal state read.
     */
-   private def getNormalReadData(checkpointDir: String): Set[(Int, Row, Row)] = {
-     val normalReadDf = spark.read
+   private def getNormalReadData(checkpointDir: String): DataFrame = {
+     spark.read
        .format("statestore")
        .option(StateSourceOptions.PATH, checkpointDir)
        .load()
        .selectExpr("partition_id", "key", "value")
-
-     normalReadDf.collect()
-       .map { row =>
-         val partitionId = row.getInt(0)
-         val key = row.getStruct(1)
-         val value = row.getStruct(2)
-         (partitionId, key, value)
-       }
-       .toSet
    }

    /**
@@ -108,9 +99,14 @@ class StatePartitionReaderAllColumnFamiliesSuite extends StateDataSourceTestBase

    /**
     * Parses the bytes read DataFrame into a set of (partitionId, key, value, columnFamily) tuples.
+    * For RocksDB provider, skipVersionBytes should be true.
+    * For HDFS provider, skipVersionBytes should be false.
     */
    private def parseBytesReadData(
-       df: DataFrame, numOfKey: Int, numOfValue: Int): Set[(Int, UnsafeRow, UnsafeRow, String)] = {
+       df: DataFrame,
+       numOfKey: Int,
+       numOfValue: Int,
+       skipVersionBytes: Boolean = true): Set[(Int, UnsafeRow, UnsafeRow, String)] = {
      df.selectExpr("partition_id", "key_bytes", "value_bytes", "column_family_name")
        .collect()
        .map { row =>
@@ -120,20 +116,38 @@ class StatePartitionReaderAllColumnFamiliesSuite extends StateDataSourceTestBase
          val columnFamily = row.getString(3)

          // Deserialize key bytes to UnsafeRow
-         // Skip the version byte (STATE_ENCODING_NUM_VERSION_BYTES) at the beginning
          val keyRow = new UnsafeRow(numOfKey)
-         keyRow.pointTo(
-           keyBytes,
-           Platform.BYTE_ARRAY_OFFSET + RocksDBStateStoreProvider.STATE_ENCODING_NUM_VERSION_BYTES,
-           keyBytes.length - RocksDBStateStoreProvider.STATE_ENCODING_NUM_VERSION_BYTES)
+         if (skipVersionBytes) {
+           // Skip the version byte (STATE_ENCODING_NUM_VERSION_BYTES) at the beginning
+           // This is for RocksDB provider
+           keyRow.pointTo(
+             keyBytes,
+             Platform.BYTE_ARRAY_OFFSET + RocksDBStateStoreProvider.STATE_ENCODING_NUM_VERSION_BYTES,
+             keyBytes.length - RocksDBStateStoreProvider.STATE_ENCODING_NUM_VERSION_BYTES)
+         } else {
+           // HDFS provider doesn't add version bytes, use bytes directly
+           keyRow.pointTo(
+             keyBytes,
+             Platform.BYTE_ARRAY_OFFSET,
+             keyBytes.length)
+         }

          // Deserialize value bytes to UnsafeRow
-         // Skip the version byte (STATE_ENCODING_NUM_VERSION_BYTES) at the beginning
          val valueRow = new UnsafeRow(numOfValue)
-         valueRow.pointTo(
-           valueBytes,
-           Platform.BYTE_ARRAY_OFFSET + RocksDBStateStoreProvider.STATE_ENCODING_NUM_VERSION_BYTES,
-           valueBytes.length - RocksDBStateStoreProvider.STATE_ENCODING_NUM_VERSION_BYTES)
+         if (skipVersionBytes) {
+           // Skip the version byte (STATE_ENCODING_NUM_VERSION_BYTES) at the beginning
+           // This is for RocksDB provider
+           valueRow.pointTo(
+             valueBytes,
+             Platform.BYTE_ARRAY_OFFSET + RocksDBStateStoreProvider.STATE_ENCODING_NUM_VERSION_BYTES,
+             valueBytes.length - RocksDBStateStoreProvider.STATE_ENCODING_NUM_VERSION_BYTES)
+         } else {
+           // HDFS provider doesn't add version bytes, use bytes directly
+           valueRow.pointTo(
+             valueBytes,
+             Platform.BYTE_ARRAY_OFFSET,
+             valueBytes.length)
+         }

          (partitionId, keyRow.copy(), valueRow.copy(), columnFamily)
        }
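
Both branches apply the same rule to keys and values: rows written by RocksDBStateStoreProvider carry a STATE_ENCODING_NUM_VERSION_BYTES prefix that must be skipped, while HDFSBackedStateStoreProvider stores the UnsafeRow bytes as-is. A sketch of that rule factored into a single helper (the helper itself is hypothetical; the suite keeps the two branches inline):

    import org.apache.spark.sql.catalyst.expressions.UnsafeRow
    import org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider
    import org.apache.spark.unsafe.Platform

    def bytesToUnsafeRow(bytes: Array[Byte], numFields: Int, skipVersionBytes: Boolean): UnsafeRow = {
      val row = new UnsafeRow(numFields)
      if (skipVersionBytes) {
        // RocksDB provider prefixes each encoded row with a version byte; skip it.
        val skip = RocksDBStateStoreProvider.STATE_ENCODING_NUM_VERSION_BYTES
        row.pointTo(bytes, Platform.BYTE_ARRAY_OFFSET + skip, bytes.length - skip)
      } else {
        // HDFS-backed provider stores the UnsafeRow bytes as-is.
        row.pointTo(bytes, Platform.BYTE_ARRAY_OFFSET, bytes.length)
      }
      row.copy() // copy so the returned row no longer aliases the input buffer
    }
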
@@ -144,24 +158,31 @@ class StatePartitionReaderAllColumnFamiliesSuite extends StateDataSourceTestBase
     * Compares normal read data with bytes read data for a specific column family.
     */
    private def compareNormalAndBytesData(
-       normalData: Set[(Int, Row, Row)],
-       bytesData: Set[(Int, UnsafeRow, UnsafeRow, String)],
+       normalReadDf: DataFrame,
+       bytesReadDf: DataFrame,
        columnFamily: String,
        keySchema: StructType,
-       valueSchema: StructType): Unit = {
+       valueSchema: StructType,
+       skipVersionBytes: Boolean): Unit = {
+
      // Filter bytes data for the specified column family
+     val bytesData = parseBytesReadData(bytesReadDf, keySchema.length, valueSchema.length,
+       skipVersionBytes)
      val filteredBytesData = bytesData.filter(_._4 == columnFamily)

-     // Verify same number of rows
-     assert(filteredBytesData.size == normalData.size,
-       s"Row count mismatch for column family '$columnFamily': " +
-       s"normal read has ${filteredBytesData.size} rows, bytes read has ${normalData.size} rows")
      // Convert to comparable format (extract field values)
-     val normalSet = normalData.map { case (partId, key, value) =>
+     val normalSet = normalReadDf.collect().map { row =>
+       val partitionId = row.getInt(0)
+       val key = row.getStruct(1)
+       val value = row.getStruct(2)
        val keyFields = (0 until key.length).map(i => key.get(i))
        val valueFields = (0 until value.length).map(i => value.get(i))
-       (partId, keyFields, valueFields)
-     }
+       (partitionId, keyFields, valueFields)
+     }.toSet
+     // Verify same number of rows
+     assert(filteredBytesData.size == normalSet.size,
+       s"Row count mismatch for column family '$columnFamily': " +
+       s"normal read has ${filteredBytesData.size} rows, bytes read has ${normalSet.size} rows")

      val bytesSet = filteredBytesData.map { case (partId, keyRow, valueRow, _) =>
        val keyFields = (0 until keySchema.length).map(i =>
@@ -174,67 +195,70 @@ class StatePartitionReaderAllColumnFamiliesSuite extends StateDataSourceTestBase
      assert(normalSet == bytesSet)
    }

-   test("read all column families with simple operator") {
-     withTempDir { tempDir =>
-       withSQLConf(
-         SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
-           classOf[RocksDBStateStoreProvider].getName,
-         SQLConf.SHUFFLE_PARTITIONS.key -> "2") {
-
-         val inputData = MemoryStream[Int]
-         val aggregated = inputData.toDF()
-           .selectExpr("value", "value % 10 AS groupKey")
-           .groupBy($"groupKey")
-           .agg(
-             count("*").as("cnt"),
-             sum("value").as("sum")
+   Seq(
+     ("RocksDBStateStoreProvider", classOf[RocksDBStateStoreProvider], true),
+     ("HDFSBackedStateStoreProvider", classOf[HDFSBackedStateStoreProvider], false)
+   ).foreach { case (providerName, providerClass, skipVersionBytes) =>
+     test(s"read all column families with simple operator - $providerName") {
+       withTempDir { tempDir =>
+         withSQLConf(
+           SQLConf.STATE_STORE_PROVIDER_CLASS.key -> providerClass.getName,
+           SQLConf.SHUFFLE_PARTITIONS.key -> "2") {
+
+           val inputData = MemoryStream[Int]
+           val aggregated = inputData.toDF()
+             .selectExpr("value", "value % 10 AS groupKey")
+             .groupBy($"groupKey")
+             .agg(
+               count("*").as("cnt"),
+               sum("value").as("sum")
+             )
+             .as[(Int, Long, Long)]
+
+           testStream(aggregated, OutputMode.Update)(
+             StartStream(checkpointLocation = tempDir.getAbsolutePath),
+             // batch 0
+             AddData(inputData, 0 until 20: _*),
+             CheckLastBatch(
+               (0, 2, 10), // 0, 10
+               (1, 2, 12), // 1, 11
+               (2, 2, 14), // 2, 12
+               (3, 2, 16), // 3, 13
+               (4, 2, 18), // 4, 14
+               (5, 2, 20), // 5, 15
+               (6, 2, 22), // 6, 16
+               (7, 2, 24), // 7, 17
+               (8, 2, 26), // 8, 18
+               (9, 2, 28) // 9, 19
+             ),
+             StopStream
            )
-           .as[(Int, Long, Long)]
-
-         testStream(aggregated, OutputMode.Update)(
-           StartStream(checkpointLocation = tempDir.getAbsolutePath),
-           // batch 0
-           AddData(inputData, 0 until 20: _*),
-           CheckLastBatch(
-             (0, 2, 10), // 0, 10
-             (1, 2, 12), // 1, 11
-             (2, 2, 14), // 2, 12
-             (3, 2, 16), // 3, 13
-             (4, 2, 18), // 4, 14
-             (5, 2, 20), // 5, 15
-             (6, 2, 22), // 6, 16
-             (7, 2, 24), // 7, 17
-             (8, 2, 26), // 8, 18
-             (9, 2, 28) // 9, 19
-           ),
-           StopStream
-         )
-
-         // Read state data once with READ_ALL_COLUMN_FAMILIES = true
-         val bytesReadDf = getBytesReadDf(tempDir.getAbsolutePath)
-
-         // Verify schema and column families
-         validateBytesReadSchema(bytesReadDf,
-           expectedRowCount = 10,
-           expectedColumnFamilies = Seq("default"))
-
-         // Get normal read data for comparison
-         val normalData = getNormalReadData(tempDir.getAbsolutePath)
-
-         // Compare normal and bytes data for default column family
-         val keySchema: StructType = StructType(Array(
-           StructField("key", IntegerType, nullable = false)
-         ))
-
-         // Value schema for the aggregation: count and sum columns
-         val valueSchema: StructType = StructType(Array(
-           StructField("count", LongType, nullable = false),
-           StructField("sum", LongType, nullable = false)
-         ))
-         // Parse bytes read data
-         val bytesData = parseBytesReadData(bytesReadDf, keySchema.length, valueSchema.length)
-
-         compareNormalAndBytesData(normalData, bytesData, "default", keySchema, valueSchema)
+
+           // Read state data once with READ_ALL_COLUMN_FAMILIES = true
+           val bytesReadDf = getBytesReadDf(tempDir.getAbsolutePath)
+
+           // Verify schema and column families
+           validateBytesReadSchema(bytesReadDf,
+             expectedRowCount = 10,
+             expectedColumnFamilies = Seq("default"))
+
+           // Compare normal and bytes data for default column family
+           val keySchema: StructType = StructType(Array(
+             StructField("key", IntegerType, nullable = false)
+           ))
+
+           // Value schema for the aggregation: count and sum columns
+           val valueSchema: StructType = StructType(Array(
+             StructField("count", LongType, nullable = false),
+             StructField("sum", LongType, nullable = false)
+           ))
+           // Parse bytes read data
+
+           // Get normal read data for comparison
+           val normalData = getNormalReadData(tempDir.getAbsolutePath)
+           compareNormalAndBytesData(
+             normalData, bytesReadDf, "default", keySchema, valueSchema, skipVersionBytes)
+         }
        }
      }
    }
