@@ -3,8 +3,9 @@ use crate::pool::BanReason;
33/// Handle clients by pretending to be a PostgreSQL server.
44use bytes:: { Buf , BufMut , BytesMut } ;
55use log:: { debug, error, info, trace, warn} ;
6+ use once_cell:: sync:: Lazy ;
67use std:: collections:: HashMap ;
7- use std:: sync:: Arc ;
8+ use std:: sync:: { atomic :: AtomicUsize , Arc } ;
89use std:: time:: Instant ;
910use tokio:: io:: { split, AsyncReadExt , BufReader , ReadHalf , WriteHalf } ;
1011use tokio:: net:: TcpStream ;
@@ -13,7 +14,9 @@ use tokio::sync::mpsc::Sender;
1314
1415use crate :: admin:: { generate_server_info_for_admin, handle_admin} ;
1516use crate :: auth_passthrough:: refetch_auth_hash;
16- use crate :: config:: { get_config, get_idle_client_in_transaction_timeout, Address , PoolMode } ;
17+ use crate :: config:: {
18+ get_config, get_idle_client_in_transaction_timeout, get_prepared_statements, Address , PoolMode ,
19+ } ;
1720use crate :: constants:: * ;
1821use crate :: messages:: * ;
1922use crate :: plugins:: PluginOutput ;
@@ -25,6 +28,11 @@ use crate::tls::Tls;
2528
2629use tokio_rustls:: server:: TlsStream ;
2730
31+ /// Incrementally count prepared statements
32+ /// to avoid random conflicts in places where the random number generator is weak.
33+ pub static PREPARED_STATEMENT_COUNTER : Lazy < Arc < AtomicUsize > > =
34+ Lazy :: new ( || Arc :: new ( AtomicUsize :: new ( 0 ) ) ) ;
35+
2836/// Type of connection received from client.
2937enum ClientConnectionType {
3038 Startup ,
@@ -93,6 +101,9 @@ pub struct Client<S, T> {
93101
94102 /// Used to notify clients about an impending shutdown
95103 shutdown : Receiver < ( ) > ,
104+
105+ /// Prepared statements
106+ prepared_statements : HashMap < String , Parse > ,
96107}
97108
98109/// Client entrypoint.
@@ -682,6 +693,7 @@ where
682693 application_name : application_name. to_string ( ) ,
683694 shutdown,
684695 connected_to_server : false ,
696+ prepared_statements : HashMap :: new ( ) ,
685697 } )
686698 }
687699
@@ -716,6 +728,7 @@ where
716728 application_name : String :: from ( "undefined" ) ,
717729 shutdown,
718730 connected_to_server : false ,
731+ prepared_statements : HashMap :: new ( ) ,
719732 } )
720733 }
721734
@@ -757,6 +770,10 @@ where
757770 // Result returned by one of the plugins.
758771 let mut plugin_output = None ;
759772
773+ // Prepared statement being executed
774+ let mut prepared_statement = None ;
775+ let mut will_prepare = false ;
776+
760777 // Our custom protocol loop.
761778 // We expect the client to either start a transaction with regular queries
762779 // or issue commands for our sharding and server selection protocol.
@@ -766,13 +783,16 @@ where
766783 self . transaction_mode
767784 ) ;
768785
786+ // Should we rewrite prepared statements and bind messages?
787+ let mut prepared_statements_enabled = get_prepared_statements ( ) ;
788+
769789 // Read a complete message from the client, which normally would be
770790 // either a `Q` (query) or `P` (prepare, extended protocol).
771791 // We can parse it here before grabbing a server from the pool,
772792 // in case the client is sending some custom protocol messages, e.g.
773793 // SET SHARDING KEY TO 'bigint';
774794
775- let message = tokio:: select! {
795+ let mut message = tokio:: select! {
776796 _ = self . shutdown. recv( ) => {
777797 if !self . admin {
778798 error_response_terminal(
@@ -800,7 +820,21 @@ where
800820 // allocate a connection, we wouldn't be able to send back an error message
801821 // to the client so we buffer them and defer the decision to error out or not
802822 // to when we get the S message
803- 'D' | 'E' => {
823+ 'D' => {
824+ if prepared_statements_enabled {
825+ let name;
826+ ( name, message) = self . rewrite_describe ( message) . await ?;
827+
828+ if let Some ( name) = name {
829+ prepared_statement = Some ( name) ;
830+ }
831+ }
832+
833+ self . buffer . put ( & message[ ..] ) ;
834+ continue ;
835+ }
836+
837+ 'E' => {
804838 self . buffer . put ( & message[ ..] ) ;
805839 continue ;
806840 }
@@ -830,6 +864,11 @@ where
830864 }
831865
832866 'P' => {
867+ if prepared_statements_enabled {
868+ ( prepared_statement, message) = self . rewrite_parse ( message) ?;
869+ will_prepare = true ;
870+ }
871+
833872 self . buffer . put ( & message[ ..] ) ;
834873
835874 if query_router. query_parser_enabled ( ) {
@@ -846,6 +885,10 @@ where
846885 }
847886
848887 'B' => {
888+ if prepared_statements_enabled {
889+ ( prepared_statement, message) = self . rewrite_bind ( message) . await ?;
890+ }
891+
849892 self . buffer . put ( & message[ ..] ) ;
850893
851894 if query_router. query_parser_enabled ( ) {
@@ -1054,7 +1097,48 @@ where
10541097 // If the client is in session mode, no more custom protocol
10551098 // commands will be accepted.
10561099 loop {
1057- let message = match initial_message {
1100+ // Only check if we should rewrite prepared statements
1101+ // in session mode. In transaction mode, we check at the beginning of
1102+ // each transaction.
1103+ if !self . transaction_mode {
1104+ prepared_statements_enabled = get_prepared_statements ( ) ;
1105+ }
1106+
1107+ debug ! ( "Prepared statement active: {:?}" , prepared_statement) ;
1108+
1109+ // We are processing a prepared statement.
1110+ if let Some ( ref name) = prepared_statement {
1111+ debug ! ( "Checking prepared statement is on server" ) ;
1112+ // Get the prepared statement the server expects to see.
1113+ let statement = match self . prepared_statements . get ( name) {
1114+ Some ( statement) => {
1115+ debug ! ( "Prepared statement `{}` found in cache" , name) ;
1116+ statement
1117+ }
1118+ None => {
1119+ return Err ( Error :: ClientError ( format ! (
1120+ "prepared statement `{}` not found" ,
1121+ name
1122+ ) ) )
1123+ }
1124+ } ;
1125+
1126+ // Since it's already in the buffer, we don't need to prepare it on this server.
1127+ if will_prepare {
1128+ server. will_prepare ( & statement. name ) ;
1129+ will_prepare = false ;
1130+ } else {
1131+ // The statement is not prepared on the server, so we need to prepare it.
1132+ if server. should_prepare ( & statement. name ) {
1133+ server. prepare ( statement) . await ?;
1134+ }
1135+ }
1136+
1137+ // Done processing the prepared statement.
1138+ prepared_statement = None ;
1139+ }
1140+
1141+ let mut message = match initial_message {
10581142 None => {
10591143 trace ! ( "Waiting for message inside transaction or in session mode" ) ;
10601144
@@ -1173,6 +1257,11 @@ where
11731257 // Parse
11741258 // The query with placeholders is here, e.g. `SELECT * FROM users WHERE email = $1 AND active = $2`.
11751259 'P' => {
1260+ if prepared_statements_enabled {
1261+ ( prepared_statement, message) = self . rewrite_parse ( message) ?;
1262+ will_prepare = true ;
1263+ }
1264+
11761265 if query_router. query_parser_enabled ( ) {
11771266 if let Ok ( ast) = QueryRouter :: parse ( & message) {
11781267 if let Ok ( output) = query_router. execute_plugins ( & ast) . await {
@@ -1187,12 +1276,25 @@ where
11871276 // Bind
11881277 // The placeholder's replacements are here, e.g. 'user@email.com' and 'true'
11891278 'B' => {
1279+ if prepared_statements_enabled {
1280+ ( prepared_statement, message) = self . rewrite_bind ( message) . await ?;
1281+ }
1282+
11901283 self . buffer . put ( & message[ ..] ) ;
11911284 }
11921285
11931286 // Describe
11941287 // Command a client can issue to describe a previously prepared named statement.
11951288 'D' => {
1289+ if prepared_statements_enabled {
1290+ let name;
1291+ ( name, message) = self . rewrite_describe ( message) . await ?;
1292+
1293+ if let Some ( name) = name {
1294+ prepared_statement = Some ( name) ;
1295+ }
1296+ }
1297+
11961298 self . buffer . put ( & message[ ..] ) ;
11971299 }
11981300
@@ -1235,7 +1337,7 @@ where
12351337 let first_message_code = ( * self . buffer . get ( 0 ) . unwrap_or ( & 0 ) ) as char ;
12361338
12371339 // Almost certainly true
1238- if first_message_code == 'P' {
1340+ if first_message_code == 'P' && !prepared_statements_enabled {
12391341 // Message layout
12401342 // P followed by 32 int followed by null-terminated statement name
12411343 // So message code should be in offset 0 of the buffer, first character
@@ -1363,6 +1465,107 @@ where
13631465 }
13641466 }
13651467
1468+ /// Rewrite Parse (F) message to set the prepared statement name to one we control.
1469+ /// Save it into the client cache.
1470+ fn rewrite_parse ( & mut self , message : BytesMut ) -> Result < ( Option < String > , BytesMut ) , Error > {
1471+ let parse: Parse = ( & message) . try_into ( ) ?;
1472+
1473+ let name = parse. name . clone ( ) ;
1474+
1475+ // Don't rewrite anonymous prepared statements
1476+ if parse. anonymous ( ) {
1477+ debug ! ( "Anonymous prepared statement" ) ;
1478+ return Ok ( ( None , message) ) ;
1479+ }
1480+
1481+ let parse = parse. rename ( ) ;
1482+
1483+ debug ! (
1484+ "Renamed prepared statement `{}` to `{}` and saved to cache" ,
1485+ name, parse. name
1486+ ) ;
1487+
1488+ self . prepared_statements . insert ( name. clone ( ) , parse. clone ( ) ) ;
1489+
1490+ Ok ( ( Some ( name) , parse. try_into ( ) ?) )
1491+ }
1492+
1493+ /// Rewrite the Bind (F) message to use the prepared statement name
1494+ /// saved in the client cache.
1495+ async fn rewrite_bind (
1496+ & mut self ,
1497+ message : BytesMut ,
1498+ ) -> Result < ( Option < String > , BytesMut ) , Error > {
1499+ let bind: Bind = ( & message) . try_into ( ) ?;
1500+ let name = bind. prepared_statement . clone ( ) ;
1501+
1502+ if bind. anonymous ( ) {
1503+ debug ! ( "Anonymous bind message" ) ;
1504+ return Ok ( ( None , message) ) ;
1505+ }
1506+
1507+ match self . prepared_statements . get ( & name) {
1508+ Some ( prepared_stmt) => {
1509+ let bind = bind. reassign ( prepared_stmt) ;
1510+
1511+ debug ! ( "Rewrote bind `{}` to `{}`" , name, bind. prepared_statement) ;
1512+
1513+ Ok ( ( Some ( name) , bind. try_into ( ) ?) )
1514+ }
1515+ None => {
1516+ debug ! ( "Got bind for unknown prepared statement {:?}" , bind) ;
1517+
1518+ error_response (
1519+ & mut self . write ,
1520+ & format ! (
1521+ "prepared statement \" {}\" does not exist" ,
1522+ bind. prepared_statement
1523+ ) ,
1524+ )
1525+ . await ?;
1526+
1527+ Err ( Error :: ClientError ( format ! (
1528+ "Prepared statement `{}` doesn't exist" ,
1529+ name
1530+ ) ) )
1531+ }
1532+ }
1533+ }
1534+
1535+ /// Rewrite the Describe (F) message to use the prepared statement name
1536+ /// saved in the client cache.
1537+ async fn rewrite_describe (
1538+ & mut self ,
1539+ message : BytesMut ,
1540+ ) -> Result < ( Option < String > , BytesMut ) , Error > {
1541+ let describe: Describe = ( & message) . try_into ( ) ?;
1542+ let name = describe. statement_name . clone ( ) ;
1543+
1544+ if describe. anonymous ( ) {
1545+ debug ! ( "Anonymous describe" ) ;
1546+ return Ok ( ( None , message) ) ;
1547+ }
1548+
1549+ match self . prepared_statements . get ( & name) {
1550+ Some ( prepared_stmt) => {
1551+ let describe = describe. rename ( & prepared_stmt. name ) ;
1552+
1553+ debug ! (
1554+ "Rewrote describe `{}` to `{}`" ,
1555+ name, describe. statement_name
1556+ ) ;
1557+
1558+ Ok ( ( Some ( name) , describe. try_into ( ) ?) )
1559+ }
1560+
1561+ None => {
1562+ debug ! ( "Got describe for unknown prepared statement {:?}" , describe) ;
1563+
1564+ Ok ( ( None , message) )
1565+ }
1566+ }
1567+ }
1568+
13661569 /// Release the server from the client: it can't cancel its queries anymore.
13671570 pub fn release ( & self ) {
13681571 let mut guard = self . client_server_map . lock ( ) ;
0 commit comments