Commit fe78246

Update metrics for reads and writes via DSV2
Thanks @ymuzammil for fixing the issue with the read metrics. Fixes SPARKC-712
1 parent d2d13b2 commit fe78246
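
On both read paths the wiring is the same: an InputMetricsUpdater is built from the running TaskContext and the ReadConf, every row pulled from the Cassandra result iterator passes through updateMetrics (an identity-style map that bumps the counters), and finish() flushes the totals when the reader is closed. Below is a minimal, self-contained sketch of that shape; RowCountingUpdater and the println reporting are illustrative stand-ins, not the connector's InputMetricsUpdater, which publishes into Spark's task-level input metrics.

// Illustrative stand-in for the connector's InputMetricsUpdater (which the diff builds
// from TaskContext.get() and the ReadConf): count each row as it streams past, then
// report the total once when the reader closes.
class RowCountingUpdater {
  private var rowsRead = 0L

  // Identity-style hook: note the row and return it unchanged so it keeps flowing
  // through the rate limiter and the row reader.
  def updateMetrics[T](row: T): T = { rowsRead += 1; row }

  // Called from the reader's close() path: hand back whatever was accumulated.
  def finish(): Long = rowsRead
}

object RowCountingUpdaterExample {
  def main(args: Array[String]): Unit = {
    val updater = new RowCountingUpdater
    val resultSet = Iterator("row1", "row2", "row3")                // stand-in result set
    val iteratorWithMetrics = resultSet.map(updater.updateMetrics)  // same wrapping as the diff
    iteratorWithMetrics.foreach(_ => ())                            // consume, as the row reader would
    println(s"rows read: ${updater.finish()}")
  }
}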

File tree

3 files changed: 25 additions & 3 deletions

connector/src/main/scala/com/datastax/spark/connector/datasource/CassandraInJoinReaderFactory.scala

Lines changed: 6 additions & 1 deletion

@@ -12,6 +12,8 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory}
 import org.apache.spark.sql.sources.In
 import org.apache.spark.sql.types.{LongType, StructField, StructType}
+import org.apache.spark.metrics.InputMetricsUpdater
+import org.apache.spark.TaskContext

 import scala.util.{Failure, Success}

@@ -62,16 +64,18 @@ abstract class CassandraBaseInJoinReader(
   protected val maybeRateLimit = JoinHelper.maybeRateLimit(readConf)
   protected val requestsPerSecondRateLimiter = JoinHelper.requestsPerSecondRateLimiter(readConf)

+  protected val metricsUpdater = InputMetricsUpdater(TaskContext.get(), readConf)
   protected def pairWithRight(left: CassandraRow): SettableFuture[Iterator[(CassandraRow, InternalRow)]] = {
     val resultFuture = SettableFuture.create[Iterator[(CassandraRow, InternalRow)]]
     val leftSide = Iterator.continually(left)

     queryExecutor.executeAsync(bsb.bind(left).executeAs(readConf.executeAs)).onComplete {
       case Success(rs) =>
         val resultSet = new PrefetchingResultSetIterator(rs)
+        val iteratorWithMetrics = resultSet.map(metricsUpdater.updateMetrics)
         /* This is a much less than ideal place to actually rate limit, we are buffering
            these futures this means we will most likely exceed our threshold*/
-        val throttledIterator = resultSet.map(maybeRateLimit)
+        val throttledIterator = iteratorWithMetrics.map(maybeRateLimit)
         val rightSide = throttledIterator.map(rowReader.read(_, rowMetadata))
         resultFuture.set(leftSide.zip(rightSide))
       case Failure(throwable) =>

@@ -103,6 +107,7 @@ abstract class CassandraBaseInJoinReader(
   override def get(): InternalRow = currentRow

   override def close(): Unit = {
+    metricsUpdater.finish()
     session.close()
   }
 }
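
On this in-join read path the metrics wrap is applied directly to the PrefetchingResultSetIterator, before the rate limiter and before rowReader.read, so every row fetched from Cassandra is counted exactly once regardless of throttling; the matching metricsUpdater.finish() in close() then publishes the accumulated totals when the reader shuts down.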

connector/src/main/scala/com/datastax/spark/connector/datasource/CassandraScanPartitionReaderFactory.scala

Lines changed: 7 additions & 1 deletion

@@ -12,6 +12,8 @@ import com.datastax.spark.connector.util.Logging
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.connector.read._
 import org.apache.spark.sql.types.{LongType, StructField, StructType}
+import org.apache.spark.metrics.InputMetricsUpdater
+import org.apache.spark.TaskContext

 case class CassandraScanPartitionReaderFactory(
   connector: CassandraConnector,

@@ -61,6 +63,8 @@ abstract class CassandraPartitionReaderBase
   protected val rowIterator = getIterator()
   protected var lastRow: InternalRow = InternalRow()

+  protected val metricsUpdater = InputMetricsUpdater(TaskContext.get(), readConf)
+
   override def next(): Boolean = {
     if (rowIterator.hasNext) {
       lastRow = rowIterator.next()

@@ -73,6 +77,7 @@ abstract class CassandraPartitionReaderBase
   override def get(): InternalRow = lastRow

   override def close(): Unit = {
+    metricsUpdater.finish()
     scanner.close()
   }

@@ -107,7 +112,8 @@ abstract class CassandraPartitionReaderBase
     tokenRanges.iterator.flatMap { range =>
       val scanResult = ScanHelper.fetchTokenRange(scanner, tableDef, queryParts, range, readConf.consistencyLevel, readConf.fetchSizeInRows)
       val meta = scanResult.metadata
-      scanResult.rows.map(rowReader.read(_, meta))
+      val iteratorWithMetrics = scanResult.rows.map(metricsUpdater.updateMetrics)
+      iteratorWithMetrics.map(rowReader.read(_, meta))
     }
   }
connector/src/main/scala/com/datastax/spark/connector/datasource/CasssandraDriverDataWriterFactory.scala

Lines changed: 12 additions & 1 deletion

@@ -7,6 +7,8 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.connector.write.streaming.StreamingDataWriterFactory
 import org.apache.spark.sql.connector.write.{DataWriter, DataWriterFactory, WriterCommitMessage}
 import org.apache.spark.sql.types.StructType
+import org.apache.spark.metrics.OutputMetricsUpdater
+import org.apache.spark.TaskContext

 case class CassandraDriverDataWriterFactory(
   connector: CassandraConnector,

@@ -36,22 +38,31 @@ case class CassandraDriverDataWriter(

   private val columns = SomeColumns(inputSchema.fieldNames.map(name => ColumnName(name)): _*)

-  private val writer =
+  private val metricsUpdater = OutputMetricsUpdater(TaskContext.get(), writeConf)
+
+  private val asycWriter =
     TableWriter(connector, tableDef, columns, writeConf, false)(unsafeRowWriterFactory)
       .getAsyncWriter()

+  private val writer = asycWriter.copy(
+    successHandler = Some(metricsUpdater.batchFinished(success = true, _, _, _)),
+    failureHandler = Some(metricsUpdater.batchFinished(success = false, _, _, _)))
+
   override def write(record: InternalRow): Unit = writer.write(record)

   override def commit(): WriterCommitMessage = {
+    metricsUpdater.finish()
     writer.close()
     CassandraCommitMessage()
   }

   override def abort(): Unit = {
+    metricsUpdater.finish()
     writer.close()
   }

   override def close(): Unit = {
+    metricsUpdater.finish()
     //Our proxy Session Handler handles double closes by ignoring them so this is fine
     writer.close()
   }
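
The write path hooks in through callbacks rather than iterator wrapping: the async TableWriter is rebuilt via copy with success and failure handlers that report each finished batch to the OutputMetricsUpdater, and finish() is called from commit(), abort(), and close() so the totals are flushed however the writer ends. The sketch below mirrors that shape with stand-in types (BatchWriter, BatchMetrics); it is not the connector's TableWriter or OutputMetricsUpdater API, and the real batchFinished takes more arguments than shown here.

// Stand-in for the async writer: a case class, so handlers can be attached with copy()
// the same way the diff copies asycWriter with successHandler/failureHandler.
case class BatchWriter(
    successHandler: Option[Int => Unit] = None,
    failureHandler: Option[Int => Unit] = None) {

  // Pretend every batch succeeds and report its size to whichever handler is attached.
  def write(rows: Seq[String]): Unit =
    successHandler.foreach(handler => handler(rows.size))

  def close(): Unit = ()
}

// Stand-in for OutputMetricsUpdater: accumulate per-batch counts, flush once on finish().
class BatchMetrics {
  private var recordsWritten = 0L

  def batchFinished(success: Boolean, rowCount: Int): Unit =
    if (success) recordsWritten += rowCount

  def finish(): Long = recordsWritten
}

object WriteMetricsExample {
  def main(args: Array[String]): Unit = {
    val metrics = new BatchMetrics
    val writer = BatchWriter().copy(
      successHandler = Some(metrics.batchFinished(success = true, _)),
      failureHandler = Some(metrics.batchFinished(success = false, _)))

    writer.write(Seq("a", "b", "c"))
    writer.close()
    println(s"records written: ${metrics.finish()}")  // the commit()/abort()/close() path
  }
}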
