@@ -12,7 +12,7 @@ use tokio::net::TcpStream;
1212use tokio:: sync:: broadcast:: Receiver ;
1313use tokio:: sync:: mpsc:: Sender ;
1414
15- use crate :: admin:: { generate_server_info_for_admin , handle_admin} ;
15+ use crate :: admin:: { generate_server_parameters_for_admin , handle_admin} ;
1616use crate :: auth_passthrough:: refetch_auth_hash;
1717use crate :: config:: {
1818 get_config, get_idle_client_in_transaction_timeout, get_prepared_statements, Address , PoolMode ,
@@ -22,7 +22,7 @@ use crate::messages::*;
2222use crate :: plugins:: PluginOutput ;
2323use crate :: pool:: { get_pool, ClientServerMap , ConnectionPool } ;
2424use crate :: query_router:: { Command , QueryRouter } ;
25- use crate :: server:: Server ;
25+ use crate :: server:: { Server , ServerParameters } ;
2626use crate :: stats:: { ClientStats , ServerStats } ;
2727use crate :: tls:: Tls ;
2828
@@ -96,8 +96,8 @@ pub struct Client<S, T> {
9696 /// Postgres user for this client (This comes from the user in the connection string)
9797 username : String ,
9898
99- /// Application name for this client (defaults to pgcat)
100- application_name : String ,
99+ /// Server startup and session parameters that we're going to track
100+ server_parameters : ServerParameters ,
101101
102102 /// Used to notify clients about an impending shutdown
103103 shutdown : Receiver < ( ) > ,
@@ -502,7 +502,7 @@ where
502502 } ;
503503
504504 // Authenticate admin user.
505- let ( transaction_mode, server_info ) = if admin {
505+ let ( transaction_mode, mut server_parameters ) = if admin {
506506 let config = get_config ( ) ;
507507
508508 // Compare server and client hashes.
@@ -521,7 +521,7 @@ where
521521 return Err ( error) ;
522522 }
523523
524- ( false , generate_server_info_for_admin ( ) )
524+ ( false , generate_server_parameters_for_admin ( ) )
525525 }
526526 // Authenticate normal user.
527527 else {
@@ -654,13 +654,16 @@ where
654654 }
655655 }
656656
657- ( transaction_mode, pool. server_info ( ) )
657+ ( transaction_mode, pool. server_parameters ( ) )
658658 } ;
659659
660+ // Update the parameters to merge what the application sent and what's originally on the server
661+ server_parameters. set_from_hashmap ( & parameters, false ) ;
662+
660663 debug ! ( "Password authentication successful" ) ;
661664
662665 auth_ok ( & mut write) . await ?;
663- write_all ( & mut write, server_info ) . await ?;
666+ write_all ( & mut write, ( & server_parameters ) . into ( ) ) . await ?;
664667 backend_key_data ( & mut write, process_id, secret_key) . await ?;
665668 ready_for_query ( & mut write) . await ?;
666669
@@ -690,7 +693,7 @@ where
690693 last_server_stats : None ,
691694 pool_name : pool_name. clone ( ) ,
692695 username : username. clone ( ) ,
693- application_name : application_name . to_string ( ) ,
696+ server_parameters ,
694697 shutdown,
695698 connected_to_server : false ,
696699 prepared_statements : HashMap :: new ( ) ,
@@ -725,7 +728,7 @@ where
725728 last_server_stats : None ,
726729 pool_name : String :: from ( "undefined" ) ,
727730 username : String :: from ( "undefined" ) ,
728- application_name : String :: from ( "undefined" ) ,
731+ server_parameters : ServerParameters :: new ( ) ,
729732 shutdown,
730733 connected_to_server : false ,
731734 prepared_statements : HashMap :: new ( ) ,
@@ -774,8 +777,11 @@ where
774777 let mut prepared_statement = None ;
775778 let mut will_prepare = false ;
776779
777- let client_identifier =
778- ClientIdentifier :: new ( & self . application_name , & self . username , & self . pool_name ) ;
780+ let client_identifier = ClientIdentifier :: new (
781+ & self . server_parameters . get_application_name ( ) ,
782+ & self . username ,
783+ & self . pool_name ,
784+ ) ;
779785
780786 // Our custom protocol loop.
781787 // We expect the client to either start a transaction with regular queries
@@ -1115,10 +1121,7 @@ where
11151121 server. address( )
11161122 ) ;
11171123
1118- // TODO: investigate other parameters and set them too.
1119-
1120- // Set application_name.
1121- server. set_name ( & self . application_name ) . await ?;
1124+ server. sync_parameters ( & self . server_parameters ) . await ?;
11221125
11231126 let mut initial_message = Some ( message) ;
11241127
@@ -1296,7 +1299,9 @@ where
12961299 if !server. in_transaction ( ) {
12971300 // Report transaction executed statistics.
12981301 self . stats . transaction ( ) ;
1299- server. stats ( ) . transaction ( & self . application_name ) ;
1302+ server
1303+ . stats ( )
1304+ . transaction ( & self . server_parameters . get_application_name ( ) ) ;
13001305
13011306 // Release server back to the pool if we are in transaction mode.
13021307 // If we are in session mode, we keep the server until the client disconnects.
@@ -1446,7 +1451,9 @@ where
14461451
14471452 if !server. in_transaction ( ) {
14481453 self . stats . transaction ( ) ;
1449- server. stats ( ) . transaction ( & self . application_name ) ;
1454+ server
1455+ . stats ( )
1456+ . transaction ( & self . server_parameters . get_application_name ( ) ) ;
14501457
14511458 // Release server back to the pool if we are in transaction mode.
14521459 // If we are in session mode, we keep the server until the client disconnects.
@@ -1495,7 +1502,9 @@ where
14951502
14961503 if !server. in_transaction ( ) {
14971504 self . stats . transaction ( ) ;
1498- server. stats ( ) . transaction ( & self . application_name ) ;
1505+ server
1506+ . stats ( )
1507+ . transaction ( self . server_parameters . get_application_name ( ) ) ;
14991508
15001509 // Release server back to the pool if we are in transaction mode.
15011510 // If we are in session mode, we keep the server until the client disconnects.
@@ -1547,7 +1556,9 @@ where
15471556
15481557 Err ( Error :: ClientError ( format ! (
15491558 "Invalid pool name {{ username: {}, pool_name: {}, application_name: {} }}" ,
1550- self . pool_name, self . username, self . application_name
1559+ self . pool_name,
1560+ self . username,
1561+ self . server_parameters. get_application_name( )
15511562 ) ) )
15521563 }
15531564 }
@@ -1704,7 +1715,7 @@ where
17041715 client_stats. query ( ) ;
17051716 server. stats ( ) . query (
17061717 Instant :: now ( ) . duration_since ( query_start) . as_millis ( ) as u64 ,
1707- & self . application_name ,
1718+ & self . server_parameters . get_application_name ( ) ,
17081719 ) ;
17091720
17101721 Ok ( ( ) )
@@ -1733,38 +1744,18 @@ where
17331744 pool : & ConnectionPool ,
17341745 client_stats : & ClientStats ,
17351746 ) -> Result < BytesMut , Error > {
1736- if pool. settings . user . statement_timeout > 0 {
1737- match tokio:: time:: timeout (
1738- tokio:: time:: Duration :: from_millis ( pool. settings . user . statement_timeout ) ,
1739- server. recv ( ) ,
1740- )
1741- . await
1742- {
1743- Ok ( result) => match result {
1744- Ok ( message) => Ok ( message) ,
1745- Err ( err) => {
1746- pool. ban ( address, BanReason :: MessageReceiveFailed , Some ( client_stats) ) ;
1747- error_response_terminal (
1748- & mut self . write ,
1749- & format ! ( "error receiving data from server: {:?}" , err) ,
1750- )
1751- . await ?;
1752- Err ( err)
1753- }
1754- } ,
1755- Err ( _) => {
1756- error ! (
1757- "Statement timeout while talking to {:?} with user {}" ,
1758- address, pool. settings. user. username
1759- ) ;
1760- server. mark_bad ( ) ;
1761- pool. ban ( address, BanReason :: StatementTimeout , Some ( client_stats) ) ;
1762- error_response_terminal ( & mut self . write , "pool statement timeout" ) . await ?;
1763- Err ( Error :: StatementTimeout )
1764- }
1765- }
1766- } else {
1767- match server. recv ( ) . await {
1747+ let statement_timeout_duration = match pool. settings . user . statement_timeout {
1748+ 0 => tokio:: time:: Duration :: MAX ,
1749+ timeout => tokio:: time:: Duration :: from_millis ( timeout) ,
1750+ } ;
1751+
1752+ match tokio:: time:: timeout (
1753+ statement_timeout_duration,
1754+ server. recv ( Some ( & mut self . server_parameters ) ) ,
1755+ )
1756+ . await
1757+ {
1758+ Ok ( result) => match result {
17681759 Ok ( message) => Ok ( message) ,
17691760 Err ( err) => {
17701761 pool. ban ( address, BanReason :: MessageReceiveFailed , Some ( client_stats) ) ;
@@ -1775,6 +1766,16 @@ where
17751766 . await ?;
17761767 Err ( err)
17771768 }
1769+ } ,
1770+ Err ( _) => {
1771+ error ! (
1772+ "Statement timeout while talking to {:?} with user {}" ,
1773+ address, pool. settings. user. username
1774+ ) ;
1775+ server. mark_bad ( ) ;
1776+ pool. ban ( address, BanReason :: StatementTimeout , Some ( client_stats) ) ;
1777+ error_response_terminal ( & mut self . write , "pool statement timeout" ) . await ?;
1778+ Err ( Error :: StatementTimeout )
17781779 }
17791780 }
17801781 }
0 commit comments