@@ -19,7 +19,7 @@ use rustls::{
1919use crate :: error:: Error ;
2020use crate :: io:: ReadBuf ;
2121use crate :: net:: tls:: util:: StdSocket ;
22- use crate :: net:: tls:: TlsConfig ;
22+ use crate :: net:: tls:: { RawTlsConfig , TlsConfig } ;
2323use crate :: net:: Socket ;
2424
2525pub struct RustlsSocket < S : Socket > {
@@ -87,100 +87,135 @@ impl<S: Socket> Socket for RustlsSocket<S> {
8787 }
8888}
8989
90- pub async fn handshake < S > ( socket : S , tls_config : TlsConfig < ' _ > ) -> Result < RustlsSocket < S > , Error >
91- where
92- S : Socket ,
93- {
94- #[ cfg( all(
95- feature = "_tls-rustls-aws-lc-rs" ,
96- not( feature = "_tls-rustls-ring-webpki" ) ,
97- not( feature = "_tls-rustls-ring-native-roots" )
98- ) ) ]
99- let provider = Arc :: new ( rustls:: crypto:: aws_lc_rs:: default_provider ( ) ) ;
100- #[ cfg( any(
101- feature = "_tls-rustls-ring-webpki" ,
102- feature = "_tls-rustls-ring-native-roots"
103- ) ) ]
104- let provider = Arc :: new ( rustls:: crypto:: ring:: default_provider ( ) ) ;
105-
106- // Unwrapping is safe here because we use a default provider.
107- let config = ClientConfig :: builder_with_provider ( provider. clone ( ) )
90+ impl TlsConfig < ' _ > {
91+ async fn rustls_config ( & self ) -> crate :: Result < ( rustls:: ClientConfig , & str ) , Error > {
92+ let RawTlsConfig {
93+ accept_invalid_certs,
94+ accept_invalid_hostnames,
95+ hostname,
96+ root_cert,
97+ client_cert,
98+ client_key,
99+ } = match self {
100+ TlsConfig :: RawTlsConfig ( raw) => raw,
101+ TlsConfig :: PrebuiltRustls { config, hostname } => {
102+ return Ok ( ( ( * config) . to_owned ( ) , hostname) ) ;
103+ }
104+ } ;
105+
106+ #[ cfg( all(
107+ feature = "_tls-rustls-aws-lc-rs" ,
108+ not( feature = "_tls-rustls-ring-webpki" ) ,
109+ not( feature = "_tls-rustls-ring-native-roots" )
110+ ) ) ]
111+ let config = ClientConfig :: builder_with_provider ( Arc :: new (
112+ rustls:: crypto:: aws_lc_rs:: default_provider ( ) ,
113+ ) )
108114 . with_safe_default_protocol_versions ( )
109115 . unwrap ( ) ;
116+ #[ cfg( any(
117+ feature = "_tls-rustls-ring-webpki" ,
118+ feature = "_tls-rustls-ring-native-roots"
119+ ) ) ]
120+ let config =
121+ ClientConfig :: builder_with_provider ( Arc :: new ( rustls:: crypto:: ring:: default_provider ( ) ) )
122+ . with_safe_default_protocol_versions ( )
123+ . unwrap ( ) ;
124+ #[ cfg( all(
125+ not( feature = "_tls-rustls-aws-lc-rs" ) ,
126+ not( feature = "_tls-rustls-ring-webpki" ) ,
127+ not( feature = "_tls-rustls-ring-native-roots" )
128+ ) ) ]
129+ let config = ClientConfig :: builder ( ) ;
130+
131+ // authentication using user's key and its associated certificate
132+ let user_auth = match ( client_cert, client_key) {
133+ ( Some ( cert) , Some ( key) ) => {
134+ let cert_chain = certs_from_pem ( cert. data ( ) . await ?) ?;
135+ let key_der = private_key_from_pem ( key. data ( ) . await ?) ?;
136+ Some ( ( cert_chain, key_der) )
137+ }
138+ ( None , None ) => None ,
139+ ( _, _) => {
140+ return Err ( Error :: Configuration (
141+ "user auth key and certs must be given together" . into ( ) ,
142+ ) )
143+ }
144+ } ;
110145
111- // authentication using user's key and its associated certificate
112- let user_auth = match ( tls_config. client_cert_path , tls_config. client_key_path ) {
113- ( Some ( cert_path) , Some ( key_path) ) => {
114- let cert_chain = certs_from_pem ( cert_path. data ( ) . await ?) ?;
115- let key_der = private_key_from_pem ( key_path. data ( ) . await ?) ?;
116- Some ( ( cert_chain, key_der) )
117- }
118- ( None , None ) => None ,
119- ( _, _) => {
120- return Err ( Error :: Configuration (
121- "user auth key and certs must be given together" . into ( ) ,
122- ) )
123- }
124- } ;
146+ let provider = config. crypto_provider ( ) . clone ( ) ;
125147
126- let config = if tls_config. accept_invalid_certs {
127- if let Some ( user_auth) = user_auth {
128- config
129- . dangerous ( )
130- . with_custom_certificate_verifier ( Arc :: new ( DummyTlsVerifier { provider } ) )
131- . with_client_auth_cert ( user_auth. 0 , user_auth. 1 )
132- . map_err ( Error :: tls) ?
148+ let config = if * accept_invalid_certs {
149+ if let Some ( user_auth) = user_auth {
150+ config
151+ . dangerous ( )
152+ . with_custom_certificate_verifier ( Arc :: new ( DummyTlsVerifier { provider } ) )
153+ . with_client_auth_cert ( user_auth. 0 , user_auth. 1 )
154+ . map_err ( Error :: tls) ?
155+ } else {
156+ config
157+ . dangerous ( )
158+ . with_custom_certificate_verifier ( Arc :: new ( DummyTlsVerifier { provider } ) )
159+ . with_no_client_auth ( )
160+ }
133161 } else {
134- config
135- . dangerous ( )
136- . with_custom_certificate_verifier ( Arc :: new ( DummyTlsVerifier { provider } ) )
137- . with_no_client_auth ( )
138- }
139- } else {
140- let mut cert_store = import_root_certs ( ) ;
162+ let mut cert_store = import_root_certs ( ) ;
141163
142- if let Some ( ca) = tls_config . root_cert_path {
143- let data = ca. data ( ) . await ?;
164+ if let Some ( ca) = root_cert {
165+ let data = ca. data ( ) . await ?;
144166
145- for result in CertificateDer :: pem_slice_iter ( & data) {
146- let Ok ( cert) = result else {
147- return Err ( Error :: Tls ( format ! ( "Invalid certificate {ca}" ) . into ( ) ) ) ;
148- } ;
167+ for result in CertificateDer :: pem_slice_iter ( & data) {
168+ let Ok ( cert) = result else {
169+ return Err ( Error :: Tls ( format ! ( "Invalid certificate {ca}" ) . into ( ) ) ) ;
170+ } ;
149171
150- cert_store. add ( cert) . map_err ( |err| Error :: Tls ( err. into ( ) ) ) ?;
172+ cert_store. add ( cert) . map_err ( |err| Error :: Tls ( err. into ( ) ) ) ?;
173+ }
151174 }
152- }
153-
154- if tls_config. accept_invalid_hostnames {
155- let verifier = WebPkiServerVerifier :: builder ( Arc :: new ( cert_store) )
156- . build ( )
157- . map_err ( |err| Error :: Tls ( err. into ( ) ) ) ?;
158175
159- if let Some ( user_auth) = user_auth {
176+ if * accept_invalid_hostnames {
177+ let verifier = WebPkiServerVerifier :: builder ( Arc :: new ( cert_store) )
178+ . build ( )
179+ . map_err ( |err| Error :: Tls ( err. into ( ) ) ) ?;
180+
181+ if let Some ( user_auth) = user_auth {
182+ config
183+ . dangerous ( )
184+ . with_custom_certificate_verifier ( Arc :: new ( NoHostnameTlsVerifier {
185+ verifier,
186+ } ) )
187+ . with_client_auth_cert ( user_auth. 0 , user_auth. 1 )
188+ . map_err ( Error :: tls) ?
189+ } else {
190+ config
191+ . dangerous ( )
192+ . with_custom_certificate_verifier ( Arc :: new ( NoHostnameTlsVerifier {
193+ verifier,
194+ } ) )
195+ . with_no_client_auth ( )
196+ }
197+ } else if let Some ( user_auth) = user_auth {
160198 config
161- . dangerous ( )
162- . with_custom_certificate_verifier ( Arc :: new ( NoHostnameTlsVerifier { verifier } ) )
199+ . with_root_certificates ( cert_store)
163200 . with_client_auth_cert ( user_auth. 0 , user_auth. 1 )
164201 . map_err ( Error :: tls) ?
165202 } else {
166203 config
167- . dangerous ( )
168- . with_custom_certificate_verifier ( Arc :: new ( NoHostnameTlsVerifier { verifier } ) )
204+ . with_root_certificates ( cert_store)
169205 . with_no_client_auth ( )
170206 }
171- } else if let Some ( user_auth) = user_auth {
172- config
173- . with_root_certificates ( cert_store)
174- . with_client_auth_cert ( user_auth. 0 , user_auth. 1 )
175- . map_err ( Error :: tls) ?
176- } else {
177- config
178- . with_root_certificates ( cert_store)
179- . with_no_client_auth ( )
180- }
181- } ;
207+ } ;
208+
209+ Ok ( ( config, hostname) )
210+ }
211+ }
182212
183- let host = ServerName :: try_from ( tls_config. hostname . to_owned ( ) ) . map_err ( Error :: tls) ?;
213+ pub async fn handshake < S > ( socket : S , tls_config : TlsConfig < ' _ > ) -> Result < RustlsSocket < S > , Error >
214+ where
215+ S : Socket ,
216+ {
217+ let ( config, hostname) = tls_config. rustls_config ( ) . await ?;
218+ let host = ServerName :: try_from ( hostname. to_owned ( ) ) . map_err ( Error :: tls) ?;
184219
185220 let mut socket = RustlsSocket {
186221 inner : StdSocket :: new ( socket) ,
0 commit comments