@@ -4,7 +4,7 @@ use std::sync::Arc;
44use super :: schema:: postgres_to_delta_schema;
55use deltalake:: arrow:: record_batch:: RecordBatch ;
66use deltalake:: { DeltaOps , DeltaResult , DeltaTable , DeltaTableBuilder , open_table} ;
7- use etl:: types:: TableSchema ;
7+ use etl:: types:: { TableSchema , TableRow , Cell } ;
88
99/// Client for connecting to Delta Lake tables.
1010#[ derive( Clone ) ]
@@ -220,27 +220,56 @@ impl DeltaLakeClient {
220220 primary_keys : & HashSet < String > ,
221221 pk_column_names : & [ String ] ,
222222 ) -> String {
223- // todo(abhi): Implement proper predicate building for primary key matching
224- // todo(abhi): Handle composite primary keys
225- // todo(abhi): Handle SQL injection prevention
226- // todo(abhi): Build disjunction for multiple keys
227-
228223 if primary_keys. is_empty ( ) {
229224 return "false" . to_string ( ) ; // No rows to match
230225 }
231226
232- // Simple single-column PK case for now
233- if pk_column_names. len ( ) == 1 {
234- let pk_column = & pk_column_names[ 0 ] ;
235- let keys: Vec < String > = primary_keys. iter ( ) . map ( |k| format ! ( "'{k}'" ) ) . collect ( ) ;
236- return format ! ( "{} IN ({})" , pk_column, keys. join( ", " ) ) ;
227+ if pk_column_names. is_empty ( ) {
228+ return "false" . to_string ( ) ; // No PK columns
237229 }
238230
239- // todo(abhi): Handle composite primary keys
240- // For composite keys, need to build something like:
241- // (col1 = 'val1' AND col2 = 'val2') OR (col1 = 'val3' AND col2 = 'val4') ...
242-
243- "false" . to_string ( ) // Fallback
231+ if pk_column_names. len ( ) == 1 {
232+ // Single column primary key: col IN ('val1', 'val2', ...)
233+ let pk_column = Self :: escape_identifier ( & pk_column_names[ 0 ] ) ;
234+ let escaped_keys: Vec < String > = primary_keys
235+ . iter ( )
236+ . map ( |k| Self :: escape_string_literal ( k) )
237+ . collect ( ) ;
238+ format ! ( "{} IN ({})" , pk_column, escaped_keys. join( ", " ) )
239+ } else {
240+ // Composite primary key: (col1 = 'val1' AND col2 = 'val2') OR (col1 = 'val3' AND col2 = 'val4') ...
241+ let conditions: Vec < String > = primary_keys
242+ . iter ( )
243+ . map ( |composite_key| {
244+ let key_parts = Self :: split_composite_key ( composite_key) ;
245+ if key_parts. len ( ) != pk_column_names. len ( ) {
246+ // Malformed composite key, skip
247+ return "false" . to_string ( ) ;
248+ }
249+
250+ let conditions: Vec < String > = pk_column_names
251+ . iter ( )
252+ . zip ( key_parts. iter ( ) )
253+ . map ( |( col, val) | {
254+ format ! (
255+ "{} = {}" ,
256+ Self :: escape_identifier( col) ,
257+ Self :: escape_string_literal( val)
258+ )
259+ } )
260+ . collect ( ) ;
261+
262+ format ! ( "({})" , conditions. join( " AND " ) )
263+ } )
264+ . filter ( |cond| cond != "false" ) // Remove malformed conditions
265+ . collect ( ) ;
266+
267+ if conditions. is_empty ( ) {
268+ "false" . to_string ( )
269+ } else {
270+ conditions. join ( " OR " )
271+ }
272+ }
244273 }
245274
246275 /// Generate app-level transaction ID for idempotency
@@ -274,4 +303,281 @@ impl DeltaLakeClient {
274303
275304 Ok ( vec ! [ ] ) // No missing columns for now
276305 }
306+
307+ /// Extract primary key from a TableRow using the table schema
308+ pub fn extract_primary_key (
309+ & self ,
310+ table_row : & TableRow ,
311+ table_schema : & TableSchema ,
312+ ) -> Result < String , String > {
313+ let pk_columns: Vec < & str > = table_schema
314+ . column_schemas
315+ . iter ( )
316+ . enumerate ( )
317+ . filter_map ( |( idx, col) | {
318+ if col. primary {
319+ Some ( ( idx, col. name . as_str ( ) ) )
320+ } else {
321+ None
322+ }
323+ } )
324+ . map ( |( _, name) | name)
325+ . collect ( ) ;
326+
327+ if pk_columns. is_empty ( ) {
328+ return Err ( "No primary key columns found in table schema" . to_string ( ) ) ;
329+ }
330+
331+ let pk_indices: Vec < usize > = table_schema
332+ . column_schemas
333+ . iter ( )
334+ . enumerate ( )
335+ . filter_map ( |( idx, col) | if col. primary { Some ( idx) } else { None } )
336+ . collect ( ) ;
337+
338+ if pk_indices. len ( ) != pk_columns. len ( ) {
339+ return Err ( "Mismatch between PK column count and indices" . to_string ( ) ) ;
340+ }
341+
342+ // Check that all PK indices are within bounds
343+ for & idx in & pk_indices {
344+ if idx >= table_row. values . len ( ) {
345+ return Err ( format ! (
346+ "Primary key column index {} out of bounds for row with {} columns" ,
347+ idx,
348+ table_row. values. len( )
349+ ) ) ;
350+ }
351+ }
352+
353+ if pk_columns. len ( ) == 1 {
354+ // Single column primary key
355+ let cell = & table_row. values [ pk_indices[ 0 ] ] ;
356+ Ok ( Self :: cell_to_string ( cell) )
357+ } else {
358+ // Composite primary key - join with delimiter
359+ let key_parts: Vec < String > = pk_indices
360+ . iter ( )
361+ . map ( |& idx| Self :: cell_to_string ( & table_row. values [ idx] ) )
362+ . collect ( ) ;
363+ Ok ( Self :: join_composite_key ( & key_parts) )
364+ }
365+ }
366+
367+ /// Convert a Cell to its string representation for primary key purposes
///
/// Returns an owned `String` for every `Cell` variant listed in the match:
/// * `Cell::Null` becomes the text `NULL`, which makes it indistinguishable
///   from a `Cell::String` holding the same four characters.
/// * `Cell::String` is cloned unchanged.
/// * `Cell::Bytes` is rendered as a prefix followed by two lowercase hex
///   digits per byte.
/// * `Cell::Array` always yields the fixed placeholder `[ARRAY]`, whatever the
///   array contains.
/// * Every other variant is rendered through its `to_string()`.
368+ fn cell_to_string ( cell : & Cell ) -> String {
369+ match cell {
370+ Cell :: Null => "NULL" . to_string ( ) ,
371+ Cell :: Bool ( b) => b. to_string ( ) ,
372+ Cell :: String ( s) => s. clone ( ) ,
373+ Cell :: I16 ( i) => i. to_string ( ) ,
374+ Cell :: I32 ( i) => i. to_string ( ) ,
375+ Cell :: I64 ( i) => i. to_string ( ) ,
376+ Cell :: U32 ( i) => i. to_string ( ) ,
377+ Cell :: F32 ( f) => f. to_string ( ) ,
378+ Cell :: F64 ( f) => f. to_string ( ) ,
379+ Cell :: Numeric ( n) => n. to_string ( ) ,
380+ Cell :: Date ( d) => d. to_string ( ) ,
381+ Cell :: Time ( t) => t. to_string ( ) ,
382+ Cell :: Timestamp ( ts) => ts. to_string ( ) ,
383+ Cell :: TimestampTz ( ts) => ts. to_string ( ) ,
384+ Cell :: Uuid ( u) => u. to_string ( ) ,
385+ Cell :: Json ( j) => j. to_string ( ) ,
386+ Cell :: Bytes ( b) => {
// Each byte is formatted as exactly two lowercase hex digits and the
// digits are concatenated in order.
387+ let hex_string: String = b. iter ( ) . map ( |byte| format ! ( "{:02x}" , byte) ) . collect ( ) ;
// NOTE(review): the prefix literal is shown here with a space after the
// escaped backslash; presumably the intended prefix is a backslash
// followed directly by `x` — confirm against the real file.
388+ format ! ( "\\ x{}" , hex_string)
389+ } ,
390+ Cell :: Array ( _) => "[ARRAY]" . to_string ( ) , // Arrays are not expected as primary keys; all arrays map to the same placeholder
391+ }
392+ }
393+
394+ /// Join composite key parts with a delimiter
395+ const COMPOSITE_KEY_DELIMITER : & ' static str = "::" ;
396+ const COMPOSITE_KEY_ESCAPE_REPLACEMENT : & ' static str = "::::" ;
397+
398+ fn join_composite_key ( parts : & [ String ] ) -> String {
399+ let escaped_parts: Vec < String > = parts
400+ . iter ( )
401+ . map ( |part| {
402+ part. replace (
403+ Self :: COMPOSITE_KEY_DELIMITER ,
404+ Self :: COMPOSITE_KEY_ESCAPE_REPLACEMENT ,
405+ )
406+ } )
407+ . collect ( ) ;
408+ escaped_parts. join ( Self :: COMPOSITE_KEY_DELIMITER )
409+ }
410+
411+ /// Split a composite key back into its parts
412+ fn split_composite_key ( composite_key : & str ) -> Vec < String > {
413+ // Split on single delimiter (::) but avoid splitting on escaped delimiter (::::)
414+ let mut parts = Vec :: new ( ) ;
415+ let mut current_part = String :: new ( ) ;
416+ let mut chars = composite_key. chars ( ) . peekable ( ) ;
417+
418+ while let Some ( ch) = chars. next ( ) {
419+ if ch == ':' {
420+ if chars. peek ( ) == Some ( & ':' ) {
421+ chars. next ( ) ; // consume second ':'
422+ if chars. peek ( ) == Some ( & ':' ) {
423+ // This is the escaped delimiter "::::" - treat as literal "::"
424+ chars. next ( ) ; // consume third ':'
425+ chars. next ( ) ; // consume fourth ':'
426+ current_part. push_str ( Self :: COMPOSITE_KEY_DELIMITER ) ;
427+ } else {
428+ // This is the actual delimiter "::" - split here
429+ parts. push ( current_part. clone ( ) ) ;
430+ current_part. clear ( ) ;
431+ }
432+ } else {
433+ // Single colon, just add it
434+ current_part. push ( ch) ;
435+ }
436+ } else {
437+ current_part. push ( ch) ;
438+ }
439+ }
440+
441+ // Add the final part
442+ if !current_part. is_empty ( ) || !parts. is_empty ( ) {
443+ parts. push ( current_part) ;
444+ }
445+
446+ parts
447+ }
448+
/// Quote a SQL identifier (such as a column name) with backticks.
///
/// Any backtick inside `identifier` is doubled so it cannot terminate the
/// quoted name early; the result is always wrapped in a pair of backticks.
fn escape_identifier(identifier: &str) -> String {
    // Two extra bytes for the surrounding backticks.
    let mut quoted = String::with_capacity(identifier.len() + 2);
    quoted.push('`');
    for ch in identifier.chars() {
        if ch == '`' {
            quoted.push('`');
        }
        quoted.push(ch);
    }
    quoted.push('`');
    quoted
}
454+
/// Render `value` as a single-quoted SQL string literal.
///
/// Each single quote inside `value` is doubled, and the result is wrapped in
/// single quotes.
fn escape_string_literal(value: &str) -> String {
    // Splitting on the quote and re-joining with two quotes doubles every one.
    let doubled = value.split('\'').collect::<Vec<&str>>().join("''");
    ["'", doubled.as_str(), "'"].concat()
}
460+
461+ /// Get primary key column names from table schema
462+ pub fn get_primary_key_columns ( table_schema : & TableSchema ) -> Vec < String > {
463+ table_schema
464+ . column_schemas
465+ . iter ( )
466+ . filter ( |col| col. primary )
467+ . map ( |col| col. name . clone ( ) )
468+ . collect ( )
469+ }
470+ }
471+
#[cfg(test)]
mod tests {
    use super::*;
    use etl::types::{Cell, ColumnSchema, TableId, TableName, TableRow, TableSchema, Type};

    /// Schema with an `id` (INT4) column followed by a `name` (TEXT) column.
    fn create_test_schema() -> TableSchema {
        let columns = vec![
            ColumnSchema::new("id".to_string(), Type::INT4, -1, false, true),
            ColumnSchema::new("name".to_string(), Type::TEXT, -1, true, false),
        ];
        let table_name = TableName::new("public".to_string(), "test_table".to_string());
        TableSchema::new(TableId(1), table_name, columns)
    }

    /// Row whose cells line up with the columns of `create_test_schema`.
    fn create_test_row(id: i32, name: &str) -> TableRow {
        let cells = vec![Cell::I32(id), Cell::String(name.to_string())];
        TableRow::new(cells)
    }

    /// Build an owned key set from string slices.
    fn key_set(keys: &[&str]) -> HashSet<String> {
        keys.iter().map(|key| key.to_string()).collect()
    }

    #[test]
    fn test_extract_primary_key_single_column() {
        let client = DeltaLakeClient::new(None);
        let schema = create_test_schema();
        let row = create_test_row(42, "test");

        let key = client.extract_primary_key(&row, &schema);
        assert_eq!(key, Ok("42".to_string()));
    }

    #[test]
    fn test_extract_primary_key_composite() {
        let client = DeltaLakeClient::new(None);
        let mut schema = create_test_schema();
        // Flag the second column as well so the key spans both columns.
        schema.column_schemas[1].primary = true;
        let row = create_test_row(42, "test");

        let key = client.extract_primary_key(&row, &schema);
        assert_eq!(key, Ok("42::test".to_string()));
    }

    #[test]
    fn test_build_pk_predicate_single_column() {
        let client = DeltaLakeClient::new(None);
        let keys = key_set(&["42", "43"]);
        let pk_columns = vec!["id".to_string()];

        let predicate = client.build_pk_predicate(&keys, &pk_columns);

        // Expected shape: `id` IN ('42', '43'); set iteration order is not fixed.
        for expected in ["`id` IN", "'42'", "'43'"].iter() {
            assert!(predicate.contains(expected));
        }
    }

    #[test]
    fn test_build_pk_predicate_composite() {
        let client = DeltaLakeClient::new(None);
        let keys = key_set(&["42::test", "43::hello"]);
        let pk_columns = vec!["id".to_string(), "name".to_string()];

        let predicate = client.build_pk_predicate(&keys, &pk_columns);

        // Expected shape:
        // (`id` = '42' AND `name` = 'test') OR (`id` = '43' AND `name` = 'hello')
        let expected_fragments = [
            "`id` = '42' AND `name` = 'test'",
            "`id` = '43' AND `name` = 'hello'",
            " OR ",
        ];
        for expected in expected_fragments.iter() {
            assert!(predicate.contains(expected));
        }
    }

    #[test]
    fn test_build_pk_predicate_empty() {
        let client = DeltaLakeClient::new(None);
        let keys = HashSet::<String>::new();
        let pk_columns = vec!["id".to_string()];

        assert_eq!(client.build_pk_predicate(&keys, &pk_columns), "false");
    }

    #[test]
    fn test_composite_key_escape() {
        let parts = vec!["value::with::delimiter".to_string(), "normal".to_string()];

        let composite = DeltaLakeClient::join_composite_key(&parts);
        assert_eq!(composite, "value::::with::::delimiter::normal");

        // Splitting the joined key must give back the original parts.
        assert_eq!(DeltaLakeClient::split_composite_key(&composite), parts);
    }

    #[test]
    fn test_escape_identifier() {
        let cases = [("normal", "`normal`"), ("with`backtick", "`with``backtick`")];
        for (input, expected) in cases.iter() {
            assert_eq!(DeltaLakeClient::escape_identifier(input), *expected);
        }
    }

    #[test]
    fn test_escape_string_literal() {
        let cases = [("normal", "'normal'"), ("with'quote", "'with''quote'")];
        for (input, expected) in cases.iter() {
            assert_eq!(DeltaLakeClient::escape_string_literal(input), *expected);
        }
    }
}
0 commit comments