@@ -23,8 +23,9 @@ use std::sync::Arc;
2323use arrow:: array:: GenericStringArray ;
2424use arrow:: array:: {
2525 ArrayRef , BooleanArray , Float32Array , Float64Array , Int16Array , Int32Array ,
26- Int64Array , Int8Array , StringOffsetSizeTrait , UInt16Array , UInt32Array , UInt64Array ,
27- UInt8Array ,
26+ Int64Array , Int8Array , StringOffsetSizeTrait , TimestampMicrosecondArray ,
27+ TimestampMillisecondArray , TimestampNanosecondArray , TimestampSecondArray ,
28+ UInt16Array , UInt32Array , UInt64Array , UInt8Array ,
2829} ;
2930use arrow:: datatypes:: ArrowPrimitiveType ;
3031use arrow:: {
@@ -35,6 +36,7 @@ use arrow::{
3536use crate :: PhysicalExpr ;
3637use arrow:: array:: * ;
3738use arrow:: buffer:: { Buffer , MutableBuffer } ;
39+ use arrow:: datatypes:: TimeUnit ;
3840use datafusion_common:: ScalarValue ;
3941use datafusion_common:: { DataFusionError , Result } ;
4042use datafusion_expr:: ColumnarValue ;
@@ -134,8 +136,8 @@ macro_rules! make_contains_primitive {
134136 . iter( )
135137 . flat_map( |expr| match expr {
136138 ColumnarValue :: Scalar ( s) => match s {
137- ScalarValue :: $SCALAR_VALUE( Some ( v) ) => Some ( * v) ,
138- ScalarValue :: $SCALAR_VALUE( None ) => None ,
139+ ScalarValue :: $SCALAR_VALUE( Some ( v) , .. ) => Some ( * v) ,
140+ ScalarValue :: $SCALAR_VALUE( None , .. ) => None ,
139141 ScalarValue :: Utf8 ( None ) => None ,
140142 datatype => unimplemented!( "Unexpected type {} for InList" , datatype) ,
141143 } ,
@@ -451,6 +453,36 @@ impl PhysicalExpr for InListExpr {
451453 DataType :: LargeUtf8 => {
452454 self . compare_utf8 :: < i64 > ( array, list_values, self . negated )
453455 }
456+ DataType :: Timestamp ( unit, _) => match unit {
457+ TimeUnit :: Second => make_contains_primitive ! (
458+ array,
459+ list_values,
460+ self . negated,
461+ TimestampSecond ,
462+ TimestampSecondArray
463+ ) ,
464+ TimeUnit :: Millisecond => make_contains_primitive ! (
465+ array,
466+ list_values,
467+ self . negated,
468+ TimestampMillisecond ,
469+ TimestampMillisecondArray
470+ ) ,
471+ TimeUnit :: Microsecond => make_contains_primitive ! (
472+ array,
473+ list_values,
474+ self . negated,
475+ TimestampMicrosecond ,
476+ TimestampMicrosecondArray
477+ ) ,
478+ TimeUnit :: Nanosecond => make_contains_primitive ! (
479+ array,
480+ list_values,
481+ self . negated,
482+ TimestampNanosecond ,
483+ TimestampNanosecondArray
484+ ) ,
485+ } ,
454486 datatype => Result :: Err ( DataFusionError :: NotImplemented ( format ! (
455487 "InList does not support datatype {:?}." ,
456488 datatype
@@ -713,4 +745,108 @@ mod tests {
713745
714746 Ok ( ( ) )
715747 }
748+
/// Checks `IN` / `NOT IN` on a nullable microsecond-timestamp column against a
/// literal list that contains a NULL and more entries than
/// `optimizer_inset_threshold`.
///
/// Expected rows, in column order:
/// * row 0 equals a list entry, so `IN` is `true` and `NOT IN` is `false`;
/// * row 1 matches nothing, but the list contains NULL, so both results are NULL;
/// * row 2 is NULL, so both results are NULL.
#[test]
fn in_list_set_timestamp() -> Result<()> {
    // Size at which to use a Set rather than Vec for `IN` / `NOT IN`
    // Value chosen by the benchmark at
    // https://github.com/apache/arrow-datafusion/pull/2156#discussion_r845198369
    // TODO: add switch codeGen in In_List
    //
    // NOTE(review): presumably mirrors the threshold used by the implementation;
    // confirm the two stay in sync.
    //
    // Kept as `i64` so that the timestamp arithmetic below is done in the same
    // type the scalar stores. With `usize`, the 19-digit literal is out of range
    // on 32-bit targets and every value needs an `as i64` cast.
    let optimizer_inset_threshold: i64 = 30;

    let schema = Schema::new(vec![Field::new(
        "a",
        DataType::Timestamp(TimeUnit::Microsecond, None),
        true,
    )]);
    let a = TimestampMicrosecondArray::from(vec![
        Some(1388588401000000000),
        Some(1288588501000000000),
        None,
    ]);
    let col_a = col("a", &schema)?;
    let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;

    let mut list = vec![
        lit(ScalarValue::TimestampMicrosecond(
            Some(1388588401000000000),
            None,
        )),
        lit(ScalarValue::TimestampMicrosecond(None, None)),
        lit(ScalarValue::TimestampMicrosecond(
            Some(1388588401000000001),
            None,
        )),
    ];

    // Pad the list past the threshold with consecutive, distinct timestamps.
    // The first padded value repeats the last hand-written literal; the
    // duplicate does not affect the expected results.
    let start_ts: i64 = 1388588401000000001;
    let end_ts = start_ts + optimizer_inset_threshold + 4;
    list.extend(
        (start_ts..end_ts).map(|v| lit(ScalarValue::TimestampMicrosecond(Some(v), None))),
    );

    // `a IN (list)`
    in_list!(
        batch,
        list.clone(),
        &false,
        vec![Some(true), None, None],
        col_a.clone()
    );

    // `a NOT IN (list)`
    in_list!(
        batch,
        list.clone(),
        &true,
        vec![Some(false), None, None],
        col_a.clone()
    );

    Ok(())
}
804+
/// Checks `IN` / `NOT IN` on a nullable microsecond-timestamp column against a
/// short literal list with no NULL entries.
///
/// Row 0 equals a list entry, row 1 matches nothing, and row 2 is NULL, so the
/// third result is NULL in both the plain and the negated case.
#[test]
fn in_list_timestamp() -> Result<()> {
    let ts_field = Field::new(
        "a",
        DataType::Timestamp(TimeUnit::Microsecond, None),
        true,
    );
    let schema = Schema::new(vec![ts_field]);

    let column_values = vec![
        Some(1388588401000000000),
        Some(1288588501000000000),
        None,
    ];
    let a = TimestampMicrosecondArray::from(column_values);

    let col_a = col("a", &schema)?;
    let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;

    // Three consecutive timestamps; only the first one occurs in the column.
    let candidates: [i64; 3] = [
        1388588401000000000,
        1388588401000000001,
        1388588401000000002,
    ];
    let list: Vec<_> = candidates
        .iter()
        .map(|&ts| lit(ScalarValue::TimestampMicrosecond(Some(ts), None)))
        .collect();

    // (negated, expected result per row), evaluated in this order.
    let cases = vec![
        (false, vec![Some(true), Some(false), None]),
        (true, vec![Some(false), Some(true), None]),
    ];
    for (negated, expected) in cases {
        in_list!(batch, list.clone(), &negated, expected, col_a.clone());
    }

    Ok(())
}
716852}
0 commit comments