Commit 48521c3

get keySchema from stateStoreColFamilySchemaOpt
1 parent aee5732 commit 48521c3
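
In short: partition readers no longer rebuild a column-family-to-schema map by re-reading the state schema file; the already-resolved StateStoreColFamilySchema is threaded in as an Option and unwrapped where the key and value schemas are needed. A minimal sketch of the new lookup idiom, using simplified stand-in types rather than Spark's real StructType and StateStoreColFamilySchema:

// Stand-ins for org.apache.spark.sql.types.StructType and
// StateStoreColFamilySchema; the real classes carry more structure.
final case class StructTypeStub(ddl: String)
final case class ColFamilySchemaStub(
    colFamilyName: String,
    keySchema: StructTypeStub,
    valueSchema: StructTypeStub)

object SchemaLookup {
  // The idiom from this commit: require the Option to be populated by the
  // caller, then unwrap it, instead of re-reading the schema file locally.
  def keySchemaFrom(opt: Option[ColFamilySchemaStub]): StructTypeStub = {
    require(opt.isDefined)
    opt.map(_.keySchema).get
  }

  def main(args: Array[String]): Unit = {
    val cf = ColFamilySchemaStub(
      "default", StructTypeStub("key STRING"), StructTypeStub("value BIGINT"))
    println(keySchemaFrom(Some(cf)))  // StructTypeStub(key STRING)
    // keySchemaFrom(None) would throw IllegalArgumentException via require(...)
  }
}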

3 files changed: +47 -49 lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala

Lines changed: 12 additions & 27 deletions

@@ -16,8 +16,6 @@
  */
 package org.apache.spark.sql.execution.datasources.v2.state
 
-import scala.collection.mutable
-
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeRow}
@@ -53,7 +51,7 @@ class StatePartitionReaderFactory(
     val stateStoreInputPartition = partition.asInstanceOf[StateStoreInputPartition]
     if (stateStoreInputPartition.sourceOptions.internalOnlyReadAllColumnFamilies) {
       new StatePartitionAllColumnFamiliesReader(storeConf, hadoopConf,
-        stateStoreInputPartition, schema, keyStateEncoderSpec)
+        stateStoreInputPartition, schema, keyStateEncoderSpec, stateStoreColFamilySchemaOpt)
     } else if (stateStoreInputPartition.sourceOptions.readChangeFeed) {
       new StateStoreChangeDataPartitionReader(storeConf, hadoopConf,
         stateStoreInputPartition, schema, keyStateEncoderSpec, stateVariableInfoOpt,
@@ -83,39 +81,25 @@ abstract class StatePartitionReaderBase(
   extends PartitionReader[InternalRow] with Logging {
   // Used primarily as a placeholder for the value schema in the context of
   // state variables used within the transformWithState operator.
-  // Also used as a placeholder for both key and value schema for
-  // StatePartitionAllColumnFamiliesReader
-  private val placeholderSchema: StructType =
+  private val schemaForValueRow: StructType =
     StructType(Array(StructField("__dummy__", NullType)))
 
-  private val colFamilyToSchema : mutable.HashMap[String, StateStoreColFamilySchema] = {
-    val stateStoreId = StateStoreId(
-      partition.sourceOptions.stateCheckpointLocation.toString,
-      partition.sourceOptions.operatorId,
-      StateStore.PARTITION_ID_TO_CHECK_SCHEMA,
-      partition.sourceOptions.storeName)
-    val stateStoreProviderId = StateStoreProviderId(stateStoreId, partition.queryId)
-    val manager = new StateSchemaCompatibilityChecker(stateStoreProviderId, hadoopConf.value)
-    val schemaFile = manager.readSchemaFile()
-    val schemaMap = mutable.HashMap[String, StateStoreColFamilySchema]()
-    schemaFile.foreach { schema => schemaMap.put(schema.colFamilyName, schema)}
-    schemaMap
-  }
-
-  protected val keySchema = {
+  protected val keySchema : StructType = {
     if (SchemaUtil.checkVariableType(stateVariableInfoOpt, StateVariableType.MapState)) {
       SchemaUtil.getCompositeKeySchema(schema, partition.sourceOptions)
     } else if (partition.sourceOptions.internalOnlyReadAllColumnFamilies) {
-      colFamilyToSchema(StateStore.DEFAULT_COL_FAMILY_NAME).keySchema
+      require(stateStoreColFamilySchemaOpt.isDefined)
+      stateStoreColFamilySchemaOpt.map(_.keySchema).get
     } else {
       SchemaUtil.getSchemaAsDataType(schema, "key").asInstanceOf[StructType]
     }
   }
 
-  protected val valueSchema = if (stateVariableInfoOpt.isDefined) {
-    placeholderSchema
+  protected val valueSchema : StructType = if (stateVariableInfoOpt.isDefined) {
+    schemaForValueRow
   } else if (partition.sourceOptions.internalOnlyReadAllColumnFamilies) {
-    colFamilyToSchema(StateStore.DEFAULT_COL_FAMILY_NAME).valueSchema
+    require(stateStoreColFamilySchemaOpt.isDefined)
+    stateStoreColFamilySchemaOpt.map(_.valueSchema).get
   } else {
     SchemaUtil.getSchemaAsDataType(
       schema, "value").asInstanceOf[StructType]
@@ -273,11 +257,12 @@ class StatePartitionAllColumnFamiliesReader(
     hadoopConf: SerializableConfiguration,
     partition: StateStoreInputPartition,
     schema: StructType,
-    keyStateEncoderSpec: KeyStateEncoderSpec)
+    keyStateEncoderSpec: KeyStateEncoderSpec,
+    stateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema])
   extends StatePartitionReaderBase(
     storeConf,
     hadoopConf, partition, schema,
-    keyStateEncoderSpec, None, None, None, None) {
+    keyStateEncoderSpec, None, stateStoreColFamilySchemaOpt, None, None) {
 
   private lazy val store: ReadStateStore = {
     assert(getStartStoreUniqueId == getEndStoreUniqueId,
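
The last hunk above threads the new option through the reader's constructor into the StatePartitionReaderBase argument slot that was previously hard-coded to None. A hedged sketch of that wiring, with hypothetical and heavily simplified signatures (the real base class takes several more parameters):

// Hypothetical, simplified wiring: the factory already holds the resolved
// schema option and now forwards it instead of dropping it.
final case class ColFamilySchema(keySchema: String, valueSchema: String)

abstract class PartitionReaderBase(
    val stateStoreColFamilySchemaOpt: Option[ColFamilySchema])

class AllColumnFamiliesReader(stateStoreColFamilySchemaOpt: Option[ColFamilySchema])
  extends PartitionReaderBase(stateStoreColFamilySchemaOpt) // was: PartitionReaderBase(None)

object ReaderFactory {
  def create(schemaOpt: Option[ColFamilySchema]): PartitionReaderBase =
    new AllColumnFamiliesReader(schemaOpt)
}

The practical effect is fail-fast behavior: if internalOnlyReadAllColumnFamilies is set but the option is empty, the require(...) in the base class raises IllegalArgumentException at reader construction rather than silently re-reading the schema file.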

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionErrors.scala

Lines changed: 3 additions & 3 deletions

@@ -88,8 +88,7 @@ object OfflineStateRepartitionErrors {
 
   def unsupportedStateStoreProviderError(
       checkpointLocation: String,
-      providerClass: String
-  ): StateRepartitionUnsupportedProviderError = {
+      providerClass: String): StateRepartitionUnsupportedProviderError = {
     new StateRepartitionUnsupportedProviderError(checkpointLocation, providerClass)
   }
 }
@@ -211,7 +210,8 @@ class StateRepartitionUnsupportedOffsetSeqVersionError(
 
 class StateRepartitionUnsupportedProviderError(
     checkpointLocation: String,
-    provider: String) extends StateRepartitionInvalidCheckpointError(
+    provider: String)
+  extends StateRepartitionInvalidCheckpointError(
     checkpointLocation,
     subClass = "UNSUPPORTED_PROVIDER",
     messageParameters = Map("provider" -> provider))
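
Both hunks in this file are formatting-only signature and indentation changes; behavior is unchanged. For orientation, a hedged sketch of a call site for the reformatted helper (hypothetical, not part of this commit):

// Hypothetical call site: abort offline repartitioning when the checkpoint
// was written by a state store provider this path does not support.
throw OfflineStateRepartitionErrors.unsupportedStateStoreProviderError(
  checkpointLocation = "/path/to/checkpoint",
  providerClass = "com.example.UnsupportedProvider")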

sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionAllColumnFamiliesReaderSuite.scala

Lines changed: 32 additions & 19 deletions

@@ -64,7 +64,6 @@ class StatePartitionAllColumnFamiliesReaderSuite extends StateDataSourceTestBase
       .option(StateSourceOptions.INTERNAL_ONLY_READ_ALL_COLUMN_FAMILIES, "true")
       .option(StateSourceOptions.STORE_NAME, storeName.orNull)
       .load()
-      .selectExpr("partition_key", "key_bytes", "value_bytes", "column_family_name")
   }
 
   /**
@@ -221,9 +220,10 @@ class StatePartitionAllColumnFamiliesReaderSuite extends StateDataSourceTestBase
         ))
 
         val normalData = getNormalReadDf(tempDir.getAbsolutePath).collect()
-        val bytesDf = getBytesReadDf(tempDir.getAbsolutePath).collect()
+        val bytesDf = getBytesReadDf(tempDir.getAbsolutePath)
 
-        compareNormalAndBytesData(normalData, bytesDf, "default", keySchema, valueSchema)
+        validateBytesReadDfSchema(bytesDf)
+        compareNormalAndBytesData(normalData, bytesDf.collect(), "default", keySchema, valueSchema)
       }
     }
   }
@@ -248,9 +248,10 @@ class StatePartitionAllColumnFamiliesReaderSuite extends StateDataSourceTestBase
         ))
 
         val normalData = getNormalReadDf(tempDir.getAbsolutePath).collect()
-        val bytesDf = getBytesReadDf(tempDir.getAbsolutePath).collect()
+        val bytesDf = getBytesReadDf(tempDir.getAbsolutePath)
 
-        compareNormalAndBytesData(normalData, bytesDf, "default", keySchema, valueSchema)
+        validateBytesReadDfSchema(bytesDf)
+        compareNormalAndBytesData(normalData, bytesDf.collect(), "default", keySchema, valueSchema)
       }
     }
   }
@@ -272,9 +273,10 @@ class StatePartitionAllColumnFamiliesReaderSuite extends StateDataSourceTestBase
         ))
 
         val normalData = getNormalReadDf(tempDir.getAbsolutePath).collect()
-        val bytesDf = getBytesReadDf(tempDir.getAbsolutePath).collect()
+        val bytesDf = getBytesReadDf(tempDir.getAbsolutePath)
 
-        compareNormalAndBytesData(normalData, bytesDf, "default", keySchema, valueSchema)
+        validateBytesReadDfSchema(bytesDf)
+        compareNormalAndBytesData(normalData, bytesDf.collect(), "default", keySchema, valueSchema)
       }
     }
   }
@@ -292,9 +294,10 @@ class StatePartitionAllColumnFamiliesReaderSuite extends StateDataSourceTestBase
       ))
 
       val normalData = getNormalReadDf(tempDir.getAbsolutePath).collect()
-      val bytesDf = getBytesReadDf(tempDir.getAbsolutePath).collect()
+      val bytesDf = getBytesReadDf(tempDir.getAbsolutePath)
 
-      compareNormalAndBytesData(normalData, bytesDf, "default", keySchema, valueSchema)
+      validateBytesReadDfSchema(bytesDf)
+      compareNormalAndBytesData(normalData, bytesDf.collect(), "default", keySchema, valueSchema)
     }
   }
 
@@ -310,9 +313,10 @@ class StatePartitionAllColumnFamiliesReaderSuite extends StateDataSourceTestBase
       ))
 
      val normalData = getNormalReadDf(tempDir.getAbsolutePath).collect()
-      val bytesDf = getBytesReadDf(tempDir.getAbsolutePath).collect()
+      val bytesDf = getBytesReadDf(tempDir.getAbsolutePath)
 
-      compareNormalAndBytesData(normalData, bytesDf, "default", keySchema, valueSchema)
+      validateBytesReadDfSchema(bytesDf)
+      compareNormalAndBytesData(normalData, bytesDf.collect(), "default", keySchema, valueSchema)
     }
   }
 
@@ -328,9 +332,10 @@ class StatePartitionAllColumnFamiliesReaderSuite extends StateDataSourceTestBase
      ))
 
      val normalData = getNormalReadDf(tempDir.getAbsolutePath).collect()
-      val bytesDf = getBytesReadDf(tempDir.getAbsolutePath).collect()
+      val bytesDf = getBytesReadDf(tempDir.getAbsolutePath)
 
-      compareNormalAndBytesData(normalData, bytesDf, "default", keySchema, valueSchema)
+      validateBytesReadDfSchema(bytesDf)
+      compareNormalAndBytesData(normalData, bytesDf.collect(), "default", keySchema, valueSchema)
     }
   }
 
@@ -353,9 +358,10 @@ class StatePartitionAllColumnFamiliesReaderSuite extends StateDataSourceTestBase
      ))
 
      val normalData = getNormalReadDf(tempDir.getAbsolutePath).collect()
-      val bytesDf = getBytesReadDf(tempDir.getAbsolutePath).collect()
+      val bytesDf = getBytesReadDf(tempDir.getAbsolutePath)
 
-      compareNormalAndBytesData(normalData, bytesDf, "default", keySchema, valueSchema)
+      validateBytesReadDfSchema(bytesDf)
+      compareNormalAndBytesData(normalData, bytesDf.collect(), "default", keySchema, valueSchema)
     }
   }
 
@@ -378,9 +384,11 @@ class StatePartitionAllColumnFamiliesReaderSuite extends StateDataSourceTestBase
         ))
 
         val normalData = getNormalReadDf(tempDir.getAbsolutePath).collect()
-        val bytesDf = getBytesReadDf(tempDir.getAbsolutePath).collect()
+        val bytesDf = getBytesReadDf(tempDir.getAbsolutePath)
 
-        compareNormalAndBytesData(normalData, bytesDf, "default", keySchema, valueSchema)
+        validateBytesReadDfSchema(bytesDf)
+        compareNormalAndBytesData(
+          normalData, bytesDf.collect(), "default", keySchema, valueSchema)
       }
     }
   }
@@ -403,8 +411,11 @@ class StatePartitionAllColumnFamiliesReaderSuite extends StateDataSourceTestBase
         ))
 
         val normalData = getNormalReadDf(tempDir.getAbsolutePath).collect()
-        val bytesDf = getBytesReadDf(tempDir.getAbsolutePath).collect()
-        compareNormalAndBytesData(normalData, bytesDf, "default", keySchema, valueSchema)
+        val bytesDf = getBytesReadDf(tempDir.getAbsolutePath)
+
+        validateBytesReadDfSchema(bytesDf)
+        compareNormalAndBytesData(
+          normalData, bytesDf.collect(), "default", keySchema, valueSchema)
       }
     }
   }
@@ -427,6 +438,7 @@ class StatePartitionAllColumnFamiliesReaderSuite extends StateDataSourceTestBase
       StructField("value", LongType)
     ))
 
+    validateBytesReadDfSchema(stateBytesDfForRight)
    compareNormalAndBytesData(
      stateReaderForRight.collect(),
      stateBytesDfForRight.collect(),
@@ -458,6 +470,7 @@ class StatePartitionAllColumnFamiliesReaderSuite extends StateDataSourceTestBase
      ))
    }
 
+    validateBytesReadDfSchema(stateBytesDfForRight)
    compareNormalAndBytesData(
      stateReaderForRight.collect(),
      stateBytesDfForRight.collect(),
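
Each test now validates the schema of the raw bytes DataFrame before collecting it, via a validateBytesReadDfSchema helper whose definition lies outside these hunks. A plausible sketch of a suite-level helper, assuming it asserts exactly the four columns that the removed selectExpr used to project:

import org.apache.spark.sql.DataFrame

// Hypothetical body for the helper (its actual definition is outside the
// diff context): check that the raw all-column-families DataFrame exposes
// the expected byte-level columns in order, now that getBytesReadDf no
// longer projects them explicitly.
private def validateBytesReadDfSchema(df: DataFrame): Unit = {
  val expected = Seq("partition_key", "key_bytes", "value_bytes", "column_family_name")
  assert(df.schema.fieldNames.toSeq == expected,
    s"unexpected bytes-read schema: ${df.schema.fieldNames.mkString(", ")}")
}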
