@@ -18,7 +18,9 @@ import (
1818 "context"
1919 "encoding/json"
2020 "fmt"
21+ "maps"
2122 "math/big"
23+ "slices"
2224 "strings"
2325
2426 sq "github.com/Masterminds/squirrel"
@@ -74,6 +76,10 @@ func (d *Destination) Write(ctx context.Context, recs []opencdc.Record) (int, er
7476 b := & pgx.Batch {}
7577 for _ , rec := range recs {
7678 var err error
79+ rec , err = d .ensureStructuredData (rec )
80+ if err != nil {
81+ return 0 , fmt .Errorf ("failed to clean record: %w" , err )
82+ }
7783 switch rec .Operation {
7884 case opencdc .OperationCreate :
7985 err = d .handleInsert (ctx , rec , b )
@@ -117,9 +123,6 @@ func (d *Destination) Teardown(ctx context.Context) error {
117123// exists and no key column name is configured, it will plainly insert the data.
118124// Otherwise it upserts the record.
119125func (d * Destination ) handleInsert (ctx context.Context , r opencdc.Record , b * pgx.Batch ) error {
120- if ! d .hasKey (r ) || d .config .Key == "" {
121- return d .insert (ctx , r , b )
122- }
123126 return d .upsert (ctx , r , b )
124127}
125128
@@ -143,179 +146,114 @@ func (d *Destination) handleDelete(ctx context.Context, r opencdc.Record, b *pgx
143146}
144147
145148func (d * Destination ) upsert (ctx context.Context , r opencdc.Record , b * pgx.Batch ) error {
146- payload , err := d .getPayload (r )
147- if err != nil {
148- return fmt .Errorf ("failed to get payload: %w" , err )
149- }
150-
151- key , err := d .getKey (r )
152- if err != nil {
153- return fmt .Errorf ("failed to get key: %w" , err )
154- }
155-
156- keyColumnName := d .getKeyColumnName (key , d .config .Key )
157-
149+ payload := r .Payload .After .(opencdc.StructuredData )
150+ key := r .Key .(opencdc.StructuredData )
158151 tableName , err := d .getTableName (r )
159152 if err != nil {
160- return fmt .Errorf ("failed to get table name for write : %w" , err )
153+ return fmt .Errorf ("failed to get table name for upsert : %w" , err )
161154 }
162155
163- query , args , err := d .formatUpsertQuery (ctx , key , payload , keyColumnName , tableName )
156+ query , args , err := d .formatUpsertQuery (ctx , key , payload , tableName )
164157 if err != nil {
165158 return fmt .Errorf ("error formatting query: %w" , err )
166159 }
167160 sdk .Logger (ctx ).Trace ().
168- Str ("table_name" , tableName ).
169- Any ("key" , map [string ]interface {}{keyColumnName : key [keyColumnName ]}).
161+ Str ("table" , tableName ).
162+ Str ("query" , query ).
163+ Any ("key" , key ).
170164 Msg ("upserting record" )
171165
172166 b .Queue (query , args ... )
173167 return nil
174168}
175169
176170func (d * Destination ) remove (ctx context.Context , r opencdc.Record , b * pgx.Batch ) error {
177- key , err := d .getKey (r )
178- if err != nil {
179- return err
180- }
181- keyColumnName := d .getKeyColumnName (key , d .config .Key )
171+ key := r .Key .(opencdc.StructuredData )
182172 tableName , err := d .getTableName (r )
183173 if err != nil {
184- return fmt .Errorf ("failed to get table name for write: %w" , err )
174+ return fmt .Errorf ("failed to get table name for delete: %w" , err )
175+ }
176+
177+ where := make (sq.Eq )
178+ for col , val := range key {
179+ where [internal .WrapSQLIdent (col )] = val
185180 }
186181
187- sdk .Logger (ctx ).Trace ().
188- Str ("table_name" , tableName ).
189- Any ("key" , map [string ]interface {}{keyColumnName : key [keyColumnName ]}).
190- Msg ("deleting record" )
191182 query , args , err := d .stmtBuilder .
192183 Delete (internal .WrapSQLIdent (tableName )).
193- Where (sq. Eq { internal . WrapSQLIdent ( keyColumnName ): key [ keyColumnName ]} ).
184+ Where (where ).
194185 ToSql ()
195186 if err != nil {
196187 return fmt .Errorf ("error formatting delete query: %w" , err )
197188 }
198189
199- b .Queue (query , args ... )
200- return nil
201- }
202-
203- // insert is an append-only operation that doesn't care about keys, but
204- // can error on constraints violations so should only be used when no table
205- // key or unique constraints are otherwise present.
206- func (d * Destination ) insert (ctx context.Context , r opencdc.Record , b * pgx.Batch ) error {
207- tableName , err := d .getTableName (r )
208- if err != nil {
209- return err
210- }
211-
212- key , err := d .getKey (r )
213- if err != nil {
214- return err
215- }
216-
217- payload , err := d .getPayload (r )
218- if err != nil {
219- return err
220- }
221-
222- colArgs , valArgs , err := d .formatColumnsAndValues (ctx , tableName , key , payload )
223- if err != nil {
224- return fmt .Errorf ("error formatting columns and values: %w" , err )
225- }
226-
227190 sdk .Logger (ctx ).Trace ().
228- Str ("table_name" , tableName ).
229- Msg ("inserting record" )
230- query , args , err := d .stmtBuilder .
231- Insert (internal .WrapSQLIdent (tableName )).
232- Columns (colArgs ... ).
233- Values (valArgs ... ).
234- ToSql ()
235- if err != nil {
236- return fmt .Errorf ("error formatting insert query: %w" , err )
237- }
191+ Str ("table" , tableName ).
192+ Str ("query" , query ).
193+ Any ("key" , key ).
194+ Msg ("deleting record" )
238195
239196 b .Queue (query , args ... )
240197 return nil
241198}
242199
243- func (d * Destination ) getPayload (r opencdc.Record ) (opencdc.StructuredData , error ) {
244- if r .Payload .After == nil {
245- return opencdc.StructuredData {}, nil
246- }
247- return d .structuredDataFormatter (r .Payload .After )
248- }
249-
250- func (d * Destination ) getKey (r opencdc.Record ) (opencdc.StructuredData , error ) {
251- if r .Key == nil {
252- return opencdc.StructuredData {}, nil
253- }
254- return d .structuredDataFormatter (r .Key )
255- }
256-
257- func (d * Destination ) structuredDataFormatter (data opencdc.Data ) (opencdc.StructuredData , error ) {
258- if data == nil {
259- return opencdc.StructuredData {}, nil
260- }
261- if sdata , ok := data .(opencdc.StructuredData ); ok {
262- return sdata , nil
263- }
264- raw := data .Bytes ()
265- if len (raw ) == 0 {
266- return opencdc.StructuredData {}, nil
267- }
268-
269- m := make (map [string ]interface {})
270- err := json .Unmarshal (raw , & m )
271- if err != nil {
272- return nil , err
273- }
274- return m , nil
275- }
276-
277200// formatUpsertQuery manually formats the UPSERT and ON CONFLICT query statements.
278201// The `ON CONFLICT` portion of this query needs to specify the constraint
279202// name.
280203// * In our case, we can only rely on the record.Key's parsed key value.
281204// * If other schema constraints prevent a write, this won't upsert on
282205// that conflict.
283- func (d * Destination ) formatUpsertQuery (ctx context.Context , key opencdc.StructuredData , payload opencdc.StructuredData , keyColumnName string , tableName string ) (string , []interface {}, error ) {
284- upsertQuery := fmt .Sprintf ("ON CONFLICT (%s) DO UPDATE SET" , internal .WrapSQLIdent (keyColumnName ))
285- for column := range payload {
286- // tuples form a comma separated list, so they need a comma at the end.
287- // `EXCLUDED` references the new record's values. This will overwrite
288- // every column's value except for the key column.
289- wrappedCol := internal .WrapSQLIdent (column )
290- tuple := fmt .Sprintf ("%s=EXCLUDED.%s," , wrappedCol , wrappedCol )
291- // TODO: Consider removing this space.
292- upsertQuery += " "
293- // add the tuple to the query string
294- upsertQuery += tuple
295- }
296-
297- // remove the last comma from the list of tuples
298- upsertQuery = strings .TrimSuffix (upsertQuery , "," )
299-
300- // we have to manually append a semicolon to the upsert sql;
301- upsertQuery += ";"
302-
303- colArgs , valArgs , err := d .formatColumnsAndValues (ctx , tableName , key , payload )
206+ func (d * Destination ) formatUpsertQuery (
207+ ctx context.Context ,
208+ key , payload opencdc.StructuredData ,
209+ tableName string ,
210+ ) (string , []interface {}, error ) {
211+ colArgs , valArgs , err := d .formatColumnsAndValues (ctx , key , payload , tableName )
304212 if err != nil {
305213 return "" , nil , fmt .Errorf ("error formatting columns and values: %w" , err )
306214 }
307215
308- return d .stmtBuilder .
216+ stmt := d .stmtBuilder .
309217 Insert (internal .WrapSQLIdent (tableName )).
310218 Columns (colArgs ... ).
311- Values (valArgs ... ).
312- SuffixExpr (sq .Expr (upsertQuery )).
313- ToSql ()
219+ Values (valArgs ... )
220+
221+ if len (key ) > 0 {
222+ keyColumns := slices .Collect (maps .Keys (key ))
223+ for i := range keyColumns {
224+ keyColumns [i ] = internal .WrapSQLIdent (keyColumns [i ])
225+ }
226+
227+ var setOnConflict []string
228+ for column := range payload {
229+ // tuples form a comma separated list, so they need a comma at the end.
230+ // `EXCLUDED` references the new record's values. This will overwrite
231+ // every column's value except for the key columns.
232+ wrappedCol := internal .WrapSQLIdent (column )
233+ tuple := fmt .Sprintf ("%s=EXCLUDED.%s" , wrappedCol , wrappedCol )
234+ // add the tuple to the query string
235+ setOnConflict = append (setOnConflict , tuple )
236+ }
237+
238+ upsertQuery := fmt .Sprintf (
239+ "ON CONFLICT (%s) DO UPDATE SET %s" ,
240+ strings .Join (keyColumns , "," ),
241+ strings .Join (setOnConflict , "," ),
242+ )
243+
244+ stmt = stmt .Suffix (upsertQuery )
245+ }
246+
247+ return stmt .ToSql ()
314248}
315249
316250// formatColumnsAndValues turns the key and payload into a slice of ordered
317251// columns and values for upserting into Postgres.
318- func (d * Destination ) formatColumnsAndValues (ctx context.Context , table string , key , payload opencdc.StructuredData ) ([]string , []interface {}, error ) {
252+ func (d * Destination ) formatColumnsAndValues (
253+ ctx context.Context ,
254+ key , payload opencdc.StructuredData ,
255+ table string ,
256+ ) ([]string , []interface {}, error ) {
319257 var colArgs []string
320258 var valArgs []interface {}
321259
@@ -343,22 +281,51 @@ func (d *Destination) formatColumnsAndValues(ctx context.Context, table string,
343281 return colArgs , valArgs , nil
344282}
345283
346- // getKeyColumnName will return the name of the first item in the key or the
347- // connector-configured default name of the key column name.
348- func (d * Destination ) getKeyColumnName (key opencdc.StructuredData , defaultKeyName string ) string {
349- if len (key ) > 1 {
350- // Go maps aren't order preserving, so anything over len 1 will have
351- // non-deterministic results until we handle composite keys.
352- panic ("composite keys not yet supported" )
284+ func (d * Destination ) hasKey (e opencdc.Record ) bool {
285+ structuredKey , ok := e .Key .(opencdc.StructuredData )
286+ if ! ok {
287+ return false
353288 }
354- for k := range key {
355- return k
289+ return len (structuredKey ) > 0
290+ }
291+
292+ // ensureStructuredData makes sure the record key and payload are structured data.
293+ func (d * Destination ) ensureStructuredData (r opencdc.Record ) (opencdc.Record , error ) {
294+ payloadAfter , err := d .structuredDataFormatter (r .Payload .After )
295+ if err != nil {
296+ return opencdc.Record {}, fmt .Errorf ("failed to get structured data for .Payload.After: %w" , err )
356297 }
357- return defaultKeyName
298+ key , err := d .structuredDataFormatter (r .Key )
299+ if err != nil {
300+ return opencdc.Record {}, fmt .Errorf ("failed to get structured data for .Key: %w" , err )
301+ }
302+
303+ r .Key = key
304+ r .Payload .After = payloadAfter
305+ return r , nil
358306}
359307
360- func (d * Destination ) hasKey (e opencdc.Record ) bool {
361- return e .Key != nil && len (e .Key .Bytes ()) > 0
308+ func (d * Destination ) structuredDataFormatter (data opencdc.Data ) (opencdc.StructuredData , error ) {
309+ switch data := data .(type ) {
310+ case opencdc.StructuredData :
311+ // already structured data, no need to convert
312+ return data , nil
313+ case opencdc.RawData :
314+ raw := data .Bytes ()
315+ if len (raw ) == 0 {
316+ return opencdc.StructuredData {}, nil
317+ }
318+ m := make (map [string ]interface {})
319+ err := json .Unmarshal (raw , & m )
320+ if err != nil {
321+ return nil , fmt .Errorf ("failed to JSON unmarshal raw data: %w" , err )
322+ }
323+ return m , nil
324+ case nil :
325+ return opencdc.StructuredData {}, nil
326+ default :
327+ return nil , fmt .Errorf ("unexpected data type %T, expected StructuredData or RawData" , data )
328+ }
362329}
363330
364331func (d * Destination ) formatValue (ctx context.Context , table string , column string , val interface {}) (interface {}, error ) {
0 commit comments