diff --git a/Cargo.lock b/Cargo.lock index b67a4a65f..61fc0b4ff 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3723,6 +3723,7 @@ dependencies = [ "mas-data-model", "mas-i18n", "mas-iana", + "mas-policy", "mas-router", "mas-spa", "minijinja", diff --git a/crates/cli/src/util.rs b/crates/cli/src/util.rs index c0f31557b..454276150 100644 --- a/crates/cli/src/util.rs +++ b/crates/cli/src/util.rs @@ -145,6 +145,7 @@ pub async fn policy_factory_from_config( register: config.register_entrypoint.clone(), client_registration: config.client_registration_entrypoint.clone(), authorization_grant: config.authorization_grant_entrypoint.clone(), + compat_login: config.compat_login_entrypoint.clone(), email: config.email_entrypoint.clone(), }; diff --git a/crates/config/src/sections/policy.rs b/crates/config/src/sections/policy.rs index 37d052ade..3b816b713 100644 --- a/crates/config/src/sections/policy.rs +++ b/crates/config/src/sections/policy.rs @@ -62,6 +62,14 @@ fn is_default_password_entrypoint(value: &String) -> bool { *value == default_password_entrypoint() } +fn default_compat_login_entrypoint() -> String { + "compat_login/violation".to_owned() +} + +fn is_default_compat_login_entrypoint(value: &String) -> bool { + *value == default_compat_login_entrypoint() +} + fn default_email_entrypoint() -> String { "email/violation".to_owned() } @@ -111,6 +119,13 @@ pub struct PolicyConfig { )] pub authorization_grant_entrypoint: String, + /// Entrypoint to use when evaluating compatibility logins + #[serde( + default = "default_compat_login_entrypoint", + skip_serializing_if = "is_default_compat_login_entrypoint" + )] + pub compat_login_entrypoint: String, + /// Entrypoint to use when changing password #[serde( default = "default_password_entrypoint", @@ -137,6 +152,7 @@ impl Default for PolicyConfig { client_registration_entrypoint: default_client_registration_entrypoint(), register_entrypoint: default_register_entrypoint(), authorization_grant_entrypoint: default_authorization_grant_entrypoint(), + compat_login_entrypoint: default_compat_login_entrypoint(), password_entrypoint: default_password_entrypoint(), email_entrypoint: default_email_entrypoint(), data: default_data(), diff --git a/crates/handlers/src/compat/login.rs b/crates/handlers/src/compat/login.rs index d3c7c979f..9f10d0373 100644 --- a/crates/handlers/src/compat/login.rs +++ b/crates/handlers/src/compat/login.rs @@ -16,6 +16,7 @@ use mas_data_model::{ User, }; use mas_matrix::HomeserverConnection; +use mas_policy::{Policy, Requester, ViolationCode, model::CompatLogin}; use mas_storage::{ BoxRepository, BoxRepositoryFactory, RepositoryAccess, compat::{ @@ -37,6 +38,7 @@ use crate::{ BoundActivityTracker, Limiter, METER, RequesterFingerprint, impl_from_error_for_route, passwords::{PasswordManager, PasswordVerificationResult}, rate_limit::PasswordCheckLimitedError, + session::count_user_sessions_for_limiting, }; static LOGIN_COUNTER: LazyLock> = LazyLock::new(|| { @@ -213,9 +215,16 @@ pub enum RouteError { #[error("failed to provision device")] ProvisionDeviceFailed(#[source] anyhow::Error), + + #[error("login rejected by policy")] + PolicyRejected, + + #[error("login rejected by policy (hard session limit reached)")] + PolicyHardSessionLimitReached, } impl_from_error_for_route!(mas_storage::RepositoryError); +impl_from_error_for_route!(mas_policy::EvaluationError); impl From for RouteError { fn from(err: anyhow::Error) -> Self { @@ -274,6 +283,16 @@ impl IntoResponse for RouteError { error: "User account has been locked", status: StatusCode::UNAUTHORIZED, }, + Self::PolicyRejected => MatrixError { + errcode: "M_FORBIDDEN", + error: "Login denied by the policy enforced by this service", + status: StatusCode::FORBIDDEN, + }, + Self::PolicyHardSessionLimitReached => MatrixError { + errcode: "M_FORBIDDEN", + error: "You have reached your hard device limit. Please visit your account page to sign some out.", + status: StatusCode::FORBIDDEN, + }, }; (sentry_event_id, response).into_response() @@ -290,6 +309,7 @@ pub(crate) async fn post( State(homeserver): State>, State(site_config): State, State(limiter): State, + mut policy: Policy, requester: RequesterFingerprint, user_agent: Option>, MatrixJsonBody(input): MatrixJsonBody, @@ -329,6 +349,11 @@ pub(crate) async fn post( &limiter, requester, &mut repo, + &mut policy, + Requester { + ip_address: activity_tracker.ip(), + user_agent: user_agent.clone(), + }, username, password, input.device_id, // TODO check for validity @@ -342,6 +367,11 @@ pub(crate) async fn post( &mut rng, &clock, &mut repo, + &mut policy, + Requester { + ip_address: activity_tracker.ip(), + user_agent: user_agent.clone(), + }, &token, input.device_id, input.initial_device_display_name, @@ -459,6 +489,8 @@ async fn token_login( rng: &mut (dyn RngCore + Send), clock: &dyn Clock, repo: &mut BoxRepository, + policy: &mut Policy, + requester: Requester, token: &str, requested_device_id: Option, initial_device_display_name: Option, @@ -544,10 +576,38 @@ async fn token_login( Device::generate(rng) }; - repo.app_session() + let session_replaced = repo + .app_session() .finish_sessions_to_replace_device(clock, &browser_session.user, &device) .await?; + let session_counts = count_user_sessions_for_limiting(repo, &browser_session.user).await?; + + let res = policy + .evaluate_compat_login(mas_policy::CompatLoginInput { + user: &browser_session.user, + login: CompatLogin::Token, + session_replaced, + session_counts, + requester, + }) + .await?; + if !res.valid() { + // If the only violation is that we have too many sessions, then handle that + // separately. + // In the future, we intend to evict some sessions automatically instead. We + // don't trigger this if there was some other violation anyway, since that means + // that removing a session wouldn't actually unblock the login. + if res.violations.len() == 1 { + let violation = &res.violations[0]; + if violation.code == Some(ViolationCode::TooManySessions) { + // The only violation is having reached the session limit. + return Err(RouteError::PolicyHardSessionLimitReached); + } + } + return Err(RouteError::PolicyRejected); + } + // We first create the session in the database, commit the transaction, then // create it on the homeserver, scheduling a device sync job afterwards to // make sure we don't end up in an inconsistent state. @@ -578,6 +638,8 @@ async fn user_password_login( limiter: &Limiter, requester: RequesterFingerprint, repo: &mut BoxRepository, + policy: &mut Policy, + policy_requester: Requester, username: &str, password: String, requested_device_id: Option, @@ -647,10 +709,38 @@ async fn user_password_login( Device::generate(&mut rng) }; - repo.app_session() + let session_replaced = repo + .app_session() .finish_sessions_to_replace_device(clock, &user, &device) .await?; + let session_counts = count_user_sessions_for_limiting(repo, &user).await?; + + let res = policy + .evaluate_compat_login(mas_policy::CompatLoginInput { + user: &user, + login: CompatLogin::Password, + session_replaced, + session_counts, + requester: policy_requester, + }) + .await?; + if !res.valid() { + // If the only violation is that we have too many sessions, then handle that + // separately. + // In the future, we intend to evict some sessions automatically instead. We + // don't trigger this if there was some other violation anyway, since that means + // that removing a session wouldn't actually unblock the login. + if res.violations.len() == 1 { + let violation = &res.violations[0]; + if violation.code == Some(ViolationCode::TooManySessions) { + // The only violation is having reached the session limit. + return Err(RouteError::PolicyHardSessionLimitReached); + } + } + return Err(RouteError::PolicyRejected); + } + let session = repo .compat_session() .add( diff --git a/crates/handlers/src/compat/login_sso_complete.rs b/crates/handlers/src/compat/login_sso_complete.rs index a4fbb24fb..b0735ffce 100644 --- a/crates/handlers/src/compat/login_sso_complete.rs +++ b/crates/handlers/src/compat/login_sso_complete.rs @@ -11,23 +11,27 @@ use axum::{ extract::{Form, Path, State}, response::{Html, IntoResponse, Redirect, Response}, }; -use axum_extra::extract::Query; +use axum_extra::{TypedHeader, extract::Query}; use chrono::Duration; +use hyper::StatusCode; use mas_axum_utils::{ InternalError, cookies::CookieJar, csrf::{CsrfExt, ProtectedForm}, }; use mas_data_model::{BoxClock, BoxRng, Clock}; +use mas_policy::{Policy, model::CompatLogin}; use mas_router::{CompatLoginSsoAction, UrlBuilder}; use mas_storage::{BoxRepository, RepositoryAccess, compat::CompatSsoLoginRepository}; -use mas_templates::{CompatSsoContext, ErrorContext, TemplateContext, Templates}; +use mas_templates::{ + CompatLoginPolicyViolationContext, CompatSsoContext, ErrorContext, TemplateContext, Templates, +}; use serde::{Deserialize, Serialize}; use ulid::Ulid; use crate::{ - PreferredLanguage, - session::{SessionOrFallback, load_session_or_fallback}, + BoundActivityTracker, PreferredLanguage, + session::{SessionOrFallback, count_user_sessions_for_limiting, load_session_or_fallback}, }; #[derive(Serialize)] @@ -56,10 +60,15 @@ pub async fn get( mut repo: BoxRepository, State(templates): State, State(url_builder): State, + mut policy: Policy, + activity_tracker: BoundActivityTracker, + user_agent: Option>, cookie_jar: CookieJar, Path(id): Path, Query(params): Query, ) -> Result { + let user_agent = user_agent.map(|ua| ua.to_string()); + let (cookie_jar, maybe_session) = match load_session_or_fallback( cookie_jar, &clock, &mut rng, &templates, &locale, &mut repo, ) @@ -107,6 +116,35 @@ pub async fn get( return Ok((cookie_jar, Html(content)).into_response()); } + let session_counts = count_user_sessions_for_limiting(&mut repo, &session.user).await?; + + let res = policy + .evaluate_compat_login(mas_policy::CompatLoginInput { + user: &session.user, + login: CompatLogin::Sso { + redirect_uri: login.redirect_uri.to_string(), + }, + // We don't know if there's going to be a replacement until we received the device ID, + // which happens too late. + session_replaced: false, + session_counts, + requester: mas_policy::Requester { + ip_address: activity_tracker.ip(), + user_agent, + }, + }) + .await?; + if !res.valid() { + let ctx = CompatLoginPolicyViolationContext::for_violations(res.violations) + .with_session(session) + .with_csrf(csrf_token.form_value()) + .with_language(locale); + + let content = templates.render_compat_login_policy_violation(&ctx)?; + + return Ok((StatusCode::FORBIDDEN, cookie_jar, Html(content)).into_response()); + } + let ctx = CompatSsoContext::new(login) .with_session(session) .with_csrf(csrf_token.form_value()) @@ -129,11 +167,16 @@ pub async fn post( PreferredLanguage(locale): PreferredLanguage, State(templates): State, State(url_builder): State, + mut policy: Policy, + activity_tracker: BoundActivityTracker, + user_agent: Option>, cookie_jar: CookieJar, Path(id): Path, Query(params): Query, Form(form): Form>, ) -> Result { + let user_agent = user_agent.map(|ua| ua.to_string()); + let (cookie_jar, maybe_session) = match load_session_or_fallback( cookie_jar, &clock, &mut rng, &templates, &locale, &mut repo, ) @@ -200,6 +243,37 @@ pub async fn post( redirect_uri }; + let session_counts = count_user_sessions_for_limiting(&mut repo, &session.user).await?; + + let res = policy + .evaluate_compat_login(mas_policy::CompatLoginInput { + user: &session.user, + login: CompatLogin::Sso { + redirect_uri: login.redirect_uri.to_string(), + }, + session_counts, + // We don't know if there's going to be a replacement until we received the device ID, + // which happens too late. + session_replaced: false, + requester: mas_policy::Requester { + ip_address: activity_tracker.ip(), + user_agent, + }, + }) + .await?; + + if !res.valid() { + let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng); + let ctx = CompatLoginPolicyViolationContext::for_violations(res.violations) + .with_session(session) + .with_csrf(csrf_token.form_value()) + .with_language(locale); + + let content = templates.render_compat_login_policy_violation(&ctx)?; + + return Ok((StatusCode::FORBIDDEN, cookie_jar, Html(content)).into_response()); + } + // Note that if the login is not Pending, // this fails and aborts the transaction. repo.compat_sso_login() diff --git a/crates/handlers/src/lib.rs b/crates/handlers/src/lib.rs index 65a75f550..ebd223e4a 100644 --- a/crates/handlers/src/lib.rs +++ b/crates/handlers/src/lib.rs @@ -272,6 +272,7 @@ where BoxRepository: FromRequestParts, BoxClock: FromRequestParts, BoxRng: FromRequestParts, + Policy: FromRequestParts, { // A sub-router for human-facing routes with error handling let human_router = Router::new() diff --git a/crates/handlers/src/test_utils.rs b/crates/handlers/src/test_utils.rs index 4b93177de..521a4848d 100644 --- a/crates/handlers/src/test_utils.rs +++ b/crates/handlers/src/test_utils.rs @@ -82,6 +82,7 @@ pub(crate) async fn policy_factory( register: "register/violation".to_owned(), client_registration: "client_registration/violation".to_owned(), authorization_grant: "authorization_grant/violation".to_owned(), + compat_login: "compat_login/violation".to_owned(), email: "email/violation".to_owned(), }; diff --git a/crates/policy/src/bin/schema.rs b/crates/policy/src/bin/schema.rs index 8e9c81a07..be778f6e1 100644 --- a/crates/policy/src/bin/schema.rs +++ b/crates/policy/src/bin/schema.rs @@ -12,7 +12,7 @@ use std::path::{Path, PathBuf}; use mas_policy::model::{ - AuthorizationGrantInput, ClientRegistrationInput, EmailInput, RegisterInput, + AuthorizationGrantInput, ClientRegistrationInput, CompatLoginInput, EmailInput, RegisterInput, }; use schemars::{JsonSchema, generate::SchemaSettings}; @@ -42,5 +42,6 @@ fn main() { write_schema::(output_root, "register_input.json"); write_schema::(output_root, "client_registration_input.json"); write_schema::(output_root, "authorization_grant_input.json"); + write_schema::(output_root, "compat_login_input.json"); write_schema::(output_root, "email_input.json"); } diff --git a/crates/policy/src/lib.rs b/crates/policy/src/lib.rs index 8a038aea8..dcb68dd36 100644 --- a/crates/policy/src/lib.rs +++ b/crates/policy/src/lib.rs @@ -19,8 +19,9 @@ use thiserror::Error; use tokio::io::{AsyncRead, AsyncReadExt}; pub use self::model::{ - AuthorizationGrantInput, ClientRegistrationInput, Code as ViolationCode, EmailInput, - EvaluationResult, GrantType, RegisterInput, RegistrationMethod, Requester, Violation, + AuthorizationGrantInput, ClientRegistrationInput, Code as ViolationCode, CompatLoginInput, + EmailInput, EvaluationResult, GrantType, RegisterInput, RegistrationMethod, Requester, + Violation, }; #[derive(Debug, Error)] @@ -72,15 +73,17 @@ pub struct Entrypoints { pub register: String, pub client_registration: String, pub authorization_grant: String, + pub compat_login: String, pub email: String, } impl Entrypoints { - fn all(&self) -> [&str; 4] { + fn all(&self) -> [&str; 5] { [ self.register.as_str(), self.client_registration.as_str(), self.authorization_grant.as_str(), + self.compat_login.as_str(), self.email.as_str(), ] } @@ -459,6 +462,30 @@ impl Policy { Ok(res) } + + /// Evaluate the `compat_login` entrypoint. + /// + /// # Errors + /// + /// Returns an error if the policy engine fails to evaluate the entrypoint. + #[tracing::instrument( + name = "policy.evaluate.compat_login", + skip_all, + fields( + %input.user.id, + ), + )] + pub async fn evaluate_compat_login( + &mut self, + input: CompatLoginInput<'_>, + ) -> Result { + let [res]: [EvaluationResult; 1] = self + .instance + .evaluate(&mut self.store, &self.entrypoints.compat_login, &input) + .await?; + + Ok(res) + } } #[cfg(test)] @@ -468,6 +495,16 @@ mod tests { use super::*; + fn make_entrypoints() -> Entrypoints { + Entrypoints { + register: "register/violation".to_owned(), + client_registration: "client_registration/violation".to_owned(), + authorization_grant: "authorization_grant/violation".to_owned(), + compat_login: "compat_login/violation".to_owned(), + email: "email/violation".to_owned(), + } + } + #[tokio::test] async fn test_register() { let data = Data::new("example.com".to_owned(), None).with_rest(serde_json::json!({ @@ -484,14 +521,9 @@ mod tests { let file = tokio::fs::File::open(path).await.unwrap(); - let entrypoints = Entrypoints { - register: "register/violation".to_owned(), - client_registration: "client_registration/violation".to_owned(), - authorization_grant: "authorization_grant/violation".to_owned(), - email: "email/violation".to_owned(), - }; - - let factory = PolicyFactory::load(file, data, entrypoints).await.unwrap(); + let factory = PolicyFactory::load(file, data, make_entrypoints()) + .await + .unwrap(); let mut policy = factory.instantiate().await.unwrap(); @@ -551,14 +583,9 @@ mod tests { let file = tokio::fs::File::open(path).await.unwrap(); - let entrypoints = Entrypoints { - register: "register/violation".to_owned(), - client_registration: "client_registration/violation".to_owned(), - authorization_grant: "authorization_grant/violation".to_owned(), - email: "email/violation".to_owned(), - }; - - let factory = PolicyFactory::load(file, data, entrypoints).await.unwrap(); + let factory = PolicyFactory::load(file, data, make_entrypoints()) + .await + .unwrap(); let mut policy = factory.instantiate().await.unwrap(); @@ -620,14 +647,9 @@ mod tests { let file = tokio::fs::File::open(path).await.unwrap(); - let entrypoints = Entrypoints { - register: "register/violation".to_owned(), - client_registration: "client_registration/violation".to_owned(), - authorization_grant: "authorization_grant/violation".to_owned(), - email: "email/violation".to_owned(), - }; - - let factory = PolicyFactory::load(file, data, entrypoints).await.unwrap(); + let factory = PolicyFactory::load(file, data, make_entrypoints()) + .await + .unwrap(); // That is around 1 MB of JSON data. Each element is a 5-digit string, so 8 // characters including the quotes and a comma. diff --git a/crates/policy/src/model.rs b/crates/policy/src/model.rs index b85170025..a9f5fb502 100644 --- a/crates/policy/src/model.rs +++ b/crates/policy/src/model.rs @@ -17,7 +17,7 @@ use schemars::JsonSchema; use serde::{Deserialize, Serialize}; /// A well-known policy code. -#[derive(Deserialize, Debug, Clone, Copy, JsonSchema)] +#[derive(Serialize, Deserialize, Debug, Clone, Copy, JsonSchema, PartialEq, Eq)] #[serde(rename_all = "kebab-case")] pub enum Code { /// The username is too short. @@ -75,7 +75,7 @@ impl Code { } /// A single violation of a policy. -#[derive(Deserialize, Debug, JsonSchema)] +#[derive(Serialize, Deserialize, Debug, JsonSchema)] pub struct Violation { pub msg: String, pub redirect_uri: Option, @@ -187,6 +187,42 @@ pub struct AuthorizationGrantInput<'a> { pub requester: Requester, } +/// Input for the compatibility login policy. +#[derive(Serialize, Debug, JsonSchema)] +#[serde(rename_all = "snake_case")] +pub struct CompatLoginInput<'a> { + #[schemars(with = "std::collections::HashMap")] + pub user: &'a User, + + /// How many sessions the user has. + pub session_counts: SessionCounts, + + /// Whether a session will be replaced by this login + pub session_replaced: bool, + + /// What type of login is being performed. + /// This also determines whether the login is interactive. + pub login: CompatLogin, + + pub requester: Requester, +} + +#[derive(Serialize, Debug, JsonSchema)] +#[serde(tag = "type")] +pub enum CompatLogin { + /// Used as the interactive part of SSO login. + #[serde(rename = "m.login.sso")] + Sso { redirect_uri: String }, + + /// Used as the final (non-interactive) stage of SSO login. + #[serde(rename = "m.login.token")] + Token, + + /// Non-interactive password-over-the-API login. + #[serde(rename = "m.login.password")] + Password, +} + /// Information about how many sessions the user has #[derive(Serialize, Debug, JsonSchema)] pub struct SessionCounts { diff --git a/crates/storage-pg/src/app_session.rs b/crates/storage-pg/src/app_session.rs index 4e12810cc..2867534c3 100644 --- a/crates/storage-pg/src/app_session.rs +++ b/crates/storage-pg/src/app_session.rs @@ -487,14 +487,15 @@ impl AppSessionRepository for PgAppSessionRepository<'_> { clock: &dyn Clock, user: &User, device: &Device, - ) -> Result<(), Self::Error> { + ) -> Result { + let mut affected = false; // TODO need to invoke this from all the oauth2 login sites let span = tracing::info_span!( "db.app_session.finish_sessions_to_replace_device.compat_sessions", { DB_QUERY_TEXT } = tracing::field::Empty, ); let finished_at = clock.now(); - sqlx::query!( + let compat_affected = sqlx::query!( " UPDATE compat_sessions SET finished_at = $3 WHERE user_id = $1 AND device_id = $2 AND finished_at IS NULL ", @@ -505,7 +506,9 @@ impl AppSessionRepository for PgAppSessionRepository<'_> { .record(&span) .execute(&mut *self.conn) .instrument(span) - .await?; + .await? + .rows_affected(); + affected |= compat_affected > 0; if let Ok([stable_device_as_scope_token, unstable_device_as_scope_token]) = device.to_scope_token() @@ -514,7 +517,7 @@ impl AppSessionRepository for PgAppSessionRepository<'_> { "db.app_session.finish_sessions_to_replace_device.oauth2_sessions", { DB_QUERY_TEXT } = tracing::field::Empty, ); - sqlx::query!( + let oauth2_affected = sqlx::query!( " UPDATE oauth2_sessions SET finished_at = $4 @@ -530,10 +533,12 @@ impl AppSessionRepository for PgAppSessionRepository<'_> { .record(&span) .execute(&mut *self.conn) .instrument(span) - .await?; + .await? + .rows_affected(); + affected |= oauth2_affected > 0; } - Ok(()) + Ok(affected) } } diff --git a/crates/storage/src/app_session.rs b/crates/storage/src/app_session.rs index d649ff35e..4c0b7703a 100644 --- a/crates/storage/src/app_session.rs +++ b/crates/storage/src/app_session.rs @@ -196,12 +196,14 @@ pub trait AppSessionRepository: Send + Sync { /// replacing a device). /// /// Should be called *before* creating a new session for the device. + /// + /// Returns true if a session was finished. async fn finish_sessions_to_replace_device( &mut self, clock: &dyn Clock, user: &User, device: &Device, - ) -> Result<(), Self::Error>; + ) -> Result; } repository_impl!(AppSessionRepository: @@ -218,5 +220,5 @@ repository_impl!(AppSessionRepository: clock: &dyn Clock, user: &User, device: &Device, - ) -> Result<(), Self::Error>; + ) -> Result; ); diff --git a/crates/templates/Cargo.toml b/crates/templates/Cargo.toml index 46ff80c74..d9c1bb019 100644 --- a/crates/templates/Cargo.toml +++ b/crates/templates/Cargo.toml @@ -41,6 +41,7 @@ oauth2-types.workspace = true mas-data-model.workspace = true mas-i18n.workspace = true mas-iana.workspace = true +mas-policy.workspace = true mas-router.workspace = true mas-spa.workspace = true diff --git a/crates/templates/src/context.rs b/crates/templates/src/context.rs index 4ed09c3e1..bee16efbd 100644 --- a/crates/templates/src/context.rs +++ b/crates/templates/src/context.rs @@ -28,6 +28,7 @@ use mas_data_model::{ }; use mas_i18n::DataLocale; use mas_iana::jose::JsonWebSignatureAlg; +use mas_policy::{Violation, ViolationCode}; use mas_router::{Account, GraphQL, PostAuthAction, UrlBuilder}; use oauth2_types::scope::{OPENID, Scope}; use rand::{ @@ -860,6 +861,44 @@ impl PolicyViolationContext { } } +/// Context used by the `compat_login_policy_violation.html` template +#[derive(Serialize)] +pub struct CompatLoginPolicyViolationContext { + violations: Vec, +} + +impl TemplateContext for CompatLoginPolicyViolationContext { + fn sample( + _now: chrono::DateTime, + _rng: &mut R, + _locales: &[DataLocale], + ) -> BTreeMap + where + Self: Sized, + { + sample_list(vec![ + CompatLoginPolicyViolationContext { violations: vec![] }, + CompatLoginPolicyViolationContext { + violations: vec![Violation { + msg: "user has too many active sessions".to_owned(), + redirect_uri: None, + field: None, + code: Some(ViolationCode::TooManySessions), + }], + }, + ]) + } +} + +impl CompatLoginPolicyViolationContext { + /// Constructs a context for the compatibility login policy violation page + /// given the list of violations + #[must_use] + pub const fn for_violations(violations: Vec) -> Self { + Self { violations } + } +} + /// Context used by the `sso.html` template #[derive(Serialize)] pub struct CompatSsoContext { diff --git a/crates/templates/src/lib.rs b/crates/templates/src/lib.rs index 32a41e8b2..dc0e1e714 100644 --- a/crates/templates/src/lib.rs +++ b/crates/templates/src/lib.rs @@ -37,14 +37,15 @@ mod macros; pub use self::{ context::{ - AccountInactiveContext, ApiDocContext, AppContext, CompatSsoContext, ConsentContext, - DeviceConsentContext, DeviceLinkContext, DeviceLinkFormField, DeviceNameContext, - EmailRecoveryContext, EmailVerificationContext, EmptyContext, ErrorContext, - FormPostContext, IndexContext, LoginContext, LoginFormField, NotFoundContext, - PasswordRegisterContext, PolicyViolationContext, PostAuthContext, PostAuthContextInner, - RecoveryExpiredContext, RecoveryFinishContext, RecoveryFinishFormField, - RecoveryProgressContext, RecoveryStartContext, RecoveryStartFormField, RegisterContext, - RegisterFormField, RegisterStepsDisplayNameContext, RegisterStepsDisplayNameFormField, + AccountInactiveContext, ApiDocContext, AppContext, CompatLoginPolicyViolationContext, + CompatSsoContext, ConsentContext, DeviceConsentContext, DeviceLinkContext, + DeviceLinkFormField, DeviceNameContext, EmailRecoveryContext, EmailVerificationContext, + EmptyContext, ErrorContext, FormPostContext, IndexContext, LoginContext, LoginFormField, + NotFoundContext, PasswordRegisterContext, PolicyViolationContext, PostAuthContext, + PostAuthContextInner, RecoveryExpiredContext, RecoveryFinishContext, + RecoveryFinishFormField, RecoveryProgressContext, RecoveryStartContext, + RecoveryStartFormField, RegisterContext, RegisterFormField, + RegisterStepsDisplayNameContext, RegisterStepsDisplayNameFormField, RegisterStepsEmailInUseContext, RegisterStepsRegistrationTokenContext, RegisterStepsRegistrationTokenFormField, RegisterStepsVerifyEmailContext, RegisterStepsVerifyEmailFormField, SiteBranding, SiteConfigExt, SiteFeatures, @@ -391,6 +392,9 @@ register_templates! { /// Render the policy violation page pub fn render_policy_violation(WithLanguage>>) { "pages/policy_violation.html" } + /// Render the compatibility login policy violation page + pub fn render_compat_login_policy_violation(WithLanguage>>) { "pages/compat_login_policy_violation.html" } + /// Render the legacy SSO login consent page pub fn render_sso_login(WithLanguage>>) { "pages/sso.html" } diff --git a/docs/config.schema.json b/docs/config.schema.json index cda68f145..496cd2c5b 100644 --- a/docs/config.schema.json +++ b/docs/config.schema.json @@ -1883,6 +1883,10 @@ "description": "Entrypoint to use when evaluating authorization grants", "type": "string" }, + "compat_login_entrypoint": { + "description": "Entrypoint to use when evaluating compatibility logins", + "type": "string" + }, "password_entrypoint": { "description": "Entrypoint to use when changing password", "type": "string" diff --git a/policies/Makefile b/policies/Makefile index 0d515b904..db5991672 100644 --- a/policies/Makefile +++ b/policies/Makefile @@ -16,6 +16,7 @@ INPUTS := \ client_registration/client_registration.rego \ register/register.rego \ authorization_grant/authorization_grant.rego \ + compat_login/compat_login.rego \ email/email.rego ifeq ($(DOCKER), 1) @@ -38,6 +39,7 @@ policy.wasm: $(INPUTS) -e "client_registration/violation" \ -e "register/violation" \ -e "authorization_grant/violation" \ + -e "compat_login/violation" \ -e "email/violation" \ $^ tar xzf bundle.tar.gz /policy.wasm diff --git a/policies/compat_login/compat_login.rego b/policies/compat_login/compat_login.rego new file mode 100644 index 000000000..4f76842cd --- /dev/null +++ b/policies/compat_login/compat_login.rego @@ -0,0 +1,74 @@ +# Copyright 2025 Element Creations Ltd. +# +# SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial +# Please see LICENSE files in the repository root for full details. + +# METADATA +# schemas: +# - input: schema["compat_login_input"] +package compat_login + +import rego.v1 + +import data.common + +default allow := false + +allow if { + count(violation) == 0 +} + +violation contains {"msg": sprintf( + "Requester [%s] isn't allowed to do this action", + [common.format_requester(input.requester)], +)} if { + common.requester_banned(input.requester, data.requester) +} + +violation contains { + "code": "too-many-sessions", + "msg": "user has too many active sessions (soft limit)", +} if { + # Only apply if session limits are enabled in the config + data.session_limit != null + + # This is a web-based interactive login + is_interactive + + # Only apply if this login doesn't replace a session + # (As then this login is not actually increasing the number of devices) + not input.session_replaced + + # For web-based 'compat SSO' login, a violation occurs when the soft limit has already been + # reached or exceeded. + # We use the soft limit because the user will be able to interactively remove + # sessions to return under the limit. + data.session_limit.soft_limit <= input.session_counts.total +} + +violation contains { + "code": "too-many-sessions", + "msg": "user has too many active sessions (hard limit)", +} if { + # Only apply if session limits are enabled in the config + data.session_limit != null + + # This is not a web-based interactive login + not is_interactive + + # Only apply if this login doesn't replace a session + # (As then this login is not actually increasing the number of devices) + not input.session_replaced + + # For `m.login.password` login, a violation occurs when the hard limit has already been + # reached or exceeded. + # We don't use the soft limit because the user won't be able to interactively remove + # sessions to return under the limit. + data.session_limit.hard_limit <= input.session_counts.total +} + +is_interactive if { + # Only `m.login.sso` (the interactive web form) is interactive; + # `m.login.password` and `m.login.token` (including the finalisation of an SSO login) are not + input.login.type == "m.login.sso" +} diff --git a/policies/compat_login/compat_login_test.rego b/policies/compat_login/compat_login_test.rego new file mode 100644 index 000000000..1b8049844 --- /dev/null +++ b/policies/compat_login/compat_login_test.rego @@ -0,0 +1,99 @@ +# Copyright 2025 Element Creations Ltd. +# +# SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial +# Please see LICENSE files in the repository root for full details. + +package compat_login_test + +import data.compat_login +import rego.v1 + +user := {"username": "john"} + +# Tests session limiting when using (the interactive part of) `m.login.sso` +test_session_limiting_sso if { + compat_login.allow with input.user as user + with input.session_counts as {"total": 1} + with input.login as {"type": "m.login.sso"} + with input.session_replaced as false + with data.session_limit as {"soft_limit": 32, "hard_limit": 64} + + compat_login.allow with input.user as user + with input.session_counts as {"total": 31} + with input.login as {"type": "m.login.sso"} + with input.session_replaced as false + with data.session_limit as {"soft_limit": 32, "hard_limit": 64} + + not compat_login.allow with input.user as user + with input.session_counts as {"total": 32} + with input.login as {"type": "m.login.sso"} + with input.session_replaced as false + with data.session_limit as {"soft_limit": 32, "hard_limit": 64} + + not compat_login.allow with input.user as user + with input.session_counts as {"total": 42} + with input.login as {"type": "m.login.sso"} + with input.session_replaced as false + with data.session_limit as {"soft_limit": 32, "hard_limit": 64} + + not compat_login.allow with input.user as user + with input.session_counts as {"total": 65} + with input.login as {"type": "m.login.sso"} + with input.session_replaced as false + with data.session_limit as {"soft_limit": 32, "hard_limit": 64} + + # No limit configured + compat_login.allow with input.user as user + with input.session_counts as {"total": 1} + with input.login as {"type": "m.login.sso"} + with input.session_replaced as false + with data.session_limit as null +} + +# Test session limiting when using `m.login.password` +test_session_limiting_password if { + compat_login.allow with input.user as user + with input.session_counts as {"total": 1} + with input.login as {"type": "m.login.password"} + with input.session_replaced as false + with data.session_limit as {"soft_limit": 32, "hard_limit": 64} + + compat_login.allow with input.user as user + with input.session_counts as {"total": 63} + with input.login as {"type": "m.login.password"} + with input.session_replaced as false + with data.session_limit as {"soft_limit": 32, "hard_limit": 64} + + not compat_login.allow with input.user as user + with input.session_counts as {"total": 64} + with input.login as {"type": "m.login.password"} + with input.session_replaced as false + with data.session_limit as {"soft_limit": 32, "hard_limit": 64} + + not compat_login.allow with input.user as user + with input.session_counts as {"total": 65} + with input.login as {"type": "m.login.password"} + with input.session_replaced as false + with data.session_limit as {"soft_limit": 32, "hard_limit": 64} + + # No limit configured + compat_login.allow with input.user as user + with input.session_counts as {"total": 1} + with input.login as {"type": "m.login.password"} + with input.session_replaced as false + with data.session_limit as null +} + +test_no_session_limiting_upon_replacement if { + not compat_login.allow with input.user as user + with input.session_counts as {"total": 65} + with input.login as {"type": "m.login.password"} + with input.session_replaced as false + with data.session_limit as {"soft_limit": 32, "hard_limit": 64} + + not compat_login.allow with input.user as user + with input.session_counts as {"total": 65} + with input.login as {"type": "m.login.sso"} + with input.session_replaced as false + with data.session_limit as {"soft_limit": 32, "hard_limit": 64} +} diff --git a/policies/schema/compat_login_input.json b/policies/schema/compat_login_input.json new file mode 100644 index 000000000..ffb182de4 --- /dev/null +++ b/policies/schema/compat_login_input.json @@ -0,0 +1,144 @@ +{ + "$schema": "http://json-schema.org/draft-07/schema#", + "title": "CompatLoginInput", + "description": "Input for the compatibility login policy.", + "type": "object", + "properties": { + "user": { + "type": "object", + "additionalProperties": true + }, + "session_counts": { + "description": "How many sessions the user has.", + "allOf": [ + { + "$ref": "#/definitions/SessionCounts" + } + ] + }, + "session_replaced": { + "description": "Whether a session will be replaced by this login", + "type": "boolean" + }, + "login": { + "description": "What type of login is being performed.\n This also determines whether the login is interactive.", + "allOf": [ + { + "$ref": "#/definitions/CompatLogin" + } + ] + }, + "requester": { + "$ref": "#/definitions/Requester" + } + }, + "required": [ + "user", + "session_counts", + "session_replaced", + "login", + "requester" + ], + "definitions": { + "SessionCounts": { + "description": "Information about how many sessions the user has", + "type": "object", + "properties": { + "total": { + "type": "integer", + "format": "uint64", + "minimum": 0 + }, + "oauth2": { + "type": "integer", + "format": "uint64", + "minimum": 0 + }, + "compat": { + "type": "integer", + "format": "uint64", + "minimum": 0 + }, + "personal": { + "type": "integer", + "format": "uint64", + "minimum": 0 + } + }, + "required": [ + "total", + "oauth2", + "compat", + "personal" + ] + }, + "CompatLogin": { + "oneOf": [ + { + "description": "Used as the interactive part of SSO login.", + "type": "object", + "properties": { + "type": { + "type": "string", + "const": "m.login.sso" + }, + "redirect_uri": { + "type": "string" + } + }, + "required": [ + "type", + "redirect_uri" + ] + }, + { + "description": "Used as the final (non-interactive) stage of SSO login.", + "type": "object", + "properties": { + "type": { + "type": "string", + "const": "m.login.token" + } + }, + "required": [ + "type" + ] + }, + { + "description": "Non-interactive password-over-the-API login.", + "type": "object", + "properties": { + "type": { + "type": "string", + "const": "m.login.password" + } + }, + "required": [ + "type" + ] + } + ] + }, + "Requester": { + "description": "Identity of the requester", + "type": "object", + "properties": { + "ip_address": { + "description": "IP address of the entity making the request", + "type": [ + "string", + "null" + ], + "format": "ip" + }, + "user_agent": { + "description": "User agent of the entity making the request", + "type": [ + "string", + "null" + ] + } + } + } + } +} \ No newline at end of file diff --git a/templates/pages/compat_login_policy_violation.html b/templates/pages/compat_login_policy_violation.html new file mode 100644 index 000000000..5953faefb --- /dev/null +++ b/templates/pages/compat_login_policy_violation.html @@ -0,0 +1,32 @@ +{# +Copyright 2024, 2025 New Vector Ltd. +Copyright 2022-2024 The Matrix.org Foundation C.I.C. + +SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial +Please see LICENSE files in the repository root for full details. +-#} + +{% extends "base.html" %} + +{% block content %} +
+
+ {{ icon.error_solid() }} +
+ +
+

{{ _("mas.policy_violation.heading") }}

+

{{ _("mas.policy_violation.description") }}

+
+
+ +
+
+

+ {{ _("mas.policy_violation.logged_as", username=current_session.user.username) }} +

+ + {{ logout.button(text=_("action.sign_out"), csrf_token=csrf_token, post_logout_action=action, as_link=True) }} +
+
+{% endblock content %} diff --git a/translations/en.json b/translations/en.json index cdf2df82d..06ae76773 100644 --- a/translations/en.json +++ b/translations/en.json @@ -22,7 +22,7 @@ }, "sign_out": "Sign out", "@sign_out": { - "context": "pages/account/logged_out.html:22:28-48, pages/consent.html:65:28-48, pages/device_consent.html:136:30-50, pages/index.html:28:28-48, pages/policy_violation.html:38:28-48, pages/sso.html:45:28-48, pages/upstream_oauth2/link_mismatch.html:24:24-44, pages/upstream_oauth2/suggest_link.html:32:26-46" + "context": "pages/account/logged_out.html:22:28-48, pages/compat_login_policy_violation.html:29:28-48, pages/consent.html:65:28-48, pages/device_consent.html:136:30-50, pages/index.html:28:28-48, pages/policy_violation.html:38:28-48, pages/sso.html:45:28-48, pages/upstream_oauth2/link_mismatch.html:24:24-44, pages/upstream_oauth2/suggest_link.html:32:26-46" }, "skip": "Skip", "@skip": { @@ -496,17 +496,17 @@ "policy_violation": { "description": "This might be because of the client which authored the request, the currently logged in user, or the request itself.", "@description": { - "context": "pages/policy_violation.html:19:25-62", + "context": "pages/compat_login_policy_violation.html:19:25-62, pages/policy_violation.html:19:25-62", "description": "Displayed when an authorization request is denied by the policy" }, "heading": "The authorization request was denied by the policy enforced by this service", "@heading": { - "context": "pages/policy_violation.html:18:27-60", + "context": "pages/compat_login_policy_violation.html:18:27-60, pages/policy_violation.html:18:27-60", "description": "Displayed when an authorization request is denied by the policy" }, "logged_as": "Logged as %(username)s", "@logged_as": { - "context": "pages/policy_violation.html:35:11-86" + "context": "pages/compat_login_policy_violation.html:26:11-86, pages/policy_violation.html:35:11-86" } }, "recovery": {