Skip to content

Commit 25f925d

Browse files
authored
Supporting reading and writing numeric/decimal values without loss of precision (#290)
1 parent 5b39f4c commit 25f925d

File tree

13 files changed

+479
-130
lines changed

13 files changed

+479
-130
lines changed

destination.go

Lines changed: 58 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ import (
1818
"context"
1919
"encoding/json"
2020
"fmt"
21+
"math/big"
2122
"strings"
2223

2324
sq "github.com/Masterminds/squirrel"
@@ -26,6 +27,7 @@ import (
2627
"github.com/conduitio/conduit-connector-postgres/internal"
2728
sdk "github.com/conduitio/conduit-connector-sdk"
2829
"github.com/jackc/pgx/v5"
30+
"github.com/shopspring/decimal"
2931
)
3032

3133
type Destination struct {
@@ -35,6 +37,7 @@ type Destination struct {
3537
getTableName destination.TableFn
3638

3739
conn *pgx.Conn
40+
dbInfo *internal.DbInfo
3841
stmtBuilder sq.StatementBuilderType
3942
}
4043

@@ -61,6 +64,7 @@ func (d *Destination) Open(ctx context.Context) error {
6164
return fmt.Errorf("invalid table name or table name function: %w", err)
6265
}
6366

67+
d.dbInfo = internal.NewDbInfo(conn)
6468
return nil
6569
}
6670

@@ -156,7 +160,7 @@ func (d *Destination) upsert(ctx context.Context, r opencdc.Record, b *pgx.Batch
156160
return fmt.Errorf("failed to get table name for write: %w", err)
157161
}
158162

159-
query, args, err := d.formatUpsertQuery(key, payload, keyColumnName, tableName)
163+
query, args, err := d.formatUpsertQuery(ctx, key, payload, keyColumnName, tableName)
160164
if err != nil {
161165
return fmt.Errorf("error formatting query: %w", err)
162166
}
@@ -215,7 +219,11 @@ func (d *Destination) insert(ctx context.Context, r opencdc.Record, b *pgx.Batch
215219
return err
216220
}
217221

218-
colArgs, valArgs := d.formatColumnsAndValues(key, payload)
222+
colArgs, valArgs, err := d.formatColumnsAndValues(ctx, tableName, key, payload)
223+
if err != nil {
224+
return fmt.Errorf("error formatting columns and values: %w", err)
225+
}
226+
219227
sdk.Logger(ctx).Trace().
220228
Str("table_name", tableName).
221229
Msg("inserting record")
@@ -272,12 +280,7 @@ func (d *Destination) structuredDataFormatter(data opencdc.Data) (opencdc.Struct
272280
// * In our case, we can only rely on the record.Key's parsed key value.
273281
// * If other schema constraints prevent a write, this won't upsert on
274282
// that conflict.
275-
func (d *Destination) formatUpsertQuery(
276-
key opencdc.StructuredData,
277-
payload opencdc.StructuredData,
278-
keyColumnName string,
279-
tableName string,
280-
) (string, []interface{}, error) {
283+
func (d *Destination) formatUpsertQuery(ctx context.Context, key opencdc.StructuredData, payload opencdc.StructuredData, keyColumnName string, tableName string) (string, []interface{}, error) {
281284
upsertQuery := fmt.Sprintf("ON CONFLICT (%s) DO UPDATE SET", internal.WrapSQLIdent(keyColumnName))
282285
for column := range payload {
283286
// tuples form a comma separated list, so they need a comma at the end.
@@ -294,10 +297,13 @@ func (d *Destination) formatUpsertQuery(
294297
// remove the last comma from the list of tuples
295298
upsertQuery = strings.TrimSuffix(upsertQuery, ",")
296299

297-
// we have to manually append a semi colon to the upsert sql;
300+
// we have to manually append a semicolon to the upsert sql;
298301
upsertQuery += ";"
299302

300-
colArgs, valArgs := d.formatColumnsAndValues(key, payload)
303+
colArgs, valArgs, err := d.formatColumnsAndValues(ctx, tableName, key, payload)
304+
if err != nil {
305+
return "", nil, fmt.Errorf("error formatting columns and values: %w", err)
306+
}
301307

302308
return d.stmtBuilder.
303309
Insert(internal.WrapSQLIdent(tableName)).
@@ -309,32 +315,40 @@ func (d *Destination) formatUpsertQuery(
309315

310316
// formatColumnsAndValues turns the key and payload into a slice of ordered
311317
// columns and values for upserting into Postgres.
312-
func (d *Destination) formatColumnsAndValues(key, payload opencdc.StructuredData) ([]string, []interface{}) {
318+
func (d *Destination) formatColumnsAndValues(ctx context.Context, table string, key, payload opencdc.StructuredData) ([]string, []interface{}, error) {
313319
var colArgs []string
314320
var valArgs []interface{}
315321

316322
// range over both the key and payload values in order to format the
317323
// query for args and values in proper order
318324
for key, val := range key {
319325
colArgs = append(colArgs, internal.WrapSQLIdent(key))
320-
valArgs = append(valArgs, val)
326+
formatted, err := d.formatValue(ctx, table, key, val)
327+
if err != nil {
328+
return nil, nil, fmt.Errorf("error formatting value: %w", err)
329+
}
330+
valArgs = append(valArgs, formatted)
321331
delete(payload, key) // NB: Delete Key from payload arguments
322332
}
323333

324-
for field, value := range payload {
334+
for field, val := range payload {
325335
colArgs = append(colArgs, internal.WrapSQLIdent(field))
326-
valArgs = append(valArgs, value)
336+
formatted, err := d.formatValue(ctx, table, field, val)
337+
if err != nil {
338+
return nil, nil, fmt.Errorf("error formatting value: %w", err)
339+
}
340+
valArgs = append(valArgs, formatted)
327341
}
328342

329-
return colArgs, valArgs
343+
return colArgs, valArgs, nil
330344
}
331345

332346
// getKeyColumnName will return the name of the first item in the key or the
333347
// connector-configured default name of the key column name.
334348
func (d *Destination) getKeyColumnName(key opencdc.StructuredData, defaultKeyName string) string {
335349
if len(key) > 1 {
336350
// Go maps aren't order preserving, so anything over len 1 will have
337-
// non deterministic results until we handle composite keys.
351+
// non-deterministic results until we handle composite keys.
338352
panic("composite keys not yet supported")
339353
}
340354
for k := range key {
@@ -346,3 +360,31 @@ func (d *Destination) getKeyColumnName(key opencdc.StructuredData, defaultKeyNam
346360
func (d *Destination) hasKey(e opencdc.Record) bool {
347361
return e.Key != nil && len(e.Key.Bytes()) > 0
348362
}
363+
364+
func (d *Destination) formatValue(ctx context.Context, table string, column string, val interface{}) (interface{}, error) {
365+
switch v := val.(type) {
366+
case *big.Rat:
367+
return d.formatBigRat(ctx, table, column, v)
368+
case big.Rat:
369+
return d.formatBigRat(ctx, table, column, &v)
370+
default:
371+
return val, nil
372+
}
373+
}
374+
375+
// formatBigRat formats a big.Rat into a string that can be written into a NUMERIC/DECIMAL column.
376+
func (d *Destination) formatBigRat(ctx context.Context, table string, column string, v *big.Rat) (string, error) {
377+
if v == nil {
378+
return "", nil
379+
}
380+
381+
// we need to get the scale of the column so we that we can properly
382+
// round the result of dividing the input big.Rat's numerator and denominator.
383+
scale, err := d.dbInfo.GetNumericColumnScale(ctx, table, column)
384+
if err != nil {
385+
return "", fmt.Errorf("failed getting scale of numeric column: %w", err)
386+
}
387+
388+
//nolint:gosec // no risk of overflow, because the scale in Pg is always <= 16383
389+
return decimal.NewFromBigRat(v, int32(scale)).String(), nil
390+
}

0 commit comments

Comments
 (0)