Commit 69647b4

feat: native support for intervals <-> Neo4j duration
This commit introduces native writes of Spark SQL interval types to the Neo4j duration type. The change is additive: the previous way of writing Neo4j durations via the custom duration struct still works, so existing jobs remain backwards compatible. Fixes CONN-341
1 parent 88b7286 commit 69647b4
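
For context, a minimal usage sketch of the new path (not part of the commit; the URL, label, and column names below are placeholders): a column of Spark's DayTimeIntervalType or YearMonthIntervalType can now be written directly, without first packing it into the custom duration struct.

    import org.apache.spark.sql.{SaveMode, SparkSession}

    val spark = SparkSession.builder().master("local[*]").getOrCreate()

    // Spark infers INTERVAL literals as DayTimeIntervalType / YearMonthIntervalType.
    val df = spark.sql(
      "SELECT INTERVAL '10 05:30:15' DAY TO SECOND AS dt, INTERVAL '4-5' YEAR TO MONTH AS ym")

    df.write
      .format("org.neo4j.spark.DataSource")
      .mode(SaveMode.Append)
      .option("url", "bolt://localhost:7687") // placeholder URL
      .option("labels", "HasDuration")       // placeholder label
      .save() // dt and ym are stored as native Neo4j DURATION properties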

File tree

4 files changed: +180 −9 lines


common/src/main/scala/org/neo4j/spark/converter/DataConverter.scala

Lines changed: 25 additions & 0 deletions
@@ -55,11 +55,36 @@ trait DataConverter[T] {

 object SparkToNeo4jDataConverter {
   def apply(): SparkToNeo4jDataConverter = new SparkToNeo4jDataConverter()
+
+  def dayTimeIntervalToNeo4j(micros: Long): Value = {
+    val oneSecondInMicros = 1000000L
+    val oneDayInMicros = 24 * 3600 * oneSecondInMicros
+
+    val numberDays = Math.floorDiv(micros, oneDayInMicros)
+    val remainderMicros = Math.floorMod(micros, oneDayInMicros)
+    val numberSeconds = Math.floorDiv(remainderMicros, oneSecondInMicros)
+    val numberNanos = Math.floorMod(remainderMicros, oneSecondInMicros) * 1000
+
+    Values.isoDuration(0L, numberDays, numberSeconds, numberNanos.toInt)
+  }
+
+  // while Neo4j supports years, this driver version's API does not expose it.
+  def yearMonthIntervalToNeo4j(months: Int): Value = {
+    Values.isoDuration(months.toLong, 0L, 0L, 0)
+  }
 }

 class SparkToNeo4jDataConverter extends DataConverter[Value] {

   override def convert(value: Any, dataType: DataType): Value = {
+    dataType match {
+      case _: DayTimeIntervalType if value != null =>
+        return SparkToNeo4jDataConverter.dayTimeIntervalToNeo4j(value.asInstanceOf[Long])
+      case _: YearMonthIntervalType if value != null =>
+        return SparkToNeo4jDataConverter.yearMonthIntervalToNeo4j(value.asInstanceOf[Int])
+      case _ => // do nothing
+    }
+
     value match {
       case date: java.sql.Date => convert(date.toLocalDate, dataType)
       case timestamp: java.sql.Timestamp => convert(timestamp.toLocalDateTime, dataType)
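
The day-time branch above splits the interval's total microseconds into Neo4j's (days, seconds, nanoseconds) components; floorDiv/floorMod keep negative intervals consistent, because the remainder is never negative. A self-contained sketch of the same arithmetic (mirroring dayTimeIntervalToNeo4j, with made-up inputs):

    // Mirrors dayTimeIntervalToNeo4j: split total microseconds into days/seconds/nanos.
    def split(micros: Long): (Long, Long, Long) = {
      val oneSecondInMicros = 1000000L
      val oneDayInMicros = 24 * 3600 * oneSecondInMicros
      val days = Math.floorDiv(micros, oneDayInMicros)
      val rem = Math.floorMod(micros, oneDayInMicros)
      (days, Math.floorDiv(rem, oneSecondInMicros), Math.floorMod(rem, oneSecondInMicros) * 1000)
    }

    // INTERVAL '1 02:03:04.000005' DAY TO SECOND = 93,784,000,005 microseconds
    val (d1, s1, n1) = split(93784000005L)
    assert(d1 == 1L && s1 == 7384L && n1 == 5000L)

    // Minus one microsecond: floorDiv/floorMod yield -1 day plus a positive remainder,
    // i.e. -1 day + 86399 s + 999,999,000 ns, which still sums to -1 microsecond.
    val (d2, s2, n2) = split(-1L)
    assert(d2 == -1L && s2 == 86399L && n2 == 999999000L)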

common/src/main/scala/org/neo4j/spark/converter/TypeConverter.scala

Lines changed: 8 additions & 0 deletions
@@ -18,6 +18,8 @@ package org.neo4j.spark.converter

 import org.apache.spark.sql.types.DataType
 import org.apache.spark.sql.types.DataTypes
+import org.apache.spark.sql.types.DayTimeIntervalType
+import org.apache.spark.sql.types.YearMonthIntervalType
 import org.neo4j.driver.types.Entity
 import org.neo4j.spark.converter.CypherToSparkTypeConverter.cleanTerms
 import org.neo4j.spark.converter.CypherToSparkTypeConverter.durationType
@@ -129,6 +131,8 @@ object SparkToCypherTypeConverter {
     DataTypes.DoubleType -> "FLOAT",
     DataTypes.DateType -> "DATE",
     DataTypes.TimestampType -> "LOCAL DATETIME",
+    DayTimeIntervalType() -> "DURATION",
+    YearMonthIntervalType() -> "DURATION",
     durationType -> "DURATION",
     pointType -> "POINT",
     // Cypher graph entities do not allow null values in arrays
@@ -141,6 +145,10 @@
     DataTypes.createArrayType(DataTypes.DateType, false) -> "LIST<DATE NOT NULL>",
     DataTypes.createArrayType(DataTypes.TimestampType, false) -> "LIST<LOCAL DATETIME NOT NULL>",
     DataTypes.createArrayType(DataTypes.TimestampType, true) -> "LIST<LOCAL DATETIME NOT NULL>",
+    DataTypes.createArrayType(DayTimeIntervalType(), false) -> "LIST<DURATION NOT NULL>",
+    DataTypes.createArrayType(DayTimeIntervalType(), true) -> "LIST<DURATION NOT NULL>",
+    DataTypes.createArrayType(YearMonthIntervalType(), false) -> "LIST<DURATION NOT NULL>",
+    DataTypes.createArrayType(YearMonthIntervalType(), true) -> "LIST<DURATION NOT NULL>",
     DataTypes.createArrayType(durationType, false) -> "LIST<DURATION NOT NULL>",
     DataTypes.createArrayType(pointType, false) -> "LIST<POINT NOT NULL>"
   )
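
Since the mapping is a plain DataType-to-Cypher-type-name table keyed on case-class equality, both interval flavours (and arrays of them) now advertise themselves as DURATION. A trimmed, standalone illustration of the lookup (this map is a copy for demonstration, not the connector's actual object):

    import org.apache.spark.sql.types._

    // Trimmed copy of the mapping above, for illustration only.
    val cypherTypeOf: Map[DataType, String] = Map(
      DayTimeIntervalType() -> "DURATION",
      YearMonthIntervalType() -> "DURATION",
      DataTypes.createArrayType(DayTimeIntervalType(), false) -> "LIST<DURATION NOT NULL>"
    )

    // Case-class equality makes a freshly constructed interval type a valid key.
    assert(cypherTypeOf(DayTimeIntervalType()) == "DURATION")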

spark-3/pom.xml

Lines changed: 6 additions & 0 deletions
@@ -63,6 +63,12 @@
       <artifactId>scalatest_${scala.binary.version}</artifactId>
       <scope>test</scope>
     </dependency>
+    <dependency>
+      <groupId>pl.pragmatists</groupId>
+      <artifactId>JUnitParams</artifactId>
+      <version>1.1.1</version>
+      <scope>test</scope>
+    </dependency>
   </dependencies>
   <build>
     <resources>
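
JUnitParams backs the parameterized tests in the next file: a @RunWith(classOf[JUnitParamsRunner]) on the class plus a method-sourced @Parameters per test. A minimal standalone example of that pattern (class and data are illustrative, not from this commit):

    import junitparams.{JUnitParamsRunner, Parameters}
    import org.junit.Assert.assertEquals
    import org.junit.Test
    import org.junit.runner.RunWith

    @RunWith(classOf[JUnitParamsRunner])
    class SquareTest {
      // JUnitParams invokes this method by name, producing one test run per entry.
      private def cases: java.util.List[Array[Any]] = java.util.Arrays.asList(
        Array[Any](2, 4),
        Array[Any](3, 9)
      )

      @Test
      @Parameters(method = "cases")
      def squaresCorrectly(input: Int, expected: Int): Unit =
        assertEquals(expected, input * input)
    }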

spark-3/src/test/scala/org/neo4j/spark/DataSourceWriterTSE.scala

Lines changed: 141 additions & 9 deletions
@@ -16,20 +16,28 @@
  */
 package org.neo4j.spark

+import junitparams.JUnitParamsRunner
+import junitparams.Parameters
 import org.apache.commons.lang3.exception.ExceptionUtils
 import org.apache.spark.SparkException
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.SaveMode
 import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.types.ArrayType
+import org.apache.spark.sql.types.DataType
+import org.apache.spark.sql.types.DayTimeIntervalType
+import org.apache.spark.sql.types.YearMonthIntervalType
 import org.junit
 import org.junit.Assert._
 import org.junit.Ignore
 import org.junit.Test
+import org.junit.runner.RunWith
 import org.neo4j.driver.Result
 import org.neo4j.driver.Transaction
 import org.neo4j.driver.TransactionWork
 import org.neo4j.driver.Value
 import org.neo4j.driver.exceptions.ClientException
+import org.neo4j.driver.exceptions.value.Uncoercible
 import org.neo4j.driver.internal.InternalPoint2D
 import org.neo4j.driver.internal.InternalPoint3D
 import org.neo4j.driver.internal.types.InternalTypeSystem
@@ -54,7 +62,7 @@ import scala.util.Random

 abstract class Neo4jType(`type`: String)

-case class Duration(`type`: String = "duration", months: Long, days: Long, seconds: Long, nanoseconds: Long)
+case class Duration(months: Long, days: Long, seconds: Long, nanoseconds: Long, `type`: String = "duration")
   extends Neo4jType(`type`)

 case class Point2d(`type`: String = "point-2d", srid: Int, x: Double, y: Double) extends Neo4jType(`type`)
@@ -73,8 +81,9 @@ case class SimplePerson(name: String, surname: String)

 case class EmptyRow[T](data: T)

+@RunWith(classOf[JUnitParamsRunner])
 class DataSourceWriterTSE extends SparkConnectorScalaBaseTSE {
-  val sparkSession = SparkSession.builder().getOrCreate()
+  val sparkSession = SparkSession.builder().master("local[*]").getOrCreate()

   import sparkSession.implicits._

@@ -414,11 +423,11 @@ class DataSourceWriterTSE extends SparkConnectorScalaBaseTSE {
   }

   @Test
-  def `should write nodes with duration values into Neo4j`(): Unit = {
+  def `should write nodes with duration values into Neo4j from struct`(): Unit = {
     val total = 10
     val ds = (1 to total)
       .map(i => i.toLong)
-      .map(i => EmptyRow(Duration(months = i, days = i, seconds = i, nanoseconds = i)))
+      .map(i => EmptyRow(Duration(i, i, i, i)))
       .toDS()

     ds.write
@@ -430,10 +439,10 @@

     val records = SparkConnectorScalaSuiteIT.session().run(
       """MATCH (p:BeanWithDuration)
-        |RETURN p.data AS data
+        |RETURN p.data AS duration
         |""".stripMargin
     ).list().asScala
-      .map(r => r.get("data").asIsoDuration())
+      .map(r => r.get("duration").asIsoDuration())
       .map(data => (data.months, data.days, data.seconds, data.nanoseconds))
       .toSet

@@ -445,14 +454,14 @@
   }

   @Test
-  def `should write nodes with duration array values into Neo4j`(): Unit = {
+  def `should write nodes with duration array values into Neo4j from struct`(): Unit = {
     val total = 10
     val ds = (1 to total)
       .map(i => i.toLong)
       .map(i =>
         EmptyRow(Seq(
-          Duration(months = i, days = i, seconds = i, nanoseconds = i),
-          Duration(months = i, days = i, seconds = i, nanoseconds = i)
+          Duration(i, i, i, i),
+          Duration(i, i, i, i)
         ))
       )
       .toDS()
@@ -484,6 +493,129 @@
     assertEquals(expected, records)
   }

+  private case class DurationCase(
+    intervalExpression: String,
+    duration: Duration,
+    expectedDt: Class[_ <: DataType] = classOf[DayTimeIntervalType]
+  ) {
+    private val isArithmetic = intervalExpression.startsWith("timestamp")
+
+    val sql: String = if (isArithmetic) {
+      intervalExpression
+    } else {
+      s"INTERVAL $intervalExpression"
+    }
+  }
+
+  private def sqlDurationCases: java.util.List[DurationCase] = java.util.Arrays.asList(
+    // DAY/TIME -> DayTimeIntervalType
+    DurationCase("'3' DAY", Duration(0, 3, 0, 0)),
+    DurationCase("'10 05' DAY TO HOUR", Duration(0, 10, 5L * 3600, 0)),
+    DurationCase("'10 05:30' DAY TO MINUTE", Duration(0, 10, 5L * 3600 + 30L * 60, 0)),
+    DurationCase("'10 05:30:15.123456' DAY TO SECOND", Duration(0, 10, 5L * 3600 + 30L * 60 + 15L, 123456000)),
+    DurationCase("'12' HOUR", Duration(0, 0, 12L * 3600, 0)),
+    DurationCase("'12:34' HOUR TO MINUTE", Duration(0, 0, 12L * 3600 + 34L * 60, 0)),
+    DurationCase("'12:34:56.123456' HOUR TO SECOND", Duration(0, 0, 12L * 3600 + 34L * 60 + 56L, 123456000)),
+    DurationCase("'42' MINUTE", Duration(0, 0, 42L * 60, 0)),
+    DurationCase("'42:07.001002' MINUTE TO SECOND", Duration(0, 0, 42L * 60 + 7L, 1002000)),
+    DurationCase("'59.000001' SECOND", Duration(0, 0, 59L, 1000)),
+    DurationCase(
+      "timestamp('2025-01-02 18:30:00.454') - timestamp('2024-01-01 00:00:00')",
+      Duration(0, 367, 66600L, 454000000)
+    ),
+    // YEAR/MONTH -> YearMonthIntervalType
+    DurationCase("'3' YEAR", Duration(36, 0, 0, 0), classOf[YearMonthIntervalType]),
+    DurationCase("'7' MONTH", Duration(7, 0, 0, 0), classOf[YearMonthIntervalType]),
+    DurationCase("'4-5' YEAR TO MONTH", Duration(53, 0, 0, 0), classOf[YearMonthIntervalType])
+  )
+
+  @Test
+  @Parameters(method = "sqlDurationCases")
+  def `interval literals map to native neo4j durations`(testCase: DurationCase): Unit = {
+    val id = java.util.UUID.randomUUID().toString
+    val df = sparkSession.sql(s"SELECT '$id' AS id, ${testCase.sql} AS duration")
+
+    df.write
+      .format(classOf[DataSource].getName)
+      .mode(SaveMode.Append)
+      .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
+      .option("labels", "Dur")
+      .save()
+
+    val wantType = testCase.expectedDt.getSimpleName
+    val gotType = df.schema("duration").dataType
+    assertTrue(s"expected Spark to pick $wantType but it was $gotType", testCase.expectedDt.isInstance(gotType))
+
+    val gotDuration = SparkConnectorScalaSuiteIT.session().run(
+      s"""MATCH (d:Dur {id: '$id'})
+         |RETURN d.duration AS duration
+         |""".stripMargin
+    ).single().get("duration").asIsoDuration()
+
+    assertEquals(s"${testCase.sql} -> months", testCase.duration.months, gotDuration.months)
+    assertEquals(s"${testCase.sql} -> days", testCase.duration.days, gotDuration.days)
+    assertEquals(s"${testCase.sql} -> seconds", testCase.duration.seconds, gotDuration.seconds)
+    assertEquals(s"${testCase.sql} -> nanos", testCase.duration.nanoseconds, gotDuration.nanoseconds)
+  }
+
+  private val sqlDurationArrayCases: java.util.List[Seq[DurationCase]] = java.util.Arrays.asList(
+    Seq(
+      DurationCase("'10 05:30:15.123' DAY TO SECOND", null),
+      DurationCase("'0 00:00:01.000' DAY TO SECOND", null)
+    ),
+    Seq(
+      DurationCase("timestamp('2024-01-02 00:00:00') - timestamp('2024-01-01 00:00:00')", null),
+      DurationCase("timestamp('2024-01-01 00:00:00') - current_timestamp()", null)
+    ),
+    Seq(
+      DurationCase("'1-02' YEAR TO MONTH", null, classOf[YearMonthIntervalType]),
+      DurationCase("'0-11' YEAR TO MONTH", null, classOf[YearMonthIntervalType])
+    )
+  )
+
+  @Test
+  @Parameters(method = "sqlDurationArrayCases")
+  def `should write interval arrays as native neo4j durations`(testCase: Seq[DurationCase]): Unit = {
+    val id = java.util.UUID.randomUUID().toString
+    val expectedDt = testCase.head.expectedDt
+    val sqlArray = testCase.map(_.sql).mkString("array(", ", ", ")")
+    val df = sparkSession.sql(s"SELECT '$id' AS id, $sqlArray AS durations")
+
+    df.write
+      .format(classOf[DataSource].getName)
+      .mode(SaveMode.Append)
+      .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
+      .option("labels", "DurArr")
+      .save()
+
+    val gotType = df.schema("durations").dataType
+
+    assertTrue(
+      s"expected Spark to infer ArrayType(${expectedDt.getSimpleName}) but it was $gotType",
+      gotType match {
+        case ArrayType(et, _) if expectedDt.isInstance(et) => true
+        case _ => false
+      }
+    )
+
+    val result = SparkConnectorScalaSuiteIT.session().run(
+      s"""MATCH (d:DurArr {id: '$id'})
+         |RETURN d.durations AS durations
+         |""".stripMargin
+    ).single().get("durations")
+
+    assertTrue(
+      s"expected successful conversion to IsoDuration array, but it failed: $result",
+      try {
+        val _ = result.asList((v: Value) => v.asIsoDuration())
+        true
+      } catch {
+        case _: Uncoercible => false
+        case e => throw e
+      }
+    )
+  }
+
   @Test
   def `should write nodes into Neo4j with points`(): Unit = {
     val total = 10
