@@ -35,6 +35,7 @@ use crate::{
3535 custom_serve:: { CustomServeTrait , HibernationResult } ,
3636 errors, metrics,
3737 request_context:: RequestContext ,
38+ task_group:: TaskGroup ,
3839} ;
3940
4041const X_RIVET_TARGET : HeaderName = HeaderName :: from_static ( "x-rivet-target" ) ;
@@ -350,6 +351,7 @@ pub struct ProxyState {
350351 in_flight_counters : Cache < ( Id , std:: net:: IpAddr ) , Arc < Mutex < InFlightCounter > > > ,
351352 port_type : PortType ,
352353 clickhouse_inserter : Option < clickhouse_inserter:: ClickHouseInserterHandle > ,
354+ tasks : Arc < TaskGroup > ,
353355}
354356
355357impl ProxyState {
@@ -377,6 +379,7 @@ impl ProxyState {
377379 . build ( ) ,
378380 port_type,
379381 clickhouse_inserter,
382+ tasks : TaskGroup :: new ( ) ,
380383 }
381384 }
382385
@@ -782,14 +785,6 @@ impl ProxyService {
782785 metrics:: PROXY_REQUEST_PENDING . add ( 1 , & [ ] ) ;
783786 metrics:: PROXY_REQUEST_TOTAL . add ( 1 , & [ ] ) ;
784787
785- // Prepare to release in-flight counter when done
786- let state_clone = self . state . clone ( ) ;
787- crate :: defer! {
788- tokio:: spawn( async move {
789- state_clone. release_in_flight( client_ip, & actor_id) . await ;
790- } . instrument( tracing:: info_span!( "release_in_flight_task" ) ) ) ;
791- }
792-
793788 // Update request context with target info
794789 if let Some ( actor_id) = actor_id {
795790 request_context. service_actor_id = Some ( actor_id) ;
@@ -814,6 +809,15 @@ impl ProxyService {
814809
815810 metrics:: PROXY_REQUEST_PENDING . add ( -1 , & [ ] ) ;
816811
812+ // Release in-flight counter when done
813+ let state_clone = self . state . clone ( ) ;
814+ tokio:: spawn (
815+ async move {
816+ state_clone. release_in_flight ( client_ip, & actor_id) . await ;
817+ }
818+ . instrument ( tracing:: info_span!( "release_in_flight_task" ) ) ,
819+ ) ;
820+
817821 res
818822 }
819823
@@ -1254,7 +1258,7 @@ impl ProxyService {
12541258 match target {
12551259 ResolveRouteOutput :: Target ( mut target) => {
12561260 tracing:: debug!( "Spawning task to handle WebSocket communication" ) ;
1257- tokio :: spawn (
1261+ self . state . tasks . spawn (
12581262 async move {
12591263 // Set up a timeout for the entire operation
12601264 let timeout_duration = Duration :: from_secs ( 30 ) ; // 30 seconds timeout
@@ -1837,7 +1841,7 @@ impl ProxyService {
18371841 let req_path = req_path. clone ( ) ;
18381842 let req_host = req_host. clone ( ) ;
18391843
1840- tokio :: spawn (
1844+ self . state . tasks . spawn (
18411845 async move {
18421846 let request_id = Uuid :: new_v4 ( ) ;
18431847 let mut ws_hibernation_close = false ;
@@ -2194,7 +2198,7 @@ impl ProxyService {
21942198 Ok ( ( client_response, client_ws) ) => {
21952199 tracing:: debug!( "Client WebSocket upgrade for error proxy successful" ) ;
21962200
2197- tokio :: spawn (
2201+ self . state . tasks . spawn (
21982202 async move {
21992203 let ws_handle = match WebSocketHandle :: new ( client_ws) . await {
22002204 Ok ( ws_handle) => ws_handle,
@@ -2337,11 +2341,14 @@ impl ProxyService {
23372341
23382342 // Insert analytics event asynchronously
23392343 let mut context_clone = request_context. clone ( ) ;
2340- tokio:: spawn ( async move {
2341- if let Err ( error) = context_clone. insert_event ( ) . await {
2342- tracing:: warn!( ?error, "failed to insert guard analytics event" ) ;
2344+ tokio:: spawn (
2345+ async move {
2346+ if let Err ( error) = context_clone. insert_event ( ) . await {
2347+ tracing:: warn!( ?error, "failed to insert guard analytics event" ) ;
2348+ }
23432349 }
2344- } ) ;
2350+ . instrument ( tracing:: info_span!( "insert_event_task" ) ) ,
2351+ ) ;
23452352
23462353 let content_length = res
23472354 . headers ( )
@@ -2407,24 +2414,10 @@ impl ProxyServiceFactory {
24072414 pub fn create_service ( & self , remote_addr : SocketAddr ) -> ProxyService {
24082415 ProxyService :: new ( self . state . clone ( ) , remote_addr)
24092416 }
2410- }
24112417
2412- // Helper macro for defer-like functionality
2413- #[ macro_export]
2414- macro_rules! defer {
2415- ( $( $body: tt) * ) => {
2416- let _guard = {
2417- struct Guard <F : FnOnce ( ) >( Option <F >) ;
2418- impl <F : FnOnce ( ) > Drop for Guard <F > {
2419- fn drop( & mut self ) {
2420- if let Some ( f) = self . 0 . take( ) {
2421- f( )
2422- }
2423- }
2424- }
2425- Guard ( Some ( || { $( $body) * } ) )
2426- } ;
2427- } ;
2418+ pub async fn wait_idle ( & self ) {
2419+ self . state . tasks . wait_idle ( ) . await
2420+ }
24282421}
24292422
24302423fn add_proxy_headers_with_addr (
0 commit comments