diff --git a/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala b/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala index e619adfce17b..f4bd782617ae 100644 --- a/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala +++ b/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala @@ -42,7 +42,7 @@ import org.apache.spark.sql.connector.read.streaming.SparkDataStream import org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2ScanRelation import org.apache.spark.sql.execution.exchange.ReusedExchangeExec import org.apache.spark.sql.execution.streaming._ -import org.apache.spark.sql.execution.streaming.checkpointing.OffsetSeq +import org.apache.spark.sql.execution.streaming.checkpointing.OffsetSeqBase import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution import org.apache.spark.sql.execution.streaming.runtime.{MicroBatchExecution, StreamExecution, StreamingExecutionRelation} import org.apache.spark.sql.execution.streaming.runtime.AsyncProgressTrackingMicroBatchExecution.{ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, ASYNC_PROGRESS_TRACKING_ENABLED} @@ -854,7 +854,7 @@ abstract class KafkaMicroBatchSourceSuiteBase extends KafkaSourceSuiteBase with true }, AssertOnQuery { q => - val latestOffset: Option[(Long, OffsetSeq)] = q.offsetLog.getLatest() + val latestOffset: Option[(Long, OffsetSeqBase)] = q.offsetLog.getLatest() latestOffset.exists { offset => !offset._2.offsets.exists(_.exists(_.json == "{}")) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala index 6af418e1ddc2..c97a70eb3c8c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala @@ -126,7 +126,7 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging val offsetLog = new StreamingQueryCheckpointMetadata(session, checkpointLocation).offsetLog offsetLog.get(batchId) match { case Some(value) => - val metadata = value.metadata.getOrElse( + val metadata = value.metadataOpt.getOrElse( throw StateDataSourceErrors.offsetMetadataLogUnavailable(batchId, checkpointLocation) ) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/checkpointing/AsyncOffsetSeqLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/checkpointing/AsyncOffsetSeqLog.scala index 18d18e61da47..e6ba644ed483 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/checkpointing/AsyncOffsetSeqLog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/checkpointing/AsyncOffsetSeqLog.scala @@ -81,7 +81,7 @@ class AsyncOffsetSeqLog( * the async write of the batch is completed. Future may also be completed exceptionally * to indicate some write error. 
   */
-  def addAsync(batchId: Long, metadata: OffsetSeq): CompletableFuture[(Long, Boolean)] = {
+  def addAsync(batchId: Long, metadata: OffsetSeqBase): CompletableFuture[(Long, Boolean)] = {
     require(metadata != null, "'null' metadata cannot written to a metadata log")
 
     def issueAsyncWrite(batchId: Long): CompletableFuture[Long] = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/checkpointing/OffsetSeq.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/checkpointing/OffsetSeq.scala
index 888dc0cdb912..a882d9539c4b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/checkpointing/OffsetSeq.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/checkpointing/OffsetSeq.scala
@@ -35,30 +35,67 @@ import org.apache.spark.sql.execution.streaming.runtime.{MultipleWatermarkPolicy
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.internal.SQLConf._
 
+trait OffsetSeqBase {
+  def offsets: Seq[Option[OffsetV2]]
 
-/**
- * An ordered collection of offsets, used to track the progress of processing data from one or more
- * [[Source]]s that are present in a streaming query. This is similar to simplified, single-instance
- * vector clock that must progress linearly forward.
- */
-case class OffsetSeq(offsets: Seq[Option[OffsetV2]], metadata: Option[OffsetSeqMetadata] = None) {
+  def metadataOpt: Option[OffsetSeqMetadata]
+
+  override def toString: String = this match {
+    case offsetMap: OffsetMap =>
+      offsetMap.offsetsMap.map { case (sourceId, offsetOpt) =>
+        s"$sourceId: ${offsetOpt.map(_.json).getOrElse("-")}"
+      }.mkString("{", ", ", "}")
+    case _ =>
+      offsets.map(_.map(_.json).getOrElse("-")).mkString("[", ", ", "]")
+  }
 
   /**
-   * Unpacks an offset into [[StreamProgress]] by associating each offset with the ordered list of
-   * sources.
+   * Unpacks an offset into [[StreamProgress]] by associating each offset with the
+   * ordered list of sources.
    *
-   * This method is typically used to associate a serialized offset with actual sources (which
-   * cannot be serialized).
+   * This method is typically used to associate a serialized offset with actual
+   * sources (which cannot be serialized).
    */
   def toStreamProgress(sources: Seq[SparkDataStream]): StreamProgress = {
+    assert(!this.isInstanceOf[OffsetMap], "toStreamProgress(sources) does not support OffsetMap")
     assert(sources.size == offsets.size, s"There are [${offsets.size}] sources in the " +
-      s"checkpoint offsets and now there are [${sources.size}] sources requested by the query. " +
-      s"Cannot continue.")
+      s"checkpoint offsets and now there are [${sources.size}] sources requested by " +
+      s"the query. Cannot continue.")
     new StreamProgress ++ sources.zip(offsets).collect { case (s, Some(o)) => (s, o) }
   }
 
-  override def toString: String =
-    offsets.map(_.map(_.json).getOrElse("-")).mkString("[", ", ", "]")
+  /**
+   * Converts OffsetMap to StreamProgress using source ID mapping.
+   * This method is specific to OffsetMap and requires a mapping from sourceId to SparkDataStream.
+ */ + def toStreamProgress( + sources: Seq[SparkDataStream], + sourceIdToSourceMap: Map[String, SparkDataStream]): StreamProgress = { + this match { + case offsetMap: OffsetMap => + val streamProgressEntries = for { + (sourceId, offsetOpt) <- offsetMap.offsetsMap + offset <- offsetOpt + source <- sourceIdToSourceMap.get(sourceId) + } yield source -> offset + new StreamProgress ++ streamProgressEntries + case _ => + // Fallback to original method for backward compatibility + toStreamProgress(sources) + } + } +} + +/** + * An ordered collection of offsets, used to track the progress of processing data from one or more + * [[Source]]s that are present in a streaming query. This is similar to simplified, single-instance + * vector clock that must progress linearly forward. + */ +case class OffsetSeq( + offsets: Seq[Option[OffsetV2]], + metadata: Option[OffsetSeqMetadata] = None) extends OffsetSeqBase { + + override def metadataOpt: Option[OffsetSeqMetadata] = metadata } object OffsetSeq { @@ -79,6 +116,23 @@ object OffsetSeq { } +/** + * A map-based collection of offsets, used to track the progress of processing data from one or more + * streaming sources. Each source is identified by a string key (initially sourceId.toString()). + * This replaces the sequence-based approach with a more flexible map-based approach to support + * named source identities. + */ +case class OffsetMap( + offsetsMap: Map[String, Option[OffsetV2]], + metadataOpt: Option[OffsetSeqMetadata] = None) extends OffsetSeqBase { + + // OffsetMap does not support sequence-based access + override def offsets: Seq[Option[OffsetV2]] = { + throw new UnsupportedOperationException( + "OffsetMap does not support sequence-based offsets access. Use offsetsMap directly.") + } +} + /** * Contains metadata associated with a [[OffsetSeq]]. This information is * persisted to the offset log in the checkpoint location via the [[OffsetSeq]] metadata field. @@ -97,7 +151,8 @@ object OffsetSeq { case class OffsetSeqMetadata( batchWatermarkMs: Long = 0, batchTimestampMs: Long = 0, - conf: Map[String, String] = Map.empty) { + conf: Map[String, String] = Map.empty, + version: Int = 1) { def json: String = Serialization.write(this)(OffsetSeqMetadata.format) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/checkpointing/OffsetSeqLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/checkpointing/OffsetSeqLog.scala index 816563b3f09f..891a66b21b52 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/checkpointing/OffsetSeqLog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/checkpointing/OffsetSeqLog.scala @@ -43,22 +43,26 @@ import org.apache.spark.sql.execution.streaming.runtime.SerializedOffset * - // No offset for this source i.e., an invalid JSON string * {2} // LongOffset 2 * ... + * + * Version 2 format (OffsetMap): + * v2 // version 2 + * metadata + * 0:{0} // sourceId:offset + * 1:{3} // sourceId:offset + * ... 
*/ class OffsetSeqLog(sparkSession: SparkSession, path: String) - extends HDFSMetadataLog[OffsetSeq](sparkSession, path) { + extends HDFSMetadataLog[OffsetSeqBase](sparkSession, path) { - override protected def deserialize(in: InputStream): OffsetSeq = { + override protected def deserialize(in: InputStream): OffsetSeqBase = { // called inside a try-finally where the underlying stream is closed in the caller - def parseOffset(value: String): OffsetV2 = value match { - case OffsetSeqLog.SERIALIZED_VOID_OFFSET => null - case json => SerializedOffset(json) - } val lines = IOSource.fromInputStream(in, UTF_8.name()).getLines() if (!lines.hasNext) { throw new IllegalStateException("Incomplete log file") } - validateVersion(lines.next(), OffsetSeqLog.VERSION) + val versionStr = lines.next() + val versionInt = validateVersion(versionStr, OffsetSeqLog.MAX_VERSION) // read metadata val metadata = lines.next().trim match { @@ -66,33 +70,82 @@ class OffsetSeqLog(sparkSession: SparkSession, path: String) case md => Some(md) } import org.apache.spark.util.ArrayImplicits._ - OffsetSeq.fill(metadata, lines.map(parseOffset).toArray.toImmutableArraySeq: _*) + if (versionInt == OffsetSeqLog.VERSION_2) { + // deserialize the remaining lines into the offset map + val remainingLines = lines.toArray + // New OffsetMap format: sourceId:offset + val offsetsMap = remainingLines.map { line => + val colonIndex = line.indexOf(':') + if (colonIndex == -1) { + throw new IllegalStateException(s"Invalid OffsetMap format: $line") + } + val sourceId = line.substring(0, colonIndex) + val offsetStr = line.substring(colonIndex + 1) + val offset = if (offsetStr == OffsetSeqLog.SERIALIZED_VOID_OFFSET) { + None + } else { + Some(OffsetSeqLog.parseOffset(offsetStr)) + } + sourceId -> offset + }.toMap + OffsetMap(offsetsMap, metadata.map(OffsetSeqMetadata.apply)) + } else { + OffsetSeq.fill(metadata, + lines.map(OffsetSeqLog.parseOffset).toArray.toImmutableArraySeq: _*) + } } - override protected def serialize(offsetSeq: OffsetSeq, out: OutputStream): Unit = { + override protected def serialize(offsetSeq: OffsetSeqBase, out: OutputStream): Unit = { // called inside a try-finally where the underlying stream is closed in the caller - out.write(("v" + OffsetSeqLog.VERSION).getBytes(UTF_8)) + out.write(("v" + offsetSeq.metadataOpt.map(_.version).getOrElse(OffsetSeqLog.VERSION_1)) + .getBytes(UTF_8)) // write metadata out.write('\n') - out.write(offsetSeq.metadata.map(_.json).getOrElse("").getBytes(UTF_8)) + out.write(offsetSeq.metadataOpt.map(_.json).getOrElse("").getBytes(UTF_8)) - // write offsets, one per line - offsetSeq.offsets.map(_.map(_.json)).foreach { offset => - out.write('\n') - offset match { - case Some(json: String) => out.write(json.getBytes(UTF_8)) - case None => out.write(OffsetSeqLog.SERIALIZED_VOID_OFFSET.getBytes(UTF_8)) - } + offsetSeq match { + case offsetMap: OffsetMap => + // For OffsetMap, write sourceId:offset pairs, one per line + offsetMap.offsetsMap.foreach { case (sourceId, offsetOpt) => + out.write('\n') + out.write(sourceId.getBytes(UTF_8)) + out.write(':') + offsetOpt match { + case Some(offset) => out.write(offset.json.getBytes(UTF_8)) + case None => out.write(OffsetSeqLog.SERIALIZED_VOID_OFFSET.getBytes(UTF_8)) + } + } + case _ => + // Original sequence-based serialization + offsetSeq.offsets.map(_.map(_.json)).foreach { offset => + out.write('\n') + offset match { + case Some(json: String) => out.write(json.getBytes(UTF_8)) + case None => out.write(OffsetSeqLog.SERIALIZED_VOID_OFFSET.getBytes(UTF_8)) 
+ } + } } } def offsetSeqMetadataForBatchId(batchId: Long): Option[OffsetSeqMetadata] = { - if (batchId < 0) None else get(batchId).flatMap(_.metadata) + if (batchId < 0) { + None + } else { + get(batchId).flatMap(_.metadataOpt) + } } } object OffsetSeqLog { - private[streaming] val VERSION = 1 - private val SERIALIZED_VOID_OFFSET = "-" + private[streaming] val VERSION_1 = 1 + private[streaming] val VERSION_2 = 2 + private[streaming] val VERSION = VERSION_1 // Default version for backward compatibility + private[streaming] val MAX_VERSION = VERSION_2 + private[streaming] val SERIALIZED_VOID_OFFSET = "-" + + private[checkpointing] def parseOffset(value: String): OffsetV2 = value match { + case SERIALIZED_VOID_OFFSET => null + case json => SerializedOffset(json) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index 51cd457fbc85..ad7e9f3e4aa9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.trees.TreePattern.CURRENT_LIKE import org.apache.spark.sql.classic.SparkSession import org.apache.spark.sql.connector.catalog.{SupportsRead, SupportsWrite, TableCapability} import org.apache.spark.sql.connector.distributions.UnspecifiedDistribution -import org.apache.spark.sql.connector.read.streaming.{ContinuousStream, PartitionOffset, ReadLimit} +import org.apache.spark.sql.connector.read.streaming.{ContinuousStream, PartitionOffset, ReadLimit, SparkDataStream} import org.apache.spark.sql.connector.write.{RequiresDistributionAndOrdering, Write} import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.execution.SQLExecution @@ -65,6 +65,8 @@ class ContinuousExecution( @volatile protected var sources: Seq[ContinuousStream] = Seq() + def sourceToIdMap: Map[SparkDataStream, String] = Map.empty + // For use only in test harnesses. 
private[sql] var currentEpochCoordinatorId: String = _ @@ -186,7 +188,7 @@ class ContinuousExecution( val nextOffsets = offsetLog.get(latestEpochId).getOrElse { throw new IllegalStateException( s"Batch $latestEpochId was committed without end epoch offsets!") - } + }.asInstanceOf[OffsetSeq] committedOffsets = nextOffsets.toStreamProgress(sources) execCtx.batchId = latestEpochId + 1 @@ -210,7 +212,8 @@ class ContinuousExecution( val execCtx = latestExecutionContext if (execCtx.batchId > 0) { - AcceptsLatestSeenOffsetHandler.setLatestSeenOffsetOnSources(Some(offsets), sources) + AcceptsLatestSeenOffsetHandler.setLatestSeenOffsetOnSources( + Some(offsets), sources, Map.empty[String, SparkDataStream]) } val withNewSources: LogicalPlan = logicalPlan transform { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/AcceptsLatestSeenOffsetHandler.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/AcceptsLatestSeenOffsetHandler.scala index b15b93b47ada..3eb5e6eb7d70 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/AcceptsLatestSeenOffsetHandler.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/AcceptsLatestSeenOffsetHandler.scala @@ -20,18 +20,19 @@ package org.apache.spark.sql.execution.streaming.runtime import org.apache.spark.SparkUnsupportedOperationException import org.apache.spark.sql.connector.read.streaming.{AcceptsLatestSeenOffset, SparkDataStream} import org.apache.spark.sql.execution.streaming.Source -import org.apache.spark.sql.execution.streaming.checkpointing.OffsetSeq +import org.apache.spark.sql.execution.streaming.checkpointing.OffsetSeqBase /** * This feeds "latest seen offset" to the sources that implement AcceptsLatestSeenOffset. 
*/ object AcceptsLatestSeenOffsetHandler { def setLatestSeenOffsetOnSources( - offsets: Option[OffsetSeq], - sources: Seq[SparkDataStream]): Unit = { + offsets: Option[OffsetSeqBase], + sources: Seq[SparkDataStream], + sourceIdMap: Map[String, SparkDataStream]): Unit = { assertNoAcceptsLatestSeenOffsetWithDataSourceV1(sources) - offsets.map(_.toStreamProgress(sources)) match { + offsets.map(_.toStreamProgress(sources, sourceIdMap)) match { case Some(streamProgress) => streamProgress.foreach { case (src: AcceptsLatestSeenOffset, offset) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/AsyncProgressTrackingMicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/AsyncProgressTrackingMicroBatchExecution.scala index 2a87ba338088..4168df2e1f51 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/AsyncProgressTrackingMicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/AsyncProgressTrackingMicroBatchExecution.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.streaming.WriteToStream import org.apache.spark.sql.classic.SparkSession import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.execution.streaming.{AvailableNowTrigger, OneTimeTrigger, ProcessingTimeTrigger} -import org.apache.spark.sql.execution.streaming.checkpointing.{AsyncCommitLog, AsyncOffsetSeqLog, CommitMetadata, OffsetSeq} +import org.apache.spark.sql.execution.streaming.checkpointing.{AsyncCommitLog, AsyncOffsetSeqLog, CommitMetadata, OffsetSeqBase} import org.apache.spark.sql.execution.streaming.operators.stateful.StateStoreWriter import org.apache.spark.sql.streaming.Trigger import org.apache.spark.util.{Clock, ThreadUtils} @@ -49,7 +49,7 @@ class AsyncProgressTrackingMicroBatchExecution( // Offsets that are ready to be committed by the source. // This is needed so that we can call source commit in the same thread as micro-batch execution // to be thread safe - private val sourceCommitQueue = new ConcurrentLinkedQueue[OffsetSeq]() + private val sourceCommitQueue = new ConcurrentLinkedQueue[OffsetSeqBase]() // to cache the batch id of the last batch written to storage private val lastBatchPersistedToDurableStorage = new AtomicLong(-1) @@ -104,7 +104,7 @@ class AsyncProgressTrackingMicroBatchExecution( // perform quick validation to fail faster validateAndGetTrigger() - override def validateOffsetLogAndGetPrevOffset(latestBatchId: Long): Option[OffsetSeq] = { + override def validateOffsetLogAndGetPrevOffset(latestBatchId: Long): Option[OffsetSeqBase] = { /* Initialize committed offsets to a committed batch, which at this * is the second latest batch id in the offset log. 
* The offset log may not be contiguous */ @@ -137,14 +137,15 @@ class AsyncProgressTrackingMicroBatchExecution( // Because we are using a thread pool with only one thread, async writes to the offset log // are still written in a serial / in order fashion offsetLog - .addAsync(execCtx.batchId, execCtx.endOffsets.toOffsetSeq(sources, execCtx.offsetSeqMetadata)) - .thenAccept(tuple => { - val (batchId, persistedToDurableStorage) = tuple + .addAsync(execCtx.batchId, + execCtx.endOffsets.toOffsets(sources, sourceIdMap, execCtx.offsetSeqMetadata)) + .thenAccept((tuple: (Long, Boolean)) => { + val (batchId: Long, persistedToDurableStorage: Boolean) = tuple if (persistedToDurableStorage) { // batch id cache not initialized if (lastBatchPersistedToDurableStorage.get == -1) { lastBatchPersistedToDurableStorage.set( - offsetLog.getPrevBatchFromStorage(batchId).getOrElse(-1)) + offsetLog.getPrevBatchFromStorage(batchId).getOrElse(-1L)) } if (batchId != 0 && lastBatchPersistedToDurableStorage.get != -1) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/MicroBatchExecution.scala index cf2fca3d3cd8..5ea97e6a2c32 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/MicroBatchExecution.scala @@ -42,7 +42,7 @@ import org.apache.spark.sql.execution.{SparkPlan, SQLExecution} import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, RealTimeStreamScanExec, StreamingDataSourceV2Relation, StreamingDataSourceV2ScanRelation, StreamWriterCommitProgress, WriteToDataSourceV2Exec} import org.apache.spark.sql.execution.streaming.{AvailableNowTrigger, Offset, OneTimeTrigger, ProcessingTimeTrigger, RealTimeModeAllowlist, RealTimeTrigger, Sink, Source, StreamingQueryPlanTraverseHelper} -import org.apache.spark.sql.execution.streaming.checkpointing.{CheckpointFileManager, CommitMetadata, OffsetSeq, OffsetSeqMetadata} +import org.apache.spark.sql.execution.streaming.checkpointing.{CheckpointFileManager, CommitMetadata, OffsetSeqBase, OffsetSeqMetadata} import org.apache.spark.sql.execution.streaming.operators.stateful.{StatefulOperatorStateInfo, StatefulOpStateStoreCheckpointInfo, StateStoreWriter} import org.apache.spark.sql.execution.streaming.runtime.AcceptsLatestSeenOffsetHandler import org.apache.spark.sql.execution.streaming.runtime.StreamingCheckpointConstants.{DIR_NAME_COMMITS, DIR_NAME_OFFSETS, DIR_NAME_STATE} @@ -100,6 +100,14 @@ class MicroBatchExecution( @volatile protected var sources: Seq[SparkDataStream] = Seq.empty + // Source ID mapping for OffsetMap support + // Using index as sourceId initially, can be extended to support user-provided names + // This is initialized in the same path as the sources Seq (defined above) and is used + // in the same way, when OffsetLog v2 is used. 
+ @volatile protected var sourceIdMap: Map[String, SparkDataStream] = Map.empty + + override protected def sourceToIdMap: Map[SparkDataStream, String] = sourceIdMap.map(_.swap) + @volatile protected[sql] var triggerExecutor: TriggerExecutor = _ protected def getTrigger(): TriggerExecutor = { @@ -243,6 +251,11 @@ class MicroBatchExecution( case r: StreamingDataSourceV2ScanRelation => r.stream } + // Create source ID mapping for OffsetMap support + sourceIdMap = sources.zipWithIndex.map { + case (source, index) => index.toString -> source + }.toMap + // Inform the source if it is in real time mode if (trigger.isInstanceOf[RealTimeTrigger]) { sources.foreach{ @@ -399,7 +412,10 @@ class MicroBatchExecution( } AcceptsLatestSeenOffsetHandler.setLatestSeenOffsetOnSources( - offsetLog.getLatest().map(_._2), sources) + offsetLog.getLatest().map(_._2), + sources, + sourceIdMap + ) val execCtx = new MicroBatchExecutionContext(id, runId, name, triggerClock, sources, sink, progressReporter, -1, sparkSession, None) @@ -552,7 +568,7 @@ class MicroBatchExecution( * @param latestBatchId the batch id of the current micro batch * @return A option that contains the offset of the previously written batch */ - def validateOffsetLogAndGetPrevOffset(latestBatchId: Long): Option[OffsetSeq] = { + def validateOffsetLogAndGetPrevOffset(latestBatchId: Long): Option[OffsetSeqBase] = { if (latestBatchId != 0) { Some(offsetLog.get(latestBatchId - 1).getOrElse { logError(log"The offset log for batch ${MDC(LogKeys.BATCH_ID, latestBatchId - 1)} " + @@ -601,16 +617,16 @@ class MicroBatchExecution( * in the offset log */ execCtx.batchId = latestBatchId execCtx.isCurrentBatchConstructed = true - execCtx.endOffsets = nextOffsets.toStreamProgress(sources) + execCtx.endOffsets = nextOffsets.toStreamProgress(sources, sourceIdMap) // validate the integrity of offset log and get the previous offset from the offset log val secondLatestOffsets = validateOffsetLogAndGetPrevOffset(latestBatchId) secondLatestOffsets.foreach { offset => - execCtx.startOffsets = offset.toStreamProgress(sources) + execCtx.startOffsets = offset.toStreamProgress(sources, sourceIdMap) } // update offset metadata - nextOffsets.metadata.foreach { metadata => + nextOffsets.metadataOpt.foreach { metadata => OffsetSeqMetadata.setSessionConf(metadata, sparkSessionToRunBatches.sessionState.conf) execCtx.offsetSeqMetadata = OffsetSeqMetadata( metadata.batchWatermarkMs, metadata.batchTimestampMs, sparkSessionToRunBatches.conf) @@ -846,8 +862,8 @@ class MicroBatchExecution( shouldConstructNextBatch } - protected def commitSources(offsetSeq: OffsetSeq): Unit = { - offsetSeq.toStreamProgress(sources).foreach { + protected def commitSources(offsetSeq: OffsetSeqBase): Unit = { + offsetSeq.toStreamProgress(sources, sourceIdMap).foreach { case (src: Source, off: Offset) => src.commit(off) case (stream: MicroBatchStream, off) => stream.commit(stream.deserializeOffset(off.json)) @@ -1106,7 +1122,7 @@ class MicroBatchExecution( if (!trigger.isInstanceOf[RealTimeTrigger]) { if (!offsetLog.add( execCtx.batchId, - execCtx.endOffsets.toOffsetSeq(sources, execCtx.offsetSeqMetadata) + execCtx.endOffsets.toOffsets(sources, sourceIdMap, execCtx.offsetSeqMetadata) )) { throw QueryExecutionErrors.concurrentStreamLogUpdate(execCtx.batchId) } @@ -1262,7 +1278,7 @@ class MicroBatchExecution( execCtx.reportTimeTaken("walCommit") { if (!offsetLog.add( execCtx.batchId, - execCtx.endOffsets.toOffsetSeq(sources, execCtx.offsetSeqMetadata) + execCtx.endOffsets.toOffsets(sources, sourceIdMap, 
execCtx.offsetSeqMetadata) )) { throw QueryExecutionErrors.concurrentStreamLogUpdate(execCtx.batchId) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/StreamExecution.scala index 56ed0de1fcdc..65c1226a85df 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/StreamExecution.scala @@ -154,6 +154,12 @@ abstract class StreamExecution( */ protected def sources: Seq[SparkDataStream] + /** + * Source-to-ID mapping for OffsetMap support. + * Using index as sourceId initially, can be extended to support user-provided names. + */ + protected def sourceToIdMap: Map[SparkDataStream, String] + /** Isolated spark session to run the batches with. */ protected[sql] val sparkSessionForStream: SparkSession = sparkSession.cloneSession() @@ -370,10 +376,12 @@ abstract class StreamExecution( toDebugString(includeLogicalPlan = isInitialized), cause = cause, getLatestExecutionContext().startOffsets - .toOffsetSeq(sources.toSeq, getLatestExecutionContext().offsetSeqMetadata) + .toOffsets(sources.toSeq, sourceToIdMap.map(_.swap), + getLatestExecutionContext().offsetSeqMetadata) .toString, getLatestExecutionContext().endOffsets - .toOffsetSeq(sources.toSeq, getLatestExecutionContext().offsetSeqMetadata) + .toOffsets(sources.toSeq, sourceToIdMap.map(_.swap), + getLatestExecutionContext().offsetSeqMetadata) .toString, errorClass = "STREAM_FAILED", messageParameters = Map( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/StreamProgress.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/StreamProgress.scala index a6fd103e8d6a..0708f931b77e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/StreamProgress.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/StreamProgress.scala @@ -20,7 +20,8 @@ package org.apache.spark.sql.execution.streaming.runtime import scala.collection.immutable import org.apache.spark.sql.connector.read.streaming.{Offset => OffsetV2, SparkDataStream} -import org.apache.spark.sql.execution.streaming.checkpointing.{OffsetSeq, OffsetSeqMetadata} +import org.apache.spark.sql.errors.QueryExecutionErrors +import org.apache.spark.sql.execution.streaming.checkpointing.{OffsetMap, OffsetSeq, OffsetSeqBase, OffsetSeqLog, OffsetSeqMetadata} /** * A helper class that looks like a Map[Source, Offset]. @@ -30,10 +31,43 @@ class StreamProgress( new immutable.HashMap[SparkDataStream, OffsetV2]) extends scala.collection.immutable.Map[SparkDataStream, OffsetV2] { - def toOffsetSeq(source: Seq[SparkDataStream], metadata: OffsetSeqMetadata): OffsetSeq = { + /** + * Unified method to convert StreamProgress to appropriate OffsetSeq format. + * Handles both VERSION_1 (OffsetSeq) and VERSION_2 (OffsetMap) based on metadata version. 
+ */ + def toOffsets( + sources: Seq[SparkDataStream], + sourceIdMap: Map[String, SparkDataStream], + metadata: OffsetSeqMetadata): OffsetSeqBase = { + metadata.version match { + case OffsetSeqLog.VERSION_1 => + toOffsetSeq(sources, metadata) + case OffsetSeqLog.VERSION_2 => + toOffsetMap(sourceIdMap, metadata) + case v => + throw QueryExecutionErrors.logVersionGreaterThanSupported(v, OffsetSeqLog.MAX_VERSION) + } + } + + def toOffsetSeq( + source: Seq[SparkDataStream], + metadata: OffsetSeqMetadata): OffsetSeq = { OffsetSeq(source.map(get), Some(metadata)) } + private def toOffsetMap( + sourceIdMap: Map[String, SparkDataStream], + metadata: OffsetSeqMetadata): OffsetMap = { + // Compute reverse mapping only when needed + val sourceToIdMap = sourceIdMap.map(_.swap) + val offsetsMap = baseMap.map { case (source, offset) => + val sourceId = sourceToIdMap.getOrElse(source, + throw new IllegalArgumentException(s"Source $source not found in sourceToIdMap")) + sourceId -> Some(offset) + } + OffsetMap(offsetsMap, Some(metadata)) + } + override def toString: String = baseMap.map { case (k, v) => s"$k: $v"}.mkString("{", ",", "}") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionRunner.scala index 2456b2c9b73b..63e3a6ec8d9a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionRunner.scala @@ -129,7 +129,8 @@ class OfflineStateRepartitionRunner( // If it is a failed repartition batch, lets check if the shuffle partitions // is the same as the requested. If same, then we can retry the batch. val lastBatch = checkpointMetadata.offsetLog.get(lastBatchId).get - val lastBatchShufflePartitions = getShufflePartitions(lastBatch.metadata.get).get + val lastBatchShufflePartitions = getShufflePartitions( + lastBatch.metadataOpt.get).get if (lastBatchShufflePartitions == numPartitions) { // We can retry the repartition batch. 
logInfo(log"The last batch is a failed repartition batch " + @@ -193,7 +194,7 @@ class OfflineStateRepartitionRunner( .offsetSeqNotFoundError(checkpointLocation, lastCommittedBatchId)) // Missing offset metadata not supported - val lastCommittedMetadata = lastCommittedOffsetSeq.metadata.getOrElse( + val lastCommittedMetadata = lastCommittedOffsetSeq.metadataOpt.getOrElse( throw OfflineStateRepartitionErrors.missingOffsetSeqMetadataError( checkpointLocation, version = 1, batchId = lastCommittedBatchId) ) @@ -253,11 +254,11 @@ object OfflineStateRepartitionUtils { throw OfflineStateRepartitionErrors .offsetSeqNotFoundError(checkpointLocation, prevBatchId)) - val batchMetadata = batch.metadata.getOrElse(throw OfflineStateRepartitionErrors + val batchMetadata = batch.metadataOpt.getOrElse(throw OfflineStateRepartitionErrors .missingOffsetSeqMetadataError(checkpointLocation, version = 1, batchId = batchId)) val shufflePartitions = getShufflePartitions(batchMetadata).get - val previousBatchMetadata = previousBatch.metadata.getOrElse( + val previousBatchMetadata = previousBatch.metadataOpt.getOrElse( throw OfflineStateRepartitionErrors .missingOffsetSeqMetadataError(checkpointLocation, version = 1, batchId = prevBatchId)) val previousShufflePartitions = getShufflePartitions(previousBatchMetadata).get diff --git a/sql/core/src/test/resources/structured-streaming/offset-map/0 b/sql/core/src/test/resources/structured-streaming/offset-map/0 new file mode 100644 index 000000000000..ca9c25d8cf3f --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/offset-map/0 @@ -0,0 +1,3 @@ +v2 +{"batchWatermarkMs":0,"batchTimestampMs":1758651405232,"conf":{"spark.sql.streaming.stateStore.providerClass":"org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider","spark.sql.streaming.join.stateFormatVersion":"2","spark.sql.streaming.stateStore.compression.codec":"lz4","spark.sql.shuffle.partitions":"5"}} +0:0 diff --git a/sql/core/src/test/resources/structured-streaming/offset-map/1 b/sql/core/src/test/resources/structured-streaming/offset-map/1 new file mode 100644 index 000000000000..9e01cb1e2eae --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/offset-map/1 @@ -0,0 +1,3 @@ +v2 +{"batchWatermarkMs":0,"batchTimestampMs":1758651405232,"conf":{"spark.sql.streaming.stateStore.providerClass":"org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider","spark.sql.streaming.join.stateFormatVersion":"2","spark.sql.streaming.stateStore.compression.codec":"lz4","spark.sql.shuffle.partitions":"5"}} +0:1 diff --git a/sql/core/src/test/resources/structured-streaming/offset-map/2 b/sql/core/src/test/resources/structured-streaming/offset-map/2 new file mode 100644 index 000000000000..833abc798a6a --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/offset-map/2 @@ -0,0 +1,3 @@ +v2 +{"batchWatermarkMs":0,"batchTimestampMs":1758651405232,"conf":{"spark.sql.streaming.stateStore.providerClass":"org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider","spark.sql.streaming.join.stateFormatVersion":"2","spark.sql.streaming.stateStore.compression.codec":"lz4","spark.sql.shuffle.partitions":"5"}} +0:2 diff --git a/sql/core/src/test/resources/structured-streaming/offset-map/3 b/sql/core/src/test/resources/structured-streaming/offset-map/3 new file mode 100644 index 000000000000..f108ad977be0 --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/offset-map/3 @@ -0,0 +1,3 @@ +v2 
+{"batchWatermarkMs":0,"batchTimestampMs":1758651405232,"conf":{"spark.sql.streaming.stateStore.providerClass":"org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider","spark.sql.streaming.join.stateFormatVersion":"2","spark.sql.streaming.stateStore.compression.codec":"lz4","spark.sql.shuffle.partitions":"5"}} +0:3 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecutionSuite.scala index e31e0e70cf39..3f0e65264662 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecutionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecutionSuite.scala @@ -29,7 +29,7 @@ import org.scalatest.time.{Seconds, Span} import org.apache.spark.TestUtils import org.apache.spark.sql._ import org.apache.spark.sql.connector.read.streaming -import org.apache.spark.sql.execution.streaming.checkpointing.{AsyncCommitLog, AsyncOffsetSeqLog, OffsetSeq} +import org.apache.spark.sql.execution.streaming.checkpointing.{AsyncCommitLog, AsyncOffsetSeqLog} import org.apache.spark.sql.execution.streaming.runtime.{AsyncProgressTrackingMicroBatchExecution, MemoryStream, StreamExecution} import org.apache.spark.sql.execution.streaming.runtime.AsyncProgressTrackingMicroBatchExecution.{ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, ASYNC_PROGRESS_TRACKING_ENABLED, ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK} import org.apache.spark.sql.functions.{column, window} @@ -835,7 +835,7 @@ class AsyncProgressTrackingMicroBatchExecutionSuite val offsetLog = new AsyncOffsetSeqLog(ds.sparkSession, checkpointLocation + "/offsets", null, 0) // commits received at source should match up to the ones found in the offset log for (i <- 0 until inputData.commits.length) { - val offsetOnDisk: OffsetSeq = offsetLog.get(offsetLogFiles(i)).get + val offsetOnDisk = offsetLog.get(offsetLogFiles(i)).get val sourceCommittedOffset: streaming.Offset = inputData.commits(i) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLogSuite.scala index e4312fd16d1f..2c3ae11a4e7f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLogSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.streaming import java.io.File import org.apache.spark.sql.catalyst.util.stringToFile -import org.apache.spark.sql.execution.streaming.checkpointing.{OffsetSeq, OffsetSeqLog, OffsetSeqMetadata} +import org.apache.spark.sql.execution.streaming.checkpointing.{OffsetMap, OffsetSeq, OffsetSeqBase, OffsetSeqLog, OffsetSeqMetadata} import org.apache.spark.sql.execution.streaming.runtime.{LongOffset, SerializedOffset} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession @@ -109,7 +109,7 @@ class OffsetSeqLogSuite extends SharedSparkSession { log.get(0) } Seq( - s"maximum supported log version is v${OffsetSeqLog.VERSION}, but encountered v99999", + s"maximum supported log version is v${OffsetSeqLog.MAX_VERSION}, but encountered v99999", "produced by a newer version of Spark and cannot be read by this version" ).foreach { message => 
assert(e.getMessage.contains(message)) @@ -124,10 +124,10 @@ class OffsetSeqLogSuite extends SharedSparkSession { Some(SerializedOffset("""{"logOffset":345}""")), Some(SerializedOffset("""{"topic-0":{"0":1}}""")) )) - assert(offsetSeq.metadata === Some(OffsetSeqMetadata(0L, 1480981499528L))) + assert(offsetSeq.metadataOpt === Some(OffsetSeqMetadata(0L, 1480981499528L))) } - private def readFromResource(dir: String): (Long, OffsetSeq) = { + private def readFromResource(dir: String): (Long, OffsetSeqBase) = { val input = getClass.getResource(s"/structured-streaming/$dir") val log = new OffsetSeqLog(spark, input.toString) log.getLatest().get @@ -161,7 +161,7 @@ class OffsetSeqLogSuite extends SharedSparkSession { // Read the latest offset log val offsetSeq = log.get(latestBatchId.get).get - val offsetSeqMetadata = offsetSeq.metadata.get + val offsetSeqMetadata = offsetSeq.metadataOpt.get if (entryExists) { val encodingFormatOpt = offsetSeqMetadata.conf.get( @@ -210,7 +210,7 @@ class OffsetSeqLogSuite extends SharedSparkSession { withSQLConf(rowChecksumConf -> true.toString) { val existingChkpt = "offset-log-version-2.1.0" val (_, offsetSeq) = readFromResource(existingChkpt) - val offsetSeqMetadata = offsetSeq.metadata.get + val offsetSeqMetadata = offsetSeq.metadataOpt.get // Not present in existing checkpoint assert(offsetSeqMetadata.conf.get(rowChecksumConf) === None) @@ -219,4 +219,22 @@ class OffsetSeqLogSuite extends SharedSparkSession { assert(!clonedSqlConf.stateStoreRowChecksumEnabled) } } + + test("OffsetMap golden file compatibility test - VERSION_2 format") { + val (batchId, offsetSeq) = readFromResource("offset-map") + assert(batchId === 3) + + // Verify it's an OffsetMap (VERSION_2) + assert(offsetSeq.isInstanceOf[OffsetMap]) + val offsetMap = offsetSeq.asInstanceOf[OffsetMap] + + // Verify the offset data + assert(offsetMap.offsetsMap === Map("0" -> Some(SerializedOffset("3")))) + + // Verify metadata + assert(offsetSeq.metadataOpt.isDefined) + val metadata = offsetSeq.metadataOpt.get + assert(metadata.batchWatermarkMs === 0) + assert(metadata.batchTimestampMs === 1758651405232L) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionSuite.scala index 86b5502b652e..860e7a1ab2e4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionSuite.scala @@ -258,7 +258,7 @@ class OfflineStateRepartitionSuite extends StreamTest { assert(lastBatchId == batchId) val lastBatch = checkpointMetadata.offsetLog.get(lastBatchId).get - val lastBatchShufflePartitions = getShufflePartitions(lastBatch.metadata.get).get + val lastBatchShufflePartitions = getShufflePartitions(lastBatch.metadataOpt.get).get assert(lastBatchShufflePartitions == expectedShufflePartitions) // Verify the commit log @@ -277,7 +277,7 @@ class OfflineStateRepartitionSuite extends StreamTest { s"Offsets should be identical between batch $previousBatchId and $batchId") // Verify metadata is the same except for shuffle partitions config - (lastBatch.metadata, previousBatch.metadata) match { + (lastBatch.metadataOpt, previousBatch.metadataOpt) match { case (Some(lastMetadata), Some(previousMetadata)) => // Check watermark and timestamp are the same assert(lastMetadata.batchWatermarkMs == 
previousMetadata.batchWatermarkMs,
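
A minimal sketch of how the new v2 (OffsetMap) offset log format round-trips through OffsetSeqLog, based on the serialize/deserialize changes above and the golden files added under sql/core/src/test/resources/structured-streaming/offset-map. The spark session, checkpoint path, batch id, and offset values are illustrative only.

// Sketch only: exercises the OffsetMap / v2 path introduced in this patch.
// Assumes an active `spark` session and a writable scratch path.
import org.apache.spark.sql.execution.streaming.checkpointing.{OffsetMap, OffsetSeqLog, OffsetSeqMetadata}
import org.apache.spark.sql.execution.streaming.runtime.SerializedOffset

val log = new OffsetSeqLog(spark, "/tmp/example-checkpoint/offsets")  // illustrative path

// metadata.version = 2 is what makes serialize() emit the "v2" header followed by
// one "sourceId:offset" line per entry (cf. the offset-map golden files).
val metadata = OffsetSeqMetadata(batchWatermarkMs = 0L, batchTimestampMs = 0L, version = 2)
val offsets = OffsetMap(Map("0" -> Some(SerializedOffset("3"))), Some(metadata))
assert(log.add(3, offsets))

// deserialize() keys off the version header, so a "v2" file is read back as an OffsetMap.
log.get(3) match {
  case Some(m: OffsetMap) => assert(m.offsetsMap("0").map(_.json).contains("3"))
  case other => sys.error(s"expected an OffsetMap, got $other")
}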
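Similarly, a sketch of how the sourceIdMap plumbing resolves an OffsetMap back into a StreamProgress. The stub SparkDataStream and the "0"/"42" values are hypothetical stand-ins introduced purely for illustration; MicroBatchExecution builds the real mapping as index.toString -> stream.

// Sketch only: stub source standing in for a real MicroBatchStream.
import org.apache.spark.sql.connector.read.streaming.{Offset => OffsetV2, SparkDataStream}
import org.apache.spark.sql.execution.streaming.checkpointing.{OffsetMap, OffsetSeqMetadata}
import org.apache.spark.sql.execution.streaming.runtime.SerializedOffset

val stubStream: SparkDataStream = new SparkDataStream {
  override def initialOffset(): OffsetV2 = SerializedOffset("0")
  override def deserializeOffset(json: String): OffsetV2 = SerializedOffset(json)
  override def commit(end: OffsetV2): Unit = {}
  override def stop(): Unit = {}
}

// Mirrors the mapping built in MicroBatchExecution: sourceId -> stream.
val sourceIdMap: Map[String, SparkDataStream] = Map("0" -> stubStream)

val restored = OffsetMap(
  Map("0" -> Some(SerializedOffset("42"))),
  Some(OffsetSeqMetadata(version = 2)))

// The map-based overload resolves offsets by source ID rather than by position.
val progress = restored.toStreamProgress(Seq(stubStream), sourceIdMap)
assert(progress(stubStream).json == "42")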