Commit e2a88e8

Update metrics for reads and writes via DSV2

Thanks @ymuzammil for fixing the issue with the read metrics. Fixes SPARKC-712.

1 parent 1c2ffa1 commit e2a88e8
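
For context, both updaters feed Spark's standard per-task metrics, so once the changes below are applied the row and byte counts should show up in the Spark UI and in listener callbacks. A minimal sketch of one way to observe them; the listener class here is illustrative, only the Spark listener API itself is standard:

import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}

// Hypothetical observer: prints the per-task input/output counters that the
// InputMetricsUpdater / OutputMetricsUpdater changes in this commit should populate.
class CassandraMetricsListener extends SparkListener {
  override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
    Option(taskEnd.taskMetrics).foreach { m =>
      println(s"read:  ${m.inputMetrics.recordsRead} rows / ${m.inputMetrics.bytesRead} bytes")
      println(s"wrote: ${m.outputMetrics.recordsWritten} rows / ${m.outputMetrics.bytesWritten} bytes")
    }
  }
}

// Register before running a read or write:
// spark.sparkContext.addSparkListener(new CassandraMetricsListener)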

File tree: 3 files changed, +25 −3 lines

connector/src/main/scala/com/datastax/spark/connector/datasource/CassandraInJoinReaderFactory.scala
Lines changed: 6 additions & 1 deletion

@@ -30,6 +30,8 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory}
 import org.apache.spark.sql.sources.In
 import org.apache.spark.sql.types.{LongType, StructField, StructType}
+import org.apache.spark.metrics.InputMetricsUpdater
+import org.apache.spark.TaskContext

 import scala.util.{Failure, Success}

@@ -80,16 +82,18 @@ abstract class CassandraBaseInJoinReader(
   protected val maybeRateLimit = JoinHelper.maybeRateLimit(readConf)
   protected val requestsPerSecondRateLimiter = JoinHelper.requestsPerSecondRateLimiter(readConf)

+  protected val metricsUpdater = InputMetricsUpdater(TaskContext.get(), readConf)
   protected def pairWithRight(left: CassandraRow): SettableFuture[Iterator[(CassandraRow, InternalRow)]] = {
     val resultFuture = SettableFuture.create[Iterator[(CassandraRow, InternalRow)]]
     val leftSide = Iterator.continually(left)

     queryExecutor.executeAsync(bsb.bind(left).executeAs(readConf.executeAs)).onComplete {
       case Success(rs) =>
         val resultSet = new PrefetchingResultSetIterator(rs)
+        val iteratorWithMetrics = resultSet.map(metricsUpdater.updateMetrics)
         /* This is a much less than ideal place to actually rate limit, we are buffering
         these futures this means we will most likely exceed our threshold*/
-        val throttledIterator = resultSet.map(maybeRateLimit)
+        val throttledIterator = iteratorWithMetrics.map(maybeRateLimit)
         val rightSide = throttledIterator.map(rowReader.read(_, rowMetadata))
         resultFuture.set(leftSide.zip(rightSide))
       case Failure(throwable) =>

@@ -121,6 +125,7 @@ abstract class CassandraBaseInJoinReader(
   override def get(): InternalRow = currentRow

   override def close(): Unit = {
+    metricsUpdater.finish()
     session.close()
   }
 }
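
The key idea in the read path above is that updateMetrics is a pass-through: it records a row and hands it back unchanged, so wrapping the result iterator with .map adds counting lazily, without buffering anything or disturbing the rate limiter downstream. A minimal sketch of that shape (hypothetical; not the connector's actual InputMetricsUpdater):

// Hypothetical pass-through counter, just to illustrate the .map(updateMetrics) pattern.
class CountingUpdater {
  private var rows = 0L

  // Record the row, then return it unchanged so it keeps flowing through the
  // rate limiter and the row reader.
  def updateMetrics[T](row: T): T = {
    rows += 1
    row
  }

  // Called once from close(); the real updater flushes totals into Spark's task metrics here.
  def finish(): Long = rows
}

// val updater = new CountingUpdater
// val counted = resultIterator.map(updater.updateMetrics)  // lazy, no extra buffering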

connector/src/main/scala/com/datastax/spark/connector/datasource/CassandraScanPartitionReaderFactory.scala
Lines changed: 7 additions & 1 deletion

@@ -30,6 +30,8 @@ import com.datastax.spark.connector.util.Logging
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.connector.read._
 import org.apache.spark.sql.types.{LongType, StructField, StructType}
+import org.apache.spark.metrics.InputMetricsUpdater
+import org.apache.spark.TaskContext

 case class CassandraScanPartitionReaderFactory(
   connector: CassandraConnector,

@@ -79,6 +81,8 @@ abstract class CassandraPartitionReaderBase
   protected val rowIterator = getIterator()
   protected var lastRow: InternalRow = InternalRow()

+  protected val metricsUpdater = InputMetricsUpdater(TaskContext.get(), readConf)
+
   override def next(): Boolean = {
     if (rowIterator.hasNext) {
       lastRow = rowIterator.next()

@@ -91,6 +95,7 @@ abstract class CassandraPartitionReaderBase
   override def get(): InternalRow = lastRow

   override def close(): Unit = {
+    metricsUpdater.finish()
     scanner.close()
   }

@@ -125,7 +130,8 @@ abstract class CassandraPartitionReaderBase
     tokenRanges.iterator.flatMap { range =>
       val scanResult = ScanHelper.fetchTokenRange(scanner, tableDef, queryParts, range, readConf.consistencyLevel, readConf.fetchSizeInRows)
       val meta = scanResult.metadata
-      scanResult.rows.map(rowReader.read(_, meta))
+      val iteratorWithMetrics = scanResult.rows.map(metricsUpdater.updateMetrics)
+      iteratorWithMetrics.map(rowReader.read(_, meta))
     }
   }
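
The scan reader above is what a plain DSV2 read exercises. A hypothetical session showing that path end to end; the catalog class and option names come from the connector's documentation rather than from this commit:

import org.apache.spark.sql.SparkSession

// Assumed setup: a local Cassandra node and an existing test_ks.test_table.
val spark = SparkSession.builder()
  .master("local[*]")
  .config("spark.sql.catalog.cassandra", "com.datastax.spark.connector.datasource.CassandraCatalog")
  .config("spark.cassandra.connection.host", "127.0.0.1")
  .getOrCreate()

// Every row returned by this scan passes through the patched
// CassandraPartitionReaderBase iterator, so the task's recordsRead/bytesRead
// should now reflect the scan.
val df = spark.read.table("cassandra.test_ks.test_table")
df.show()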

connector/src/main/scala/com/datastax/spark/connector/datasource/CasssandraDriverDataWriterFactory.scala
Lines changed: 12 additions & 1 deletion

@@ -25,6 +25,8 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.connector.write.streaming.StreamingDataWriterFactory
 import org.apache.spark.sql.connector.write.{DataWriter, DataWriterFactory, WriterCommitMessage}
 import org.apache.spark.sql.types.StructType
+import org.apache.spark.metrics.OutputMetricsUpdater
+import org.apache.spark.TaskContext

 case class CassandraDriverDataWriterFactory(
   connector: CassandraConnector,

@@ -54,22 +56,31 @@ case class CassandraDriverDataWriter(

   private val columns = SomeColumns(inputSchema.fieldNames.map(name => ColumnName(name)): _*)

-  private val writer =
+  private val metricsUpdater = OutputMetricsUpdater(TaskContext.get(), writeConf)
+
+  private val asycWriter =
     TableWriter(connector, tableDef, columns, writeConf, false)(unsafeRowWriterFactory)
       .getAsyncWriter()

+  private val writer = asycWriter.copy(
+    successHandler = Some(metricsUpdater.batchFinished(success = true, _, _, _)),
+    failureHandler = Some(metricsUpdater.batchFinished(success = false, _, _, _)))
+
   override def write(record: InternalRow): Unit = writer.write(record)

   override def commit(): WriterCommitMessage = {
+    metricsUpdater.finish()
     writer.close()
     CassandraCommitMessage()
   }

   override def abort(): Unit = {
+    metricsUpdater.finish()
     writer.close()
   }

   override def close(): Unit = {
+    metricsUpdater.finish()
     //Our proxy Session Handler handles double closes by ignoring them so this is fine
     writer.close()
   }
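
On the write side the updater is driven by the success/failure handlers attached to the async writer, and finish() is called from commit, abort, and close so the totals are flushed however the task ends. Continuing the hypothetical session from the read sketch above, a DSV2 write that goes through this DataWriter:

// Assumed: test_ks.test_table_copy already exists with a compatible schema.
// Rows written here flow through CassandraDriverDataWriter, so the batchFinished
// callbacks added in this commit should surface as recordsWritten/bytesWritten.
df.writeTo("cassandra.test_ks.test_table_copy").append()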
