Commit ddb5cd8

Author: Ubuntu
Message: refactor test
Parent: 0a878e9


sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReaderAllColumnFamiliesSuite.scala

Lines changed: 14 additions & 40 deletions
@@ -16,7 +16,7 @@
  */
 package org.apache.spark.sql.execution.datasources.v2.state
 
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Row}
 import org.apache.spark.sql.catalyst.expressions.{GenericRowWithSchema, UnsafeRow}
 import org.apache.spark.sql.execution.streaming.runtime.MemoryStream
 import org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider
@@ -60,10 +60,7 @@ class StatePartitionReaderAllColumnFamiliesSuite extends StateDataSourceTestBase
   /**
    * Validates the schema and column families of the bytes read DataFrame.
    */
-  private def validateBytesReadSchema(
-      df: DataFrame,
-      expectedRowCount: Int,
-      expectedColumnFamilies: Seq[String]): Unit = {
+  private def validateBytesReadSchema(df: DataFrame): Unit = {
     // Verify schema
     val schema = df.schema
     assert(schema.fieldNames === Array(
@@ -72,50 +69,25 @@ class StatePartitionReaderAllColumnFamiliesSuite extends StateDataSourceTestBase
     assert(schema("key_bytes").dataType.typeName === "binary")
     assert(schema("value_bytes").dataType.typeName === "binary")
     assert(schema("column_family_name").dataType.typeName === "string")
-
-    // Verify data
-    val rows = df
-      .selectExpr("partition_key", "key_bytes", "value_bytes", "column_family_name")
-      .collect()
-    assert(rows.length == expectedRowCount,
-      s"Expected $expectedRowCount rows but got: ${rows.length}")
-
-    val columnFamilies = rows.map(r => Option(r.getString(3)).getOrElse("null")).distinct.sorted
-    assert(columnFamilies.length == expectedColumnFamilies.length,
-      s"Expected ${expectedColumnFamilies.length} column families, " +
-        s"but got ${columnFamilies.length}: ${columnFamilies.mkString(", ")}")
-
-    expectedColumnFamilies.foreach { expectedCF =>
-      assert(columnFamilies.contains(expectedCF),
-        s"Expected column family '$expectedCF', " +
-          s"but got: ${columnFamilies.mkString(", ")}")
-    }
   }
 
-  private def parseBytesReadData(
-      df: DataFrame)
+  private def parseBytesReadData(df: Array[Row], keyLength: Int, valueLength: Int)
     : Set[(GenericRowWithSchema, UnsafeRow, UnsafeRow, String)] = {
-    df.selectExpr("partition_key", "key_bytes", "value_bytes", "column_family_name")
-      .collect()
-      .map { row =>
+    df.map { row =>
       val partitionKey = row.getAs[GenericRowWithSchema](0)
       val keyBytes = row.getAs[Array[Byte]](1)
       val valueBytes = row.getAs[Array[Byte]](2)
       val columnFamily = row.getString(3)
 
       // Deserialize key bytes to UnsafeRow
-      val keyRow = new UnsafeRow(1)
-      // Skip the version byte (STATE_ENCODING_NUM_VERSION_BYTES) at the beginning
-      // This is for RocksDB provider
+      val keyRow = new UnsafeRow(keyLength)
       keyRow.pointTo(
         keyBytes,
         Platform.BYTE_ARRAY_OFFSET,
         keyBytes.length)
 
       // Deserialize value bytes to UnsafeRow
-      val valueRow = new UnsafeRow(2)
-      // Skip the version byte (STATE_ENCODING_NUM_VERSION_BYTES) at the beginning
-      // This is for RocksDB provider
+      val valueRow = new UnsafeRow(valueLength)
       valueRow.pointTo(
         valueBytes,
         Platform.BYTE_ARRAY_OFFSET,
@@ -134,9 +106,15 @@ class StatePartitionReaderAllColumnFamiliesSuite extends StateDataSourceTestBase
       columnFamily: String,
       keySchema: StructType,
       valueSchema: StructType): Unit = {
+    // Verify data
+    val bytesDf = bytesReadDf
+      .selectExpr("partition_key", "key_bytes", "value_bytes", "column_family_name")
+      .collect()
+    assert(bytesDf.length == 10,
+      s"Expected 10 rows but got: ${bytesDf.length}")
 
     // Filter bytes data for the specified column family
-    val bytesData = parseBytesReadData(bytesReadDf)
+    val bytesData = parseBytesReadData(bytesDf, keySchema.length, valueSchema.length)
     val filteredBytesData = bytesData.filter(_._4 == columnFamily)
 
     // Apply the projection
@@ -203,10 +181,7 @@ class StatePartitionReaderAllColumnFamiliesSuite extends StateDataSourceTestBase
     val bytesReadDf = getBytesReadDf(tempDir.getAbsolutePath)
 
     // Verify schema and column families
-    validateBytesReadSchema(bytesReadDf,
-      expectedRowCount = 10,
-      expectedColumnFamilies = Seq("default"))
-
+    validateBytesReadSchema(bytesReadDf)
     // Compare normal and bytes data for default column family
     val keySchema: StructType = StructType(Array(
       StructField("key", IntegerType, nullable = false)
@@ -217,7 +192,6 @@ class StatePartitionReaderAllColumnFamiliesSuite extends StateDataSourceTestBase
       StructField("count", LongType, nullable = false),
       StructField("sum", LongType, nullable = false)
     ))
-    // Parse bytes read data
 
     // Get normal read data for comparison
     val normalData = getNormalReadData(tempDir.getAbsolutePath)
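The mechanism this refactor leans on is UnsafeRow.pointTo, which wraps an existing byte array in place rather than copying it; the arity passed to the UnsafeRow constructor must match the schema the bytes were written with, which is why parseBytesReadData now takes keyLength and valueLength instead of hardcoding 1 and 2. Below is a minimal round-trip sketch of that pattern, runnable in a spark-shell. The two-field count/sum schema mirrors the suite's value schema; the sample values are illustrative, not taken from the test.

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.types.{LongType, StructField, StructType}
import org.apache.spark.unsafe.Platform

val valueSchema = StructType(Array(
  StructField("count", LongType, nullable = false),
  StructField("sum", LongType, nullable = false)))

// Serialize an in-memory row to UnsafeRow bytes. getBytes copies the buffer,
// since UnsafeProjection reuses its output row across calls.
val toUnsafe = UnsafeProjection.create(valueSchema)
val bytes: Array[Byte] = toUnsafe(InternalRow(3L, 15L)).getBytes

// Rehydrate: pointTo wraps the byte array without copying. The arity comes
// from the schema, as in the refactored parseBytesReadData.
val valueRow = new UnsafeRow(valueSchema.length)
valueRow.pointTo(bytes, Platform.BYTE_ARRAY_OFFSET, bytes.length)
assert(valueRow.getLong(0) == 3L && valueRow.getLong(1) == 15L)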
