@@ -20,24 +20,21 @@ use {
2020 std:: marker:: Unpin ,
2121 std:: pin:: Pin ,
2222 std:: task:: { Context as TaskContext , Poll } ,
23+ tokio:: io:: { AsyncRead , AsyncWrite , ReadBuf } ,
24+ crate :: ssl:: async_utils:: IoAdapter ,
2325} ;
2426
2527use mbedtls_sys:: types:: raw_types:: { c_int, c_uchar, c_void} ;
2628use mbedtls_sys:: types:: size_t;
2729use mbedtls_sys:: * ;
2830
29- #[ cfg( all( feature = "std" , feature = "async" ) ) ]
30- use tokio:: io:: { AsyncRead , AsyncWrite , ReadBuf } ;
31-
3231#[ cfg( not( feature = "std" ) ) ]
3332use crate :: alloc_prelude:: * ;
3433use crate :: alloc:: { List as MbedtlsList } ;
3534use crate :: error:: { Error , Result , IntoResult } ;
3635use crate :: pk:: Pk ;
3736use crate :: private:: UnsafeFrom ;
3837use crate :: ssl:: config:: { Config , Version , AuthMode } ;
39- #[ cfg( all( feature = "std" , feature = "async" ) ) ]
40- use crate :: ssl:: async_utils:: IoAdapter ;
4138use crate :: x509:: { Certificate , Crl , VerifyError } ;
4239
4340pub trait IoCallback {
@@ -199,7 +196,7 @@ define!(
199196 struct HandshakeContext {
200197 handshake_ca_cert: Option <Arc <MbedtlsList <Certificate >>>,
201198 handshake_crl: Option <Arc <Crl >>,
202-
199+
203200 handshake_cert: Vec <Arc <MbedtlsList <Certificate >>>,
204201 handshake_pk: Vec <Arc <Pk >>,
205202 } ;
@@ -213,10 +210,10 @@ define!(
213210pub struct Context < T > {
214211 // Base structure used in SNI callback where we cannot determine the io type.
215212 inner : HandshakeContext ,
216-
213+
217214 // config is used read-only for multiple contexts and is immutable once configured.
218- config : Arc < Config > ,
219-
215+ config : Arc < Config > ,
216+
220217 // Must be held in heap and pointer to it as pointer is sent to MbedSSL and can't be re-allocated.
221218 io : Option < Box < T > > ,
222219
@@ -240,14 +237,10 @@ impl<'a, T> Into<*mut ssl_context> for &'a mut Context<T> {
240237 }
241238}
242239
243- #[ cfg( all( feature = "std" , feature = "async" ) ) ]
244- pub type AsyncContext < T > = Context < IoAdapter < T > > ;
245-
246-
247240impl < T > Context < T > {
248241 pub fn new ( config : Arc < Config > ) -> Self {
249242 let mut inner = ssl_context:: default ( ) ;
250-
243+
251244 unsafe {
252245 ssl_init ( & mut inner) ;
253246 ssl_setup ( & mut inner, ( & * config) . into ( ) ) ;
@@ -258,7 +251,7 @@ impl<T> Context<T> {
258251 inner,
259252 handshake_ca_cert : None ,
260253 handshake_crl : None ,
261-
254+
262255 handshake_cert : vec ! [ ] ,
263256 handshake_pk : vec ! [ ] ,
264257 } ,
@@ -268,11 +261,11 @@ impl<T> Context<T> {
268261 client_transport_id : None ,
269262 }
270263 }
271-
264+
272265 pub ( crate ) fn handle ( & self ) -> & :: mbedtls_sys:: ssl_context {
273266 self . inner . handle ( )
274267 }
275-
268+
276269 pub ( crate ) fn handle_mut ( & mut self ) -> & mut :: mbedtls_sys:: ssl_context {
277270 self . inner . handle_mut ( )
278271 }
@@ -385,23 +378,23 @@ impl<T> Context<T> {
385378 pub fn config ( & self ) -> & Arc < Config > {
386379 & self . config
387380 }
388-
381+
389382 pub fn close ( & mut self ) {
390383 unsafe {
391384 ssl_close_notify ( self . into ( ) ) ;
392385 ssl_set_bio ( self . into ( ) , :: core:: ptr:: null_mut ( ) , None , None , None ) ;
393386 self . io = None ;
394387 }
395388 }
396-
389+
397390 pub fn io ( & self ) -> Option < & T > {
398391 self . io . as_ref ( ) . map ( |v| & * * v)
399392 }
400-
393+
401394 pub fn io_mut ( & mut self ) -> Option < & mut T > {
402395 self . io . as_mut ( ) . map ( |v| & mut * * v)
403396 }
404-
397+
405398 /// Return the minor number of the negotiated TLS version
406399 pub fn minor_version ( & self ) -> i32 {
407400 self . handle ( ) . minor_ver
@@ -433,15 +426,15 @@ impl<T> Context<T> {
433426
434427
435428 // Session specific functions
436-
429+
437430 /// Return the 16-bit ciphersuite identifier.
438431 /// All assigned ciphersuites are listed by the IANA in
439432 /// <https://www.iana.org/assignments/tls-parameters/tls-parameters.txt>
440433 pub fn ciphersuite ( & self ) -> Result < u16 > {
441434 if self . handle ( ) . session . is_null ( ) {
442435 return Err ( Error :: SslBadInputData ) ;
443436 }
444-
437+
445438 Ok ( unsafe { self . handle ( ) . session . as_ref ( ) . unwrap ( ) . ciphersuite as u16 } )
446439 }
447440
@@ -578,12 +571,12 @@ impl HandshakeContext {
578571 self . handshake_ca_cert = None ;
579572 self . handshake_crl = None ;
580573 }
581-
574+
582575 pub fn set_authmode ( & mut self , am : AuthMode ) -> Result < ( ) > {
583576 if self . inner . handshake as * const _ == :: core:: ptr:: null ( ) {
584577 return Err ( Error :: SslBadInputData ) ;
585578 }
586-
579+
587580 unsafe { ssl_set_hs_authmode ( self . into ( ) , am as i32 ) }
588581 Ok ( ( ) )
589582 }
@@ -637,6 +630,9 @@ impl HandshakeContext {
637630 }
638631}
639632
633+ #[ cfg( all( feature = "std" , feature = "async" ) ) ]
634+ pub type AsyncContext < T > = Context < IoAdapter < T > > ;
635+
640636#[ cfg( all( feature = "std" , feature = "async" ) ) ]
641637pub trait IoAsyncCallback {
642638 unsafe extern "C" fn call_recv_async ( user_data : * mut c_void , data : * mut c_uchar , len : size_t ) -> c_int where Self : Sized ;
@@ -700,7 +696,7 @@ impl<T> std::future::Future for HandshakeFuture<'_, T> {
700696 fn poll ( mut self : Pin < & mut Self > , ctx : & mut TaskContext ) -> std:: task:: Poll < Self :: Output > {
701697 self . 0 . io_mut ( ) . ok_or ( Error :: NetInvalidContext ) ?
702698 . ecx . set ( ctx) ;
703-
699+
704700 let result = match self . 0 . handshake ( ) {
705701 Err ( Error :: SslWantRead ) |
706702 Err ( Error :: SslWantWrite ) => {
@@ -709,9 +705,9 @@ impl<T> std::future::Future for HandshakeFuture<'_, T> {
709705 Err ( e) => Poll :: Ready ( Err ( e) ) ,
710706 Ok ( ( ) ) => Poll :: Ready ( Ok ( ( ) ) )
711707 } ;
712-
708+
713709 self . 0 . io_mut ( ) . map ( |v| v. ecx . clear ( ) ) ;
714-
710+
715711 result
716712 }
717713}
@@ -741,7 +737,7 @@ impl<T: AsyncRead + AsyncWrite + Unpin + 'static> AsyncContext<T> {
741737 ) ;
742738
743739 self . io = Some ( io) ;
744- self . inner . reset_handshake ( ) ;
740+ self . inner . reset_handshake ( ) ;
745741 }
746742
747743 HandshakeFuture ( self ) . await
@@ -762,7 +758,7 @@ impl<T: AsyncRead> AsyncRead for Context<IoAdapter<T>> {
762758
763759 self . io_mut ( ) . ok_or ( IoError :: new ( IoErrorKind :: Other , "stream has been shutdown" ) ) ?
764760 . ecx . set ( cx) ;
765-
761+
766762 let result = match unsafe { ssl_read ( ( & mut * self ) . into ( ) , buf. initialize_unfilled ( ) . as_mut_ptr ( ) , buf. initialize_unfilled ( ) . len ( ) ) . into_result ( ) } {
767763 Err ( Error :: SslPeerCloseNotify ) => Poll :: Ready ( Ok ( ( ) ) ) ,
768764 Err ( Error :: SslWantRead ) => Poll :: Pending ,
@@ -798,10 +794,10 @@ impl<T: AsyncWrite + Unpin> AsyncWrite for Context<IoAdapter<T>> {
798794 io. write_tracker . adjust_buf ( buf)
799795 } ?;
800796
801-
797+
802798 self . io_mut ( ) . ok_or ( IoError :: new ( IoErrorKind :: Other , "stream has been shutdown" ) ) ?
803799 . ecx . set ( cx) ;
804-
800+
805801 let result = match unsafe { ssl_write ( ( & mut * self ) . into ( ) , buf. as_ptr ( ) , buf. len ( ) ) . into_result ( ) } {
806802 Err ( Error :: SslPeerCloseNotify ) => Poll :: Ready ( Ok ( 0 ) ) ,
807803 Err ( Error :: SslWantWrite ) => Poll :: Pending ,
@@ -868,7 +864,7 @@ mod tests {
868864
869865 use crate :: ssl:: context:: { HandshakeContext , Context } ;
870866 use crate :: tests:: TestTrait ;
871-
867+
872868 #[ test]
873869 fn handshakecontext_sync ( ) {
874870 assert ! ( !TestTrait :: <dyn Sync , HandshakeContext >:: new( ) . impls_trait( ) , "HandshakeContext must be !Sync" ) ;
@@ -884,7 +880,7 @@ mod tests {
884880 unimplemented ! ( )
885881 }
886882 }
887-
883+
888884 #[ cfg( feature = "std" ) ]
889885 impl Write for NonSendStream {
890886 fn write ( & mut self , _: & [ u8 ] ) -> IoResult < usize > {
@@ -906,7 +902,7 @@ mod tests {
906902 unimplemented ! ( )
907903 }
908904 }
909-
905+
910906 #[ cfg( feature = "std" ) ]
911907 impl Write for SendStream {
912908 fn write ( & mut self , _: & [ u8 ] ) -> IoResult < usize > {
0 commit comments