diff --git a/crates/data-model/src/users.rs b/crates/data-model/src/users.rs index 7c7da6293..541eb26d2 100644 --- a/crates/data-model/src/users.rs +++ b/crates/data-model/src/users.rs @@ -272,6 +272,7 @@ pub struct UserRegistration { pub email_authentication_id: Option, pub user_registration_token_id: Option, pub password: Option, + pub upstream_oauth_authorization_session_id: Option, pub post_auth_action: Option, pub ip_address: Option, pub user_agent: Option, diff --git a/crates/handlers/src/graphql/mutations/user_email.rs b/crates/handlers/src/graphql/mutations/user_email.rs index 63b825566..34fb54050 100644 --- a/crates/handlers/src/graphql/mutations/user_email.rs +++ b/crates/handlers/src/graphql/mutations/user_email.rs @@ -817,7 +817,7 @@ impl UserEmailMutations { let authentication = repo .user_email() - .complete_authentication(&clock, authentication, &code) + .complete_authentication_with_code(&clock, authentication, &code) .await?; // Check the email is not already in use by anyone, including the current user diff --git a/crates/handlers/src/upstream_oauth2/link.rs b/crates/handlers/src/upstream_oauth2/link.rs index d9577bafd..96d1b0180 100644 --- a/crates/handlers/src/upstream_oauth2/link.rs +++ b/crates/handlers/src/upstream_oauth2/link.rs @@ -26,7 +26,6 @@ use mas_policy::Policy; use mas_router::UrlBuilder; use mas_storage::{ BoxRepository, RepositoryAccess, - queue::{ProvisionUserJob, QueueJobRepositoryExt as _}, upstream_oauth2::{UpstreamOAuthLinkRepository, UpstreamOAuthSessionRepository}, user::{BrowserSessionRepository, UserEmailRepository, UserRepository}, }; @@ -46,7 +45,7 @@ use super::{ }; use crate::{ BoundActivityTracker, METER, PreferredLanguage, SiteConfig, impl_from_error_for_route, - views::shared::OptionalPostAuthAction, + views::{register::UserRegistrationSessionsCookie, shared::OptionalPostAuthAction}, }; static LOGIN_COUNTER: LazyLock> = LazyLock::new(|| { @@ -610,10 +609,6 @@ pub(crate) async fn post( .lookup_link(link_id) .map_err(|_| RouteError::MissingCookie)?; - let post_auth_action = OptionalPostAuthAction { - post_auth_action: post_auth_action.cloned(), - }; - let link = repo .upstream_oauth_link() .lookup(link_id) @@ -641,7 +636,7 @@ pub(crate) async fn post( let maybe_user_session = user_session_info.load_active_session(&mut repo).await?; let form_state = form.to_form_state(); - let session = match (maybe_user_session, link.user_id, form) { + match (maybe_user_session, link.user_id, form) { (Some(session), None, FormData::Link) => { // The user is already logged in, the link is not linked to any user, and the // user asked to link their account. @@ -649,7 +644,27 @@ pub(crate) async fn post( .associate_to_user(&link, &session.user) .await?; - session + let upstream_session = repo + .upstream_oauth_session() + .consume(&clock, upstream_session) + .await?; + + repo.browser_session() + .authenticate_with_upstream(&mut rng, &clock, &session, &upstream_session) + .await?; + + let post_auth_action = OptionalPostAuthAction { + post_auth_action: post_auth_action.cloned(), + }; + + let cookie_jar = sessions_cookie + .consume_link(link_id)? + .save(cookie_jar, &clock); + let cookie_jar = cookie_jar.set_session(&session); + + repo.save().await?; + + Ok((cookie_jar, post_auth_action.go_next(&url_builder)).into_response()) } (None, None, FormData::Link) => { @@ -714,14 +729,38 @@ pub(crate) async fn post( return Err(RouteError::InvalidFormAction); } UpstreamOAuthProviderOnConflict::Add => { - //add link to the user + // Add link to the user repo.upstream_oauth_link() .associate_to_user(&link, &user) .await?; - repo.browser_session() + // And sign in the user + let session = repo + .browser_session() .add(&mut rng, &clock, &user, user_agent) - .await? + .await?; + + let upstream_session = repo + .upstream_oauth_session() + .consume(&clock, upstream_session) + .await?; + + repo.browser_session() + .authenticate_with_upstream(&mut rng, &clock, &session, &upstream_session) + .await?; + + let post_auth_action = OptionalPostAuthAction { + post_auth_action: post_auth_action.cloned(), + }; + + let cookie_jar = sessions_cookie + .consume_link(link_id)? + .save(cookie_jar, &clock); + let cookie_jar = cookie_jar.set_session(&session); + + repo.save().await?; + + Ok((cookie_jar, post_auth_action.go_next(&url_builder)).into_response()) } } } @@ -950,61 +989,84 @@ pub(crate) async fn post( REGISTRATION_COUNTER.add(1, &[KeyValue::new(PROVIDER, provider.id.to_string())]); - // Now we can create the user - let user = repo.user().add(&mut rng, &clock, username).await?; + let mut registration = repo + .user_registration() + .add( + &mut rng, + &clock, + username, + activity_tracker.ip(), + user_agent, + post_auth_action.map(|action| serde_json::json!(action)), + ) + .await?; if let Some(terms_url) = &site_config.tos_uri { - repo.user_terms() - .accept_terms(&mut rng, &clock, &user, terms_url.clone()) + registration = repo + .user_registration() + .set_terms_url(registration, terms_url.clone()) .await?; } - // And schedule the job to provision it - let mut job = ProvisionUserJob::new(&user); + // If we have an email, add an email authentication and complete it + if let Some(email) = email { + let authentication = repo + .user_email() + .add_authentication_for_registration(&mut rng, &clock, email, ®istration) + .await?; + let authentication = repo + .user_email() + .complete_authentication_with_upstream( + &clock, + authentication, + &upstream_session, + ) + .await?; - // If we have a display name, set it during provisioning - if let Some(name) = display_name { - job = job.set_display_name(name); + registration = repo + .user_registration() + .set_email_authentication(registration, &authentication) + .await?; } - repo.queue_job().schedule_job(&mut rng, &clock, job).await?; - - // If we have an email, add it to the user - if let Some(email) = email { - repo.user_email() - .add(&mut rng, &clock, &user, email) + // If we have a display name, add it to the registration + if let Some(name) = display_name { + registration = repo + .user_registration() + .set_display_name(registration, name) .await?; } - repo.upstream_oauth_link() - .associate_to_user(&link, &user) + let registration = repo + .user_registration() + .set_upstream_oauth_authorization_session(registration, &upstream_session) .await?; - repo.browser_session() - .add(&mut rng, &clock, &user, user_agent) - .await? - } + repo.upstream_oauth_session() + .consume(&clock, upstream_session) + .await?; - _ => return Err(RouteError::InvalidFormAction), - }; + let registrations = UserRegistrationSessionsCookie::load(&cookie_jar); - let upstream_session = repo - .upstream_oauth_session() - .consume(&clock, upstream_session) - .await?; + let cookie_jar = sessions_cookie + .consume_link(link_id)? + .save(cookie_jar, &clock); - repo.browser_session() - .authenticate_with_upstream(&mut rng, &clock, &session, &upstream_session) - .await?; + let cookie_jar = registrations.add(®istration).save(cookie_jar, &clock); - let cookie_jar = sessions_cookie - .consume_link(link_id)? - .save(cookie_jar, &clock); - let cookie_jar = cookie_jar.set_session(&session); + repo.save().await?; - repo.save().await?; + // Redirect to the user registration flow, in case we have any other step to + // finish + Ok(( + cookie_jar, + url_builder.redirect(&mas_router::RegisterFinish::new(registration.id)), + ) + .into_response()) + } - Ok((cookie_jar, post_auth_action.go_next(&url_builder)).into_response()) + _ => Err(RouteError::InvalidFormAction), + } } #[cfg(test)] @@ -1013,20 +1075,18 @@ mod tests { use mas_data_model::{ UpstreamOAuthAuthorizationSession, UpstreamOAuthLink, UpstreamOAuthProviderClaimsImports, UpstreamOAuthProviderImportPreference, UpstreamOAuthProviderLocalpartPreference, - UpstreamOAuthProviderTokenAuthMethod, + UpstreamOAuthProviderTokenAuthMethod, UserEmailAuthentication, UserRegistration, }; use mas_iana::jose::JsonWebSignatureAlg; use mas_jose::jwt::{JsonWebSignatureHeader, Jwt}; use mas_keystore::Keystore; use mas_router::Route; - use mas_storage::{ - Pagination, Repository, RepositoryError, upstream_oauth2::UpstreamOAuthProviderParams, - user::UserEmailFilter, - }; + use mas_storage::{Repository, RepositoryError, upstream_oauth2::UpstreamOAuthProviderParams}; use oauth2_types::scope::{OPENID, Scope}; use rand_chacha::ChaChaRng; use serde_json::Value; use sqlx::PgPool; + use ulid::Ulid; use super::UpstreamSessionsCookie; use crate::test_utils::{CookieHelper, RequestBuilderExt, ResponseExt, TestState, setup}; @@ -1188,33 +1248,41 @@ mod tests { let response = state.request(request).await; cookies.save_cookies(&response); response.assert_status(StatusCode::SEE_OTHER); + let location = response.headers().get(hyper::header::LOCATION).unwrap(); + // Grab the registration ID from the redirected URL: + // /register/steps/{id}/finish + let registration_id: Ulid = str::from_utf8(location.as_bytes()) + .unwrap() + .rsplit('/') + .nth(1) + .expect("Location to have two slashes") + .parse() + .expect("last segment of location to be a ULID"); // Check that we have a registered user, with the email imported let mut repo = state.repository().await.unwrap(); - let user = repo - .user() - .find_by_username("john") - .await - .unwrap() - .expect("user exists"); - - let link = repo - .upstream_oauth_link() - .find_by_subject(&provider, "subject") + let registration: UserRegistration = repo + .user_registration() + .lookup(registration_id) .await .unwrap() - .expect("link exists"); + .expect("user registration exists"); - assert_eq!(link.user_id, Some(user.id)); + assert_eq!(registration.password, None); + assert_eq!(registration.completed_at, None); + assert_eq!(registration.username, "john"); - let page = repo + let email_auth_id = registration + .email_authentication_id + .expect("registration should have an email authentication"); + let email_auth: UserEmailAuthentication = repo .user_email() - .list(UserEmailFilter::new().for_user(&user), Pagination::first(1)) + .lookup_authentication(email_auth_id) .await - .unwrap(); - let edge = page.edges.first().expect("email exists"); - - assert_eq!(edge.node.email, "john@example.com"); + .unwrap() + .expect("email authentication should exist"); + assert_eq!(email_auth.email, "john@example.com"); + assert!(email_auth.completed_at.is_some()); } #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")] diff --git a/crates/handlers/src/views/register/mod.rs b/crates/handlers/src/views/register/mod.rs index ad7867a39..6a51852ae 100644 --- a/crates/handlers/src/views/register/mod.rs +++ b/crates/handlers/src/views/register/mod.rs @@ -21,6 +21,8 @@ mod cookie; pub(crate) mod password; pub(crate) mod steps; +pub use self::cookie::UserRegistrationSessions as UserRegistrationSessionsCookie; + #[tracing::instrument(name = "handlers.views.register.get", skip_all)] pub(crate) async fn get( mut rng: BoxRng, diff --git a/crates/handlers/src/views/register/steps/finish.rs b/crates/handlers/src/views/register/steps/finish.rs index e1ed8a3f0..6b7b5bfc5 100644 --- a/crates/handlers/src/views/register/steps/finish.rs +++ b/crates/handlers/src/views/register/steps/finish.rs @@ -154,56 +154,90 @@ pub(crate) async fn get( // If there is an email authentication, we need to check that the email // address was verified. If there is no email authentication attached, we // need to make sure the server doesn't require it - let email_authentication = if let Some(email_authentication_id) = - registration.email_authentication_id + let email_authentication = + if let Some(email_authentication_id) = registration.email_authentication_id { + let email_authentication = repo + .user_email() + .lookup_authentication(email_authentication_id) + .await? + .context("Could not load the email authentication") + .map_err(InternalError::from_anyhow)?; + + // Check that the email authentication has been completed + if email_authentication.completed_at.is_none() { + return Ok(( + cookie_jar, + url_builder.redirect(&mas_router::RegisterVerifyEmail::new(id)), + ) + .into_response()); + } + + // Check that the email address isn't already used + // It is important to do that here, as we we're not checking during the + // registration, because we don't want to disclose whether an email is + // already being used or not before we verified it + if repo + .user_email() + .count(UserEmailFilter::new().for_email(&email_authentication.email)) + .await? + > 0 + { + let action = registration + .post_auth_action + .map(serde_json::from_value) + .transpose()?; + + let ctx = RegisterStepsEmailInUseContext::new(email_authentication.email, action) + .with_language(lang); + + return Ok(( + cookie_jar, + Html(templates.render_register_steps_email_in_use(&ctx)?), + ) + .into_response()); + } + + Some(email_authentication) + } else { + None + }; + + // If this registration was created from an upstream OAuth session, check + // it is still valid and wasn't linked to a user in the meantime + let upstream_oauth = if let Some(upstream_oauth_authorization_session_id) = + registration.upstream_oauth_authorization_session_id { - let email_authentication = repo - .user_email() - .lookup_authentication(email_authentication_id) + let upstream_oauth_authorization_session = repo + .upstream_oauth_session() + .lookup(upstream_oauth_authorization_session_id) .await? - .context("Could not load the email authentication") + .context("Could not load the upstream OAuth authorization session") .map_err(InternalError::from_anyhow)?; - // Check that the email authentication has been completed - if email_authentication.completed_at.is_none() { - return Ok(( - cookie_jar, - url_builder.redirect(&mas_router::RegisterVerifyEmail::new(id)), - ) - .into_response()); - } + let link_id = upstream_oauth_authorization_session + .link_id() + // This should not happen, the session is associated with the user + // registration once the link was already created + .context("Authorization session has no upstream link associated with it") + .map_err(InternalError::from_anyhow)?; - // Check that the email address isn't already used - // It is important to do that here, as we we're not checking during the - // registration, because we don't want to disclose whether an email is - // already being used or not before we verified it - if repo - .user_email() - .count(UserEmailFilter::new().for_email(&email_authentication.email)) + let upstream_oauth_link = repo + .upstream_oauth_link() + .lookup(link_id) .await? - > 0 - { - let action = registration - .post_auth_action - .map(serde_json::from_value) - .transpose()?; - - let ctx = RegisterStepsEmailInUseContext::new(email_authentication.email, action) - .with_language(lang); + .context("Could not load the upstream OAuth link") + .map_err(InternalError::from_anyhow)?; - return Ok(( - cookie_jar, - Html(templates.render_register_steps_email_in_use(&ctx)?), - ) - .into_response()); + if upstream_oauth_link.user_id.is_some() { + // This means the link was already associated to a user. This could + // in theory happen if the same user registers concurrently, but + // this is not going to happen often enough to have a dedicated page + return Err(InternalError::from_anyhow(anyhow::anyhow!( + "The upstream identity was already linked to a user. Try logging in again" + ))); } - Some(email_authentication) - } else if site_config.password_registration_email_required { - // This could only happen in theory during a configuration change - return Err(InternalError::from_anyhow(anyhow::anyhow!( - "Server requires an email address to complete the registration, but no email authentication was attached to the user registration" - ))); + Some((upstream_oauth_authorization_session, upstream_oauth_link)) } else { None }; @@ -272,6 +306,16 @@ pub(crate) async fn get( PASSWORD_REGISTER_COUNTER.add(1, &[]); } + if let Some((upstream_session, upstream_link)) = upstream_oauth { + repo.upstream_oauth_link() + .associate_to_user(&upstream_link, &user) + .await?; + + repo.browser_session() + .authenticate_with_upstream(&mut rng, &clock, &user_session, &upstream_session) + .await?; + } + if let Some(terms_url) = registration.terms_url { repo.user_terms() .accept_terms(&mut rng, &clock, &user, terms_url) diff --git a/crates/handlers/src/views/register/steps/verify_email.rs b/crates/handlers/src/views/register/steps/verify_email.rs index 9b85626e1..d1312c951 100644 --- a/crates/handlers/src/views/register/steps/verify_email.rs +++ b/crates/handlers/src/views/register/steps/verify_email.rs @@ -200,7 +200,7 @@ pub(crate) async fn post( }; repo.user_email() - .complete_authentication(&clock, email_authentication, &code) + .complete_authentication_with_code(&clock, email_authentication, &code) .await?; repo.save().await?; diff --git a/crates/storage-pg/.sqlx/query-4c37988dacca5a83c8b64209042d5f1a8ec44ec8ccccad2d7fce9ac855209883.json b/crates/storage-pg/.sqlx/query-4c37988dacca5a83c8b64209042d5f1a8ec44ec8ccccad2d7fce9ac855209883.json new file mode 100644 index 000000000..1151ca3e6 --- /dev/null +++ b/crates/storage-pg/.sqlx/query-4c37988dacca5a83c8b64209042d5f1a8ec44ec8ccccad2d7fce9ac855209883.json @@ -0,0 +1,15 @@ +{ + "db_name": "PostgreSQL", + "query": "\n UPDATE user_registrations\n SET upstream_oauth_authorization_session_id = $2\n WHERE user_registration_id = $1 AND completed_at IS NULL\n ", + "describe": { + "columns": [], + "parameters": { + "Left": [ + "Uuid", + "Uuid" + ] + }, + "nullable": [] + }, + "hash": "4c37988dacca5a83c8b64209042d5f1a8ec44ec8ccccad2d7fce9ac855209883" +} diff --git a/crates/storage-pg/.sqlx/query-5bb3ad7486365e0798e103b072514e66b5b69a347dce91135e158a5eba1d1426.json b/crates/storage-pg/.sqlx/query-b91cc2458e1a530e7cadbd1ca3e2eaf93e1c44108b6770a24c9a24ac29db37d3.json similarity index 86% rename from crates/storage-pg/.sqlx/query-5bb3ad7486365e0798e103b072514e66b5b69a347dce91135e158a5eba1d1426.json rename to crates/storage-pg/.sqlx/query-b91cc2458e1a530e7cadbd1ca3e2eaf93e1c44108b6770a24c9a24ac29db37d3.json index bad355b81..68df599ba 100644 --- a/crates/storage-pg/.sqlx/query-5bb3ad7486365e0798e103b072514e66b5b69a347dce91135e158a5eba1d1426.json +++ b/crates/storage-pg/.sqlx/query-b91cc2458e1a530e7cadbd1ca3e2eaf93e1c44108b6770a24c9a24ac29db37d3.json @@ -1,6 +1,6 @@ { "db_name": "PostgreSQL", - "query": "\n SELECT user_registration_id\n , ip_address as \"ip_address: IpAddr\"\n , user_agent\n , post_auth_action\n , username\n , display_name\n , terms_url\n , email_authentication_id\n , user_registration_token_id\n , hashed_password\n , hashed_password_version\n , created_at\n , completed_at\n FROM user_registrations\n WHERE user_registration_id = $1\n ", + "query": "\n SELECT user_registration_id\n , ip_address as \"ip_address: IpAddr\"\n , user_agent\n , post_auth_action\n , username\n , display_name\n , terms_url\n , email_authentication_id\n , user_registration_token_id\n , hashed_password\n , hashed_password_version\n , upstream_oauth_authorization_session_id\n , created_at\n , completed_at\n FROM user_registrations\n WHERE user_registration_id = $1\n ", "describe": { "columns": [ { @@ -60,11 +60,16 @@ }, { "ordinal": 11, + "name": "upstream_oauth_authorization_session_id", + "type_info": "Uuid" + }, + { + "ordinal": 12, "name": "created_at", "type_info": "Timestamptz" }, { - "ordinal": 12, + "ordinal": 13, "name": "completed_at", "type_info": "Timestamptz" } @@ -86,9 +91,10 @@ true, true, true, + true, false, true ] }, - "hash": "5bb3ad7486365e0798e103b072514e66b5b69a347dce91135e158a5eba1d1426" + "hash": "b91cc2458e1a530e7cadbd1ca3e2eaf93e1c44108b6770a24c9a24ac29db37d3" } diff --git a/crates/storage-pg/migrations/20251121145458_user_registration_upstream_oauth_session.sql b/crates/storage-pg/migrations/20251121145458_user_registration_upstream_oauth_session.sql new file mode 100644 index 000000000..4717dee1c --- /dev/null +++ b/crates/storage-pg/migrations/20251121145458_user_registration_upstream_oauth_session.sql @@ -0,0 +1,10 @@ +-- Copyright 2025 Element Creations Ltd. +-- +-- SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial +-- Please see LICENSE in the repository root for full details. + +-- Track what upstream OAuth session to associate during user registration +ALTER TABLE user_registrations + ADD COLUMN upstream_oauth_authorization_session_id UUID + REFERENCES upstream_oauth_authorization_sessions (upstream_oauth_authorization_session_id) + ON DELETE SET NULL; diff --git a/crates/storage-pg/migrations/20251127145951_user_registration_upstream_oauth_session_idx.sql b/crates/storage-pg/migrations/20251127145951_user_registration_upstream_oauth_session_idx.sql new file mode 100644 index 000000000..b9890ffad --- /dev/null +++ b/crates/storage-pg/migrations/20251127145951_user_registration_upstream_oauth_session_idx.sql @@ -0,0 +1,9 @@ +-- no-transaction +-- Copyright 2025 New Vector Ltd. +-- +-- SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial +-- Please see LICENSE in the repository root for full details. + +-- Index on the new foreign key added by the previous migration +CREATE INDEX CONCURRENTLY user_registrations_upstream_oauth_session_id_idx + ON user_registrations (upstream_oauth_authorization_session_id); diff --git a/crates/storage-pg/src/user/email.rs b/crates/storage-pg/src/user/email.rs index 0f998e55f..05122ac7a 100644 --- a/crates/storage-pg/src/user/email.rs +++ b/crates/storage-pg/src/user/email.rs @@ -7,8 +7,8 @@ use async_trait::async_trait; use chrono::{DateTime, Utc}; use mas_data_model::{ - BrowserSession, Clock, User, UserEmail, UserEmailAuthentication, UserEmailAuthenticationCode, - UserRegistration, + BrowserSession, Clock, UpstreamOAuthAuthorizationSession, User, UserEmail, + UserEmailAuthentication, UserEmailAuthenticationCode, UserRegistration, }; use mas_storage::{ Page, Pagination, @@ -668,7 +668,7 @@ impl UserEmailRepository for PgUserEmailRepository<'_> { } #[tracing::instrument( - name = "db.user_email.complete_email_authentication", + name = "db.user_email.complete_email_authentication_with_code", skip_all, fields( db.query.text, @@ -679,7 +679,7 @@ impl UserEmailRepository for PgUserEmailRepository<'_> { ), err, )] - async fn complete_authentication( + async fn complete_authentication_with_code( &mut self, clock: &dyn Clock, mut user_email_authentication: UserEmailAuthentication, @@ -712,4 +712,49 @@ impl UserEmailRepository for PgUserEmailRepository<'_> { user_email_authentication.completed_at = Some(completed_at); Ok(user_email_authentication) } + + #[tracing::instrument( + name = "db.user_email.complete_email_authentication_with_upstream", + skip_all, + fields( + db.query.text, + %user_email_authentication.id, + %user_email_authentication.email, + %upstream_oauth_authorization_session.id, + ), + err, + )] + async fn complete_authentication_with_upstream( + &mut self, + clock: &dyn Clock, + mut user_email_authentication: UserEmailAuthentication, + upstream_oauth_authorization_session: &UpstreamOAuthAuthorizationSession, + ) -> Result { + // We technically don't use the upstream_oauth_authorization_session here (other + // than recording it in the span), but this is to make sure the caller + // has fetched one before calling this + let completed_at = clock.now(); + + // We'll assume the caller has checked that completed_at is None, so in case + // they haven't, the update will not affect any rows, which will raise + // an error + let res = sqlx::query!( + r#" + UPDATE user_email_authentications + SET completed_at = $2 + WHERE user_email_authentication_id = $1 + AND completed_at IS NULL + "#, + Uuid::from(user_email_authentication.id), + completed_at, + ) + .traced() + .execute(&mut *self.conn) + .await?; + + DatabaseError::ensure_affected_rows(&res, 1)?; + + user_email_authentication.completed_at = Some(completed_at); + Ok(user_email_authentication) + } } diff --git a/crates/storage-pg/src/user/registration.rs b/crates/storage-pg/src/user/registration.rs index fc62c1cd8..69f14ff93 100644 --- a/crates/storage-pg/src/user/registration.rs +++ b/crates/storage-pg/src/user/registration.rs @@ -8,8 +8,8 @@ use std::net::IpAddr; use async_trait::async_trait; use chrono::{DateTime, Utc}; use mas_data_model::{ - Clock, UserEmailAuthentication, UserRegistration, UserRegistrationPassword, - UserRegistrationToken, + Clock, UpstreamOAuthAuthorizationSession, UserEmailAuthentication, UserRegistration, + UserRegistrationPassword, UserRegistrationToken, }; use mas_storage::user::UserRegistrationRepository; use rand::RngCore; @@ -46,6 +46,7 @@ struct UserRegistrationLookup { user_registration_token_id: Option, hashed_password: Option, hashed_password_version: Option, + upstream_oauth_authorization_session_id: Option, created_at: DateTime, completed_at: Option>, } @@ -100,6 +101,9 @@ impl TryFrom for UserRegistration { email_authentication_id: value.email_authentication_id.map(Ulid::from), user_registration_token_id: value.user_registration_token_id.map(Ulid::from), password, + upstream_oauth_authorization_session_id: value + .upstream_oauth_authorization_session_id + .map(Ulid::from), created_at: value.created_at, completed_at: value.completed_at, }) @@ -134,6 +138,7 @@ impl UserRegistrationRepository for PgUserRegistrationRepository<'_> { , user_registration_token_id , hashed_password , hashed_password_version + , upstream_oauth_authorization_session_id , created_at , completed_at FROM user_registrations @@ -208,6 +213,7 @@ impl UserRegistrationRepository for PgUserRegistrationRepository<'_> { email_authentication_id: None, user_registration_token_id: None, password: None, + upstream_oauth_authorization_session_id: None, }) } @@ -393,6 +399,42 @@ impl UserRegistrationRepository for PgUserRegistrationRepository<'_> { Ok(user_registration) } + #[tracing::instrument( + name = "db.user_registration.set_upstream_oauth_authorization_session", + skip_all, + fields( + db.query.text, + %user_registration.id, + %upstream_oauth_authorization_session.id, + ), + err, + )] + async fn set_upstream_oauth_authorization_session( + &mut self, + mut user_registration: UserRegistration, + upstream_oauth_authorization_session: &UpstreamOAuthAuthorizationSession, + ) -> Result { + let res = sqlx::query!( + r#" + UPDATE user_registrations + SET upstream_oauth_authorization_session_id = $2 + WHERE user_registration_id = $1 AND completed_at IS NULL + "#, + Uuid::from(user_registration.id), + Uuid::from(upstream_oauth_authorization_session.id), + ) + .traced() + .execute(&mut *self.conn) + .await?; + + DatabaseError::ensure_affected_rows(&res, 1)?; + + user_registration.upstream_oauth_authorization_session_id = + Some(upstream_oauth_authorization_session.id); + + Ok(user_registration) + } + #[tracing::instrument( name = "db.user_registration.complete", skip_all, @@ -433,7 +475,14 @@ impl UserRegistrationRepository for PgUserRegistrationRepository<'_> { mod tests { use std::net::{IpAddr, Ipv4Addr}; - use mas_data_model::{Clock, UserRegistrationPassword, clock::MockClock}; + use mas_data_model::{ + Clock, UpstreamOAuthProviderClaimsImports, UpstreamOAuthProviderDiscoveryMode, + UpstreamOAuthProviderOnBackchannelLogout, UpstreamOAuthProviderPkceMode, + UpstreamOAuthProviderTokenAuthMethod, UserRegistrationPassword, clock::MockClock, + }; + use mas_iana::jose::JsonWebSignatureAlg; + use mas_storage::upstream_oauth2::UpstreamOAuthProviderParams; + use oauth2_types::scope::Scope; use rand::SeedableRng; use rand_chacha::ChaChaRng; use sqlx::PgPool; @@ -851,4 +900,120 @@ mod tests { .await; assert!(res.is_err()); } + + #[sqlx::test(migrator = "crate::MIGRATOR")] + async fn test_set_upstream_oauth_session(pool: PgPool) { + let mut rng = ChaChaRng::seed_from_u64(42); + let clock = MockClock::default(); + + let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed(); + + let registration = repo + .user_registration() + .add(&mut rng, &clock, "alice".to_owned(), None, None, None) + .await + .unwrap(); + + assert_eq!(registration.upstream_oauth_authorization_session_id, None); + + let provider = repo + .upstream_oauth_provider() + .add( + &mut rng, + &clock, + UpstreamOAuthProviderParams { + issuer: Some("https://example.com/".to_owned()), + human_name: Some("Example Ltd.".to_owned()), + brand_name: None, + scope: Scope::from_iter([oauth2_types::scope::OPENID]), + token_endpoint_auth_method: UpstreamOAuthProviderTokenAuthMethod::None, + token_endpoint_signing_alg: None, + id_token_signed_response_alg: JsonWebSignatureAlg::Rs256, + client_id: "client".to_owned(), + encrypted_client_secret: None, + claims_imports: UpstreamOAuthProviderClaimsImports::default(), + authorization_endpoint_override: None, + token_endpoint_override: None, + userinfo_endpoint_override: None, + fetch_userinfo: false, + userinfo_signed_response_alg: None, + jwks_uri_override: None, + discovery_mode: UpstreamOAuthProviderDiscoveryMode::Oidc, + pkce_mode: UpstreamOAuthProviderPkceMode::Auto, + response_mode: None, + additional_authorization_parameters: Vec::new(), + forward_login_hint: false, + ui_order: 0, + on_backchannel_logout: UpstreamOAuthProviderOnBackchannelLogout::DoNothing, + }, + ) + .await + .unwrap(); + + let session = repo + .upstream_oauth_session() + .add(&mut rng, &clock, &provider, "state".to_owned(), None, None) + .await + .unwrap(); + + let registration = repo + .user_registration() + .set_upstream_oauth_authorization_session(registration, &session) + .await + .unwrap(); + + assert_eq!( + registration.upstream_oauth_authorization_session_id, + Some(session.id) + ); + + let lookup = repo + .user_registration() + .lookup(registration.id) + .await + .unwrap() + .unwrap(); + + assert_eq!( + lookup.upstream_oauth_authorization_session_id, + registration.upstream_oauth_authorization_session_id + ); + + // Setting it again should work + let registration = repo + .user_registration() + .set_upstream_oauth_authorization_session(registration, &session) + .await + .unwrap(); + + assert_eq!( + registration.upstream_oauth_authorization_session_id, + Some(session.id) + ); + + let lookup = repo + .user_registration() + .lookup(registration.id) + .await + .unwrap() + .unwrap(); + + assert_eq!( + lookup.upstream_oauth_authorization_session_id, + registration.upstream_oauth_authorization_session_id + ); + + // Can't set it once completed + let registration = repo + .user_registration() + .complete(&clock, registration) + .await + .unwrap(); + + let res = repo + .user_registration() + .set_upstream_oauth_authorization_session(registration, &session) + .await; + assert!(res.is_err()); + } } diff --git a/crates/storage-pg/src/user/tests.rs b/crates/storage-pg/src/user/tests.rs index 98489d68d..aa8c9dd07 100644 --- a/crates/storage-pg/src/user/tests.rs +++ b/crates/storage-pg/src/user/tests.rs @@ -488,7 +488,7 @@ async fn test_user_email_repo_authentications(pool: PgPool) { // Complete the authentication let authentication = repo .user_email() - .complete_authentication(&clock, authentication, &code) + .complete_authentication_with_code(&clock, authentication, &code) .await .unwrap(); @@ -514,7 +514,7 @@ async fn test_user_email_repo_authentications(pool: PgPool) { // Completing a second time should fail let res = repo .user_email() - .complete_authentication(&clock, authentication, &code) + .complete_authentication_with_code(&clock, authentication, &code) .await; assert!(res.is_err()); } diff --git a/crates/storage/src/user/email.rs b/crates/storage/src/user/email.rs index 7e973510a..f73414130 100644 --- a/crates/storage/src/user/email.rs +++ b/crates/storage/src/user/email.rs @@ -6,8 +6,8 @@ use async_trait::async_trait; use mas_data_model::{ - BrowserSession, Clock, User, UserEmail, UserEmailAuthentication, UserEmailAuthenticationCode, - UserRegistration, + BrowserSession, Clock, UpstreamOAuthAuthorizationSession, User, UserEmail, + UserEmailAuthentication, UserEmailAuthenticationCode, UserRegistration, }; use rand_core::RngCore; use ulid::Ulid; @@ -306,12 +306,34 @@ pub trait UserEmailRepository: Send + Sync { /// # Errors /// /// Returns an error if the underlying repository fails - async fn complete_authentication( + async fn complete_authentication_with_code( &mut self, clock: &dyn Clock, authentication: UserEmailAuthentication, code: &UserEmailAuthenticationCode, ) -> Result; + + /// Complete a [`UserEmailAuthentication`] by using the given upstream oauth + /// authorization session + /// + /// Returns the completed [`UserEmailAuthentication`] + /// + /// # Parameters + /// + /// * `clock`: The clock to use to generate timestamps + /// * `authentication`: The [`UserEmailAuthentication`] to complete + /// * `upstream_oauth_authorization_session`: The + /// [`UpstreamOAuthAuthorizationSession`] to use + /// + /// # Errors + /// + /// Returns an error if the underlying repository fails + async fn complete_authentication_with_upstream( + &mut self, + clock: &dyn Clock, + authentication: UserEmailAuthentication, + upstream_oauth_authorization_session: &UpstreamOAuthAuthorizationSession, + ) -> Result; } repository_impl!(UserEmailRepository: @@ -374,10 +396,17 @@ repository_impl!(UserEmailRepository: code: &str, ) -> Result, Self::Error>; - async fn complete_authentication( + async fn complete_authentication_with_code( &mut self, clock: &dyn Clock, authentication: UserEmailAuthentication, code: &UserEmailAuthenticationCode, ) -> Result; + + async fn complete_authentication_with_upstream( + &mut self, + clock: &dyn Clock, + authentication: UserEmailAuthentication, + upstream_oauth_authorization_session: &UpstreamOAuthAuthorizationSession, + ) -> Result; ); diff --git a/crates/storage/src/user/registration.rs b/crates/storage/src/user/registration.rs index 0d32684d4..77c85b932 100644 --- a/crates/storage/src/user/registration.rs +++ b/crates/storage/src/user/registration.rs @@ -6,7 +6,10 @@ use std::net::IpAddr; use async_trait::async_trait; -use mas_data_model::{Clock, UserEmailAuthentication, UserRegistration, UserRegistrationToken}; +use mas_data_model::{ + Clock, UpstreamOAuthAuthorizationSession, UserEmailAuthentication, UserRegistration, + UserRegistrationToken, +}; use rand_core::RngCore; use ulid::Ulid; use url::Url; @@ -157,6 +160,27 @@ pub trait UserRegistrationRepository: Send + Sync { user_registration_token: &UserRegistrationToken, ) -> Result; + /// Set an [`UpstreamOAuthAuthorizationSession`] to associate with a + /// [`UserRegistration`] + /// + /// Returns the updated [`UserRegistration`] + /// + /// # Parameters + /// + /// * `user_registration`: The [`UserRegistration`] to update + /// * `upstream_oauth_authorization_session`: The + /// [`UpstreamOAuthAuthorizationSession`] to set + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails or if the + /// registration is already completed + async fn set_upstream_oauth_authorization_session( + &mut self, + user_registration: UserRegistration, + upstream_oauth_authorization_session: &UpstreamOAuthAuthorizationSession, + ) -> Result; + /// Complete a [`UserRegistration`] /// /// Returns the updated [`UserRegistration`] @@ -214,6 +238,11 @@ repository_impl!(UserRegistrationRepository: user_registration: UserRegistration, user_registration_token: &UserRegistrationToken, ) -> Result; + async fn set_upstream_oauth_authorization_session( + &mut self, + user_registration: UserRegistration, + upstream_oauth_authorization_session: &UpstreamOAuthAuthorizationSession, + ) -> Result; async fn complete( &mut self, clock: &dyn Clock,