@@ -17,6 +17,7 @@ use std::{
1717 boxed,
1818 fmt:: { self , Debug , Display , Formatter } ,
1919 future:: Future ,
20+ io:: ErrorKind ,
2021 mem,
2122 pin:: Pin ,
2223 str:: FromStr ,
@@ -59,6 +60,10 @@ pub trait Client: Send + Sync + private::Sealed {
5960 * TODO specify list of stati to not retry (e.g. 204)
6061 */
6162
63+ /// Maximum amount of redirects that the client will follow before
64+ /// giving up, if not overridden via [ClientBuilder::redirect_limit].
65+ pub const DEFAULT_REDIRECT_LIMIT : u32 = 16 ;
66+
6267/// ClientBuilder provides a series of builder methods to easily construct a [`Client`].
6368pub struct ClientBuilder {
6469 url : Uri ,
@@ -68,6 +73,7 @@ pub struct ClientBuilder {
6873 last_event_id : Option < String > ,
6974 method : String ,
7075 body : Option < String > ,
76+ max_redirects : Option < u32 > ,
7177}
7278
7379impl ClientBuilder {
@@ -88,6 +94,7 @@ impl ClientBuilder {
8894 read_timeout : None ,
8995 last_event_id : None ,
9096 method : String :: from ( "GET" ) ,
97+ max_redirects : None ,
9198 body : None ,
9299 } )
93100 }
@@ -137,6 +144,14 @@ impl ClientBuilder {
137144 self
138145 }
139146
147+ /// Customize the client's following behavior when served a redirect.
148+ /// To disable following redirects, pass `0`.
149+ /// By default, the limit is [`DEFAULT_REDIRECT_LIMIT`].
150+ pub fn redirect_limit ( mut self , limit : u32 ) -> ClientBuilder {
151+ self . max_redirects = Some ( limit) ;
152+ self
153+ }
154+
140155 /// Build with a specific client connector.
141156 pub fn build_with_conn < C > ( self , conn : C ) -> impl Client
142157 where
@@ -158,6 +173,7 @@ impl ClientBuilder {
158173 method : self . method ,
159174 body : self . body ,
160175 reconnect_opts : self . reconnect_opts ,
176+ max_redirects : self . max_redirects . unwrap_or ( DEFAULT_REDIRECT_LIMIT ) ,
161177 } ,
162178 last_event_id : self . last_event_id ,
163179 }
@@ -188,6 +204,7 @@ impl ClientBuilder {
188204 method : self . method ,
189205 body : self . body ,
190206 reconnect_opts : self . reconnect_opts ,
207+ max_redirects : self . max_redirects . unwrap_or ( DEFAULT_REDIRECT_LIMIT ) ,
191208 } ,
192209 last_event_id : self . last_event_id ,
193210 }
@@ -201,6 +218,7 @@ struct RequestProps {
201218 method : String ,
202219 body : Option < String > ,
203220 reconnect_opts : ReconnectOptions ,
221+ max_redirects : u32 ,
204222}
205223
206224/// A client implementation that connects to a server using the Server-Sent Events protocol
@@ -243,6 +261,7 @@ enum State {
243261 } ,
244262 Connected ( #[ pin] hyper:: Body ) ,
245263 WaitingToReconnect ( #[ pin] Sleep ) ,
264+ FollowingRedirect ( Option < HeaderValue > ) ,
246265 StreamClosed ,
247266}
248267
@@ -254,6 +273,7 @@ impl State {
254273 State :: Connecting { retry : true , .. } => "connecting(retry)" ,
255274 State :: Connected ( _) => "connected" ,
256275 State :: WaitingToReconnect ( _) => "waiting-to-reconnect" ,
276+ State :: FollowingRedirect ( _) => "following-redirect" ,
257277 State :: StreamClosed => "closed" ,
258278 }
259279 }
@@ -273,6 +293,8 @@ pub struct ReconnectingRequest<C> {
273293 #[ pin]
274294 state : State ,
275295 next_reconnect_delay : Duration ,
296+ current_url : Uri ,
297+ redirect_count : u32 ,
276298 event_parser : EventParser ,
277299 last_event_id : Option < String > ,
278300}
@@ -284,11 +306,14 @@ impl<C> ReconnectingRequest<C> {
284306 last_event_id : Option < String > ,
285307 ) -> ReconnectingRequest < C > {
286308 let reconnect_delay = props. reconnect_opts . delay ;
309+ let url = props. url . clone ( ) ;
287310 ReconnectingRequest {
288311 props,
289312 http,
290313 state : State :: New ,
291314 next_reconnect_delay : reconnect_delay,
315+ redirect_count : 0 ,
316+ current_url : url,
292317 event_parser : EventParser :: new ( ) ,
293318 last_event_id,
294319 }
@@ -300,7 +325,7 @@ impl<C> ReconnectingRequest<C> {
300325 {
301326 let mut request_builder = Request :: builder ( )
302327 . method ( self . props . method . as_str ( ) )
303- . uri ( & self . props . url ) ;
328+ . uri ( & self . current_url ) ;
304329
305330 for ( name, value) in & self . props . headers {
306331 request_builder = request_builder. header ( name, value) ;
@@ -343,6 +368,21 @@ impl<C> ReconnectingRequest<C> {
343368 let this = self . project ( ) ;
344369 mem:: swap ( this. next_reconnect_delay , & mut delay) ;
345370 }
371+
372+ fn reset_redirects ( self : Pin < & mut Self > ) {
373+ let url = self . props . url . clone ( ) ;
374+ let this = self . project ( ) ;
375+ * this. current_url = url;
376+ * this. redirect_count = 0 ;
377+ }
378+
379+ fn increment_redirect_counter ( self : Pin < & mut Self > ) -> bool {
380+ if self . redirect_count == self . props . max_redirects {
381+ return false ;
382+ }
383+ * self . project ( ) . redirect_count += 1 ;
384+ true
385+ }
346386}
347387
348388impl < C > Stream for ReconnectingRequest < C >
@@ -400,16 +440,39 @@ where
400440 Ok ( resp) => {
401441 debug ! ( "HTTP response: {:#?}" , resp) ;
402442
403- if !resp. status ( ) . is_success ( ) {
404- self . as_mut ( ) . project ( ) . state . set ( State :: New ) ;
405- return Poll :: Ready ( Some ( Err ( Error :: HttpRequest ( resp. status ( ) ) ) ) ) ;
443+ if resp. status ( ) . is_success ( ) {
444+ self . as_mut ( ) . reset_backoff ( ) ;
445+ self . as_mut ( ) . reset_redirects ( ) ;
446+ self . as_mut ( )
447+ . project ( )
448+ . state
449+ . set ( State :: Connected ( resp. into_body ( ) ) ) ;
450+ continue ;
406451 }
407452
408- self . as_mut ( ) . reset_backoff ( ) ;
409- self . as_mut ( )
410- . project ( )
411- . state
412- . set ( State :: Connected ( resp. into_body ( ) ) ) ;
453+ if resp. status ( ) == 301 || resp. status ( ) == 307 {
454+ debug ! ( "got redirected ({})" , resp. status( ) ) ;
455+
456+ if self . as_mut ( ) . increment_redirect_counter ( ) {
457+ debug ! ( "following redirect {}" , self . redirect_count) ;
458+
459+ self . as_mut ( ) . project ( ) . state . set ( State :: FollowingRedirect (
460+ resp. headers ( ) . get ( hyper:: header:: LOCATION ) . cloned ( ) ,
461+ ) ) ;
462+ continue ;
463+ } else {
464+ debug ! ( "redirect limit reached ({})" , self . props. max_redirects) ;
465+
466+ self . as_mut ( ) . project ( ) . state . set ( State :: StreamClosed ) ;
467+ return Poll :: Ready ( Some ( Err ( Error :: MaxRedirectLimitReached (
468+ self . props . max_redirects ,
469+ ) ) ) ) ;
470+ }
471+ }
472+
473+ self . as_mut ( ) . reset_redirects ( ) ;
474+ self . as_mut ( ) . project ( ) . state . set ( State :: New ) ;
475+ return Poll :: Ready ( Some ( Err ( Error :: UnexpectedResponse ( resp. status ( ) ) ) ) ) ;
413476 }
414477 Err ( e) => {
415478 // This seems basically impossible. AFAIK we can only get this way if we
@@ -426,6 +489,16 @@ where
426489 . set ( State :: WaitingToReconnect ( delay ( duration, "retrying" ) ) )
427490 }
428491 } ,
492+ StateProj :: FollowingRedirect ( maybe_header) => match uri_from_header ( maybe_header) {
493+ Ok ( uri) => {
494+ * self . as_mut ( ) . project ( ) . current_url = uri;
495+ self . as_mut ( ) . project ( ) . state . set ( State :: New ) ;
496+ }
497+ Err ( e) => {
498+ self . as_mut ( ) . project ( ) . state . set ( State :: StreamClosed ) ;
499+ return Poll :: Ready ( Some ( Err ( e) ) ) ;
500+ }
501+ } ,
429502 StateProj :: Connected ( body) => match ready ! ( body. poll_data( cx) ) {
430503 Some ( Ok ( result) ) => {
431504 this. event_parser . process_bytes ( result) ?;
@@ -473,6 +546,23 @@ where
473546 }
474547}
475548
549+ fn uri_from_header ( maybe_header : & Option < HeaderValue > ) -> Result < Uri > {
550+ let header = maybe_header. as_ref ( ) . ok_or_else ( || {
551+ Error :: MalformedLocationHeader ( Box :: new ( std:: io:: Error :: new (
552+ ErrorKind :: NotFound ,
553+ "missing Location header" ,
554+ ) ) )
555+ } ) ?;
556+
557+ let header_string = header
558+ . to_str ( )
559+ . map_err ( |e| Error :: MalformedLocationHeader ( Box :: new ( e) ) ) ?;
560+
561+ header_string
562+ . parse :: < Uri > ( )
563+ . map_err ( |e| Error :: MalformedLocationHeader ( Box :: new ( e) ) )
564+ }
565+
476566fn delay ( dur : Duration , description : & str ) -> Sleep {
477567 info ! ( "Waiting {:?} before {}" , dur, description) ;
478568 tokio:: time:: sleep ( dur)
0 commit comments