@@ -147,7 +147,7 @@ pub struct AuthorizationManager {
147147 metadata : Option < AuthorizationMetadata > ,
148148 oauth_client : Option < OAuthClient > ,
149149 credentials : RwLock < Option < OAuthTokenResponse > > ,
150- pkce_verifier : RwLock < Option < PkceCodeVerifier > > ,
150+ state : RwLock < Option < AuthorizationState > > ,
151151 expires_at : RwLock < Option < Instant > > ,
152152 base_url : Url ,
153153}
@@ -172,6 +172,12 @@ pub struct ClientRegistrationResponse {
172172 pub additional_fields : HashMap < String , serde_json:: Value > ,
173173}
174174
175+ #[ derive( Debug ) ]
176+ struct AuthorizationState {
177+ pkce_verifier : PkceCodeVerifier ,
178+ csrf_token : CsrfToken ,
179+ }
180+
175181impl AuthorizationManager {
176182 /// create new auth manager with base url
177183 pub async fn new < U : IntoUrl > ( base_url : U ) -> Result < Self , AuthError > {
@@ -186,7 +192,7 @@ impl AuthorizationManager {
186192 metadata : None ,
187193 oauth_client : None ,
188194 credentials : RwLock :: new ( None ) ,
189- pkce_verifier : RwLock :: new ( None ) ,
195+ state : RwLock :: new ( None ) ,
190196 expires_at : RwLock :: new ( None ) ,
191197 base_url,
192198 } ;
@@ -405,11 +411,14 @@ impl AuthorizationManager {
405411 auth_request = auth_request. add_scope ( Scope :: new ( scope. to_string ( ) ) ) ;
406412 }
407413
408- let ( auth_url, _csrf_token ) = auth_request. url ( ) ;
414+ let ( auth_url, csrf_token ) = auth_request. url ( ) ;
409415
410416 // store pkce verifier for later use
411- * self . pkce_verifier . write ( ) . await = Some ( pkce_verifier) ;
412- debug ! ( "set pkce verifier: {:?}" , self . pkce_verifier. read( ) . await ) ;
417+ * self . state . write ( ) . await = Some ( AuthorizationState {
418+ pkce_verifier,
419+ csrf_token,
420+ } ) ;
421+ debug ! ( "set authorization state: {:?}" , self . state. read( ) . await ) ;
413422
414423 Ok ( auth_url. to_string ( ) )
415424 }
@@ -418,19 +427,25 @@ impl AuthorizationManager {
418427 pub async fn exchange_code_for_token (
419428 & self ,
420429 code : & str ,
430+ csrf_token : & str ,
421431 ) -> Result < StandardTokenResponse < EmptyExtraTokenFields , BasicTokenType > , AuthError > {
422432 debug ! ( "start exchange code for token: {:?}" , code) ;
423433 let oauth_client = self
424434 . oauth_client
425435 . as_ref ( )
426436 . ok_or_else ( || AuthError :: InternalError ( "OAuth client not configured" . to_string ( ) ) ) ?;
427437
428- let pkce_verifier = self
429- . pkce_verifier
430- . write ( )
431- . await
432- . take ( )
433- . ok_or_else ( || AuthError :: InternalError ( "PKCE verifier not found" . to_string ( ) ) ) ?;
438+ let AuthorizationState {
439+ pkce_verifier,
440+ csrf_token : expected_csrf_token,
441+ } =
442+ self . state . write ( ) . await . take ( ) . ok_or_else ( || {
443+ AuthError :: InternalError ( "Authorization state not found" . to_string ( ) )
444+ } ) ?;
445+
446+ if csrf_token != expected_csrf_token. secret ( ) {
447+ return Err ( AuthError :: InternalError ( "CSRF token mismatch" . to_string ( ) ) ) ;
448+ }
434449
435450 let http_client = reqwest:: ClientBuilder :: new ( )
436451 . redirect ( reqwest:: redirect:: Policy :: none ( ) )
@@ -601,8 +616,11 @@ impl AuthorizationSession {
601616 pub async fn handle_callback (
602617 & self ,
603618 code : & str ,
619+ csrf_token : & str ,
604620 ) -> Result < StandardTokenResponse < EmptyExtraTokenFields , BasicTokenType > , AuthError > {
605- self . auth_manager . exchange_code_for_token ( code) . await
621+ self . auth_manager
622+ . exchange_code_for_token ( code, csrf_token)
623+ . await
606624 }
607625}
608626
@@ -787,10 +805,10 @@ impl OAuthState {
787805 }
788806
789807 /// handle authorization callback
790- pub async fn handle_callback ( & mut self , code : & str ) -> Result < ( ) , AuthError > {
808+ pub async fn handle_callback ( & mut self , code : & str , csrf_token : & str ) -> Result < ( ) , AuthError > {
791809 match self {
792810 OAuthState :: Session ( session) => {
793- session. handle_callback ( code) . await ?;
811+ session. handle_callback ( code, csrf_token ) . await ?;
794812 self . complete_authorization ( ) . await
795813 }
796814 OAuthState :: Unauthorized ( _) => {
0 commit comments