Skip to content

Commit 71eed7b

Browse files
committed
Enforce policy on compat login
1 parent 45a7fbd commit 71eed7b

File tree

4 files changed

+157
-5
lines changed

4 files changed

+157
-5
lines changed

crates/handlers/src/compat/login.rs

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ use mas_data_model::{
1616
User,
1717
};
1818
use mas_matrix::HomeserverConnection;
19+
use mas_policy::{Policy, Requester, ViolationCode, model::CompatLoginType};
1920
use mas_storage::{
2021
BoxRepository, BoxRepositoryFactory, RepositoryAccess,
2122
compat::{
@@ -37,6 +38,7 @@ use crate::{
3738
BoundActivityTracker, Limiter, METER, RequesterFingerprint, impl_from_error_for_route,
3839
passwords::{PasswordManager, PasswordVerificationResult},
3940
rate_limit::PasswordCheckLimitedError,
41+
session::count_user_sessions_for_limiting,
4042
};
4143

4244
static LOGIN_COUNTER: LazyLock<Counter<u64>> = LazyLock::new(|| {
@@ -213,9 +215,16 @@ pub enum RouteError {
213215

214216
#[error("failed to provision device")]
215217
ProvisionDeviceFailed(#[source] anyhow::Error),
218+
219+
#[error("login rejected by policy")]
220+
PolicyRejected,
221+
222+
#[error("login rejected by policy (hard session limit reached)")]
223+
PolicyHardSessionLimitReached,
216224
}
217225

218226
impl_from_error_for_route!(mas_storage::RepositoryError);
227+
impl_from_error_for_route!(mas_policy::EvaluationError);
219228

220229
impl From<anyhow::Error> for RouteError {
221230
fn from(err: anyhow::Error) -> Self {
@@ -274,6 +283,16 @@ impl IntoResponse for RouteError {
274283
error: "User account has been locked",
275284
status: StatusCode::UNAUTHORIZED,
276285
},
286+
Self::PolicyRejected => MatrixError {
287+
errcode: "M_FORBIDDEN",
288+
error: "Login denied by the policy enforced by this service",
289+
status: StatusCode::FORBIDDEN,
290+
},
291+
Self::PolicyHardSessionLimitReached => MatrixError {
292+
errcode: "M_FORBIDDEN",
293+
error: "You have reached your hard device limit. Please visit your account page to sign some out.",
294+
status: StatusCode::FORBIDDEN,
295+
},
277296
};
278297

279298
(sentry_event_id, response).into_response()
@@ -290,6 +309,7 @@ pub(crate) async fn post(
290309
State(homeserver): State<Arc<dyn HomeserverConnection>>,
291310
State(site_config): State<SiteConfig>,
292311
State(limiter): State<Limiter>,
312+
mut policy: Policy,
293313
requester: RequesterFingerprint,
294314
user_agent: Option<TypedHeader<headers::UserAgent>>,
295315
MatrixJsonBody(input): MatrixJsonBody<RequestBody>,
@@ -329,6 +349,11 @@ pub(crate) async fn post(
329349
&limiter,
330350
requester,
331351
&mut repo,
352+
&mut policy,
353+
Requester {
354+
ip_address: activity_tracker.ip(),
355+
user_agent: user_agent.clone(),
356+
},
332357
username,
333358
password,
334359
input.device_id, // TODO check for validity
@@ -342,6 +367,11 @@ pub(crate) async fn post(
342367
&mut rng,
343368
&clock,
344369
&mut repo,
370+
&mut policy,
371+
Requester {
372+
ip_address: activity_tracker.ip(),
373+
user_agent: user_agent.clone(),
374+
},
345375
&token,
346376
input.device_id,
347377
input.initial_device_display_name,
@@ -459,6 +489,8 @@ async fn token_login(
459489
rng: &mut (dyn RngCore + Send),
460490
clock: &dyn Clock,
461491
repo: &mut BoxRepository,
492+
policy: &mut Policy,
493+
requester: Requester,
462494
token: &str,
463495
requested_device_id: Option<String>,
464496
initial_device_display_name: Option<String>,
@@ -548,6 +580,27 @@ async fn token_login(
548580
.finish_sessions_to_replace_device(clock, &browser_session.user, &device)
549581
.await?;
550582

583+
let session_counts = count_user_sessions_for_limiting(repo, &browser_session.user).await?;
584+
585+
let res = policy
586+
.evaluate_compat_login(mas_policy::CompatLoginInput {
587+
user: &browser_session.user,
588+
login_type: CompatLoginType::WebSso,
589+
session_counts,
590+
requester,
591+
})
592+
.await?;
593+
if !res.valid() {
594+
if res.violations.len() == 1 {
595+
let violation = &res.violations[0];
596+
if violation.code == Some(ViolationCode::TooManySessions) {
597+
// The only violation is having reached the session limit.
598+
return Err(RouteError::PolicyHardSessionLimitReached);
599+
}
600+
}
601+
return Err(RouteError::PolicyRejected);
602+
}
603+
551604
// We first create the session in the database, commit the transaction, then
552605
// create it on the homeserver, scheduling a device sync job afterwards to
553606
// make sure we don't end up in an inconsistent state.
@@ -578,6 +631,8 @@ async fn user_password_login(
578631
limiter: &Limiter,
579632
requester: RequesterFingerprint,
580633
repo: &mut BoxRepository,
634+
policy: &mut Policy,
635+
policy_requester: Requester,
581636
username: &str,
582637
password: String,
583638
requested_device_id: Option<String>,
@@ -651,6 +706,27 @@ async fn user_password_login(
651706
.finish_sessions_to_replace_device(clock, &user, &device)
652707
.await?;
653708

709+
let session_counts = count_user_sessions_for_limiting(repo, &user).await?;
710+
711+
let res = policy
712+
.evaluate_compat_login(mas_policy::CompatLoginInput {
713+
user: &user,
714+
login_type: CompatLoginType::Password,
715+
session_counts,
716+
requester: policy_requester,
717+
})
718+
.await?;
719+
if !res.valid() {
720+
if res.violations.len() == 1 {
721+
let violation = &res.violations[0];
722+
if violation.code == Some(ViolationCode::TooManySessions) {
723+
// The only violation is having reached the session limit.
724+
return Err(RouteError::PolicyHardSessionLimitReached);
725+
}
726+
}
727+
return Err(RouteError::PolicyRejected);
728+
}
729+
654730
let session = repo
655731
.compat_session()
656732
.add(

crates/handlers/src/compat/login_sso_complete.rs

Lines changed: 79 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,23 +11,28 @@ use axum::{
1111
extract::{Form, Path, State},
1212
response::{Html, IntoResponse, Redirect, Response},
1313
};
14-
use axum_extra::extract::Query;
14+
use axum_extra::{TypedHeader, extract::Query};
1515
use chrono::Duration;
16+
use hyper::StatusCode;
1617
use mas_axum_utils::{
1718
InternalError,
1819
cookies::CookieJar,
1920
csrf::{CsrfExt, ProtectedForm},
2021
};
2122
use mas_data_model::{BoxClock, BoxRng, Clock};
23+
use mas_policy::{Policy, ViolationCode, model::CompatLoginType};
2224
use mas_router::{CompatLoginSsoAction, UrlBuilder};
2325
use mas_storage::{BoxRepository, RepositoryAccess, compat::CompatSsoLoginRepository};
24-
use mas_templates::{CompatSsoContext, ErrorContext, TemplateContext, Templates};
26+
use mas_templates::{
27+
CompatLoginPolicyViolationContext, CompatSsoContext, EmptyContext, ErrorContext,
28+
PolicyViolationContext, TemplateContext, Templates,
29+
};
2530
use serde::{Deserialize, Serialize};
2631
use ulid::Ulid;
2732

2833
use crate::{
29-
PreferredLanguage,
30-
session::{SessionOrFallback, load_session_or_fallback},
34+
BoundActivityTracker, PreferredLanguage,
35+
session::{SessionOrFallback, count_user_sessions_for_limiting, load_session_or_fallback},
3136
};
3237

3338
#[derive(Serialize)]
@@ -56,10 +61,15 @@ pub async fn get(
5661
mut repo: BoxRepository,
5762
State(templates): State<Templates>,
5863
State(url_builder): State<UrlBuilder>,
64+
mut policy: Policy,
65+
activity_tracker: BoundActivityTracker,
66+
user_agent: Option<TypedHeader<headers::UserAgent>>,
5967
cookie_jar: CookieJar,
6068
Path(id): Path<Ulid>,
6169
Query(params): Query<Params>,
6270
) -> Result<Response, InternalError> {
71+
let user_agent = user_agent.map(|ua| ua.to_string());
72+
6373
let (cookie_jar, maybe_session) = match load_session_or_fallback(
6474
cookie_jar, &clock, &mut rng, &templates, &locale, &mut repo,
6575
)
@@ -107,6 +117,35 @@ pub async fn get(
107117
return Ok((cookie_jar, Html(content)).into_response());
108118
}
109119

120+
let session_counts = count_user_sessions_for_limiting(&mut repo, &session.user).await?;
121+
122+
let res = policy
123+
.evaluate_compat_login(mas_policy::CompatLoginInput {
124+
user: &session.user,
125+
login_type: CompatLoginType::WebSso,
126+
session_counts,
127+
requester: mas_policy::Requester {
128+
ip_address: activity_tracker.ip(),
129+
user_agent,
130+
},
131+
})
132+
.await?;
133+
if !res.valid() {
134+
let ctx = CompatLoginPolicyViolationContext::for_violations(
135+
res.violations
136+
.into_iter()
137+
.filter_map(|v| Some(v.code?.as_str()))
138+
.collect(),
139+
)
140+
.with_session(session)
141+
.with_csrf(csrf_token.form_value())
142+
.with_language(locale);
143+
144+
let content = templates.render_compat_login_policy_violation(&ctx)?;
145+
146+
return Ok((StatusCode::FORBIDDEN, cookie_jar, Html(content)).into_response());
147+
}
148+
110149
let ctx = CompatSsoContext::new(login)
111150
.with_session(session)
112151
.with_csrf(csrf_token.form_value())
@@ -129,11 +168,16 @@ pub async fn post(
129168
PreferredLanguage(locale): PreferredLanguage,
130169
State(templates): State<Templates>,
131170
State(url_builder): State<UrlBuilder>,
171+
mut policy: Policy,
172+
activity_tracker: BoundActivityTracker,
173+
user_agent: Option<TypedHeader<headers::UserAgent>>,
132174
cookie_jar: CookieJar,
133175
Path(id): Path<Ulid>,
134176
Query(params): Query<Params>,
135177
Form(form): Form<ProtectedForm<()>>,
136178
) -> Result<Response, InternalError> {
179+
let user_agent = user_agent.map(|ua| ua.to_string());
180+
137181
let (cookie_jar, maybe_session) = match load_session_or_fallback(
138182
cookie_jar, &clock, &mut rng, &templates, &locale, &mut repo,
139183
)
@@ -200,6 +244,37 @@ pub async fn post(
200244
redirect_uri
201245
};
202246

247+
let session_counts = count_user_sessions_for_limiting(&mut repo, &session.user).await?;
248+
249+
let res = policy
250+
.evaluate_compat_login(mas_policy::CompatLoginInput {
251+
user: &session.user,
252+
login_type: CompatLoginType::WebSso,
253+
session_counts,
254+
requester: mas_policy::Requester {
255+
ip_address: activity_tracker.ip(),
256+
user_agent,
257+
},
258+
})
259+
.await?;
260+
261+
if !res.valid() {
262+
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
263+
let ctx = CompatLoginPolicyViolationContext::for_violations(
264+
res.violations
265+
.into_iter()
266+
.filter_map(|v| Some(v.code?.as_str()))
267+
.collect(),
268+
)
269+
.with_session(session)
270+
.with_csrf(csrf_token.form_value())
271+
.with_language(locale);
272+
273+
let content = templates.render_compat_login_policy_violation(&ctx)?;
274+
275+
return Ok((StatusCode::FORBIDDEN, cookie_jar, Html(content)).into_response());
276+
}
277+
203278
// Note that if the login is not Pending,
204279
// this fails and aborts the transaction.
205280
repo.compat_sso_login()

crates/handlers/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,7 @@ where
272272
BoxRepository: FromRequestParts<S>,
273273
BoxClock: FromRequestParts<S>,
274274
BoxRng: FromRequestParts<S>,
275+
Policy: FromRequestParts<S>,
275276
{
276277
// A sub-router for human-facing routes with error handling
277278
let human_router = Router::new()

crates/policy/src/model.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ use schemars::JsonSchema;
1717
use serde::{Deserialize, Serialize};
1818

1919
/// A well-known policy code.
20-
#[derive(Deserialize, Debug, Clone, Copy, JsonSchema)]
20+
#[derive(Deserialize, Debug, Clone, Copy, JsonSchema, PartialEq, Eq)]
2121
#[serde(rename_all = "kebab-case")]
2222
pub enum Code {
2323
/// The username is too short.

0 commit comments

Comments
 (0)