 */
 package org.neo4j.spark

+import junitparams.JUnitParamsRunner
+import junitparams.Parameters
 import org.apache.commons.lang3.exception.ExceptionUtils
 import org.apache.spark.SparkException
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.SaveMode
 import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.types.ArrayType
+import org.apache.spark.sql.types.DataType
+import org.apache.spark.sql.types.DayTimeIntervalType
+import org.apache.spark.sql.types.YearMonthIntervalType
 import org.junit
 import org.junit.Assert._
 import org.junit.Ignore
 import org.junit.Test
+import org.junit.runner.RunWith
 import org.neo4j.driver.Result
 import org.neo4j.driver.Transaction
 import org.neo4j.driver.TransactionWork
 import org.neo4j.driver.Value
 import org.neo4j.driver.exceptions.ClientException
+import org.neo4j.driver.exceptions.value.Uncoercible
 import org.neo4j.driver.internal.InternalPoint2D
 import org.neo4j.driver.internal.InternalPoint3D
 import org.neo4j.driver.internal.types.InternalTypeSystem
@@ -54,7 +62,7 @@ import scala.util.Random

 abstract class Neo4jType(`type`: String)

-case class Duration(`type`: String = "duration", months: Long, days: Long, seconds: Long, nanoseconds: Long)
+case class Duration(months: Long, days: Long, seconds: Long, nanoseconds: Long, `type`: String = "duration")
   extends Neo4jType(`type`)

 case class Point2d(`type`: String = "point-2d", srid: Int, x: Double, y: Double) extends Neo4jType(`type`)
@@ -73,8 +81,9 @@ case class SimplePerson(name: String, surname: String)

 case class EmptyRow[T](data: T)

+@RunWith(classOf[JUnitParamsRunner])
 class DataSourceWriterTSE extends SparkConnectorScalaBaseTSE {
-  val sparkSession = SparkSession.builder().getOrCreate()
+  val sparkSession = SparkSession.builder().master("local[*]").getOrCreate()

   import sparkSession.implicits._

@@ -414,11 +423,11 @@ class DataSourceWriterTSE extends SparkConnectorScalaBaseTSE {
   }

   @Test
-  def `should write nodes with duration values into Neo4j`(): Unit = {
+  def `should write nodes with duration values into Neo4j from struct`(): Unit = {
     val total = 10
     val ds = (1 to total)
       .map(i => i.toLong)
-      .map(i => EmptyRow(Duration(months = i, days = i, seconds = i, nanoseconds = i)))
+      .map(i => EmptyRow(Duration(i, i, i, i)))
       .toDS()

     ds.write
@@ -430,10 +439,10 @@ class DataSourceWriterTSE extends SparkConnectorScalaBaseTSE {

     val records = SparkConnectorScalaSuiteIT.session().run(
       """MATCH (p:BeanWithDuration)
-        |RETURN p.data AS data
+        |RETURN p.data AS duration
         |""".stripMargin
     ).list().asScala
-      .map(r => r.get("data").asIsoDuration())
+      .map(r => r.get("duration").asIsoDuration())
       .map(data => (data.months, data.days, data.seconds, data.nanoseconds))
       .toSet

@@ -445,14 +454,14 @@ class DataSourceWriterTSE extends SparkConnectorScalaBaseTSE {
   }

   @Test
-  def `should write nodes with duration array values into Neo4j`(): Unit = {
+  def `should write nodes with duration array values into Neo4j from struct`(): Unit = {
     val total = 10
     val ds = (1 to total)
       .map(i => i.toLong)
       .map(i =>
         EmptyRow(Seq(
-          Duration(months = i, days = i, seconds = i, nanoseconds = i),
-          Duration(months = i, days = i, seconds = i, nanoseconds = i)
+          Duration(i, i, i, i),
+          Duration(i, i, i, i)
         ))
       )
       .toDS()
@@ -484,6 +493,129 @@ class DataSourceWriterTSE extends SparkConnectorScalaBaseTSE {
     assertEquals(expected, records)
   }

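+  // One parameterised case: a Spark SQL interval expression, the Neo4j duration it
+  // should round-trip to, and the interval DataType Spark is expected to infer for it.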
+  private case class DurationCase(
+    intervalExpression: String,
+    duration: Duration,
+    expectedDt: Class[_ <: DataType] = classOf[DayTimeIntervalType]
+  ) {
+    private val isArithmetic = intervalExpression.startsWith("timestamp")
+
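+    // Timestamp arithmetic is already a complete interval-valued expression;
+    // bare literals still need the INTERVAL keyword prepended.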
+    val sql: String = if (isArithmetic) {
+      intervalExpression
+    } else {
+      s"INTERVAL $intervalExpression"
+    }
+  }
+
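+  // Parameter source for JUnitParams: each list element drives one test invocation.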
+  private def sqlDurationCases: java.util.List[DurationCase] = java.util.Arrays.asList(
+    // DAY/TIME -> DayTimeIntervalType
+    DurationCase("'3' DAY", Duration(0, 3, 0, 0)),
+    DurationCase("'10 05' DAY TO HOUR", Duration(0, 10, 5L * 3600, 0)),
+    DurationCase("'10 05:30' DAY TO MINUTE", Duration(0, 10, 5L * 3600 + 30L * 60, 0)),
+    DurationCase("'10 05:30:15.123456' DAY TO SECOND", Duration(0, 10, 5L * 3600 + 30L * 60 + 15L, 123456000)),
+    DurationCase("'12' HOUR", Duration(0, 0, 12L * 3600, 0)),
+    DurationCase("'12:34' HOUR TO MINUTE", Duration(0, 0, 12L * 3600 + 34L * 60, 0)),
+    DurationCase("'12:34:56.123456' HOUR TO SECOND", Duration(0, 0, 12L * 3600 + 34L * 60 + 56L, 123456000)),
+    DurationCase("'42' MINUTE", Duration(0, 0, 42L * 60, 0)),
+    DurationCase("'42:07.001002' MINUTE TO SECOND", Duration(0, 0, 42L * 60 + 7L, 1002000)),
+    DurationCase("'59.000001' SECOND", Duration(0, 0, 59L, 1000)),
+    DurationCase(
+      "timestamp('2025-01-02 18:30:00.454') - timestamp('2024-01-01 00:00:00')",
+      Duration(0, 367, 66600L, 454000000)
+    ),
+    // YEAR/MONTH -> YearMonthIntervalType
+    DurationCase("'3' YEAR", Duration(36, 0, 0, 0), classOf[YearMonthIntervalType]),
+    DurationCase("'7' MONTH", Duration(7, 0, 0, 0), classOf[YearMonthIntervalType]),
+    DurationCase("'4-5' YEAR TO MONTH", Duration(53, 0, 0, 0), classOf[YearMonthIntervalType])
+  )
+
+  @Test
+  @Parameters(method = "sqlDurationCases")
+  def `interval literals map to native neo4j durations`(testCase: DurationCase): Unit = {
+    val id = java.util.UUID.randomUUID().toString
+    val df = sparkSession.sql(s"SELECT '$id' AS id, ${testCase.sql} AS duration")
+
+    df.write
+      .format(classOf[DataSource].getName)
+      .mode(SaveMode.Append)
+      .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
+      .option("labels", "Dur")
+      .save()
+
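+    // Check the Spark-side schema first: the expression must have been inferred
+    // as the expected DayTimeIntervalType/YearMonthIntervalType.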
+    val wantType = testCase.expectedDt.getSimpleName
+    val gotType = df.schema("duration").dataType
+    assertTrue(s"expected Spark to pick $wantType but it was $gotType", testCase.expectedDt.isInstance(gotType))
+
+    val gotDuration = SparkConnectorScalaSuiteIT.session().run(
+      s"""MATCH (d:Dur {id: '$id'})
+         |RETURN d.duration AS duration
+         |""".stripMargin
+    ).single().get("duration").asIsoDuration()
+
+    assertEquals(s"${testCase.sql} -> months", testCase.duration.months, gotDuration.months)
+    assertEquals(s"${testCase.sql} -> days", testCase.duration.days, gotDuration.days)
+    assertEquals(s"${testCase.sql} -> seconds", testCase.duration.seconds, gotDuration.seconds)
+    assertEquals(s"${testCase.sql} -> nanos", testCase.duration.nanoseconds, gotDuration.nanoseconds)
+  }
+
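+  // For the array cases only coercibility is asserted, so the expected Duration is
+  // left null; each Seq sticks to one interval family because Spark arrays are homogeneous.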
+  private val sqlDurationArrayCases: java.util.List[Seq[DurationCase]] = java.util.Arrays.asList(
+    Seq(
+      DurationCase("'10 05:30:15.123' DAY TO SECOND", null),
+      DurationCase("'0 00:00:01.000' DAY TO SECOND", null)
+    ),
+    Seq(
+      DurationCase("timestamp('2024-01-02 00:00:00') - timestamp('2024-01-01 00:00:00')", null),
+      DurationCase("timestamp('2024-01-01 00:00:00') - current_timestamp()", null)
+    ),
+    Seq(
+      DurationCase("'1-02' YEAR TO MONTH", null, classOf[YearMonthIntervalType]),
+      DurationCase("'0-11' YEAR TO MONTH", null, classOf[YearMonthIntervalType])
+    )
+  )
+
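+  // Same round trip as above, but for arrays of intervals written as a duration[] property.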
+  @Test
+  @Parameters(method = "sqlDurationArrayCases")
+  def `should write interval arrays as native neo4j durations`(testCase: Seq[DurationCase]): Unit = {
+    val id = java.util.UUID.randomUUID().toString
+    val expectedDt = testCase.head.expectedDt
+    val sqlArray = testCase.map(_.sql).mkString("array(", ", ", ")")
+    val df = sparkSession.sql(s"SELECT '$id' AS id, $sqlArray AS durations")
+
+    df.write
+      .format(classOf[DataSource].getName)
+      .mode(SaveMode.Append)
+      .option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)
+      .option("labels", "DurArr")
+      .save()
+
+    val gotType = df.schema("durations").dataType
+
+    assertTrue(
+      s"expected Spark to infer ArrayType(${expectedDt.getSimpleName}) but it was $gotType",
+      gotType match {
+        case ArrayType(et, _) if expectedDt.isInstance(et) => true
+        case _ => false
+      }
+    )
+
+    val result = SparkConnectorScalaSuiteIT.session().run(
+      s"""MATCH (d:DurArr {id: '$id'})
+         |RETURN d.durations AS durations
+         |""".stripMargin
+    ).single().get("durations")
+
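+    // The stored property must coerce back to a list of ISO durations; an Uncoercible
+    // error would mean the connector wrote something other than native durations.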
+    assertTrue(
+      s"expected successful conversion to IsoDuration array, but it failed: $result",
+      try {
+        val _ = result.asList((v: Value) => v.asIsoDuration())
+        true
+      } catch {
+        case _: Uncoercible => false
+        case e: Throwable => throw e
+      }
+    )
+  }
+
   @Test
   def `should write nodes into Neo4j with points`(): Unit = {
     val total = 10