@@ -19,7 +19,7 @@ use rustls::{
1919use crate :: error:: Error ;
2020use crate :: io:: ReadBuf ;
2121use crate :: net:: tls:: util:: StdSocket ;
22- use crate :: net:: tls:: TlsConfig ;
22+ use crate :: net:: tls:: { RawTlsConfig , TlsConfig } ;
2323use crate :: net:: Socket ;
2424
2525pub struct RustlsSocket < S : Socket > {
@@ -87,100 +87,134 @@ impl<S: Socket> Socket for RustlsSocket<S> {
8787 }
8888}
8989
90- pub async fn handshake < S > ( socket : S , tls_config : TlsConfig < ' _ > ) -> Result < RustlsSocket < S > , Error >
91- where
92- S : Socket ,
93- {
94- #[ cfg( all(
95- feature = "_tls-rustls-aws-lc-rs" ,
96- not( feature = "_tls-rustls-ring-webpki" ) ,
97- not( feature = "_tls-rustls-ring-native-roots" )
98- ) ) ]
99- let provider = Arc :: new ( rustls:: crypto:: aws_lc_rs:: default_provider ( ) ) ;
100- #[ cfg( any(
101- feature = "_tls-rustls-ring-webpki" ,
102- feature = "_tls-rustls-ring-native-roots"
103- ) ) ]
104- let provider = Arc :: new ( rustls:: crypto:: ring:: default_provider ( ) ) ;
105-
106- // Unwrapping is safe here because we use a default provider.
107- let config = ClientConfig :: builder_with_provider ( provider. clone ( ) )
90+ impl TlsConfig < ' _ > {
91+ async fn rustls_config ( & self ) -> crate :: Result < ( rustls:: ClientConfig , & str ) , Error > {
92+ let RawTlsConfig {
93+ accept_invalid_certs,
94+ accept_invalid_hostnames,
95+ hostname,
96+ root_cert,
97+ client_cert,
98+ client_key,
99+ } = match self {
100+ TlsConfig :: RawTlsConfig ( raw) => raw,
101+ TlsConfig :: PrebuiltRustls { config, hostname } => {
102+ return Ok ( ( ( * config) . to_owned ( ) , hostname) ) ;
103+ }
104+ } ;
105+
106+ #[ cfg( all(
107+ feature = "_tls-rustls-aws-lc-rs" ,
108+ not( feature = "_tls-rustls-ring-webpki" ) ,
109+ not( feature = "_tls-rustls-ring-native-roots" )
110+ ) ) ]
111+ let config = ClientConfig :: builder_with_provider ( Arc :: new (
112+ rustls:: crypto:: aws_lc_rs:: default_provider ( ) ,
113+ ) )
108114 . with_safe_default_protocol_versions ( )
109115 . unwrap ( ) ;
116+ #[ cfg( any(
117+ feature = "_tls-rustls-ring-webpki" ,
118+ feature = "_tls-rustls-ring-native-roots"
119+ ) ) ]
120+ let config =
121+ ClientConfig :: builder_with_provider ( Arc :: new ( rustls:: crypto:: ring:: default_provider ( ) ) )
122+ . with_safe_default_protocol_versions ( )
123+ . unwrap ( ) ;
124+ #[ cfg( all(
125+ not( feature = "_tls-rustls-ring-webpki" ) ,
126+ not( feature = "_tls-rustls-ring-native-roots" )
127+ ) ) ]
128+ let config = ClientConfig :: builder ( ) ;
129+
130+ // authentication using user's key and its associated certificate
131+ let user_auth = match ( client_cert, client_key) {
132+ ( Some ( cert) , Some ( key) ) => {
133+ let cert_chain = certs_from_pem ( cert. data ( ) . await ?) ?;
134+ let key_der = private_key_from_pem ( key. data ( ) . await ?) ?;
135+ Some ( ( cert_chain, key_der) )
136+ }
137+ ( None , None ) => None ,
138+ ( _, _) => {
139+ return Err ( Error :: Configuration (
140+ "user auth key and certs must be given together" . into ( ) ,
141+ ) )
142+ }
143+ } ;
110144
111- // authentication using user's key and its associated certificate
112- let user_auth = match ( tls_config. client_cert_path , tls_config. client_key_path ) {
113- ( Some ( cert_path) , Some ( key_path) ) => {
114- let cert_chain = certs_from_pem ( cert_path. data ( ) . await ?) ?;
115- let key_der = private_key_from_pem ( key_path. data ( ) . await ?) ?;
116- Some ( ( cert_chain, key_der) )
117- }
118- ( None , None ) => None ,
119- ( _, _) => {
120- return Err ( Error :: Configuration (
121- "user auth key and certs must be given together" . into ( ) ,
122- ) )
123- }
124- } ;
145+ let provider = config. crypto_provider ( ) . clone ( ) ;
125146
126- let config = if tls_config. accept_invalid_certs {
127- if let Some ( user_auth) = user_auth {
128- config
129- . dangerous ( )
130- . with_custom_certificate_verifier ( Arc :: new ( DummyTlsVerifier { provider } ) )
131- . with_client_auth_cert ( user_auth. 0 , user_auth. 1 )
132- . map_err ( Error :: tls) ?
147+ let config = if * accept_invalid_certs {
148+ if let Some ( user_auth) = user_auth {
149+ config
150+ . dangerous ( )
151+ . with_custom_certificate_verifier ( Arc :: new ( DummyTlsVerifier { provider } ) )
152+ . with_client_auth_cert ( user_auth. 0 , user_auth. 1 )
153+ . map_err ( Error :: tls) ?
154+ } else {
155+ config
156+ . dangerous ( )
157+ . with_custom_certificate_verifier ( Arc :: new ( DummyTlsVerifier { provider } ) )
158+ . with_no_client_auth ( )
159+ }
133160 } else {
134- config
135- . dangerous ( )
136- . with_custom_certificate_verifier ( Arc :: new ( DummyTlsVerifier { provider } ) )
137- . with_no_client_auth ( )
138- }
139- } else {
140- let mut cert_store = import_root_certs ( ) ;
161+ let mut cert_store = import_root_certs ( ) ;
141162
142- if let Some ( ca) = tls_config . root_cert_path {
143- let data = ca. data ( ) . await ?;
163+ if let Some ( ca) = root_cert {
164+ let data = ca. data ( ) . await ?;
144165
145- for result in CertificateDer :: pem_slice_iter ( & data) {
146- let Ok ( cert) = result else {
147- return Err ( Error :: Tls ( format ! ( "Invalid certificate {ca}" ) . into ( ) ) ) ;
148- } ;
166+ for result in CertificateDer :: pem_slice_iter ( & data) {
167+ let Ok ( cert) = result else {
168+ return Err ( Error :: Tls ( format ! ( "Invalid certificate {ca}" ) . into ( ) ) ) ;
169+ } ;
149170
150- cert_store. add ( cert) . map_err ( |err| Error :: Tls ( err. into ( ) ) ) ?;
171+ cert_store. add ( cert) . map_err ( |err| Error :: Tls ( err. into ( ) ) ) ?;
172+ }
151173 }
152- }
153-
154- if tls_config. accept_invalid_hostnames {
155- let verifier = WebPkiServerVerifier :: builder ( Arc :: new ( cert_store) )
156- . build ( )
157- . map_err ( |err| Error :: Tls ( err. into ( ) ) ) ?;
158174
159- if let Some ( user_auth) = user_auth {
175+ if * accept_invalid_hostnames {
176+ let verifier = WebPkiServerVerifier :: builder ( Arc :: new ( cert_store) )
177+ . build ( )
178+ . map_err ( |err| Error :: Tls ( err. into ( ) ) ) ?;
179+
180+ if let Some ( user_auth) = user_auth {
181+ config
182+ . dangerous ( )
183+ . with_custom_certificate_verifier ( Arc :: new ( NoHostnameTlsVerifier {
184+ verifier,
185+ } ) )
186+ . with_client_auth_cert ( user_auth. 0 , user_auth. 1 )
187+ . map_err ( Error :: tls) ?
188+ } else {
189+ config
190+ . dangerous ( )
191+ . with_custom_certificate_verifier ( Arc :: new ( NoHostnameTlsVerifier {
192+ verifier,
193+ } ) )
194+ . with_no_client_auth ( )
195+ }
196+ } else if let Some ( user_auth) = user_auth {
160197 config
161- . dangerous ( )
162- . with_custom_certificate_verifier ( Arc :: new ( NoHostnameTlsVerifier { verifier } ) )
198+ . with_root_certificates ( cert_store)
163199 . with_client_auth_cert ( user_auth. 0 , user_auth. 1 )
164200 . map_err ( Error :: tls) ?
165201 } else {
166202 config
167- . dangerous ( )
168- . with_custom_certificate_verifier ( Arc :: new ( NoHostnameTlsVerifier { verifier } ) )
203+ . with_root_certificates ( cert_store)
169204 . with_no_client_auth ( )
170205 }
171- } else if let Some ( user_auth) = user_auth {
172- config
173- . with_root_certificates ( cert_store)
174- . with_client_auth_cert ( user_auth. 0 , user_auth. 1 )
175- . map_err ( Error :: tls) ?
176- } else {
177- config
178- . with_root_certificates ( cert_store)
179- . with_no_client_auth ( )
180- }
181- } ;
206+ } ;
207+
208+ Ok ( ( config, hostname) )
209+ }
210+ }
182211
183- let host = ServerName :: try_from ( tls_config. hostname . to_owned ( ) ) . map_err ( Error :: tls) ?;
212+ pub async fn handshake < S > ( socket : S , tls_config : TlsConfig < ' _ > ) -> Result < RustlsSocket < S > , Error >
213+ where
214+ S : Socket ,
215+ {
216+ let ( config, hostname) = tls_config. rustls_config ( ) . await ?;
217+ let host = ServerName :: try_from ( hostname. to_owned ( ) ) . map_err ( Error :: tls) ?;
184218
185219 let mut socket = RustlsSocket {
186220 inner : StdSocket :: new ( socket) ,
0 commit comments