@@ -151,9 +151,6 @@ pub trait TlsConnection: Sized {
151151 /// Library-specific config struct
152152 type Config ;
153153
154- /// Name of the connection type
155- fn name ( ) -> String ;
156-
157154 /// Make connection from existing config and buffer
158155 fn new_from_config (
159156 mode : Mode ,
@@ -166,23 +163,42 @@ pub trait TlsConnection: Sized {
166163
167164 fn handshake_completed ( & self ) -> bool ;
168165
166+ /// Send `data` to the peer.
167+ fn send ( & mut self , data : & [ u8 ] ) -> Result < ( ) , Box < dyn Error > > ;
168+
169+ /// Read application data from the peer into `data`.
170+ fn recv ( & mut self , data : & mut [ u8 ] ) -> Result < ( ) , Box < dyn Error > > ;
171+
172+ /// Send a `CloseNotify` to the peer.
173+ ///
174+ /// This does not read the `CloseNotify` from the peer.
175+ ///
176+ /// Must be followed by a call to [`TlsConnection::shutdown_finish`] to ensure
177+ /// that any `CloseNotify` alerts from the peer are read.
178+ fn shutdown_send ( & mut self ) ;
179+
180+ /// Attempt to read the `CloseNotify` from the peer.
181+ ///
182+ /// Returns `true` if the connection was successfully shutdown, `false` otherwise.
183+ ///
184+ /// The `CloseNotify` might already have been read by `shutdown_send`, depending
185+ /// on the order of client/server [`TlsConnection::shutdown_send`] calls.
186+ fn shutdown_finish ( & mut self ) -> bool ;
187+ }
188+
189+ pub trait TlsInfo : Sized {
190+ fn name ( ) -> String ;
169191 fn get_negotiated_cipher_suite ( & self ) -> CipherSuite ;
170192
171193 fn negotiated_tls13 ( & self ) -> bool ;
172194
173195 /// Describes whether a connection was resumed. This method is only valid on
174196 /// server connections because of rustls API limitations.
175197 fn resumed_connection ( & self ) -> bool ;
176-
177- /// Send application data to ConnectedBuffer
178- fn send ( & mut self , data : & [ u8 ] ) -> Result < ( ) , Box < dyn Error > > ;
179-
180- /// Read application data from ConnectedBuffer
181- fn recv ( & mut self , data : & mut [ u8 ] ) -> Result < ( ) , Box < dyn Error > > ;
182198}
183199
184200/// A TlsConnPair owns the client and server tls connections along with the IO buffers.
185- pub struct TlsConnPair < C : TlsConnection , S : TlsConnection > {
201+ pub struct TlsConnPair < C , S > {
186202 pub client : C ,
187203 pub server : S ,
188204 pub io : TestPairIO ,
@@ -295,18 +311,6 @@ where
295311 self . client . handshake_completed ( ) && self . server . handshake_completed ( )
296312 }
297313
298- pub fn get_negotiated_cipher_suite ( & self ) -> CipherSuite {
299- assert ! ( self . handshake_completed( ) ) ;
300- assert ! (
301- self . client. get_negotiated_cipher_suite( ) == self . server. get_negotiated_cipher_suite( )
302- ) ;
303- self . client . get_negotiated_cipher_suite ( )
304- }
305-
306- pub fn negotiated_tls13 ( & self ) -> bool {
307- self . client . negotiated_tls13 ( ) && self . server . negotiated_tls13 ( )
308- }
309-
310314 /// Send data from client to server, and then from server to client
311315 pub fn round_trip_transfer ( & mut self , data : & mut [ u8 ] ) -> Result < ( ) , Box < dyn Error > > {
312316 // send data from client to server
@@ -319,6 +323,45 @@ where
319323
320324 Ok ( ( ) )
321325 }
326+
327+ pub fn shutdown ( & mut self ) -> Result < ( ) , Box < dyn Error > > {
328+ // These assertions do not _have_ to be true, but you are likely making
329+ // a mistake if you are hitting it. Generally all data should have been
330+ // read before attempting to shutdown
331+ assert_eq ! ( self . io. client_tx_stream. borrow( ) . len( ) , 0 ) ;
332+ assert_eq ! ( self . io. server_tx_stream. borrow( ) . len( ) , 0 ) ;
333+
334+ self . client . shutdown_send ( ) ;
335+ self . server . shutdown_send ( ) ;
336+
337+ let client_shutdown = self . client . shutdown_finish ( ) ;
338+ let server_shutdown = self . server . shutdown_finish ( ) ;
339+ if client_shutdown && server_shutdown {
340+ Ok ( ( ) )
341+ } else {
342+ Err (
343+ format ! ( "Shutdown Failed: client - {client_shutdown} server - {server_shutdown}" )
344+ . into ( ) ,
345+ )
346+ }
347+ }
348+ }
349+
350+ impl < C , S > TlsConnPair < C , S >
351+ where
352+ C : TlsInfo ,
353+ S : TlsInfo ,
354+ {
355+ pub fn get_negotiated_cipher_suite ( & self ) -> CipherSuite {
356+ assert ! (
357+ self . client. get_negotiated_cipher_suite( ) == self . server. get_negotiated_cipher_suite( )
358+ ) ;
359+ self . client . get_negotiated_cipher_suite ( )
360+ }
361+
362+ pub fn negotiated_tls13 ( & self ) -> bool {
363+ self . client . negotiated_tls13 ( ) && self . server . negotiated_tls13 ( )
364+ }
322365}
323366
324367#[ cfg( test) ]
@@ -349,8 +392,8 @@ mod tests {
349392
350393 fn test_type < C , S > ( )
351394 where
352- S : TlsConnection ,
353- C : TlsConnection ,
395+ S : TlsConnection + TlsInfo ,
396+ C : TlsConnection + TlsInfo ,
354397 C :: Config : TlsBenchConfig ,
355398 S :: Config : TlsBenchConfig ,
356399 {
@@ -361,8 +404,8 @@ mod tests {
361404
362405 fn handshake_configs < C , S > ( )
363406 where
364- S : TlsConnection ,
365- C : TlsConnection ,
407+ S : TlsConnection + TlsInfo ,
408+ C : TlsConnection + TlsInfo ,
366409 C :: Config : TlsBenchConfig ,
367410 S :: Config : TlsBenchConfig ,
368411 {
@@ -381,6 +424,13 @@ mod tests {
381424
382425 assert ! ( conn_pair. negotiated_tls13( ) ) ;
383426 assert_eq ! ( cipher_suite, conn_pair. get_negotiated_cipher_suite( ) ) ;
427+
428+ // read in "application data" handshake messages.
429+ // "NewSessionTicket" in the case of resumption
430+ let err = conn_pair. client_mut ( ) . recv ( & mut [ 0 ] ) . unwrap_err ( ) ;
431+ assert_eq ! ( & err. to_string( ) , "blocking" ) ;
432+
433+ conn_pair. shutdown ( ) . unwrap ( ) ;
384434 }
385435 }
386436 }
@@ -389,8 +439,8 @@ mod tests {
389439
390440 fn session_resumption < C , S > ( )
391441 where
392- S : TlsConnection ,
393- C : TlsConnection ,
442+ S : TlsConnection + TlsInfo ,
443+ C : TlsConnection + TlsInfo ,
394444 C :: Config : TlsBenchConfig ,
395445 S :: Config : TlsBenchConfig ,
396446 {
@@ -399,7 +449,12 @@ mod tests {
399449 TlsConnPair :: < C , S > :: new_bench_pair ( CryptoConfig :: default ( ) , HandshakeType :: Resumption )
400450 . unwrap ( ) ;
401451 conn_pair. handshake ( ) . unwrap ( ) ;
452+ // read the session tickets which were sent
453+ let err = conn_pair. client_mut ( ) . recv ( & mut [ 0 ] ) . unwrap_err ( ) ;
454+ assert_eq ! ( & err. to_string( ) , "blocking" ) ;
455+
402456 assert ! ( conn_pair. server( ) . resumed_connection( ) ) ;
457+ conn_pair. shutdown ( ) . unwrap ( ) ;
403458 }
404459
405460 #[ test]
@@ -439,6 +494,7 @@ mod tests {
439494 . unwrap ( ) ;
440495 conn_pair. handshake ( ) . unwrap ( ) ;
441496 conn_pair. round_trip_transfer ( & mut buf) . unwrap ( ) ;
497+ conn_pair. shutdown ( ) . unwrap ( ) ;
442498 }
443499 }
444500}
0 commit comments