@@ -58,6 +58,8 @@ pub struct Builder<E> {
5858 http1 : http1:: Builder ,
5959 #[ cfg( feature = "http2" ) ]
6060 http2 : http2:: Builder < E > ,
61+ #[ cfg( any( feature = "http1" , feature = "http2" ) ) ]
62+ version : Option < Version > ,
6163 #[ cfg( not( feature = "http2" ) ) ]
6264 _executor : E ,
6365}
@@ -84,6 +86,8 @@ impl<E> Builder<E> {
8486 http1 : http1:: Builder :: new ( ) ,
8587 #[ cfg( feature = "http2" ) ]
8688 http2 : http2:: Builder :: new ( executor) ,
89+ #[ cfg( any( feature = "http1" , feature = "http2" ) ) ]
90+ version : None ,
8791 #[ cfg( not( feature = "http2" ) ) ]
8892 _executor : executor,
8993 }
@@ -101,6 +105,26 @@ impl<E> Builder<E> {
101105 Http2Builder { inner : self }
102106 }
103107
108+ /// Only accepts HTTP/2
109+ ///
110+ /// Does not do anything if used with [`serve_connection_with_upgrades`]
111+ #[ cfg( feature = "http2" ) ]
112+ pub fn http2_only ( mut self ) -> Self {
113+ assert ! ( self . version. is_none( ) ) ;
114+ self . version = Some ( Version :: H2 ) ;
115+ self
116+ }
117+
118+ /// Only accepts HTTP/1
119+ ///
120+ /// Does not do anything if used with [`serve_connection_with_upgrades`]
121+ #[ cfg( feature = "http1" ) ]
122+ pub fn http1_only ( mut self ) -> Self {
123+ assert ! ( self . version. is_none( ) ) ;
124+ self . version = Some ( Version :: H1 ) ;
125+ self
126+ }
127+
104128 /// Bind a connection together with a [`Service`].
105129 pub fn serve_connection < I , S , B > ( & self , io : I , service : S ) -> Connection < ' _ , I , S , E >
106130 where
@@ -112,13 +136,28 @@ impl<E> Builder<E> {
112136 I : Read + Write + Unpin + ' static ,
113137 E : HttpServerConnExec < S :: Future , B > ,
114138 {
115- Connection {
116- state : ConnState :: ReadVersion {
139+ let state = match self . version {
140+ #[ cfg( feature = "http1" ) ]
141+ Some ( Version :: H1 ) => {
142+ let io = Rewind :: new_buffered ( io, Bytes :: new ( ) ) ;
143+ let conn = self . http1 . serve_connection ( io, service) ;
144+ ConnState :: H1 { conn }
145+ }
146+ #[ cfg( feature = "http2" ) ]
147+ Some ( Version :: H2 ) => {
148+ let io = Rewind :: new_buffered ( io, Bytes :: new ( ) ) ;
149+ let conn = self . http2 . serve_connection ( io, service) ;
150+ ConnState :: H2 { conn }
151+ }
152+ #[ cfg( any( feature = "http1" , feature = "http2" ) ) ]
153+ _ => ConnState :: ReadVersion {
117154 read_version : read_version ( io) ,
118155 builder : self ,
119156 service : Some ( service) ,
120157 } ,
121- }
158+ } ;
159+
160+ Connection { state }
122161 }
123162
124163 /// Bind a connection together with a [`Service`], with the ability to
@@ -148,7 +187,7 @@ impl<E> Builder<E> {
148187 }
149188}
150189
151- #[ derive( Copy , Clone ) ]
190+ #[ derive( Copy , Clone , Debug ) ]
152191enum Version {
153192 H1 ,
154193 H2 ,
@@ -906,7 +945,7 @@ mod tests {
906945 #[ cfg( not( miri) ) ]
907946 #[ tokio:: test]
908947 async fn http1 ( ) {
909- let addr = start_server ( ) . await ;
948+ let addr = start_server ( false , false ) . await ;
910949 let mut sender = connect_h1 ( addr) . await ;
911950
912951 let response = sender
@@ -922,7 +961,23 @@ mod tests {
922961 #[ cfg( not( miri) ) ]
923962 #[ tokio:: test]
924963 async fn http2 ( ) {
925- let addr = start_server ( ) . await ;
964+ let addr = start_server ( false , false ) . await ;
965+ let mut sender = connect_h2 ( addr) . await ;
966+
967+ let response = sender
968+ . send_request ( Request :: new ( Empty :: < Bytes > :: new ( ) ) )
969+ . await
970+ . unwrap ( ) ;
971+
972+ let body = response. into_body ( ) . collect ( ) . await . unwrap ( ) . to_bytes ( ) ;
973+
974+ assert_eq ! ( body, BODY ) ;
975+ }
976+
977+ #[ cfg( not( miri) ) ]
978+ #[ tokio:: test]
979+ async fn http2_only ( ) {
980+ let addr = start_server ( false , true ) . await ;
926981 let mut sender = connect_h2 ( addr) . await ;
927982
928983 let response = sender
@@ -935,6 +990,46 @@ mod tests {
935990 assert_eq ! ( body, BODY ) ;
936991 }
937992
993+ #[ cfg( not( miri) ) ]
994+ #[ tokio:: test]
995+ async fn http2_only_fail_if_client_is_http1 ( ) {
996+ let addr = start_server ( false , true ) . await ;
997+ let mut sender = connect_h1 ( addr) . await ;
998+
999+ let _ = sender
1000+ . send_request ( Request :: new ( Empty :: < Bytes > :: new ( ) ) )
1001+ . await
1002+ . expect_err ( "should fail" ) ;
1003+ }
1004+
1005+ #[ cfg( not( miri) ) ]
1006+ #[ tokio:: test]
1007+ async fn http1_only ( ) {
1008+ let addr = start_server ( true , false ) . await ;
1009+ let mut sender = connect_h1 ( addr) . await ;
1010+
1011+ let response = sender
1012+ . send_request ( Request :: new ( Empty :: < Bytes > :: new ( ) ) )
1013+ . await
1014+ . unwrap ( ) ;
1015+
1016+ let body = response. into_body ( ) . collect ( ) . await . unwrap ( ) . to_bytes ( ) ;
1017+
1018+ assert_eq ! ( body, BODY ) ;
1019+ }
1020+
1021+ #[ cfg( not( miri) ) ]
1022+ #[ tokio:: test]
1023+ async fn http1_only_fail_if_client_is_http2 ( ) {
1024+ let addr = start_server ( true , false ) . await ;
1025+ let mut sender = connect_h2 ( addr) . await ;
1026+
1027+ let _ = sender
1028+ . send_request ( Request :: new ( Empty :: < Bytes > :: new ( ) ) )
1029+ . await
1030+ . expect_err ( "should fail" ) ;
1031+ }
1032+
9381033 #[ cfg( not( miri) ) ]
9391034 #[ tokio:: test]
9401035 async fn graceful_shutdown ( ) {
@@ -1000,7 +1095,7 @@ mod tests {
10001095 sender
10011096 }
10021097
1003- async fn start_server ( ) -> SocketAddr {
1098+ async fn start_server ( h1_only : bool , h2_only : bool ) -> SocketAddr {
10041099 let addr: SocketAddr = ( [ 127 , 0 , 0 , 1 ] , 0 ) . into ( ) ;
10051100 let listener = TcpListener :: bind ( addr) . await . unwrap ( ) ;
10061101
@@ -1011,11 +1106,20 @@ mod tests {
10111106 let ( stream, _) = listener. accept ( ) . await . unwrap ( ) ;
10121107 let stream = TokioIo :: new ( stream) ;
10131108 tokio:: task:: spawn ( async move {
1014- let _ = auto:: Builder :: new ( TokioExecutor :: new ( ) )
1015- . http2 ( )
1016- . max_header_list_size ( 4096 )
1017- . serve_connection_with_upgrades ( stream, service_fn ( hello) )
1018- . await ;
1109+ let mut builder = auto:: Builder :: new ( TokioExecutor :: new ( ) ) ;
1110+ if h1_only {
1111+ builder = builder. http1_only ( ) ;
1112+ builder. serve_connection ( stream, service_fn ( hello) ) . await ;
1113+ } else if h2_only {
1114+ builder = builder. http2_only ( ) ;
1115+ builder. serve_connection ( stream, service_fn ( hello) ) . await ;
1116+ } else {
1117+ builder
1118+ . http2 ( )
1119+ . max_header_list_size ( 4096 )
1120+ . serve_connection_with_upgrades ( stream, service_fn ( hello) )
1121+ . await ;
1122+ }
10191123 } ) ;
10201124 }
10211125 } ) ;
0 commit comments