33use std:: str:: FromStr ;
44
55use async_std:: io:: { BufReader , Read , Write } ;
6- use async_std:: prelude:: * ;
6+ use async_std:: { prelude:: * , sync , task } ;
77use http_types:: headers:: { CONTENT_LENGTH , EXPECT , TRANSFER_ENCODING } ;
88use http_types:: { ensure, ensure_eq, format_err} ;
99use http_types:: { Body , Method , Request , Url } ;
1010
1111use crate :: chunked:: ChunkedDecoder ;
12+ use crate :: read_notifier:: ReadNotifier ;
1213use crate :: { MAX_HEADERS , MAX_HEAD_LENGTH } ;
1314
1415const LF : u8 = b'\n' ;
1516
1617/// The number returned from httparse when the request is HTTP 1.1
1718const HTTP_1_1_VERSION : u8 = 1 ;
1819
20+ const CONTINUE_HEADER_VALUE : & str = "100-continue" ;
21+ const CONTINUE_RESPONSE : & [ u8 ] = b"HTTP/1.1 100 Continue\r \n \r \n " ;
22+
1923/// Decode an HTTP request on the server.
2024pub async fn decode < IO > ( mut io : IO ) -> http_types:: Result < Option < Request > >
2125where
7680 req. insert_header ( header. name , std:: str:: from_utf8 ( header. value ) ?) ;
7781 }
7882
79- handle_100_continue ( & req, & mut io) . await ?;
80-
8183 let content_length = req. header ( CONTENT_LENGTH ) ;
8284 let transfer_encoding = req. header ( TRANSFER_ENCODING ) ;
8385
@@ -86,11 +88,32 @@ where
8688 "Unexpected Content-Length header"
8789 ) ;
8890
91+ // Establish a channel to wait for the body to be read. This
92+ // allows us to avoid sending 100-continue in situations that
93+ // respond without reading the body, saving clients from uploading
94+ // their body.
95+ let ( body_read_sender, body_read_receiver) = sync:: channel ( 1 ) ;
96+
97+ if Some ( CONTINUE_HEADER_VALUE ) == req. header ( EXPECT ) . map ( |h| h. as_str ( ) ) {
98+ task:: spawn ( async move {
99+ // If the client expects a 100-continue header, spawn a
100+ // task to wait for the first read attempt on the body.
101+ if let Ok ( ( ) ) = body_read_receiver. recv ( ) . await {
102+ io. write_all ( CONTINUE_RESPONSE ) . await . ok ( ) ;
103+ } ;
104+ // Since the sender is moved into the Body, this task will
105+ // finish when the client disconnects, whether or not
106+ // 100-continue was sent.
107+ } ) ;
108+ }
109+
89110 // Check for Transfer-Encoding
90111 if let Some ( encoding) = transfer_encoding {
91112 if encoding. last ( ) . as_str ( ) == "chunked" {
92113 let trailer_sender = req. send_trailers ( ) ;
93- let reader = BufReader :: new ( ChunkedDecoder :: new ( reader, trailer_sender) ) ;
114+ let reader = ChunkedDecoder :: new ( reader, trailer_sender) ;
115+ let reader = BufReader :: new ( reader) ;
116+ let reader = ReadNotifier :: new ( reader, body_read_sender) ;
94117 req. set_body ( Body :: from_reader ( reader, None ) ) ;
95118 return Ok ( Some ( req) ) ;
96119 }
@@ -100,7 +123,8 @@ where
100123 // Check for Content-Length.
101124 if let Some ( len) = content_length {
102125 let len = len. last ( ) . as_str ( ) . parse :: < usize > ( ) ?;
103- req. set_body ( Body :: from_reader ( reader. take ( len as u64 ) , Some ( len) ) ) ;
126+ let reader = ReadNotifier :: new ( reader. take ( len as u64 ) , body_read_sender) ;
127+ req. set_body ( Body :: from_reader ( reader, Some ( len) ) ) ;
104128 }
105129
106130 Ok ( Some ( req) )
@@ -129,20 +153,6 @@ fn url_from_httparse_req(req: &httparse::Request<'_, '_>) -> http_types::Result<
129153 }
130154}
131155
132- const EXPECT_HEADER_VALUE : & str = "100-continue" ;
133- const EXPECT_RESPONSE : & [ u8 ] = b"HTTP/1.1 100 Continue\r \n \r \n " ;
134-
135- async fn handle_100_continue < IO > ( req : & Request , io : & mut IO ) -> http_types:: Result < ( ) >
136- where
137- IO : Write + Unpin ,
138- {
139- if let Some ( EXPECT_HEADER_VALUE ) = req. header ( EXPECT ) . map ( |h| h. as_str ( ) ) {
140- io. write_all ( EXPECT_RESPONSE ) . await ?;
141- }
142-
143- Ok ( ( ) )
144- }
145-
146156#[ cfg( test) ]
147157mod tests {
148158 use super :: * ;
@@ -207,36 +217,4 @@ mod tests {
207217 } ,
208218 )
209219 }
210-
211- #[ test]
212- fn handle_100_continue_does_nothing_with_no_expect_header ( ) {
213- let request = Request :: new ( Method :: Get , Url :: parse ( "x:" ) . unwrap ( ) ) ;
214- let mut io = async_std:: io:: Cursor :: new ( vec ! [ ] ) ;
215- let result = async_std:: task:: block_on ( handle_100_continue ( & request, & mut io) ) ;
216- assert_eq ! ( std:: str :: from_utf8( & io. into_inner( ) ) . unwrap( ) , "" ) ;
217- assert ! ( result. is_ok( ) ) ;
218- }
219-
220- #[ test]
221- fn handle_100_continue_sends_header_if_expects_is_exactly_right ( ) {
222- let mut request = Request :: new ( Method :: Get , Url :: parse ( "x:" ) . unwrap ( ) ) ;
223- request. append_header ( "expect" , "100-continue" ) ;
224- let mut io = async_std:: io:: Cursor :: new ( vec ! [ ] ) ;
225- let result = async_std:: task:: block_on ( handle_100_continue ( & request, & mut io) ) ;
226- assert_eq ! (
227- std:: str :: from_utf8( & io. into_inner( ) ) . unwrap( ) ,
228- "HTTP/1.1 100 Continue\r \n \r \n "
229- ) ;
230- assert ! ( result. is_ok( ) ) ;
231- }
232-
233- #[ test]
234- fn handle_100_continue_does_nothing_if_expects_header_is_wrong ( ) {
235- let mut request = Request :: new ( Method :: Get , Url :: parse ( "x:" ) . unwrap ( ) ) ;
236- request. append_header ( "expect" , "110-extensions-not-allowed" ) ;
237- let mut io = async_std:: io:: Cursor :: new ( vec ! [ ] ) ;
238- let result = async_std:: task:: block_on ( handle_100_continue ( & request, & mut io) ) ;
239- assert_eq ! ( std:: str :: from_utf8( & io. into_inner( ) ) . unwrap( ) , "" ) ;
240- assert ! ( result. is_ok( ) ) ;
241- }
242220}
0 commit comments