Commit 0f8f7d3

get keySchema from stateStoreColFamilySchemaOpt
1 parent b46e8d1 commit 0f8f7d3

3 files changed: +47 −48 lines
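Summary: rather than having each StatePartitionReaderBase rebuild a column-family-to-schema map by re-reading the state schema file, the pre-resolved StateStoreColFamilySchema option already held by StatePartitionReaderFactory is now passed through to StatePartitionAllColumnFamiliesReader, and the key and value schemas are taken from it.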

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala

Lines changed: 12 additions & 27 deletions

@@ -16,8 +16,6 @@
  */
 package org.apache.spark.sql.execution.datasources.v2.state
 
-import scala.collection.mutable
-
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeRow}
@@ -53,7 +51,7 @@ class StatePartitionReaderFactory(
     val stateStoreInputPartition = partition.asInstanceOf[StateStoreInputPartition]
     if (stateStoreInputPartition.sourceOptions.internalOnlyReadAllColumnFamilies) {
       new StatePartitionAllColumnFamiliesReader(storeConf, hadoopConf,
-        stateStoreInputPartition, schema, keyStateEncoderSpec)
+        stateStoreInputPartition, schema, keyStateEncoderSpec, stateStoreColFamilySchemaOpt)
     } else if (stateStoreInputPartition.sourceOptions.readChangeFeed) {
       new StateStoreChangeDataPartitionReader(storeConf, hadoopConf,
         stateStoreInputPartition, schema, keyStateEncoderSpec, stateVariableInfoOpt,
@@ -83,39 +81,25 @@ abstract class StatePartitionReaderBase(
   extends PartitionReader[InternalRow] with Logging {
   // Used primarily as a placeholder for the value schema in the context of
   // state variables used within the transformWithState operator.
-  // Also used as a placeholder for both key and value schema for
-  // StatePartitionAllColumnFamiliesReader
-  private val placeholderSchema: StructType =
+  private val schemaForValueRow: StructType =
     StructType(Array(StructField("__dummy__", NullType)))
 
-  private val colFamilyToSchema : mutable.HashMap[String, StateStoreColFamilySchema] = {
-    val stateStoreId = StateStoreId(
-      partition.sourceOptions.stateCheckpointLocation.toString,
-      partition.sourceOptions.operatorId,
-      StateStore.PARTITION_ID_TO_CHECK_SCHEMA,
-      partition.sourceOptions.storeName)
-    val stateStoreProviderId = StateStoreProviderId(stateStoreId, partition.queryId)
-    val manager = new StateSchemaCompatibilityChecker(stateStoreProviderId, hadoopConf.value)
-    val schemaFile = manager.readSchemaFile()
-    val schemaMap = mutable.HashMap[String, StateStoreColFamilySchema]()
-    schemaFile.foreach { schema => schemaMap.put(schema.colFamilyName, schema)}
-    schemaMap
-  }
-
-  protected val keySchema = {
+  protected val keySchema : StructType = {
     if (SchemaUtil.checkVariableType(stateVariableInfoOpt, StateVariableType.MapState)) {
       SchemaUtil.getCompositeKeySchema(schema, partition.sourceOptions)
     } else if (partition.sourceOptions.internalOnlyReadAllColumnFamilies) {
-      colFamilyToSchema(StateStore.DEFAULT_COL_FAMILY_NAME).keySchema
+      require(stateStoreColFamilySchemaOpt.isDefined)
+      stateStoreColFamilySchemaOpt.map(_.keySchema).get
     } else {
       SchemaUtil.getSchemaAsDataType(schema, "key").asInstanceOf[StructType]
     }
   }
 
-  protected val valueSchema = if (stateVariableInfoOpt.isDefined) {
-    placeholderSchema
+  protected val valueSchema : StructType = if (stateVariableInfoOpt.isDefined) {
+    schemaForValueRow
   } else if (partition.sourceOptions.internalOnlyReadAllColumnFamilies) {
-    colFamilyToSchema(StateStore.DEFAULT_COL_FAMILY_NAME).valueSchema
+    require(stateStoreColFamilySchemaOpt.isDefined)
+    stateStoreColFamilySchemaOpt.map(_.valueSchema).get
   } else {
     SchemaUtil.getSchemaAsDataType(
       schema, "value").asInstanceOf[StructType]
@@ -273,11 +257,12 @@ class StatePartitionAllColumnFamiliesReader(
     hadoopConf: SerializableConfiguration,
     partition: StateStoreInputPartition,
     schema: StructType,
-    keyStateEncoderSpec: KeyStateEncoderSpec)
+    keyStateEncoderSpec: KeyStateEncoderSpec,
+    stateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema])
   extends StatePartitionReaderBase(
     storeConf,
     hadoopConf, partition, schema,
-    keyStateEncoderSpec, None, None, None, None) {
+    keyStateEncoderSpec, None, stateStoreColFamilySchemaOpt, None, None) {
 
   private lazy val store: ReadStateStore = {
     assert(getStartStoreUniqueId == getEndStoreUniqueId,
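The keySchema/valueSchema hunks above swap a per-partition schema-file read for a value injected at construction time: fail fast if the option is missing, then project the field that is needed. A minimal standalone sketch of that pattern, assuming illustrative names (ColFamilySchemaSketch and ReaderSketch are not the Spark classes):

import org.apache.spark.sql.types.{StringType, StructField, StructType}

// Illustrative stand-in for a column family's key/value schemas.
case class ColFamilySchemaSketch(keySchema: StructType, valueSchema: StructType)

class ReaderSketch(
    readAllColumnFamilies: Boolean,
    colFamilySchemaOpt: Option[ColFamilySchemaSketch]) {

  val keySchema: StructType =
    if (readAllColumnFamilies) {
      // Fail fast at construction instead of failing later in the read path.
      require(colFamilySchemaOpt.isDefined, "column family schema must be provided")
      colFamilySchemaOpt.map(_.keySchema).get
    } else {
      // Placeholder for the normal path, which derives the schema elsewhere.
      StructType(Array(StructField("key", StringType)))
    }
}

Per this diff, the schema option is resolved upstream of the factory and shared, so the schema file no longer has to be opened by every partition reader, and the scala.collection.mutable import goes away with the deleted colFamilyToSchema map.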

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionErrors.scala

Lines changed: 3 additions & 3 deletions

@@ -88,8 +88,7 @@ object OfflineStateRepartitionErrors {
 
   def unsupportedStateStoreProviderError(
       checkpointLocation: String,
-      providerClass: String
-  ): StateRepartitionUnsupportedProviderError = {
+      providerClass: String): StateRepartitionUnsupportedProviderError = {
     new StateRepartitionUnsupportedProviderError(checkpointLocation, providerClass)
   }
 }
@@ -211,7 +210,8 @@ class StateRepartitionUnsupportedOffsetSeqVersionError(
 
 class StateRepartitionUnsupportedProviderError(
     checkpointLocation: String,
-    provider: String) extends StateRepartitionInvalidCheckpointError(
+    provider: String)
+  extends StateRepartitionInvalidCheckpointError(
     checkpointLocation,
     subClass = "UNSUPPORTED_PROVIDER",
     messageParameters = Map("provider" -> provider))

sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionAllColumnFamiliesReaderSuite.scala

Lines changed: 32 additions & 18 deletions

@@ -221,9 +221,10 @@ class StatePartitionAllColumnFamiliesReaderSuite extends StateDataSourceTestBase
         ))
 
         val normalData = getNormalReadDf(tempDir.getAbsolutePath).collect()
-        val bytesDf = getBytesReadDf(tempDir.getAbsolutePath).collect()
+        val bytesDf = getBytesReadDf(tempDir.getAbsolutePath)
 
-        compareNormalAndBytesData(normalData, bytesDf, "default", keySchema, valueSchema)
+        validateBytesReadDfSchema(bytesDf)
+        compareNormalAndBytesData(normalData, bytesDf.collect(), "default", keySchema, valueSchema)
       }
     }
   }
@@ -248,9 +249,10 @@ class StatePartitionAllColumnFamiliesReaderSuite extends StateDataSourceTestBase
         ))
 
         val normalData = getNormalReadDf(tempDir.getAbsolutePath).collect()
-        val bytesDf = getBytesReadDf(tempDir.getAbsolutePath).collect()
+        val bytesDf = getBytesReadDf(tempDir.getAbsolutePath)
 
-        compareNormalAndBytesData(normalData, bytesDf, "default", keySchema, valueSchema)
+        validateBytesReadDfSchema(bytesDf)
+        compareNormalAndBytesData(normalData, bytesDf.collect(), "default", keySchema, valueSchema)
       }
     }
   }
@@ -272,9 +274,10 @@ class StatePartitionAllColumnFamiliesReaderSuite extends StateDataSourceTestBase
         ))
 
         val normalData = getNormalReadDf(tempDir.getAbsolutePath).collect()
-        val bytesDf = getBytesReadDf(tempDir.getAbsolutePath).collect()
+        val bytesDf = getBytesReadDf(tempDir.getAbsolutePath)
 
-        compareNormalAndBytesData(normalData, bytesDf, "default", keySchema, valueSchema)
+        validateBytesReadDfSchema(bytesDf)
+        compareNormalAndBytesData(normalData, bytesDf.collect(), "default", keySchema, valueSchema)
       }
     }
   }
@@ -292,9 +295,10 @@ class StatePartitionAllColumnFamiliesReaderSuite extends StateDataSourceTestBase
       ))
 
       val normalData = getNormalReadDf(tempDir.getAbsolutePath).collect()
-      val bytesDf = getBytesReadDf(tempDir.getAbsolutePath).collect()
+      val bytesDf = getBytesReadDf(tempDir.getAbsolutePath)
 
-      compareNormalAndBytesData(normalData, bytesDf, "default", keySchema, valueSchema)
+      validateBytesReadDfSchema(bytesDf)
+      compareNormalAndBytesData(normalData, bytesDf.collect(), "default", keySchema, valueSchema)
     }
   }
 
@@ -310,9 +314,10 @@ class StatePartitionAllColumnFamiliesReaderSuite extends StateDataSourceTestBase
      ))
 
      val normalData = getNormalReadDf(tempDir.getAbsolutePath).collect()
-     val bytesDf = getBytesReadDf(tempDir.getAbsolutePath).collect()
+     val bytesDf = getBytesReadDf(tempDir.getAbsolutePath)
 
-     compareNormalAndBytesData(normalData, bytesDf, "default", keySchema, valueSchema)
+     validateBytesReadDfSchema(bytesDf)
+     compareNormalAndBytesData(normalData, bytesDf.collect(), "default", keySchema, valueSchema)
    }
  }
 
@@ -328,9 +333,10 @@ class StatePartitionAllColumnFamiliesReaderSuite extends StateDataSourceTestBase
      ))
 
      val normalData = getNormalReadDf(tempDir.getAbsolutePath).collect()
-     val bytesDf = getBytesReadDf(tempDir.getAbsolutePath).collect()
+     val bytesDf = getBytesReadDf(tempDir.getAbsolutePath)
 
-     compareNormalAndBytesData(normalData, bytesDf, "default", keySchema, valueSchema)
+     validateBytesReadDfSchema(bytesDf)
+     compareNormalAndBytesData(normalData, bytesDf.collect(), "default", keySchema, valueSchema)
    }
  }
 
@@ -353,9 +359,10 @@ class StatePartitionAllColumnFamiliesReaderSuite extends StateDataSourceTestBase
      ))
 
      val normalData = getNormalReadDf(tempDir.getAbsolutePath).collect()
-     val bytesDf = getBytesReadDf(tempDir.getAbsolutePath).collect()
+     val bytesDf = getBytesReadDf(tempDir.getAbsolutePath)
 
-     compareNormalAndBytesData(normalData, bytesDf, "default", keySchema, valueSchema)
+     validateBytesReadDfSchema(bytesDf)
+     compareNormalAndBytesData(normalData, bytesDf.collect(), "default", keySchema, valueSchema)
    }
  }
 
@@ -378,9 +385,11 @@ class StatePartitionAllColumnFamiliesReaderSuite extends StateDataSourceTestBase
        ))
 
        val normalData = getNormalReadDf(tempDir.getAbsolutePath).collect()
-       val bytesDf = getBytesReadDf(tempDir.getAbsolutePath).collect()
+       val bytesDf = getBytesReadDf(tempDir.getAbsolutePath)
 
-       compareNormalAndBytesData(normalData, bytesDf, "default", keySchema, valueSchema)
+       validateBytesReadDfSchema(bytesDf)
+       compareNormalAndBytesData(
+         normalData, bytesDf.collect(), "default", keySchema, valueSchema)
      }
    }
  }
@@ -403,8 +412,11 @@ class StatePartitionAllColumnFamiliesReaderSuite extends StateDataSourceTestBase
        ))
 
        val normalData = getNormalReadDf(tempDir.getAbsolutePath).collect()
-       val bytesDf = getBytesReadDf(tempDir.getAbsolutePath).collect()
-       compareNormalAndBytesData(normalData, bytesDf, "default", keySchema, valueSchema)
+       val bytesDf = getBytesReadDf(tempDir.getAbsolutePath)
+
+       validateBytesReadDfSchema(bytesDf)
+       compareNormalAndBytesData(
+         normalData, bytesDf.collect(), "default", keySchema, valueSchema)
      }
    }
  }
@@ -427,6 +439,7 @@ class StatePartitionAllColumnFamiliesReaderSuite extends StateDataSourceTestBase
       StructField("value", LongType)
     ))
 
+    validateBytesReadDfSchema(stateBytesDfForRight)
     compareNormalAndBytesData(
       stateReaderForRight.collect(),
       stateBytesDfForRight.collect(),
@@ -458,6 +471,7 @@ class StatePartitionAllColumnFamiliesReaderSuite extends StateDataSourceTestBase
       ))
     }
 
+    validateBytesReadDfSchema(stateBytesDfForRight)
    compareNormalAndBytesData(
       stateReaderForRight.collect(),
       stateBytesDfForRight.collect(),
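Each test now validates the schema of the bytes-read DataFrame before collecting it. The helper's implementation is not part of this diff; the sketch below shows one plausible shape for such a check, and every column name and type in it is an assumption, not the suite's actual contract:

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.types.BinaryType

// Hypothetical sketch only: validateBytesReadDfSchema's real body is not in
// this diff. Column names and types below are assumptions for illustration.
object BytesReadSchemaCheckSketch {
  def validate(df: DataFrame): Unit = {
    val fieldTypes = df.schema.fields.map(f => f.name -> f.dataType).toMap
    // An all-column-families read plausibly surfaces raw key/value bytes;
    // check those columns exist and are binary before any rows are collected.
    assert(fieldTypes.get("key_bytes").contains(BinaryType), "expected binary key_bytes column")
    assert(fieldTypes.get("value_bytes").contains(BinaryType), "expected binary value_bytes column")
  }
}

Checking the schema on the DataFrame rather than on collected rows means the assertion still fires when the store happens to be empty, which is presumably why the .collect() is now deferred until after the validation.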
