1616 */
1717package org .apache .spark .sql .execution .datasources .v2 .state
1818
19- import org .apache .spark .sql .DataFrame
19+ import org .apache .spark .sql .{ DataFrame , Row }
2020import org .apache .spark .sql .catalyst .expressions .{GenericRowWithSchema , UnsafeRow }
2121import org .apache .spark .sql .execution .streaming .runtime .MemoryStream
2222import org .apache .spark .sql .execution .streaming .state .RocksDBStateStoreProvider
@@ -60,10 +60,7 @@ class StatePartitionReaderAllColumnFamiliesSuite extends StateDataSourceTestBase
6060 /**
6161 * Validates the schema and column families of the bytes read DataFrame.
6262 */
63- private def validateBytesReadSchema (
64- df : DataFrame ,
65- expectedRowCount : Int ,
66- expectedColumnFamilies : Seq [String ]): Unit = {
63+ private def validateBytesReadSchema (df : DataFrame ): Unit = {
6764 // Verify schema
6865 val schema = df.schema
6966 assert(schema.fieldNames === Array (
@@ -72,50 +69,25 @@ class StatePartitionReaderAllColumnFamiliesSuite extends StateDataSourceTestBase
7269 assert(schema(" key_bytes" ).dataType.typeName === " binary" )
7370 assert(schema(" value_bytes" ).dataType.typeName === " binary" )
7471 assert(schema(" column_family_name" ).dataType.typeName === " string" )
75-
76- // Verify data
77- val rows = df
78- .selectExpr(" partition_key" , " key_bytes" , " value_bytes" , " column_family_name" )
79- .collect()
80- assert(rows.length == expectedRowCount,
81- s " Expected $expectedRowCount rows but got: ${rows.length}" )
82-
83- val columnFamilies = rows.map(r => Option (r.getString(3 )).getOrElse(" null" )).distinct.sorted
84- assert(columnFamilies.length == expectedColumnFamilies.length,
85- s " Expected ${expectedColumnFamilies.length} column families, " +
86- s " but got ${columnFamilies.length}: ${columnFamilies.mkString(" , " )}" )
87-
88- expectedColumnFamilies.foreach { expectedCF =>
89- assert(columnFamilies.contains(expectedCF),
90- s " Expected column family ' $expectedCF', " +
91- s " but got: ${columnFamilies.mkString(" , " )}" )
92- }
9372 }
9473
95- private def parseBytesReadData (
96- df : DataFrame )
74+ private def parseBytesReadData (df : Array [Row ], keyLength : Int , valueLength : Int )
9775 : Set [(GenericRowWithSchema , UnsafeRow , UnsafeRow , String )] = {
98- df.selectExpr(" partition_key" , " key_bytes" , " value_bytes" , " column_family_name" )
99- .collect()
100- .map { row =>
76+ df.map { row =>
10177 val partitionKey = row.getAs[GenericRowWithSchema ](0 )
10278 val keyBytes = row.getAs[Array [Byte ]](1 )
10379 val valueBytes = row.getAs[Array [Byte ]](2 )
10480 val columnFamily = row.getString(3 )
10581
10682 // Deserialize key bytes to UnsafeRow
107- val keyRow = new UnsafeRow (1 )
108- // Skip the version byte (STATE_ENCODING_NUM_VERSION_BYTES) at the beginning
109- // This is for RocksDB provider
83+ val keyRow = new UnsafeRow (keyLength)
11084 keyRow.pointTo(
11185 keyBytes,
11286 Platform .BYTE_ARRAY_OFFSET ,
11387 keyBytes.length)
11488
11589 // Deserialize value bytes to UnsafeRow
116- val valueRow = new UnsafeRow (2 )
117- // Skip the version byte (STATE_ENCODING_NUM_VERSION_BYTES) at the beginning
118- // This is for RocksDB provider
90+ val valueRow = new UnsafeRow (valueLength)
11991 valueRow.pointTo(
12092 valueBytes,
12193 Platform .BYTE_ARRAY_OFFSET ,
@@ -134,9 +106,15 @@ class StatePartitionReaderAllColumnFamiliesSuite extends StateDataSourceTestBase
134106 columnFamily : String ,
135107 keySchema : StructType ,
136108 valueSchema : StructType ): Unit = {
109+ // Verify data
110+ val bytesDf = bytesReadDf
111+ .selectExpr(" partition_key" , " key_bytes" , " value_bytes" , " column_family_name" )
112+ .collect()
113+ assert(bytesDf.length == 10 ,
114+ s " Expected 10 rows but got: ${bytesDf.length}" )
137115
138116 // Filter bytes data for the specified column family
139- val bytesData = parseBytesReadData(bytesReadDf )
117+ val bytesData = parseBytesReadData(bytesDf, keySchema.length, valueSchema.length )
140118 val filteredBytesData = bytesData.filter(_._4 == columnFamily)
141119
142120 // Apply the projection
@@ -203,10 +181,7 @@ class StatePartitionReaderAllColumnFamiliesSuite extends StateDataSourceTestBase
203181 val bytesReadDf = getBytesReadDf(tempDir.getAbsolutePath)
204182
205183 // Verify schema and column families
206- validateBytesReadSchema(bytesReadDf,
207- expectedRowCount = 10 ,
208- expectedColumnFamilies = Seq (" default" ))
209-
184+ validateBytesReadSchema(bytesReadDf)
210185 // Compare normal and bytes data for default column family
211186 val keySchema : StructType = StructType (Array (
212187 StructField (" key" , IntegerType , nullable = false )
@@ -217,7 +192,6 @@ class StatePartitionReaderAllColumnFamiliesSuite extends StateDataSourceTestBase
217192 StructField (" count" , LongType , nullable = false ),
218193 StructField (" sum" , LongType , nullable = false )
219194 ))
220- // Parse bytes read data
221195
222196 // Get normal read data for comparison
223197 val normalData = getNormalReadData(tempDir.getAbsolutePath)
0 commit comments