@@ -18,6 +18,7 @@ import (
1818 "context"
1919 "encoding/json"
2020 "fmt"
21+ "math/big"
2122 "strings"
2223
2324 sq "github.com/Masterminds/squirrel"
@@ -26,6 +27,7 @@ import (
2627 "github.com/conduitio/conduit-connector-postgres/internal"
2728 sdk "github.com/conduitio/conduit-connector-sdk"
2829 "github.com/jackc/pgx/v5"
30+ "github.com/shopspring/decimal"
2931)
3032
3133type Destination struct {
@@ -35,6 +37,7 @@ type Destination struct {
3537 getTableName destination.TableFn
3638
3739 conn * pgx.Conn
40+ dbInfo * internal.DbInfo
3841 stmtBuilder sq.StatementBuilderType
3942}
4043
@@ -61,6 +64,7 @@ func (d *Destination) Open(ctx context.Context) error {
6164 return fmt .Errorf ("invalid table name or table name function: %w" , err )
6265 }
6366
67+ d .dbInfo = internal .NewDbInfo (conn )
6468 return nil
6569}
6670
@@ -156,7 +160,7 @@ func (d *Destination) upsert(ctx context.Context, r opencdc.Record, b *pgx.Batch
156160 return fmt .Errorf ("failed to get table name for write: %w" , err )
157161 }
158162
159- query , args , err := d .formatUpsertQuery (key , payload , keyColumnName , tableName )
163+ query , args , err := d .formatUpsertQuery (ctx , key , payload , keyColumnName , tableName )
160164 if err != nil {
161165 return fmt .Errorf ("error formatting query: %w" , err )
162166 }
@@ -215,7 +219,11 @@ func (d *Destination) insert(ctx context.Context, r opencdc.Record, b *pgx.Batch
215219 return err
216220 }
217221
218- colArgs , valArgs := d .formatColumnsAndValues (key , payload )
222+ colArgs , valArgs , err := d .formatColumnsAndValues (ctx , tableName , key , payload )
223+ if err != nil {
224+ return fmt .Errorf ("error formatting columns and values: %w" , err )
225+ }
226+
219227 sdk .Logger (ctx ).Trace ().
220228 Str ("table_name" , tableName ).
221229 Msg ("inserting record" )
@@ -272,12 +280,7 @@ func (d *Destination) structuredDataFormatter(data opencdc.Data) (opencdc.Struct
272280// * In our case, we can only rely on the record.Key's parsed key value.
273281// * If other schema constraints prevent a write, this won't upsert on
274282// that conflict.
275- func (d * Destination ) formatUpsertQuery (
276- key opencdc.StructuredData ,
277- payload opencdc.StructuredData ,
278- keyColumnName string ,
279- tableName string ,
280- ) (string , []interface {}, error ) {
283+ func (d * Destination ) formatUpsertQuery (ctx context.Context , key opencdc.StructuredData , payload opencdc.StructuredData , keyColumnName string , tableName string ) (string , []interface {}, error ) {
281284 upsertQuery := fmt .Sprintf ("ON CONFLICT (%s) DO UPDATE SET" , internal .WrapSQLIdent (keyColumnName ))
282285 for column := range payload {
283286 // tuples form a comma separated list, so they need a comma at the end.
@@ -294,10 +297,13 @@ func (d *Destination) formatUpsertQuery(
294297 // remove the last comma from the list of tuples
295298 upsertQuery = strings .TrimSuffix (upsertQuery , "," )
296299
297- // we have to manually append a semi colon to the upsert sql;
300+ // we have to manually append a semicolon to the upsert sql;
298301 upsertQuery += ";"
299302
300- colArgs , valArgs := d .formatColumnsAndValues (key , payload )
303+ colArgs , valArgs , err := d .formatColumnsAndValues (ctx , tableName , key , payload )
304+ if err != nil {
305+ return "" , nil , fmt .Errorf ("error formatting columns and values: %w" , err )
306+ }
301307
302308 return d .stmtBuilder .
303309 Insert (internal .WrapSQLIdent (tableName )).
@@ -309,32 +315,40 @@ func (d *Destination) formatUpsertQuery(
309315
310316// formatColumnsAndValues turns the key and payload into a slice of ordered
311317// columns and values for upserting into Postgres.
312- func (d * Destination ) formatColumnsAndValues (key , payload opencdc.StructuredData ) ([]string , []interface {}) {
318+ func (d * Destination ) formatColumnsAndValues (ctx context. Context , table string , key , payload opencdc.StructuredData ) ([]string , []interface {}, error ) {
313319 var colArgs []string
314320 var valArgs []interface {}
315321
316322 // range over both the key and payload values in order to format the
317323 // query for args and values in proper order
318324 for key , val := range key {
319325 colArgs = append (colArgs , internal .WrapSQLIdent (key ))
320- valArgs = append (valArgs , val )
326+ formatted , err := d .formatValue (ctx , table , key , val )
327+ if err != nil {
328+ return nil , nil , fmt .Errorf ("error formatting value: %w" , err )
329+ }
330+ valArgs = append (valArgs , formatted )
321331 delete (payload , key ) // NB: Delete Key from payload arguments
322332 }
323333
324- for field , value := range payload {
334+ for field , val := range payload {
325335 colArgs = append (colArgs , internal .WrapSQLIdent (field ))
326- valArgs = append (valArgs , value )
336+ formatted , err := d .formatValue (ctx , table , field , val )
337+ if err != nil {
338+ return nil , nil , fmt .Errorf ("error formatting value: %w" , err )
339+ }
340+ valArgs = append (valArgs , formatted )
327341 }
328342
329- return colArgs , valArgs
343+ return colArgs , valArgs , nil
330344}
331345
332346// getKeyColumnName will return the name of the first item in the key or the
333347// connector-configured default name of the key column name.
334348func (d * Destination ) getKeyColumnName (key opencdc.StructuredData , defaultKeyName string ) string {
335349 if len (key ) > 1 {
336350 // Go maps aren't order preserving, so anything over len 1 will have
337- // non deterministic results until we handle composite keys.
351+ // non- deterministic results until we handle composite keys.
338352 panic ("composite keys not yet supported" )
339353 }
340354 for k := range key {
@@ -346,3 +360,31 @@ func (d *Destination) getKeyColumnName(key opencdc.StructuredData, defaultKeyNam
346360func (d * Destination ) hasKey (e opencdc.Record ) bool {
347361 return e .Key != nil && len (e .Key .Bytes ()) > 0
348362}
363+
364+ func (d * Destination ) formatValue (ctx context.Context , table string , column string , val interface {}) (interface {}, error ) {
365+ switch v := val .(type ) {
366+ case * big.Rat :
367+ return d .formatBigRat (ctx , table , column , v )
368+ case big.Rat :
369+ return d .formatBigRat (ctx , table , column , & v )
370+ default :
371+ return val , nil
372+ }
373+ }
374+
375+ // formatBigRat formats a big.Rat into a string that can be written into a NUMERIC/DECIMAL column.
376+ func (d * Destination ) formatBigRat (ctx context.Context , table string , column string , v * big.Rat ) (string , error ) {
377+ if v == nil {
378+ return "" , nil
379+ }
380+
381+ // we need to get the scale of the column so we that we can properly
382+ // round the result of dividing the input big.Rat's numerator and denominator.
383+ scale , err := d .dbInfo .GetNumericColumnScale (ctx , table , column )
384+ if err != nil {
385+ return "" , fmt .Errorf ("failed getting scale of numeric column: %w" , err )
386+ }
387+
388+ //nolint:gosec // no risk of overflow, because the scale in Pg is always <= 16383
389+ return decimal .NewFromBigRat (v , int32 (scale )).String (), nil
390+ }
0 commit comments