  */
 package org.apache.spark.sql.execution.datasources.v2.state
 
-import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
 import org.apache.spark.sql.execution.streaming.runtime.MemoryStream
-import org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider
+import org.apache.spark.sql.execution.streaming.state.{HDFSBackedStateStoreProvider, RocksDBStateStoreProvider}
 import org.apache.spark.sql.functions.{count, sum}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.streaming.OutputMode
@@ -38,21 +38,12 @@ class StatePartitionReaderAllColumnFamiliesSuite extends StateDataSourceTestBase
   /**
-   * Returns a set of (partitionId, key, value) tuples from a normal state read.
+   * Returns a DataFrame of (partition_id, key, value) rows from a normal state read.
    */
-  private def getNormalReadData(checkpointDir: String): Set[(Int, Row, Row)] = {
-    val normalReadDf = spark.read
+  private def getNormalReadData(checkpointDir: String): DataFrame = {
+    spark.read
       .format("statestore")
       .option(StateSourceOptions.PATH, checkpointDir)
       .load()
       .selectExpr("partition_id", "key", "value")
-
-    normalReadDf.collect()
-      .map { row =>
-        val partitionId = row.getInt(0)
-        val key = row.getStruct(1)
-        val value = row.getStruct(2)
-        (partitionId, key, value)
-      }
-      .toSet
   }
 
   /**
@@ -108,9 +99,14 @@ class StatePartitionReaderAllColumnFamiliesSuite extends StateDataSourceTestBase
 
   /**
    * Parses the bytes read DataFrame into a set of (partitionId, key, value, columnFamily) tuples.
+   * For the RocksDB provider, skipVersionBytes should be true (its encoding prepends a
+   * version byte); for the HDFS-backed provider, it should be false.
    */
   private def parseBytesReadData(
-      df: DataFrame, numOfKey: Int, numOfValue: Int): Set[(Int, UnsafeRow, UnsafeRow, String)] = {
+      df: DataFrame,
+      numOfKey: Int,
+      numOfValue: Int,
+      skipVersionBytes: Boolean = true): Set[(Int, UnsafeRow, UnsafeRow, String)] = {
     df.selectExpr("partition_id", "key_bytes", "value_bytes", "column_family_name")
       .collect()
       .map { row =>
@@ -120,20 +116,38 @@ class StatePartitionReaderAllColumnFamiliesSuite extends StateDataSourceTestBase
         val columnFamily = row.getString(3)
 
         // Deserialize key bytes to UnsafeRow
-        // Skip the version byte (STATE_ENCODING_NUM_VERSION_BYTES) at the beginning
         val keyRow = new UnsafeRow(numOfKey)
-        keyRow.pointTo(
-          keyBytes,
-          Platform.BYTE_ARRAY_OFFSET + RocksDBStateStoreProvider.STATE_ENCODING_NUM_VERSION_BYTES,
-          keyBytes.length - RocksDBStateStoreProvider.STATE_ENCODING_NUM_VERSION_BYTES)
+        if (skipVersionBytes) {
+          // Skip the version byte (STATE_ENCODING_NUM_VERSION_BYTES) at the beginning
+          // This is for RocksDB provider
+          keyRow.pointTo(
+            keyBytes,
+            Platform.BYTE_ARRAY_OFFSET + RocksDBStateStoreProvider.STATE_ENCODING_NUM_VERSION_BYTES,
+            keyBytes.length - RocksDBStateStoreProvider.STATE_ENCODING_NUM_VERSION_BYTES)
+        } else {
+          // HDFS provider doesn't add version bytes, use bytes directly
+          keyRow.pointTo(
+            keyBytes,
+            Platform.BYTE_ARRAY_OFFSET,
+            keyBytes.length)
+        }
 
         // Deserialize value bytes to UnsafeRow
-        // Skip the version byte (STATE_ENCODING_NUM_VERSION_BYTES) at the beginning
         val valueRow = new UnsafeRow(numOfValue)
-        valueRow.pointTo(
-          valueBytes,
-          Platform.BYTE_ARRAY_OFFSET + RocksDBStateStoreProvider.STATE_ENCODING_NUM_VERSION_BYTES,
-          valueBytes.length - RocksDBStateStoreProvider.STATE_ENCODING_NUM_VERSION_BYTES)
+        if (skipVersionBytes) {
+          // Skip the version byte (STATE_ENCODING_NUM_VERSION_BYTES) at the beginning
+          // This is for RocksDB provider
+          valueRow.pointTo(
+            valueBytes,
+            Platform.BYTE_ARRAY_OFFSET + RocksDBStateStoreProvider.STATE_ENCODING_NUM_VERSION_BYTES,
+            valueBytes.length - RocksDBStateStoreProvider.STATE_ENCODING_NUM_VERSION_BYTES)
+        } else {
+          // HDFS provider doesn't add version bytes, use bytes directly
+          valueRow.pointTo(
+            valueBytes,
+            Platform.BYTE_ARRAY_OFFSET,
+            valueBytes.length)
+        }
 
         (partitionId, keyRow.copy(), valueRow.copy(), columnFamily)
       }
@@ -144,24 +158,31 @@ class StatePartitionReaderAllColumnFamiliesSuite extends StateDataSourceTestBase
    * Compares normal read data with bytes read data for a specific column family.
    */
   private def compareNormalAndBytesData(
-      normalData: Set[(Int, Row, Row)],
-      bytesData: Set[(Int, UnsafeRow, UnsafeRow, String)],
+      normalReadDf: DataFrame,
+      bytesReadDf: DataFrame,
       columnFamily: String,
       keySchema: StructType,
-      valueSchema: StructType): Unit = {
+      valueSchema: StructType,
+      skipVersionBytes: Boolean): Unit = {
+
     // Filter bytes data for the specified column family
+    val bytesData = parseBytesReadData(bytesReadDf, keySchema.length, valueSchema.length,
+      skipVersionBytes)
     val filteredBytesData = bytesData.filter(_._4 == columnFamily)
 
-    // Verify same number of rows
-    assert(filteredBytesData.size == normalData.size,
-      s"Row count mismatch for column family '$columnFamily': " +
-      s"normal read has ${filteredBytesData.size} rows, bytes read has ${normalData.size} rows")
     // Convert to comparable format (extract field values)
-    val normalSet = normalData.map { case (partId, key, value) =>
+    val normalSet = normalReadDf.collect().map { row =>
+      val partitionId = row.getInt(0)
+      val key = row.getStruct(1)
+      val value = row.getStruct(2)
       val keyFields = (0 until key.length).map(i => key.get(i))
       val valueFields = (0 until value.length).map(i => value.get(i))
-      (partId, keyFields, valueFields)
-    }
+      (partitionId, keyFields, valueFields)
+    }.toSet
+    // Verify same number of rows
+    assert(filteredBytesData.size == normalSet.size,
+      s"Row count mismatch for column family '$columnFamily': " +
+      s"normal read has ${normalSet.size} rows, bytes read has ${filteredBytesData.size} rows")
 
     val bytesSet = filteredBytesData.map { case (partId, keyRow, valueRow, _) =>
       val keyFields = (0 until keySchema.length).map(i =>
@@ -174,67 +195,70 @@ class StatePartitionReaderAllColumnFamiliesSuite extends StateDataSourceTestBase
     assert(normalSet == bytesSet)
   }
 
-  test("read all column families with simple operator") {
-    withTempDir { tempDir =>
-      withSQLConf(
-        SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
-          classOf[RocksDBStateStoreProvider].getName,
-        SQLConf.SHUFFLE_PARTITIONS.key -> "2") {
-
-        val inputData = MemoryStream[Int]
-        val aggregated = inputData.toDF()
-          .selectExpr("value", "value % 10 AS groupKey")
-          .groupBy($"groupKey")
-          .agg(
-            count("*").as("cnt"),
-            sum("value").as("sum")
+  Seq(
+    ("RocksDBStateStoreProvider", classOf[RocksDBStateStoreProvider], true),
+    ("HDFSBackedStateStoreProvider", classOf[HDFSBackedStateStoreProvider], false)
+  ).foreach { case (providerName, providerClass, skipVersionBytes) =>
+    test(s"read all column families with simple operator - $providerName") {
+      withTempDir { tempDir =>
+        withSQLConf(
+          SQLConf.STATE_STORE_PROVIDER_CLASS.key -> providerClass.getName,
+          SQLConf.SHUFFLE_PARTITIONS.key -> "2") {
+
+          val inputData = MemoryStream[Int]
+          val aggregated = inputData.toDF()
+            .selectExpr("value", "value % 10 AS groupKey")
+            .groupBy($"groupKey")
+            .agg(
+              count("*").as("cnt"),
+              sum("value").as("sum")
+            )
+            .as[(Int, Long, Long)]
+
+          testStream(aggregated, OutputMode.Update)(
+            StartStream(checkpointLocation = tempDir.getAbsolutePath),
+            // batch 0
+            AddData(inputData, 0 until 20: _*),
+            CheckLastBatch(
+              (0, 2, 10), // 0, 10
+              (1, 2, 12), // 1, 11
+              (2, 2, 14), // 2, 12
+              (3, 2, 16), // 3, 13
+              (4, 2, 18), // 4, 14
+              (5, 2, 20), // 5, 15
+              (6, 2, 22), // 6, 16
+              (7, 2, 24), // 7, 17
+              (8, 2, 26), // 8, 18
+              (9, 2, 28) // 9, 19
+            ),
+            StopStream
           )
-          .as[(Int, Long, Long)]
-
-        testStream(aggregated, OutputMode.Update)(
-          StartStream(checkpointLocation = tempDir.getAbsolutePath),
-          // batch 0
-          AddData(inputData, 0 until 20: _*),
-          CheckLastBatch(
-            (0, 2, 10), // 0, 10
-            (1, 2, 12), // 1, 11
-            (2, 2, 14), // 2, 12
-            (3, 2, 16), // 3, 13
-            (4, 2, 18), // 4, 14
-            (5, 2, 20), // 5, 15
-            (6, 2, 22), // 6, 16
-            (7, 2, 24), // 7, 17
-            (8, 2, 26), // 8, 18
-            (9, 2, 28) // 9, 19
-          ),
-          StopStream
-        )
-
-        // Read state data once with READ_ALL_COLUMN_FAMILIES = true
-        val bytesReadDf = getBytesReadDf(tempDir.getAbsolutePath)
-
-        // Verify schema and column families
-        validateBytesReadSchema(bytesReadDf,
-          expectedRowCount = 10,
-          expectedColumnFamilies = Seq("default"))
-
-        // Get normal read data for comparison
-        val normalData = getNormalReadData(tempDir.getAbsolutePath)
-
-        // Compare normal and bytes data for default column family
-        val keySchema: StructType = StructType(Array(
-          StructField("key", IntegerType, nullable = false)
-        ))
-
-        // Value schema for the aggregation: count and sum columns
-        val valueSchema: StructType = StructType(Array(
-          StructField("count", LongType, nullable = false),
-          StructField("sum", LongType, nullable = false)
-        ))
-
-        // Parse bytes read data
-        val bytesData = parseBytesReadData(bytesReadDf, keySchema.length, valueSchema.length)
-
-        compareNormalAndBytesData(normalData, bytesData, "default", keySchema, valueSchema)
+
+          // Read state data once with READ_ALL_COLUMN_FAMILIES = true
+          val bytesReadDf = getBytesReadDf(tempDir.getAbsolutePath)
+
+          // Verify schema and column families
+          validateBytesReadSchema(bytesReadDf,
+            expectedRowCount = 10,
+            expectedColumnFamilies = Seq("default"))
+
+          // Compare normal and bytes data for default column family
+          val keySchema: StructType = StructType(Array(
+            StructField("key", IntegerType, nullable = false)
+          ))
+
+          // Value schema for the aggregation: count and sum columns
+          val valueSchema: StructType = StructType(Array(
+            StructField("count", LongType, nullable = false),
+            StructField("sum", LongType, nullable = false)
+          ))
+
+          // Get normal read data for comparison; the bytes data is parsed and
+          // compared inside compareNormalAndBytesData
+          val normalData = getNormalReadData(tempDir.getAbsolutePath)
+          compareNormalAndBytesData(
+            normalData, bytesReadDf, "default", keySchema, valueSchema, skipVersionBytes)
+        }
       }
     }
   }
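
The change above hinges on one encoding detail: the RocksDB state store provider prefixes each serialized key and value with a version byte (RocksDBStateStoreProvider.STATE_ENCODING_NUM_VERSION_BYTES), while the HDFS-backed provider stores the raw UnsafeRow bytes. Below is a minimal sketch of the offset arithmetic that the new skipVersionBytes flag selects; it is illustrative only, not part of the patch, and the StateBytesDecoding object and toUnsafeRow helper are hypothetical names.

import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider
import org.apache.spark.unsafe.Platform

object StateBytesDecoding {
  // Wrap raw state-store bytes in an UnsafeRow, optionally skipping the version
  // byte that the RocksDB provider prepends to every encoded key and value.
  def toUnsafeRow(bytes: Array[Byte], numFields: Int, skipVersionBytes: Boolean): UnsafeRow = {
    val headerLen =
      if (skipVersionBytes) RocksDBStateStoreProvider.STATE_ENCODING_NUM_VERSION_BYTES else 0
    val row = new UnsafeRow(numFields)
    // pointTo reads sizeInBytes bytes starting at baseOffset within the backing array
    row.pointTo(bytes, Platform.BYTE_ARRAY_OFFSET + headerLen, bytes.length - headerLen)
    // copy() so the returned row no longer aliases the caller's byte array
    row.copy()
  }
}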