@@ -23,7 +23,8 @@ use std::str::FromStr;
2323use std:: sync:: Arc ;
2424
2525use arrow_arith:: boolean:: { and, and_kleene, is_not_null, is_null, not, or, or_kleene} ;
26- use arrow_array:: { Array , ArrayRef , BooleanArray , RecordBatch } ;
26+ use arrow_array:: { Array , ArrayRef , BooleanArray , Datum as ArrowDatum , RecordBatch , Scalar } ;
27+ use arrow_cast:: cast:: cast;
2728use arrow_ord:: cmp:: { eq, gt, gt_eq, lt, lt_eq, neq} ;
2829use arrow_schema:: {
2930 ArrowError , DataType , FieldRef , Schema as ArrowSchema , SchemaRef as ArrowSchemaRef ,
@@ -1103,6 +1104,7 @@ impl BoundPredicateVisitor for PredicateConverter<'_> {
11031104
11041105 Ok ( Box :: new ( move |batch| {
11051106 let left = project_column ( & batch, idx) ?;
1107+ let literal = try_cast_literal ( & literal, left. data_type ( ) ) ?;
11061108 lt ( & left, literal. as_ref ( ) )
11071109 } ) )
11081110 } else {
@@ -1122,6 +1124,7 @@ impl BoundPredicateVisitor for PredicateConverter<'_> {
11221124
11231125 Ok ( Box :: new ( move |batch| {
11241126 let left = project_column ( & batch, idx) ?;
1127+ let literal = try_cast_literal ( & literal, left. data_type ( ) ) ?;
11251128 lt_eq ( & left, literal. as_ref ( ) )
11261129 } ) )
11271130 } else {
@@ -1141,6 +1144,7 @@ impl BoundPredicateVisitor for PredicateConverter<'_> {
11411144
11421145 Ok ( Box :: new ( move |batch| {
11431146 let left = project_column ( & batch, idx) ?;
1147+ let literal = try_cast_literal ( & literal, left. data_type ( ) ) ?;
11441148 gt ( & left, literal. as_ref ( ) )
11451149 } ) )
11461150 } else {
@@ -1160,6 +1164,7 @@ impl BoundPredicateVisitor for PredicateConverter<'_> {
11601164
11611165 Ok ( Box :: new ( move |batch| {
11621166 let left = project_column ( & batch, idx) ?;
1167+ let literal = try_cast_literal ( & literal, left. data_type ( ) ) ?;
11631168 gt_eq ( & left, literal. as_ref ( ) )
11641169 } ) )
11651170 } else {
@@ -1179,6 +1184,7 @@ impl BoundPredicateVisitor for PredicateConverter<'_> {
11791184
11801185 Ok ( Box :: new ( move |batch| {
11811186 let left = project_column ( & batch, idx) ?;
1187+ let literal = try_cast_literal ( & literal, left. data_type ( ) ) ?;
11821188 eq ( & left, literal. as_ref ( ) )
11831189 } ) )
11841190 } else {
@@ -1198,6 +1204,7 @@ impl BoundPredicateVisitor for PredicateConverter<'_> {
11981204
11991205 Ok ( Box :: new ( move |batch| {
12001206 let left = project_column ( & batch, idx) ?;
1207+ let literal = try_cast_literal ( & literal, left. data_type ( ) ) ?;
12011208 neq ( & left, literal. as_ref ( ) )
12021209 } ) )
12031210 } else {
@@ -1217,6 +1224,7 @@ impl BoundPredicateVisitor for PredicateConverter<'_> {
12171224
12181225 Ok ( Box :: new ( move |batch| {
12191226 let left = project_column ( & batch, idx) ?;
1227+ let literal = try_cast_literal ( & literal, left. data_type ( ) ) ?;
12201228 starts_with ( & left, literal. as_ref ( ) )
12211229 } ) )
12221230 } else {
@@ -1236,7 +1244,7 @@ impl BoundPredicateVisitor for PredicateConverter<'_> {
12361244
12371245 Ok ( Box :: new ( move |batch| {
12381246 let left = project_column ( & batch, idx) ?;
1239-
1247+ let literal = try_cast_literal ( & literal , left . data_type ( ) ) ? ;
12401248 // update here if arrow ever adds a native not_starts_with
12411249 not ( & starts_with ( & left, literal. as_ref ( ) ) ?)
12421250 } ) )
@@ -1261,8 +1269,10 @@ impl BoundPredicateVisitor for PredicateConverter<'_> {
12611269 Ok ( Box :: new ( move |batch| {
12621270 // update this if arrow ever adds a native is_in kernel
12631271 let left = project_column ( & batch, idx) ?;
1272+
12641273 let mut acc = BooleanArray :: from ( vec ! [ false ; batch. num_rows( ) ] ) ;
12651274 for literal in & literals {
1275+ let literal = try_cast_literal ( literal, left. data_type ( ) ) ?;
12661276 acc = or ( & acc, & eq ( & left, literal. as_ref ( ) ) ?) ?
12671277 }
12681278
@@ -1291,6 +1301,7 @@ impl BoundPredicateVisitor for PredicateConverter<'_> {
12911301 let left = project_column ( & batch, idx) ?;
12921302 let mut acc = BooleanArray :: from ( vec ! [ true ; batch. num_rows( ) ] ) ;
12931303 for literal in & literals {
1304+ let literal = try_cast_literal ( literal, left. data_type ( ) ) ?;
12941305 acc = and ( & acc, & neq ( & left, literal. as_ref ( ) ) ?) ?
12951306 }
12961307
@@ -1370,14 +1381,35 @@ impl<R: FileRead> AsyncFileReader for ArrowFileReader<R> {
13701381 }
13711382}
13721383
1384+ /// The Arrow type of an array that the Parquet reader reads may not match the exact Arrow type
1385+ /// that Iceberg uses for literals - but they are effectively the same logical type,
1386+ /// i.e. LargeUtf8 and Utf8 or Utf8View and Utf8 or Utf8View and LargeUtf8.
1387+ ///
1388+ /// The Arrow compute kernels that we use must match the type exactly, so first cast the literal
1389+ /// into the type of the batch we read from Parquet before sending it to the compute kernel.
1390+ fn try_cast_literal (
1391+ literal : & Arc < dyn ArrowDatum + Send + Sync > ,
1392+ column_type : & DataType ,
1393+ ) -> std:: result:: Result < Arc < dyn ArrowDatum + Send + Sync > , ArrowError > {
1394+ let literal_array = literal. get ( ) . 0 ;
1395+
1396+ // No cast required
1397+ if literal_array. data_type ( ) == column_type {
1398+ return Ok ( Arc :: clone ( literal) ) ;
1399+ }
1400+
1401+ let literal_array = cast ( literal_array, column_type) ?;
1402+ Ok ( Arc :: new ( Scalar :: new ( literal_array) ) )
1403+ }
1404+
13731405#[ cfg( test) ]
13741406mod tests {
13751407 use std:: collections:: { HashMap , HashSet } ;
13761408 use std:: fs:: File ;
13771409 use std:: sync:: Arc ;
13781410
13791411 use arrow_array:: cast:: AsArray ;
1380- use arrow_array:: { ArrayRef , RecordBatch , StringArray } ;
1412+ use arrow_array:: { ArrayRef , LargeStringArray , RecordBatch , StringArray } ;
13811413 use arrow_schema:: { DataType , Field , Schema as ArrowSchema , TimeUnit } ;
13821414 use futures:: TryStreamExt ;
13831415 use parquet:: arrow:: arrow_reader:: { RowSelection , RowSelector } ;
@@ -1573,7 +1605,8 @@ message schema {
15731605 // Expected: [NULL, "foo"].
15741606 let expected = vec ! [ None , Some ( "foo" . to_string( ) ) ] ;
15751607
1576- let ( file_io, schema, table_location, _temp_dir) = setup_kleene_logic ( data_for_col_a) ;
1608+ let ( file_io, schema, table_location, _temp_dir) =
1609+ setup_kleene_logic ( data_for_col_a, DataType :: Utf8 ) ;
15771610 let reader = ArrowReaderBuilder :: new ( file_io) . build ( ) ;
15781611
15791612 let result_data = test_perform_read ( predicate, schema, table_location, reader) . await ;
@@ -1594,14 +1627,88 @@ message schema {
15941627 // Expected: ["bar"].
15951628 let expected = vec ! [ Some ( "bar" . to_string( ) ) ] ;
15961629
1597- let ( file_io, schema, table_location, _temp_dir) = setup_kleene_logic ( data_for_col_a) ;
1630+ let ( file_io, schema, table_location, _temp_dir) =
1631+ setup_kleene_logic ( data_for_col_a, DataType :: Utf8 ) ;
15981632 let reader = ArrowReaderBuilder :: new ( file_io) . build ( ) ;
15991633
16001634 let result_data = test_perform_read ( predicate, schema, table_location, reader) . await ;
16011635
16021636 assert_eq ! ( result_data, expected) ;
16031637 }
16041638
1639+ #[ tokio:: test]
1640+ async fn test_predicate_cast_literal ( ) {
1641+ let predicates = vec ! [
1642+ // a == 'foo'
1643+ ( Reference :: new( "a" ) . equal_to( Datum :: string( "foo" ) ) , vec![
1644+ Some ( "foo" . to_string( ) ) ,
1645+ ] ) ,
1646+ // a != 'foo'
1647+ (
1648+ Reference :: new( "a" ) . not_equal_to( Datum :: string( "foo" ) ) ,
1649+ vec![ Some ( "bar" . to_string( ) ) ] ,
1650+ ) ,
1651+ // STARTS_WITH(a, 'foo')
1652+ ( Reference :: new( "a" ) . starts_with( Datum :: string( "f" ) ) , vec![
1653+ Some ( "foo" . to_string( ) ) ,
1654+ ] ) ,
1655+ // NOT STARTS_WITH(a, 'foo')
1656+ (
1657+ Reference :: new( "a" ) . not_starts_with( Datum :: string( "f" ) ) ,
1658+ vec![ Some ( "bar" . to_string( ) ) ] ,
1659+ ) ,
1660+ // a < 'foo'
1661+ ( Reference :: new( "a" ) . less_than( Datum :: string( "foo" ) ) , vec![
1662+ Some ( "bar" . to_string( ) ) ,
1663+ ] ) ,
1664+ // a <= 'foo'
1665+ (
1666+ Reference :: new( "a" ) . less_than_or_equal_to( Datum :: string( "foo" ) ) ,
1667+ vec![ Some ( "foo" . to_string( ) ) , Some ( "bar" . to_string( ) ) ] ,
1668+ ) ,
1669+ // a > 'foo'
1670+ (
1671+ Reference :: new( "a" ) . greater_than( Datum :: string( "bar" ) ) ,
1672+ vec![ Some ( "foo" . to_string( ) ) ] ,
1673+ ) ,
1674+ // a >= 'foo'
1675+ (
1676+ Reference :: new( "a" ) . greater_than_or_equal_to( Datum :: string( "foo" ) ) ,
1677+ vec![ Some ( "foo" . to_string( ) ) ] ,
1678+ ) ,
1679+ // a IN ('foo', 'bar')
1680+ (
1681+ Reference :: new( "a" ) . is_in( [ Datum :: string( "foo" ) , Datum :: string( "baz" ) ] ) ,
1682+ vec![ Some ( "foo" . to_string( ) ) ] ,
1683+ ) ,
1684+ // a NOT IN ('foo', 'bar')
1685+ (
1686+ Reference :: new( "a" ) . is_not_in( [ Datum :: string( "foo" ) , Datum :: string( "baz" ) ] ) ,
1687+ vec![ Some ( "bar" . to_string( ) ) ] ,
1688+ ) ,
1689+ ] ;
1690+
1691+ // Table data: ["foo", "bar"]
1692+ let data_for_col_a = vec ! [ Some ( "foo" . to_string( ) ) , Some ( "bar" . to_string( ) ) ] ;
1693+
1694+ let ( file_io, schema, table_location, _temp_dir) =
1695+ setup_kleene_logic ( data_for_col_a, DataType :: LargeUtf8 ) ;
1696+ let reader = ArrowReaderBuilder :: new ( file_io) . build ( ) ;
1697+
1698+ for ( predicate, expected) in predicates {
1699+ println ! ( "testing predicate {predicate}" ) ;
1700+ let result_data = test_perform_read (
1701+ predicate. clone ( ) ,
1702+ schema. clone ( ) ,
1703+ table_location. clone ( ) ,
1704+ reader. clone ( ) ,
1705+ )
1706+ . await ;
1707+
1708+ assert_eq ! ( result_data, expected, "predicate={predicate}" ) ;
1709+ }
1710+ }
1711+
16051712 async fn test_perform_read (
16061713 predicate : Predicate ,
16071714 schema : SchemaRef ,
@@ -1644,6 +1751,7 @@ message schema {
16441751
16451752 fn setup_kleene_logic (
16461753 data_for_col_a : Vec < Option < String > > ,
1754+ col_a_type : DataType ,
16471755 ) -> ( FileIO , SchemaRef , String , TempDir ) {
16481756 let schema = Arc :: new (
16491757 Schema :: builder ( )
@@ -1660,7 +1768,7 @@ message schema {
16601768
16611769 let arrow_schema = Arc :: new ( ArrowSchema :: new ( vec ! [ Field :: new(
16621770 "a" ,
1663- DataType :: Utf8 ,
1771+ col_a_type . clone ( ) ,
16641772 true ,
16651773 )
16661774 . with_metadata( HashMap :: from( [ (
@@ -1673,7 +1781,11 @@ message schema {
16731781
16741782 let file_io = FileIO :: from_path ( & table_location) . unwrap ( ) . build ( ) . unwrap ( ) ;
16751783
1676- let col = Arc :: new ( StringArray :: from ( data_for_col_a) ) as ArrayRef ;
1784+ let col = match col_a_type {
1785+ DataType :: Utf8 => Arc :: new ( StringArray :: from ( data_for_col_a) ) as ArrayRef ,
1786+ DataType :: LargeUtf8 => Arc :: new ( LargeStringArray :: from ( data_for_col_a) ) as ArrayRef ,
1787+ _ => panic ! ( "unexpected col_a_type" ) ,
1788+ } ;
16771789
16781790 let to_write = RecordBatch :: try_new ( arrow_schema. clone ( ) , vec ! [ col] ) . unwrap ( ) ;
16791791
0 commit comments