@@ -66,28 +66,38 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging
     val sourceOptions = StateSourceOptions.modifySourceOptions(hadoopConf,
       StateSourceOptions.apply(session, hadoopConf, properties))
     val stateConf = buildStateStoreConf(sourceOptions.resolvedCpLocation, sourceOptions.batchId)
-    val stateStoreReaderInfo: StateStoreReaderInfo = getStoreMetadataAndRunChecks(
-      sourceOptions)
-
-    // The key state encoder spec should be available for all operators except stream-stream joins
-    val keyStateEncoderSpec = if (stateStoreReaderInfo.keyStateEncoderSpecOpt.isDefined) {
-      stateStoreReaderInfo.keyStateEncoderSpecOpt.get
+    if (sourceOptions.readAllColumnFamilies) {
+      // For readAllColumnFamilies mode, we don't need specific metadata
+      val keyStateEncoderSpec = NoPrefixKeyStateEncoderSpec(new StructType())
+      new StateTable(session, schema, sourceOptions, stateConf, keyStateEncoderSpec,
+        None, None, None, None)
     } else {
-      val keySchema = SchemaUtil.getSchemaAsDataType(schema, "key").asInstanceOf[StructType]
-      NoPrefixKeyStateEncoderSpec(keySchema)
-    }
+      val stateStoreReaderInfo: StateStoreReaderInfo = getStoreMetadataAndRunChecks(
+        sourceOptions)
 
-    new StateTable(session, schema, sourceOptions, stateConf, keyStateEncoderSpec,
-      stateStoreReaderInfo.transformWithStateVariableInfoOpt,
-      stateStoreReaderInfo.stateStoreColFamilySchemaOpt,
-      stateStoreReaderInfo.stateSchemaProviderOpt,
-      stateStoreReaderInfo.joinColFamilyOpt)
+      // The key state encoder spec should be available for all operators except stream-stream joins
+      val keyStateEncoderSpec = if (stateStoreReaderInfo.keyStateEncoderSpecOpt.isDefined) {
+        stateStoreReaderInfo.keyStateEncoderSpecOpt.get
+      } else {
+        val keySchema = SchemaUtil.getSchemaAsDataType(schema, "key").asInstanceOf[StructType]
+        NoPrefixKeyStateEncoderSpec(keySchema)
+      }
+      new StateTable(session, schema, sourceOptions, stateConf, keyStateEncoderSpec,
+        stateStoreReaderInfo.transformWithStateVariableInfoOpt,
+        stateStoreReaderInfo.stateStoreColFamilySchemaOpt,
+        stateStoreReaderInfo.stateSchemaProviderOpt,
+        stateStoreReaderInfo.joinColFamilyOpt)
+    }
   }
 
   override def inferSchema(options: CaseInsensitiveStringMap): StructType = {
     val sourceOptions = StateSourceOptions.modifySourceOptions(hadoopConf,
       StateSourceOptions.apply(session, hadoopConf, options))
-
+    if (sourceOptions.readAllColumnFamilies) {
+      // For readAllColumnFamilies mode, return the binary schema directly
+      return SchemaUtil.getSourceSchema(
+        sourceOptions, new StructType(), new StructType(), None, None)
+    }
     val stateStoreReaderInfo: StateStoreReaderInfo = getStoreMetadataAndRunChecks(
       sourceOptions)
     val oldSchemaFilePaths = StateDataSource.getOldSchemaFilePaths(sourceOptions, hadoopConf)
@@ -372,6 +382,7 @@ case class StateSourceOptions(
     stateVarName: Option[String],
     readRegisteredTimers: Boolean,
     flattenCollectionTypes: Boolean,
+    readAllColumnFamilies: Boolean,
     startOperatorStateUniqueIds: Option[Array[Array[String]]] = None,
     endOperatorStateUniqueIds: Option[Array[Array[String]]] = None) {
   def stateCheckpointLocation: Path = new Path(resolvedCpLocation, DIR_NAME_STATE)
@@ -380,7 +391,8 @@ case class StateSourceOptions(
     var desc = s"StateSourceOptions(checkpointLocation=$resolvedCpLocation, batchId=$batchId, " +
       s"operatorId=$operatorId, storeName=$storeName, joinSide=$joinSide, " +
       s"stateVarName=${stateVarName.getOrElse("None")}, " +
-      s"flattenCollectionTypes=$flattenCollectionTypes"
+      s"flattenCollectionTypes=$flattenCollectionTypes, " +
+      s"readAllColumnFamilies=$readAllColumnFamilies"
     if (fromSnapshotOptions.isDefined) {
       desc += s", snapshotStartBatchId=${fromSnapshotOptions.get.snapshotStartBatchId}"
       desc += s", snapshotPartitionId=${fromSnapshotOptions.get.snapshotPartitionId}"
@@ -407,6 +419,7 @@ object StateSourceOptions extends DataSourceOptions {
   val STATE_VAR_NAME = newOption("stateVarName")
   val READ_REGISTERED_TIMERS = newOption("readRegisteredTimers")
   val FLATTEN_COLLECTION_TYPES = newOption("flattenCollectionTypes")
+  val READ_ALL_COLUMN_FAMILIES = newOption("readAllColumnFamilies")
 
   object JoinSideValues extends Enumeration {
     type JoinSideValues = Value
@@ -492,6 +505,27 @@ object StateSourceOptions extends DataSourceOptions {
 
     val readChangeFeed = Option(options.get(READ_CHANGE_FEED)).exists(_.toBoolean)
 
+    val readAllColumnFamilies = try {
+      Option(options.get(READ_ALL_COLUMN_FAMILIES))
+        .map(_.toBoolean).getOrElse(false)
+    } catch {
+      case _: IllegalArgumentException =>
+        throw StateDataSourceErrors.invalidOptionValue(READ_ALL_COLUMN_FAMILIES,
+          "Boolean value is expected")
+    }
+
+    if (readAllColumnFamilies && stateVarName.isDefined) {
+      throw StateDataSourceErrors.conflictOptions(Seq(READ_ALL_COLUMN_FAMILIES, STATE_VAR_NAME))
+    }
+
+    if (readAllColumnFamilies && joinSide != JoinSideValues.none) {
+      throw StateDataSourceErrors.conflictOptions(Seq(READ_ALL_COLUMN_FAMILIES, JOIN_SIDE))
+    }
+
+    if (readAllColumnFamilies && readChangeFeed) {
+      throw StateDataSourceErrors.conflictOptions(Seq(READ_ALL_COLUMN_FAMILIES, READ_CHANGE_FEED))
+    }
+
     val changeStartBatchId = Option(options.get(CHANGE_START_BATCH_ID)).map(_.toLong)
     var changeEndBatchId = Option(options.get(CHANGE_END_BATCH_ID)).map(_.toLong)
@@ -616,7 +650,7 @@ object StateSourceOptions extends DataSourceOptions {
       resolvedCpLocation, batchId.get, operatorId, storeName, joinSide,
       readChangeFeed, fromSnapshotOptions, readChangeFeedOptions,
       stateVarName, readRegisteredTimers, flattenCollectionTypes,
-      startOperatorStateUniqueIds, endOperatorStateUniqueIds)
+      readAllColumnFamilies, startOperatorStateUniqueIds, endOperatorStateUniqueIds)
   }
 
   private def getLastCommittedBatch(session: SparkSession, checkpointLocation: String): Long = {
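
For context, a minimal usage sketch of the new option. This is illustrative only: the checkpoint path is a placeholder, and `statestore` is the data source's registered short name.

```scala
// Read every column family of operator 0's state store in one scan.
// With readAllColumnFamilies=true the rows come back in the raw binary
// schema returned by inferSchema above, rather than a typed key/value schema.
// The checkpoint path below is a placeholder.
val allFamilies = spark.read
  .format("statestore")
  .option("readAllColumnFamilies", "true") // option introduced by this change
  .option("operatorId", "0")
  .load("/tmp/streaming-query/checkpoint")
```

Per the validation added in `StateSourceOptions.apply`, combining `readAllColumnFamilies` with `stateVarName`, `joinSide`, or `readChangeFeed` fails with a conflicting-options error.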