@@ -6,7 +6,8 @@ use std::io::{Read, Write};
66use std:: ops:: { Deref , DerefMut } ;
77use std:: str;
88use std:: sync:: mpsc;
9- use std:: sync:: Arc ;
9+
10+ #[ cfg( feature = "rsasl" ) ]
1011use rsasl:: prelude:: { Mechname , SASLClient , SASLConfig , Session as SASLSession , State as SASLState } ;
1112
1213use super :: authenticator:: Authenticator ;
@@ -361,7 +362,7 @@ impl<T: Read + Write> Client<T> {
361362 /// match client.login("user", "pass") {
362363 /// Ok(s) => {
363364 /// // you are successfully authenticated!
364- /// },
365+ /// }
365366 /// Err((e, orig_client)) => {
366367 /// eprintln!("error logging in: {}", e);
367368 /// // prompt user and try again with orig_client here
@@ -419,7 +420,7 @@ impl<T: Read + Write> Client<T> {
419420 /// match client.authenticate("XOAUTH2", &auth) {
420421 /// Ok(session) => {
421422 /// // you are successfully authenticated!
422- /// },
423+ /// }
423424 /// Err((e, orig_client)) => {
424425 /// eprintln!("error authenticating: {}", e);
425426 /// // prompt user and try again with orig_client here
@@ -428,9 +429,82 @@ impl<T: Read + Write> Client<T> {
428429 /// };
429430 /// }
430431 /// ```
431- pub fn authenticate (
432+ pub fn authenticate < A : Authenticator > (
432433 mut self ,
433- config : Arc < SASLConfig > ,
434+ auth_type : impl AsRef < str > ,
435+ authenticator : & A ,
436+ ) -> :: std:: result:: Result < Session < T > , ( Error , Client < T > ) > {
437+ ok_or_unauth_client_err ! (
438+ self . run_command( & format!( "AUTHENTICATE {}" , auth_type. as_ref( ) ) ) ,
439+ self
440+ ) ;
441+ self . do_auth_handshake ( authenticator)
442+ }
443+
444+ /// This func does the handshake process once the authenticate command is made.
445+ fn do_auth_handshake < A : Authenticator > (
446+ mut self ,
447+ authenticator : & A ,
448+ ) -> :: std:: result:: Result < Session < T > , ( Error , Client < T > ) > {
449+ // TODO Clean up this code
450+ loop {
451+ let mut line = Vec :: new ( ) ;
452+
453+ // explicit match blocks neccessary to convert error to tuple and not bind self too
454+ // early (see also comment on `login`)
455+ ok_or_unauth_client_err ! ( self . readline( & mut line) , self ) ;
456+
457+ // ignore server comments
458+ if line. starts_with ( b"* " ) {
459+ continue ;
460+ }
461+
462+ // Some servers will only send `+\r\n`.
463+ if line. starts_with ( b"+ " ) || & line == b"+\r \n " {
464+ let challenge = if & line == b"+\r \n " {
465+ Vec :: new ( )
466+ } else {
467+ let line_str = ok_or_unauth_client_err ! (
468+ match str :: from_utf8( line. as_slice( ) ) {
469+ Ok ( line_str) => Ok ( line_str) ,
470+ Err ( e) => Err ( Error :: Parse ( ParseError :: DataNotUtf8 ( line, e) ) ) ,
471+ } ,
472+ self
473+ ) ;
474+ let data =
475+ ok_or_unauth_client_err ! ( parse_authenticate_response( line_str) , self ) ;
476+ ok_or_unauth_client_err ! (
477+ base64:: decode( data) . map_err( |e| Error :: Parse ( ParseError :: Authentication (
478+ data. to_string( ) ,
479+ Some ( e)
480+ ) ) ) ,
481+ self
482+ )
483+ } ;
484+
485+ let raw_response = & authenticator. process ( & challenge) ;
486+ let auth_response = base64:: encode ( raw_response) ;
487+ ok_or_unauth_client_err ! (
488+ self . write_line( auth_response. into_bytes( ) . as_slice( ) ) ,
489+ self
490+ ) ;
491+ } else {
492+ ok_or_unauth_client_err ! ( self . read_response_onto( & mut line) , self ) ;
493+ return Ok ( Session :: new ( self . conn ) ) ;
494+ }
495+ }
496+ }
497+ }
498+
499+ #[ cfg( feature = "rsasl" ) ]
500+ impl < T : Read + Write > Client < T > {
501+
502+ /// Authenticate with the server using the given custom SASLConfig to handle the server's
503+ /// challenge.
504+ ///
505+ pub fn sasl_auth (
506+ mut self ,
507+ config : :: std:: sync:: Arc < SASLConfig > ,
434508 mechanism : & Mechname ,
435509 ) -> :: std:: result:: Result < Session < T > , ( Error , Client < T > ) > {
436510 let client = SASLClient :: new ( config) ;
@@ -444,11 +518,11 @@ impl<T: Read + Write> Client<T> {
444518 self . run_command( & format!( "AUTHENTICATE {}" , mechanism. as_str( ) ) ) ,
445519 self
446520 ) ;
447- self . do_auth_handshake ( session)
521+ self . do_sasl_handshake ( session)
448522 }
449523
450- /// This func does the handshake process once the authenticate command is made.
451- fn do_auth_handshake (
524+ /// This func does the SASL handshake process once the authenticate command is made.
525+ fn do_sasl_handshake (
452526 mut self ,
453527 mut authenticator : SASLSession ,
454528 ) -> :: std:: result:: Result < Session < T > , ( Error , Client < T > ) > {
0 commit comments