11use bytes:: { Bytes , BytesMut } ;
2+ use reqwest:: header:: { HeaderMap , HeaderValue , ACCEPT } ;
23use reqwest:: Client ;
34use std:: time:: Duration ;
45use tokio:: sync:: { mpsc, oneshot} ;
@@ -39,11 +40,15 @@ impl SseStream {
3940 & self ,
4041 mut endpoint_event_tx : Option < oneshot:: Sender < Option < String > > > ,
4142 cancellation_token : CancellationToken ,
43+ custom_headers : & Option < HeaderMap > ,
4244 ) {
4345 let mut retry_count = 0 ;
4446 let mut buffer = BytesMut :: with_capacity ( BUFFER_CAPACITY ) ;
4547 let mut endpoint_event_received = false ;
4648
49+ let mut request_headers: HeaderMap = custom_headers. to_owned ( ) . unwrap_or_default ( ) ;
50+ request_headers. insert ( ACCEPT , HeaderValue :: from_static ( "text/event-stream" ) ) ;
51+
4752 // Main loop for reconnection attempts
4853 loop {
4954 // Check for cancellation before attempting connection
@@ -56,7 +61,7 @@ impl SseStream {
5661 let response = match self
5762 . sse_client
5863 . get ( & self . sse_url )
59- . header ( "Accept" , "text/event-stream" )
64+ . headers ( request_headers . clone ( ) )
6065 . send ( )
6166 . await
6267 {
@@ -86,7 +91,18 @@ impl SseStream {
8691 chunk = stream. next( ) => {
8792 match chunk {
8893 Some ( chunk) => chunk,
89- None => break , // Stream ended, break from inner loop to reconnect
94+ None => {
95+ if retry_count >= self . max_retries {
96+ tracing:: error!( "Max retries ({}) reached, giving up" , self . max_retries) ;
97+ if let Some ( tx) = endpoint_event_tx. take( ) {
98+ let _ = tx. send( None ) ;
99+ }
100+ return ;
101+ }
102+ retry_count += 1 ;
103+ time:: sleep( self . retry_delay) . await ;
104+ break ; // Stream ended, break from inner loop to reconnect
105+ }
90106 }
91107 }
92108 // Wait for cancellation
@@ -177,4 +193,81 @@ impl SseStream {
177193}
178194
179195#[ cfg( test) ]
180- mod tests { }
196+ mod tests {
197+ use super :: * ;
198+ use crate :: utils:: CancellationTokenSource ;
199+ use reqwest:: header:: { HeaderMap , HeaderValue } ;
200+ use tokio:: time:: Duration ;
201+ use wiremock:: matchers:: { header, method, path} ;
202+ use wiremock:: { Mock , MockServer , ResponseTemplate } ;
203+
204+ #[ tokio:: test]
205+ async fn test_sse_client_sends_custom_headers_on_connection ( ) {
206+ // Start WireMock server
207+ let mock_server = MockServer :: builder ( ) . start ( ) . await ;
208+
209+ // Create WireMock stub with connection close
210+ Mock :: given ( method ( "GET" ) )
211+ . and ( path ( "/sse" ) )
212+ . and ( header ( "Accept" , "text/event-stream" ) )
213+ . and ( header ( "X-Custom-Header" , "CustomValue" ) )
214+ . respond_with (
215+ ResponseTemplate :: new ( 200 )
216+ . set_body_string ( "event: endpoint\n data: mock-endpoint\n \n " )
217+ . append_header ( "Content-Type" , "text/event-stream" )
218+ . append_header ( "Connection" , "close" ) , // Ensure connection closes
219+ )
220+ . expect ( 1 ) // Expect exactly one request
221+ . mount ( & mock_server)
222+ . await ;
223+
224+ // Create custom headers
225+ let mut custom_headers = HeaderMap :: new ( ) ;
226+ custom_headers. insert ( "X-Custom-Header" , HeaderValue :: from_static ( "CustomValue" ) ) ;
227+
228+ // Create channel and SseStream
229+ let ( read_tx, _read_rx) = mpsc:: channel :: < Bytes > ( 64 ) ;
230+ let sse = SseStream {
231+ sse_client : reqwest:: Client :: new ( ) ,
232+ sse_url : format ! ( "{}/sse" , mock_server. uri( ) ) ,
233+ max_retries : 0 , // to receive one request only
234+ retry_delay : Duration :: from_millis ( 100 ) ,
235+ read_tx,
236+ } ;
237+
238+ // Create cancellation token and endpoint channel
239+ let ( cancellation_source, cancellation_token) = CancellationTokenSource :: new ( ) ;
240+ let ( endpoint_event_tx, endpoint_event_rx) = oneshot:: channel :: < Option < String > > ( ) ;
241+
242+ // Spawn the run method
243+ let sse_task = tokio:: spawn ( {
244+ async move {
245+ sse. run (
246+ Some ( endpoint_event_tx) ,
247+ cancellation_token,
248+ & Some ( custom_headers) ,
249+ )
250+ . await ;
251+ }
252+ } ) ;
253+
254+ // Wait for the endpoint event or timeout
255+ let event_result =
256+ tokio:: time:: timeout ( Duration :: from_millis ( 500 ) , endpoint_event_rx) . await ;
257+
258+ // Cancel the task to ensure loop exits
259+ let _ = cancellation_source. cancel ( ) ;
260+
261+ // Wait for the task to complete with a timeout
262+ match tokio:: time:: timeout ( Duration :: from_secs ( 1 ) , sse_task) . await {
263+ Ok ( result) => result. unwrap ( ) ,
264+ Err ( _) => panic ! ( "Test timed out after 1 second" ) ,
265+ }
266+
267+ // Verify the endpoint event was received
268+ match event_result {
269+ Ok ( Ok ( Some ( event) ) ) => assert_eq ! ( event, "mock-endpoint" , "Expected endpoint event" ) ,
270+ _ => panic ! ( "Did not receive expected endpoint event" ) ,
271+ }
272+ }
273+ }
0 commit comments