@@ -16,6 +16,7 @@ use crate::auth_passthrough::refetch_auth_hash;
1616use crate :: config:: { get_config, get_idle_client_in_transaction_timeout, Address , PoolMode } ;
1717use crate :: constants:: * ;
1818use crate :: messages:: * ;
19+ use crate :: plugins:: PluginOutput ;
1920use crate :: pool:: { get_pool, ClientServerMap , ConnectionPool } ;
2021use crate :: query_router:: { Command , QueryRouter } ;
2122use crate :: server:: Server ;
@@ -765,6 +766,9 @@ where
765766
766767 self . stats . register ( self . stats . clone ( ) ) ;
767768
769+ // Result returned by one of the plugins.
770+ let mut plugin_output = None ;
771+
768772 // Our custom protocol loop.
769773 // We expect the client to either start a transaction with regular queries
770774 // or issue commands for our sharding and server selection protocol.
@@ -815,15 +819,39 @@ where
815819
816820 'Q' => {
817821 if query_router. query_parser_enabled ( ) {
818- query_router. infer ( & message) ;
822+ if let Ok ( ast) = QueryRouter :: parse ( & message) {
823+ let plugin_result = query_router. execute_plugins ( & ast) . await ;
824+
825+ match plugin_result {
826+ Ok ( PluginOutput :: Deny ( error) ) => {
827+ error_response ( & mut self . write , & error) . await ?;
828+ continue ;
829+ }
830+
831+ Ok ( PluginOutput :: Intercept ( result) ) => {
832+ write_all ( & mut self . write , result) . await ?;
833+ continue ;
834+ }
835+
836+ _ => ( ) ,
837+ } ;
838+
839+ let _ = query_router. infer ( & ast) ;
840+ }
819841 }
820842 }
821843
822844 'P' => {
823845 self . buffer . put ( & message[ ..] ) ;
824846
825847 if query_router. query_parser_enabled ( ) {
826- query_router. infer ( & message) ;
848+ if let Ok ( ast) = QueryRouter :: parse ( & message) {
849+ if let Ok ( output) = query_router. execute_plugins ( & ast) . await {
850+ plugin_output = Some ( output) ;
851+ }
852+
853+ let _ = query_router. infer ( & ast) ;
854+ }
827855 }
828856
829857 continue ;
@@ -857,6 +885,18 @@ where
857885 continue ;
858886 }
859887
888+ // Check on plugin results.
889+ match plugin_output {
890+ Some ( PluginOutput :: Deny ( error) ) => {
891+ self . buffer . clear ( ) ;
892+ error_response ( & mut self . write , & error) . await ?;
893+ plugin_output = None ;
894+ continue ;
895+ }
896+
897+ _ => ( ) ,
898+ } ;
899+
860900 // Get a pool instance referenced by the most up-to-date
861901 // pointer. This ensures we always read the latest config
862902 // when starting a query.
@@ -1085,6 +1125,27 @@ where
10851125 match code {
10861126 // Query
10871127 'Q' => {
1128+ if query_router. query_parser_enabled ( ) {
1129+ if let Ok ( ast) = QueryRouter :: parse ( & message) {
1130+ let plugin_result = query_router. execute_plugins ( & ast) . await ;
1131+
1132+ match plugin_result {
1133+ Ok ( PluginOutput :: Deny ( error) ) => {
1134+ error_response ( & mut self . write , & error) . await ?;
1135+ continue ;
1136+ }
1137+
1138+ Ok ( PluginOutput :: Intercept ( result) ) => {
1139+ write_all ( & mut self . write , result) . await ?;
1140+ continue ;
1141+ }
1142+
1143+ _ => ( ) ,
1144+ } ;
1145+
1146+ let _ = query_router. infer ( & ast) ;
1147+ }
1148+ }
10881149 debug ! ( "Sending query to server" ) ;
10891150
10901151 self . send_and_receive_loop (
@@ -1124,6 +1185,14 @@ where
11241185 // Parse
11251186 // The query with placeholders is here, e.g. `SELECT * FROM users WHERE email = $1 AND active = $2`.
11261187 'P' => {
1188+ if query_router. query_parser_enabled ( ) {
1189+ if let Ok ( ast) = QueryRouter :: parse ( & message) {
1190+ if let Ok ( output) = query_router. execute_plugins ( & ast) . await {
1191+ plugin_output = Some ( output) ;
1192+ }
1193+ }
1194+ }
1195+
11271196 self . buffer . put ( & message[ ..] ) ;
11281197 }
11291198
@@ -1155,6 +1224,24 @@ where
11551224 'S' => {
11561225 debug ! ( "Sending query to server" ) ;
11571226
1227+ match plugin_output {
1228+ Some ( PluginOutput :: Deny ( error) ) => {
1229+ error_response ( & mut self . write , & error) . await ?;
1230+ plugin_output = None ;
1231+ self . buffer . clear ( ) ;
1232+ continue ;
1233+ }
1234+
1235+ Some ( PluginOutput :: Intercept ( result) ) => {
1236+ write_all ( & mut self . write , result) . await ?;
1237+ plugin_output = None ;
1238+ self . buffer . clear ( ) ;
1239+ continue ;
1240+ }
1241+
1242+ _ => ( ) ,
1243+ } ;
1244+
11581245 self . buffer . put ( & message[ ..] ) ;
11591246
11601247 let first_message_code = ( * self . buffer . get ( 0 ) . unwrap_or ( & 0 ) ) as char ;
0 commit comments