
Commit 250ccca

ericm-db authored and anishshri-db committed
[SPARK-54423][SS] Introducing OffsetMap to enable Named Streaming Sources

### What changes were proposed in this pull request?

Introducing the OffsetMap format to key source progress by source name, as opposed to ordinal in the logical plan.

### Why are the changes needed?

These changes are needed in order to enable source evolution on a streaming query (adding, removing, reordering sources) without requiring the user to set a new checkpoint directory.

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

Unit tests

### Was this patch authored or co-authored using generative AI tooling?

No

Closes #53123 from ericm-db/offset-map.

Authored-by: ericm-db <eric.marnadi@databricks.com>
Signed-off-by: Anish Shrigondekar <anish.shrigondekar@databricks.com>
1 parent 4435a3a commit 250ccca
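
As a rough illustration of the idea (not Spark's actual classes), the following standalone Scala sketch contrasts the v1 ordinal-keyed offset sequence with the v2 map keyed by a source identifier; the identifiers and offset JSON are invented for the example.

```scala
// Toy sketch only: plain strings stand in for Spark's OffsetV2/OffsetSeq/OffsetMap types.
object OffsetKeyingSketch {
  // v1 (OffsetSeq): progress is an ordered list, so position N always means
  // "the N-th source in the logical plan"; reordering or removing a source
  // silently re-associates checkpointed offsets with the wrong source.
  val v1Progress: Seq[Option[String]] = Seq(Some("""{"logOffset":5}"""), None)

  // v2 (OffsetMap): progress is keyed by a stable identifier (initially
  // sourceId.toString(), per the commit message), so sources can be added,
  // removed, or reordered without pointing the query at a new checkpoint.
  val v2Progress: Map[String, Option[String]] = Map(
    "0" -> Some("""{"logOffset":5}"""),
    "1" -> None // written as "-" in the log: no offset yet for this source
  )

  def main(args: Array[String]): Unit = {
    // Rendered roughly the way OffsetSeqBase.toString formats each variant.
    println(v1Progress.map(_.getOrElse("-")).mkString("[", ", ", "]"))
    println(v2Progress.map { case (id, o) => s"$id: ${o.getOrElse("-")}" }.mkString("{", ", ", "}"))
  }
}
```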

File tree

19 files changed, +284 -82 lines changed

connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala

Lines changed: 2 additions & 2 deletions
@@ -42,7 +42,7 @@ import org.apache.spark.sql.connector.read.streaming.SparkDataStream
 import org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2ScanRelation
 import org.apache.spark.sql.execution.exchange.ReusedExchangeExec
 import org.apache.spark.sql.execution.streaming._
-import org.apache.spark.sql.execution.streaming.checkpointing.OffsetSeq
+import org.apache.spark.sql.execution.streaming.checkpointing.OffsetSeqBase
 import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution
 import org.apache.spark.sql.execution.streaming.runtime.{MicroBatchExecution, StreamExecution, StreamingExecutionRelation}
 import org.apache.spark.sql.execution.streaming.runtime.AsyncProgressTrackingMicroBatchExecution.{ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, ASYNC_PROGRESS_TRACKING_ENABLED}
@@ -854,7 +854,7 @@ abstract class KafkaMicroBatchSourceSuiteBase extends KafkaSourceSuiteBase with
         true
       },
       AssertOnQuery { q =>
-        val latestOffset: Option[(Long, OffsetSeq)] = q.offsetLog.getLatest()
+        val latestOffset: Option[(Long, OffsetSeqBase)] = q.offsetLog.getLatest()
         latestOffset.exists { offset =>
           !offset._2.offsets.exists(_.exists(_.json == "{}"))
         }

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala

Lines changed: 1 addition & 1 deletion
@@ -126,7 +126,7 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging
     val offsetLog = new StreamingQueryCheckpointMetadata(session, checkpointLocation).offsetLog
     offsetLog.get(batchId) match {
       case Some(value) =>
-        val metadata = value.metadata.getOrElse(
+        val metadata = value.metadataOpt.getOrElse(
           throw StateDataSourceErrors.offsetMetadataLogUnavailable(batchId, checkpointLocation)
         )

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/checkpointing/AsyncOffsetSeqLog.scala

Lines changed: 1 addition & 1 deletion
@@ -81,7 +81,7 @@ class AsyncOffsetSeqLog(
    * the async write of the batch is completed. Future may also be completed exceptionally
    * to indicate some write error.
    */
-  def addAsync(batchId: Long, metadata: OffsetSeq): CompletableFuture[(Long, Boolean)] = {
+  def addAsync(batchId: Long, metadata: OffsetSeqBase): CompletableFuture[(Long, Boolean)] = {
     require(metadata != null, "'null' metadata cannot written to a metadata log")
 
     def issueAsyncWrite(batchId: Long): CompletableFuture[Long] = {

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/checkpointing/OffsetSeq.scala

Lines changed: 70 additions & 15 deletions
@@ -35,30 +35,67 @@ import org.apache.spark.sql.execution.streaming.runtime.{MultipleWatermarkPolicy
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.internal.SQLConf._
 
+trait OffsetSeqBase {
+  def offsets: Seq[Option[OffsetV2]]
 
-/**
- * An ordered collection of offsets, used to track the progress of processing data from one or more
- * [[Source]]s that are present in a streaming query. This is similar to simplified, single-instance
- * vector clock that must progress linearly forward.
- */
-case class OffsetSeq(offsets: Seq[Option[OffsetV2]], metadata: Option[OffsetSeqMetadata] = None) {
+  def metadataOpt: Option[OffsetSeqMetadata]
+
+  override def toString: String = this match {
+    case offsetMap: OffsetMap =>
+      offsetMap.offsetsMap.map { case (sourceId, offsetOpt) =>
+        s"$sourceId: ${offsetOpt.map(_.json).getOrElse("-")}"
+      }.mkString("{", ", ", "}")
+    case _ =>
+      offsets.map(_.map(_.json).getOrElse("-")).mkString("[", ", ", "]")
+  }
 
   /**
-   * Unpacks an offset into [[StreamProgress]] by associating each offset with the ordered list of
-   * sources.
+   * Unpacks an offset into [[StreamProgress]] by associating each offset with the
+   * ordered list of sources.
    *
-   * This method is typically used to associate a serialized offset with actual sources (which
-   * cannot be serialized).
+   * This method is typically used to associate a serialized offset with actual
+   * sources (which cannot be serialized).
    */
   def toStreamProgress(sources: Seq[SparkDataStream]): StreamProgress = {
+    assert(!this.isInstanceOf[OffsetMap], "toStreamProgress must be called with map")
     assert(sources.size == offsets.size, s"There are [${offsets.size}] sources in the " +
-      s"checkpoint offsets and now there are [${sources.size}] sources requested by the query. " +
-      s"Cannot continue.")
+      s"checkpoint offsets and now there are [${sources.size}] sources requested by " +
+      s"the query. Cannot continue.")
     new StreamProgress ++ sources.zip(offsets).collect { case (s, Some(o)) => (s, o) }
   }
 
-  override def toString: String =
-    offsets.map(_.map(_.json).getOrElse("-")).mkString("[", ", ", "]")
+  /**
+   * Converts OffsetMap to StreamProgress using source ID mapping.
+   * This method is specific to OffsetMap and requires a mapping from sourceId to SparkDataStream.
+   */
+  def toStreamProgress(
+      sources: Seq[SparkDataStream],
+      sourceIdToSourceMap: Map[String, SparkDataStream]): StreamProgress = {
+    this match {
+      case offsetMap: OffsetMap =>
+        val streamProgressEntries = for {
+          (sourceId, offsetOpt) <- offsetMap.offsetsMap
+          offset <- offsetOpt
+          source <- sourceIdToSourceMap.get(sourceId)
+        } yield source -> offset
+        new StreamProgress ++ streamProgressEntries
+      case _ =>
+        // Fallback to original method for backward compatibility
+        toStreamProgress(sources)
+    }
+  }
+}
+
+/**
+ * An ordered collection of offsets, used to track the progress of processing data from one or more
+ * [[Source]]s that are present in a streaming query. This is similar to simplified, single-instance
+ * vector clock that must progress linearly forward.
+ */
+case class OffsetSeq(
+    offsets: Seq[Option[OffsetV2]],
+    metadata: Option[OffsetSeqMetadata] = None) extends OffsetSeqBase {
+
+  override def metadataOpt: Option[OffsetSeqMetadata] = metadata
 }
 
 object OffsetSeq {
@@ -79,6 +116,23 @@ object OffsetSeq {
 }
 
 
+/**
+ * A map-based collection of offsets, used to track the progress of processing data from one or more
+ * streaming sources. Each source is identified by a string key (initially sourceId.toString()).
+ * This replaces the sequence-based approach with a more flexible map-based approach to support
+ * named source identities.
+ */
+case class OffsetMap(
+    offsetsMap: Map[String, Option[OffsetV2]],
+    metadataOpt: Option[OffsetSeqMetadata] = None) extends OffsetSeqBase {
+
+  // OffsetMap does not support sequence-based access
+  override def offsets: Seq[Option[OffsetV2]] = {
+    throw new UnsupportedOperationException(
+      "OffsetMap does not support sequence-based offsets access. Use offsetsMap directly.")
+  }
+}
+
 /**
  * Contains metadata associated with a [[OffsetSeq]]. This information is
  * persisted to the offset log in the checkpoint location via the [[OffsetSeq]] metadata field.
@@ -97,7 +151,8 @@ object OffsetSeq {
 case class OffsetSeqMetadata(
     batchWatermarkMs: Long = 0,
     batchTimestampMs: Long = 0,
-    conf: Map[String, String] = Map.empty) {
+    conf: Map[String, String] = Map.empty,
+    version: Int = 1) {
   def json: String = Serialization.write(this)(OffsetSeqMetadata.format)
 }
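
To make the map-keyed lookup above concrete, here is a hedged, self-contained sketch of what the added `toStreamProgress(sources, sourceIdToSourceMap)` overload does for an `OffsetMap`, using toy stand-ins (`Stream`, plain JSON strings) for `SparkDataStream` and `OffsetV2`; it mirrors the for-comprehension in the diff but is not the Spark implementation.

```scala
// Illustrative only: offsets are resolved by source id, so ids with no offset ("-")
// or ids that no longer appear in the query simply drop out of the progress map.
object OffsetMapAssociationSketch {
  final case class Stream(name: String) // stand-in for SparkDataStream
  type Offset = String                  // stand-in for OffsetV2 (its json)

  def toStreamProgress(
      offsetsMap: Map[String, Option[Offset]],
      sourceIdToSource: Map[String, Stream]): Map[Stream, Offset] = {
    (for {
      (sourceId, offsetOpt) <- offsetsMap
      offset <- offsetOpt                      // skip sources with no recorded offset
      source <- sourceIdToSource.get(sourceId) // skip ids not present in the query
    } yield source -> offset).toMap
  }

  def main(args: Array[String]): Unit = {
    val offsets = Map("0" -> Some("""{"p0":42}"""), "1" -> None)
    val sources = Map("0" -> Stream("kafka-a"), "1" -> Stream("rate"))
    println(toStreamProgress(offsets, sources)) // Map(Stream(kafka-a) -> {"p0":42})
  }
}
```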

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/checkpointing/OffsetSeqLog.scala

Lines changed: 74 additions & 21 deletions
@@ -43,56 +43,109 @@ import org.apache.spark.sql.execution.streaming.runtime.SerializedOffset
  *   -      // No offset for this source i.e., an invalid JSON string
  *   {2}    // LongOffset 2
  *   ...
+ *
+ * Version 2 format (OffsetMap):
+ *   v2     // version 2
+ *   metadata
+ *   0:{0}  // sourceId:offset
+ *   1:{3}  // sourceId:offset
+ *   ...
  */
 class OffsetSeqLog(sparkSession: SparkSession, path: String)
-  extends HDFSMetadataLog[OffsetSeq](sparkSession, path) {
+  extends HDFSMetadataLog[OffsetSeqBase](sparkSession, path) {
 
-  override protected def deserialize(in: InputStream): OffsetSeq = {
+  override protected def deserialize(in: InputStream): OffsetSeqBase = {
     // called inside a try-finally where the underlying stream is closed in the caller
-    def parseOffset(value: String): OffsetV2 = value match {
-      case OffsetSeqLog.SERIALIZED_VOID_OFFSET => null
-      case json => SerializedOffset(json)
-    }
     val lines = IOSource.fromInputStream(in, UTF_8.name()).getLines()
     if (!lines.hasNext) {
       throw new IllegalStateException("Incomplete log file")
     }
 
-    validateVersion(lines.next(), OffsetSeqLog.VERSION)
+    val versionStr = lines.next()
+    val versionInt = validateVersion(versionStr, OffsetSeqLog.MAX_VERSION)
 
     // read metadata
     val metadata = lines.next().trim match {
       case "" => None
       case md => Some(md)
     }
     import org.apache.spark.util.ArrayImplicits._
-    OffsetSeq.fill(metadata, lines.map(parseOffset).toArray.toImmutableArraySeq: _*)
+    if (versionInt == OffsetSeqLog.VERSION_2) {
+      // deserialize the remaining lines into the offset map
+      val remainingLines = lines.toArray
+      // New OffsetMap format: sourceId:offset
+      val offsetsMap = remainingLines.map { line =>
+        val colonIndex = line.indexOf(':')
+        if (colonIndex == -1) {
+          throw new IllegalStateException(s"Invalid OffsetMap format: $line")
+        }
+        val sourceId = line.substring(0, colonIndex)
+        val offsetStr = line.substring(colonIndex + 1)
+        val offset = if (offsetStr == OffsetSeqLog.SERIALIZED_VOID_OFFSET) {
+          None
+        } else {
+          Some(OffsetSeqLog.parseOffset(offsetStr))
+        }
+        sourceId -> offset
+      }.toMap
+      OffsetMap(offsetsMap, metadata.map(OffsetSeqMetadata.apply))
+    } else {
+      OffsetSeq.fill(metadata,
+        lines.map(OffsetSeqLog.parseOffset).toArray.toImmutableArraySeq: _*)
+    }
   }
 
-  override protected def serialize(offsetSeq: OffsetSeq, out: OutputStream): Unit = {
+  override protected def serialize(offsetSeq: OffsetSeqBase, out: OutputStream): Unit = {
     // called inside a try-finally where the underlying stream is closed in the caller
-    out.write(("v" + OffsetSeqLog.VERSION).getBytes(UTF_8))
+    out.write(("v" + offsetSeq.metadataOpt.map(_.version).getOrElse(OffsetSeqLog.VERSION_1))
+      .getBytes(UTF_8))
 
     // write metadata
     out.write('\n')
-    out.write(offsetSeq.metadata.map(_.json).getOrElse("").getBytes(UTF_8))
+    out.write(offsetSeq.metadataOpt.map(_.json).getOrElse("").getBytes(UTF_8))
 
-    // write offsets, one per line
-    offsetSeq.offsets.map(_.map(_.json)).foreach { offset =>
-      out.write('\n')
-      offset match {
-        case Some(json: String) => out.write(json.getBytes(UTF_8))
-        case None => out.write(OffsetSeqLog.SERIALIZED_VOID_OFFSET.getBytes(UTF_8))
-      }
+    offsetSeq match {
+      case offsetMap: OffsetMap =>
+        // For OffsetMap, write sourceId:offset pairs, one per line
+        offsetMap.offsetsMap.foreach { case (sourceId, offsetOpt) =>
+          out.write('\n')
+          out.write(sourceId.getBytes(UTF_8))
+          out.write(':')
+          offsetOpt match {
+            case Some(offset) => out.write(offset.json.getBytes(UTF_8))
+            case None => out.write(OffsetSeqLog.SERIALIZED_VOID_OFFSET.getBytes(UTF_8))
+          }
+        }
+      case _ =>
+        // Original sequence-based serialization
+        offsetSeq.offsets.map(_.map(_.json)).foreach { offset =>
+          out.write('\n')
+          offset match {
+            case Some(json: String) => out.write(json.getBytes(UTF_8))
+            case None => out.write(OffsetSeqLog.SERIALIZED_VOID_OFFSET.getBytes(UTF_8))
          }
        }
     }
   }
 
   def offsetSeqMetadataForBatchId(batchId: Long): Option[OffsetSeqMetadata] = {
-    if (batchId < 0) None else get(batchId).flatMap(_.metadata)
+    if (batchId < 0) {
+      None
+    } else {
+      get(batchId).flatMap(_.metadataOpt)
+    }
   }
 }
 
 object OffsetSeqLog {
-  private[streaming] val VERSION = 1
-  private val SERIALIZED_VOID_OFFSET = "-"
+  private[streaming] val VERSION_1 = 1
+  private[streaming] val VERSION_2 = 2
+  private[streaming] val VERSION = VERSION_1 // Default version for backward compatibility
+  private[streaming] val MAX_VERSION = VERSION_2
+  private[streaming] val SERIALIZED_VOID_OFFSET = "-"
+
+  private[checkpointing] def parseOffset(value: String): OffsetV2 = value match {
+    case SERIALIZED_VOID_OFFSET => null
+    case json => SerializedOffset(json)
+  }
 }
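
As a quick illustration of the version-2 on-disk entry format documented above, the sketch below parses a single `sourceId:offset` line the same way the new deserialize branch does (split on the first colon, treat `-` as "no offset"). It is a standalone toy with invented sample lines, not the Spark code path.

```scala
// Minimal sketch of parsing one v2 offset-log body line ("sourceId:offset").
object OffsetMapLineFormatSketch {
  val SerializedVoidOffset = "-"

  def parseLine(line: String): (String, Option[String]) = {
    val colonIndex = line.indexOf(':')
    if (colonIndex == -1) {
      throw new IllegalStateException(s"Invalid OffsetMap format: $line")
    }
    val sourceId = line.substring(0, colonIndex)
    val offsetStr = line.substring(colonIndex + 1)
    // "-" means no offset was recorded for this source; anything else is the offset's JSON.
    sourceId -> (if (offsetStr == SerializedVoidOffset) None else Some(offsetStr))
  }

  def main(args: Array[String]): Unit = {
    // A v2 log body (after the "v2" line and the metadata line) might contain:
    //   0:{"logOffset":2}
    //   1:-
    println(parseLine("""0:{"logOffset":2}""")) // (0,Some({"logOffset":2}))
    println(parseLine("1:-"))                   // (1,None)
  }
}
```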

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala

Lines changed: 6 additions & 3 deletions
@@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.trees.TreePattern.CURRENT_LIKE
 import org.apache.spark.sql.classic.SparkSession
 import org.apache.spark.sql.connector.catalog.{SupportsRead, SupportsWrite, TableCapability}
 import org.apache.spark.sql.connector.distributions.UnspecifiedDistribution
-import org.apache.spark.sql.connector.read.streaming.{ContinuousStream, PartitionOffset, ReadLimit}
+import org.apache.spark.sql.connector.read.streaming.{ContinuousStream, PartitionOffset, ReadLimit, SparkDataStream}
 import org.apache.spark.sql.connector.write.{RequiresDistributionAndOrdering, Write}
 import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
 import org.apache.spark.sql.execution.SQLExecution
@@ -65,6 +65,8 @@ class ContinuousExecution(
 
   @volatile protected var sources: Seq[ContinuousStream] = Seq()
 
+  def sourceToIdMap: Map[SparkDataStream, String] = Map.empty
+
   // For use only in test harnesses.
   private[sql] var currentEpochCoordinatorId: String = _
 
@@ -186,7 +188,7 @@
     val nextOffsets = offsetLog.get(latestEpochId).getOrElse {
       throw new IllegalStateException(
         s"Batch $latestEpochId was committed without end epoch offsets!")
-    }
+    }.asInstanceOf[OffsetSeq]
     committedOffsets = nextOffsets.toStreamProgress(sources)
     execCtx.batchId = latestEpochId + 1
 
@@ -210,7 +212,8 @@
     val execCtx = latestExecutionContext
 
     if (execCtx.batchId > 0) {
-      AcceptsLatestSeenOffsetHandler.setLatestSeenOffsetOnSources(Some(offsets), sources)
+      AcceptsLatestSeenOffsetHandler.setLatestSeenOffsetOnSources(
+        Some(offsets), sources, Map.empty[String, SparkDataStream])
     }
 
     val withNewSources: LogicalPlan = logicalPlan transform {

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/AcceptsLatestSeenOffsetHandler.scala

Lines changed: 5 additions & 4 deletions
@@ -20,18 +20,19 @@ package org.apache.spark.sql.execution.streaming.runtime
 import org.apache.spark.SparkUnsupportedOperationException
 import org.apache.spark.sql.connector.read.streaming.{AcceptsLatestSeenOffset, SparkDataStream}
 import org.apache.spark.sql.execution.streaming.Source
-import org.apache.spark.sql.execution.streaming.checkpointing.OffsetSeq
+import org.apache.spark.sql.execution.streaming.checkpointing.OffsetSeqBase
 
 /**
  * This feeds "latest seen offset" to the sources that implement AcceptsLatestSeenOffset.
 */
 object AcceptsLatestSeenOffsetHandler {
   def setLatestSeenOffsetOnSources(
-      offsets: Option[OffsetSeq],
-      sources: Seq[SparkDataStream]): Unit = {
+      offsets: Option[OffsetSeqBase],
+      sources: Seq[SparkDataStream],
+      sourceIdMap: Map[String, SparkDataStream]): Unit = {
     assertNoAcceptsLatestSeenOffsetWithDataSourceV1(sources)
 
-    offsets.map(_.toStreamProgress(sources)) match {
+    offsets.map(_.toStreamProgress(sources, sourceIdMap)) match {
       case Some(streamProgress) =>
         streamProgress.foreach {
           case (src: AcceptsLatestSeenOffset, offset) =>

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/AsyncProgressTrackingMicroBatchExecution.scala

Lines changed: 8 additions & 7 deletions
@@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.streaming.WriteToStream
 import org.apache.spark.sql.classic.SparkSession
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.execution.streaming.{AvailableNowTrigger, OneTimeTrigger, ProcessingTimeTrigger}
-import org.apache.spark.sql.execution.streaming.checkpointing.{AsyncCommitLog, AsyncOffsetSeqLog, CommitMetadata, OffsetSeq}
+import org.apache.spark.sql.execution.streaming.checkpointing.{AsyncCommitLog, AsyncOffsetSeqLog, CommitMetadata, OffsetSeqBase}
 import org.apache.spark.sql.execution.streaming.operators.stateful.StateStoreWriter
 import org.apache.spark.sql.streaming.Trigger
 import org.apache.spark.util.{Clock, ThreadUtils}
@@ -49,7 +49,7 @@ class AsyncProgressTrackingMicroBatchExecution(
   // Offsets that are ready to be committed by the source.
   // This is needed so that we can call source commit in the same thread as micro-batch execution
   // to be thread safe
-  private val sourceCommitQueue = new ConcurrentLinkedQueue[OffsetSeq]()
+  private val sourceCommitQueue = new ConcurrentLinkedQueue[OffsetSeqBase]()
 
   // to cache the batch id of the last batch written to storage
   private val lastBatchPersistedToDurableStorage = new AtomicLong(-1)
@@ -104,7 +104,7 @@
     // perform quick validation to fail faster
     validateAndGetTrigger()
 
-  override def validateOffsetLogAndGetPrevOffset(latestBatchId: Long): Option[OffsetSeq] = {
+  override def validateOffsetLogAndGetPrevOffset(latestBatchId: Long): Option[OffsetSeqBase] = {
     /* Initialize committed offsets to a committed batch, which at this
      * is the second latest batch id in the offset log.
      * The offset log may not be contiguous */
@@ -137,14 +137,15 @@
     // Because we are using a thread pool with only one thread, async writes to the offset log
     // are still written in a serial / in order fashion
     offsetLog
-      .addAsync(execCtx.batchId, execCtx.endOffsets.toOffsetSeq(sources, execCtx.offsetSeqMetadata))
-      .thenAccept(tuple => {
-        val (batchId, persistedToDurableStorage) = tuple
+      .addAsync(execCtx.batchId,
+        execCtx.endOffsets.toOffsets(sources, sourceIdMap, execCtx.offsetSeqMetadata))
+      .thenAccept((tuple: (Long, Boolean)) => {
+        val (batchId: Long, persistedToDurableStorage: Boolean) = tuple
         if (persistedToDurableStorage) {
          // batch id cache not initialized
          if (lastBatchPersistedToDurableStorage.get == -1) {
            lastBatchPersistedToDurableStorage.set(
-             offsetLog.getPrevBatchFromStorage(batchId).getOrElse(-1))
+             offsetLog.getPrevBatchFromStorage(batchId).getOrElse(-1L))
          }
 
          if (batchId != 0 && lastBatchPersistedToDurableStorage.get != -1) {
