@@ -11,6 +11,7 @@ use std::collections::HashMap;
1111use std:: fmt;
1212use std:: io:: BufReader ;
1313use std:: sync:: Arc ;
14+ use tokio:: sync:: OnceCell ;
1415use tokio_postgres:: error:: SqlState ;
1516use tokio_postgres:: tls:: MakeTlsConnect ;
1617use tokio_postgres:: {
@@ -162,7 +163,7 @@ impl PgReplicationSlotTransaction {
162163/// and streaming changes from the database.
163164#[ derive( Debug , Clone ) ]
164165pub struct PgReplicationClient {
165- client : Arc < Client > ,
166+ client : Arc < ( Client , OnceCell < i32 > ) > ,
166167}
167168
168169impl PgReplicationClient {
@@ -177,6 +178,15 @@ impl PgReplicationClient {
177178 }
178179 }
179180
181+
182+ // Convenience method to avoid having to access the client directly.
183+ async fn simple_query (
184+ & self ,
185+ query : & str ,
186+ ) -> Result < Vec < SimpleQueryMessage > , tokio_postgres:: Error > {
187+ self . client . 0 . simple_query ( query) . await
188+ }
189+
180190 /// Establishes a connection to Postgres without TLS encryption.
181191 ///
182192 /// The connection is configured for logical replication mode.
@@ -190,7 +200,7 @@ impl PgReplicationClient {
190200 info ! ( "successfully connected to postgres without tls" ) ;
191201
192202 Ok ( PgReplicationClient {
193- client : Arc :: new ( client) ,
203+ client : Arc :: new ( ( client, OnceCell :: new ( ) ) ) ,
194204 } )
195205 }
196206
@@ -221,7 +231,7 @@ impl PgReplicationClient {
221231 info ! ( "successfully connected to postgres with tls" ) ;
222232
223233 Ok ( PgReplicationClient {
224- client : Arc :: new ( client) ,
234+ client : Arc :: new ( ( client, OnceCell :: new ( ) ) ) ,
225235 } )
226236 }
227237
@@ -253,7 +263,7 @@ impl PgReplicationClient {
253263 quote_literal( slot_name)
254264 ) ;
255265
256- let results = self . client . simple_query ( & query) . await ?;
266+ let results = self . simple_query ( & query) . await ?;
257267 for result in results {
258268 if let SimpleQueryMessage :: Row ( row) = result {
259269 let confirmed_flush_lsn = Self :: get_row_value :: < PgLsn > (
@@ -315,7 +325,7 @@ impl PgReplicationClient {
315325 quote_identifier( slot_name)
316326 ) ;
317327
318- match self . client . simple_query ( & query) . await {
328+ match self . simple_query ( & query) . await {
319329 Ok ( _) => {
320330 info ! ( "successfully deleted replication slot '{}'" , slot_name) ;
321331
@@ -353,7 +363,7 @@ impl PgReplicationClient {
353363 "select 1 as exists from pg_publication where pubname = {};" ,
354364 quote_literal( publication)
355365 ) ;
356- for msg in self . client . simple_query ( & publication_exists_query) . await ? {
366+ for msg in self . simple_query ( & publication_exists_query) . await ? {
357367 if let SimpleQueryMessage :: Row ( _) = msg {
358368 return Ok ( true ) ;
359369 }
@@ -373,7 +383,7 @@ impl PgReplicationClient {
373383 ) ;
374384
375385 let mut table_names = vec ! [ ] ;
376- for msg in self . client . simple_query ( & publication_query) . await ? {
386+ for msg in self . simple_query ( & publication_query) . await ? {
377387 if let SimpleQueryMessage :: Row ( row) = msg {
378388 let schema =
379389 Self :: get_row_value :: < String > ( & row, "schemaname" , "pg_publication_tables" )
@@ -403,7 +413,7 @@ impl PgReplicationClient {
403413 ) ;
404414
405415 let mut table_ids = vec ! [ ] ;
406- for msg in self . client . simple_query ( & publication_query) . await ? {
416+ for msg in self . simple_query ( & publication_query) . await ? {
407417 if let SimpleQueryMessage :: Row ( row) = msg {
408418 // For the sake of simplicity, we refer to the table oid as table id.
409419 let table_id = Self :: get_row_value :: < TableId > ( & row, "oid" , "pg_class" ) . await ?;
@@ -441,7 +451,11 @@ impl PgReplicationClient {
441451 options
442452 ) ;
443453
444- let copy_stream = self . client . copy_both_simple :: < bytes:: Bytes > ( & query) . await ?;
454+ let copy_stream = self
455+ . client
456+ . 0
457+ . copy_both_simple :: < bytes:: Bytes > ( & query)
458+ . await ?;
445459 let stream = LogicalReplicationStream :: new ( copy_stream) ;
446460
447461 Ok ( stream)
@@ -452,23 +466,22 @@ impl PgReplicationClient {
452466 /// The transaction doesn't make any assumptions about the snapshot in use, since this is a
453467 /// concern of the statements issued within the transaction.
454468 async fn begin_tx ( & self ) -> EtlResult < ( ) > {
455- self . client
456- . simple_query ( "begin read only isolation level repeatable read;" )
469+ self . simple_query ( "begin read only isolation level repeatable read;" )
457470 . await ?;
458471
459472 Ok ( ( ) )
460473 }
461474
462475 /// Commits the current transaction.
463476 async fn commit_tx ( & self ) -> EtlResult < ( ) > {
464- self . client . simple_query ( "commit;" ) . await ?;
477+ self . simple_query ( "commit;" ) . await ?;
465478
466479 Ok ( ( ) )
467480 }
468481
469482 /// Rolls back the current transaction.
470483 async fn rollback_tx ( & self ) -> EtlResult < ( ) > {
471- self . client . simple_query ( "rollback;" ) . await ?;
484+ self . simple_query ( "rollback;" ) . await ?;
472485
473486 Ok ( ( ) )
474487 }
@@ -495,7 +508,7 @@ impl PgReplicationClient {
495508 quote_identifier( slot_name) ,
496509 snapshot_option
497510 ) ;
498- match self . client . simple_query ( & query) . await {
511+ match self . simple_query ( & query) . await {
499512 Ok ( results) => {
500513 for result in results {
501514 if let SimpleQueryMessage :: Row ( row) = result {
@@ -595,7 +608,7 @@ impl PgReplicationClient {
595608 where c.oid = {table_id}" ,
596609 ) ;
597610
598- for message in self . client . simple_query ( & table_info_query) . await ? {
611+ for message in self . simple_query ( & table_info_query) . await ? {
599612 if let SimpleQueryMessage :: Row ( row) = message {
600613 let schema_name =
601614 Self :: get_row_value :: < String > ( & row, "schema_name" , "pg_namespace" ) . await ?;
@@ -626,7 +639,7 @@ impl PgReplicationClient {
626639 publication : Option < & str > ,
627640 ) -> EtlResult < Vec < ColumnSchema > > {
628641 let ( pub_cte, pub_pred) = if let Some ( publication) = publication {
629- let is_pg14_or_earlier = self . is_postgres_14_or_earlier ( ) . await ? ;
642+ let is_pg14_or_earlier = self . get_server_version ( ) . await . unwrap_or ( 0 ) < 150000 ;
630643
631644 if !is_pg14_or_earlier {
632645 (
@@ -690,7 +703,7 @@ impl PgReplicationClient {
690703
691704 let mut column_schemas = vec ! [ ] ;
692705
693- for message in self . client . simple_query ( & column_info_query) . await ? {
706+ for message in self . simple_query ( & column_info_query) . await ? {
694707 if let SimpleQueryMessage :: Row ( row) = message {
695708 let name = Self :: get_row_value :: < String > ( & row, "attname" , "pg_attribute" ) . await ?;
696709 let type_oid = Self :: get_row_value :: < u32 > ( & row, "atttypid" , "pg_attribute" ) . await ?;
@@ -716,24 +729,36 @@ impl PgReplicationClient {
716729 Ok ( column_schemas)
717730 }
718731
719- async fn is_postgres_14_or_earlier ( & self ) -> EtlResult < bool > {
720- let version_query = "SHOW server_version_num" ;
721-
722- for message in self . client . simple_query ( version_query) . await ? {
723- if let SimpleQueryMessage :: Row ( row) = message {
724- let version_str =
725- Self :: get_row_value :: < String > ( & row, "server_version_num" , "server_version_num" )
732+ /// Gets the PostgreSQL server version.
733+ ///
734+ /// Returns the version in the format: MAJOR * 10000 + MINOR * 100 + PATCH
735+ /// For example: PostgreSQL 14.2 = 140200, PostgreSQL 15.1 = 150100
736+ async fn get_server_version ( & self ) -> EtlResult < i32 > {
737+ let version = self
738+ . client
739+ . 1
740+ . get_or_try_init ( || async {
741+ let version_query = "SHOW server_version_num" ;
742+
743+ for message in self . simple_query ( version_query) . await ? {
744+ if let SimpleQueryMessage :: Row ( row) = message {
745+ let version_str = Self :: get_row_value :: < String > (
746+ & row,
747+ "server_version_num" ,
748+ "server_version_num" ,
749+ )
726750 . await ?;
727- let server_version: i32 = version_str. parse ( ) . unwrap_or ( 0 ) ;
751+ let version: i32 = version_str. parse ( ) . unwrap_or ( 0 ) ;
752+ return Ok :: < _ , EtlError > ( version) ;
753+ }
754+ }
728755
729- // PostgreSQL version format is typically: MAJOR * 10000 + MINOR * 100 + PATCH
730- // For version 14.x.x, this would be 140000 + minor * 100 + patch
731- // For version 15.x.x, this would be 150000 + minor * 100 + patch
732- return Ok ( server_version < 150000 ) ;
733- }
734- }
756+ // If we can't determine, return 0 (which will be treated as very old version)
757+ Ok ( 0 )
758+ } )
759+ . await ?;
735760
736- Ok ( false )
761+ Ok ( * version )
737762 }
738763
739764 /// Creates a COPY stream for reading data from a table using its OID.
@@ -759,7 +784,7 @@ impl PgReplicationClient {
759784 column_list
760785 ) ;
761786
762- let stream = self . client . copy_out_simple ( & copy_query) . await ?;
787+ let stream = self . client . 0 . copy_out_simple ( & copy_query) . await ?;
763788
764789 Ok ( stream)
765790 }
0 commit comments