Skip to content

Commit 271cf0b

Browse files
authored
fix(oauth): require CSRF token as part of the OAuth authorization flow. (#435)
* Require CSRF token as part of the authorization flow. * Update auth example. * Update docs/OAUTH_SUPPORT.md.
1 parent 83ce13c commit 271cf0b

File tree

3 files changed

+45
-24
lines changed

3 files changed

+45
-24
lines changed

crates/rmcp/src/transport/auth.rs

Lines changed: 32 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ pub struct AuthorizationManager {
147147
metadata: Option<AuthorizationMetadata>,
148148
oauth_client: Option<OAuthClient>,
149149
credentials: RwLock<Option<OAuthTokenResponse>>,
150-
pkce_verifier: RwLock<Option<PkceCodeVerifier>>,
150+
state: RwLock<Option<AuthorizationState>>,
151151
expires_at: RwLock<Option<Instant>>,
152152
base_url: Url,
153153
}
@@ -172,6 +172,12 @@ pub struct ClientRegistrationResponse {
172172
pub additional_fields: HashMap<String, serde_json::Value>,
173173
}
174174

175+
#[derive(Debug)]
176+
struct AuthorizationState {
177+
pkce_verifier: PkceCodeVerifier,
178+
csrf_token: CsrfToken,
179+
}
180+
175181
impl AuthorizationManager {
176182
/// create new auth manager with base url
177183
pub async fn new<U: IntoUrl>(base_url: U) -> Result<Self, AuthError> {
@@ -186,7 +192,7 @@ impl AuthorizationManager {
186192
metadata: None,
187193
oauth_client: None,
188194
credentials: RwLock::new(None),
189-
pkce_verifier: RwLock::new(None),
195+
state: RwLock::new(None),
190196
expires_at: RwLock::new(None),
191197
base_url,
192198
};
@@ -405,11 +411,14 @@ impl AuthorizationManager {
405411
auth_request = auth_request.add_scope(Scope::new(scope.to_string()));
406412
}
407413

408-
let (auth_url, _csrf_token) = auth_request.url();
414+
let (auth_url, csrf_token) = auth_request.url();
409415

410416
// store pkce verifier for later use
411-
*self.pkce_verifier.write().await = Some(pkce_verifier);
412-
debug!("set pkce verifier: {:?}", self.pkce_verifier.read().await);
417+
*self.state.write().await = Some(AuthorizationState {
418+
pkce_verifier,
419+
csrf_token,
420+
});
421+
debug!("set authorization state: {:?}", self.state.read().await);
413422

414423
Ok(auth_url.to_string())
415424
}
@@ -418,19 +427,25 @@ impl AuthorizationManager {
418427
pub async fn exchange_code_for_token(
419428
&self,
420429
code: &str,
430+
csrf_token: &str,
421431
) -> Result<StandardTokenResponse<EmptyExtraTokenFields, BasicTokenType>, AuthError> {
422432
debug!("start exchange code for token: {:?}", code);
423433
let oauth_client = self
424434
.oauth_client
425435
.as_ref()
426436
.ok_or_else(|| AuthError::InternalError("OAuth client not configured".to_string()))?;
427437

428-
let pkce_verifier = self
429-
.pkce_verifier
430-
.write()
431-
.await
432-
.take()
433-
.ok_or_else(|| AuthError::InternalError("PKCE verifier not found".to_string()))?;
438+
let AuthorizationState {
439+
pkce_verifier,
440+
csrf_token: expected_csrf_token,
441+
} =
442+
self.state.write().await.take().ok_or_else(|| {
443+
AuthError::InternalError("Authorization state not found".to_string())
444+
})?;
445+
446+
if csrf_token != expected_csrf_token.secret() {
447+
return Err(AuthError::InternalError("CSRF token mismatch".to_string()));
448+
}
434449

435450
let http_client = reqwest::ClientBuilder::new()
436451
.redirect(reqwest::redirect::Policy::none())
@@ -601,8 +616,11 @@ impl AuthorizationSession {
601616
pub async fn handle_callback(
602617
&self,
603618
code: &str,
619+
csrf_token: &str,
604620
) -> Result<StandardTokenResponse<EmptyExtraTokenFields, BasicTokenType>, AuthError> {
605-
self.auth_manager.exchange_code_for_token(code).await
621+
self.auth_manager
622+
.exchange_code_for_token(code, csrf_token)
623+
.await
606624
}
607625
}
608626

@@ -787,10 +805,10 @@ impl OAuthState {
787805
}
788806

789807
/// handle authorization callback
790-
pub async fn handle_callback(&mut self, code: &str) -> Result<(), AuthError> {
808+
pub async fn handle_callback(&mut self, code: &str, csrf_token: &str) -> Result<(), AuthError> {
791809
match self {
792810
OAuthState::Session(session) => {
793-
session.handle_callback(code).await?;
811+
session.handle_callback(code, csrf_token).await?;
794812
self.complete_authorization().await
795813
}
796814
OAuthState::Unauthorized(_) => {

docs/OAUTH_SUPPORT.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,9 @@ rmcp = { version = "0.1", features = ["auth", "transport-sse-client"] }
4444
println!("Please open the following URL in your browser for authorization:\n{}", auth_url);
4545

4646
// Handle callback - In real applications, this is typically done in a callback server
47-
let auth_code = "Authorization code obtained from browser after user authorization";
48-
let credentials = oauth_state.handle_callback(auth_code).await?;
47+
let auth_code = "Authorization code (`code` param) obtained from browser after user authorization";
48+
let csrf_token = "CSRF token (`state` param) obtained from browser after user authorization";
49+
let credentials = oauth_state.handle_callback(auth_code, csrf_token).await?;
4950

5051
println!("Authorization successful, access token: {}", credentials.access_token);
5152

examples/clients/src/auth/oauth_client.rs

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -31,25 +31,24 @@ const CALLBACK_HTML: &str = include_str!("callback.html");
3131

3232
#[derive(Clone)]
3333
struct AppState {
34-
code_receiver: Arc<Mutex<Option<oneshot::Sender<String>>>>,
34+
code_receiver: Arc<Mutex<Option<oneshot::Sender<CallbackParams>>>>,
3535
}
3636

3737
#[derive(Debug, Deserialize)]
3838
struct CallbackParams {
3939
code: String,
40-
#[allow(dead_code)]
41-
state: Option<String>,
40+
state: String,
4241
}
4342

4443
async fn callback_handler(
4544
Query(params): Query<CallbackParams>,
4645
State(state): State<AppState>,
4746
) -> Html<String> {
48-
tracing::info!("Received callback with code: {}", params.code);
47+
tracing::info!("Received callback: {params:?}");
4948

5049
// Send the code to the main thread
5150
if let Some(sender) = state.code_receiver.lock().await.take() {
52-
let _ = sender.send(params.code);
51+
let _ = sender.send(params);
5352
}
5453
// Return success page
5554
Html(CALLBACK_HTML.to_string())
@@ -67,7 +66,7 @@ async fn main() -> Result<()> {
6766
.init();
6867
// it is a http server for handling callback
6968
// Create channel for receiving authorization code
70-
let (code_sender, code_receiver) = oneshot::channel::<String>();
69+
let (code_sender, code_receiver) = oneshot::channel::<CallbackParams>();
7170

7271
// Create app state
7372
let app_state = AppState {
@@ -121,14 +120,17 @@ async fn main() -> Result<()> {
121120

122121
// Wait for authorization code
123122
tracing::info!("Waiting for authorization code...");
124-
let auth_code = code_receiver
123+
let CallbackParams {
124+
code: auth_code,
125+
state: csrf_token,
126+
} = code_receiver
125127
.await
126128
.context("Failed to get authorization code")?;
127129
tracing::info!("Received authorization code: {}", auth_code);
128130
// Exchange code for access token
129131
tracing::info!("Exchanging authorization code for access token...");
130132
oauth_state
131-
.handle_callback(&auth_code)
133+
.handle_callback(&auth_code, &csrf_token)
132134
.await
133135
.context("Failed to handle callback")?;
134136
tracing::info!("Successfully obtained access token");

0 commit comments

Comments
 (0)