Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions crates/data-model/src/users.rs
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,7 @@ pub struct UserRegistration {
pub email_authentication_id: Option<Ulid>,
pub user_registration_token_id: Option<Ulid>,
pub password: Option<UserRegistrationPassword>,
pub upstream_oauth_authorization_session_id: Option<Ulid>,
pub post_auth_action: Option<serde_json::Value>,
pub ip_address: Option<IpAddr>,
pub user_agent: Option<String>,
Expand Down
2 changes: 1 addition & 1 deletion crates/handlers/src/graphql/mutations/user_email.rs
Original file line number Diff line number Diff line change
Expand Up @@ -817,7 +817,7 @@ impl UserEmailMutations {

let authentication = repo
.user_email()
.complete_authentication(&clock, authentication, &code)
.complete_authentication_with_code(&clock, authentication, &code)
.await?;

// Check the email is not already in use by anyone, including the current user
Expand Down
208 changes: 138 additions & 70 deletions crates/handlers/src/upstream_oauth2/link.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ use mas_policy::Policy;
use mas_router::UrlBuilder;
use mas_storage::{
BoxRepository, RepositoryAccess,
queue::{ProvisionUserJob, QueueJobRepositoryExt as _},
upstream_oauth2::{UpstreamOAuthLinkRepository, UpstreamOAuthSessionRepository},
user::{BrowserSessionRepository, UserEmailRepository, UserRepository},
};
Expand All @@ -46,7 +45,7 @@ use super::{
};
use crate::{
BoundActivityTracker, METER, PreferredLanguage, SiteConfig, impl_from_error_for_route,
views::shared::OptionalPostAuthAction,
views::{register::UserRegistrationSessionsCookie, shared::OptionalPostAuthAction},
};

static LOGIN_COUNTER: LazyLock<Counter<u64>> = LazyLock::new(|| {
Expand Down Expand Up @@ -610,10 +609,6 @@ pub(crate) async fn post(
.lookup_link(link_id)
.map_err(|_| RouteError::MissingCookie)?;

let post_auth_action = OptionalPostAuthAction {
post_auth_action: post_auth_action.cloned(),
};

let link = repo
.upstream_oauth_link()
.lookup(link_id)
Expand Down Expand Up @@ -641,15 +636,35 @@ pub(crate) async fn post(
let maybe_user_session = user_session_info.load_active_session(&mut repo).await?;
let form_state = form.to_form_state();

let session = match (maybe_user_session, link.user_id, form) {
match (maybe_user_session, link.user_id, form) {
(Some(session), None, FormData::Link) => {
// The user is already logged in, the link is not linked to any user, and the
// user asked to link their account.
repo.upstream_oauth_link()
.associate_to_user(&link, &session.user)
.await?;

session
let upstream_session = repo
.upstream_oauth_session()
.consume(&clock, upstream_session)
.await?;

repo.browser_session()
.authenticate_with_upstream(&mut rng, &clock, &session, &upstream_session)
.await?;

let post_auth_action = OptionalPostAuthAction {
post_auth_action: post_auth_action.cloned(),
};

let cookie_jar = sessions_cookie
.consume_link(link_id)?
.save(cookie_jar, &clock);
let cookie_jar = cookie_jar.set_session(&session);

repo.save().await?;

Ok((cookie_jar, post_auth_action.go_next(&url_builder)).into_response())
}

(None, None, FormData::Link) => {
Expand Down Expand Up @@ -714,14 +729,38 @@ pub(crate) async fn post(
return Err(RouteError::InvalidFormAction);
}
UpstreamOAuthProviderOnConflict::Add => {
//add link to the user
// Add link to the user
repo.upstream_oauth_link()
.associate_to_user(&link, &user)
.await?;

repo.browser_session()
// And sign in the user
let session = repo
.browser_session()
.add(&mut rng, &clock, &user, user_agent)
.await?
.await?;

let upstream_session = repo
.upstream_oauth_session()
.consume(&clock, upstream_session)
.await?;

repo.browser_session()
.authenticate_with_upstream(&mut rng, &clock, &session, &upstream_session)
.await?;

let post_auth_action = OptionalPostAuthAction {
post_auth_action: post_auth_action.cloned(),
};

let cookie_jar = sessions_cookie
.consume_link(link_id)?
.save(cookie_jar, &clock);
let cookie_jar = cookie_jar.set_session(&session);

repo.save().await?;

Ok((cookie_jar, post_auth_action.go_next(&url_builder)).into_response())
}
}
}
Expand Down Expand Up @@ -950,61 +989,84 @@ pub(crate) async fn post(

REGISTRATION_COUNTER.add(1, &[KeyValue::new(PROVIDER, provider.id.to_string())]);

// Now we can create the user
let user = repo.user().add(&mut rng, &clock, username).await?;
let mut registration = repo
.user_registration()
.add(
&mut rng,
&clock,
username,
activity_tracker.ip(),
user_agent,
post_auth_action.map(|action| serde_json::json!(action)),
)
.await?;

if let Some(terms_url) = &site_config.tos_uri {
repo.user_terms()
.accept_terms(&mut rng, &clock, &user, terms_url.clone())
registration = repo
.user_registration()
.set_terms_url(registration, terms_url.clone())
.await?;
}

// And schedule the job to provision it
let mut job = ProvisionUserJob::new(&user);
// If we have an email, add an email authentication and complete it
if let Some(email) = email {
let authentication = repo
.user_email()
.add_authentication_for_registration(&mut rng, &clock, email, &registration)
.await?;
let authentication = repo
.user_email()
.complete_authentication_with_upstream(
&clock,
authentication,
&upstream_session,
)
.await?;

// If we have a display name, set it during provisioning
if let Some(name) = display_name {
job = job.set_display_name(name);
registration = repo
.user_registration()
.set_email_authentication(registration, &authentication)
.await?;
}

repo.queue_job().schedule_job(&mut rng, &clock, job).await?;

// If we have an email, add it to the user
if let Some(email) = email {
repo.user_email()
.add(&mut rng, &clock, &user, email)
// If we have a display name, add it to the registration
if let Some(name) = display_name {
registration = repo
.user_registration()
.set_display_name(registration, name)
.await?;
}

repo.upstream_oauth_link()
.associate_to_user(&link, &user)
let registration = repo
.user_registration()
.set_upstream_oauth_authorization_session(registration, &upstream_session)
.await?;

repo.browser_session()
.add(&mut rng, &clock, &user, user_agent)
.await?
}
repo.upstream_oauth_session()
.consume(&clock, upstream_session)
.await?;

_ => return Err(RouteError::InvalidFormAction),
};
let registrations = UserRegistrationSessionsCookie::load(&cookie_jar);

let upstream_session = repo
.upstream_oauth_session()
.consume(&clock, upstream_session)
.await?;
let cookie_jar = sessions_cookie
.consume_link(link_id)?
.save(cookie_jar, &clock);

repo.browser_session()
.authenticate_with_upstream(&mut rng, &clock, &session, &upstream_session)
.await?;
let cookie_jar = registrations.add(&registration).save(cookie_jar, &clock);

let cookie_jar = sessions_cookie
.consume_link(link_id)?
.save(cookie_jar, &clock);
let cookie_jar = cookie_jar.set_session(&session);
repo.save().await?;

repo.save().await?;
// Redirect to the user registration flow, in case we have any other step to
// finish
Ok((
cookie_jar,
url_builder.redirect(&mas_router::RegisterFinish::new(registration.id)),
)
.into_response())
}

Ok((cookie_jar, post_auth_action.go_next(&url_builder)).into_response())
_ => Err(RouteError::InvalidFormAction),
}
}

#[cfg(test)]
Expand All @@ -1013,20 +1075,18 @@ mod tests {
use mas_data_model::{
UpstreamOAuthAuthorizationSession, UpstreamOAuthLink, UpstreamOAuthProviderClaimsImports,
UpstreamOAuthProviderImportPreference, UpstreamOAuthProviderLocalpartPreference,
UpstreamOAuthProviderTokenAuthMethod,
UpstreamOAuthProviderTokenAuthMethod, UserEmailAuthentication, UserRegistration,
};
use mas_iana::jose::JsonWebSignatureAlg;
use mas_jose::jwt::{JsonWebSignatureHeader, Jwt};
use mas_keystore::Keystore;
use mas_router::Route;
use mas_storage::{
Pagination, Repository, RepositoryError, upstream_oauth2::UpstreamOAuthProviderParams,
user::UserEmailFilter,
};
use mas_storage::{Repository, RepositoryError, upstream_oauth2::UpstreamOAuthProviderParams};
use oauth2_types::scope::{OPENID, Scope};
use rand_chacha::ChaChaRng;
use serde_json::Value;
use sqlx::PgPool;
use ulid::Ulid;

use super::UpstreamSessionsCookie;
use crate::test_utils::{CookieHelper, RequestBuilderExt, ResponseExt, TestState, setup};
Expand Down Expand Up @@ -1188,33 +1248,41 @@ mod tests {
let response = state.request(request).await;
cookies.save_cookies(&response);
response.assert_status(StatusCode::SEE_OTHER);
let location = response.headers().get(hyper::header::LOCATION).unwrap();
// Grab the registration ID from the redirected URL:
// /register/steps/{id}/finish
let registration_id: Ulid = str::from_utf8(location.as_bytes())
.unwrap()
.rsplit('/')
.nth(1)
.expect("Location to have two slashes")
.parse()
.expect("last segment of location to be a ULID");

// Check that we have a registered user, with the email imported
let mut repo = state.repository().await.unwrap();
let user = repo
.user()
.find_by_username("john")
.await
.unwrap()
.expect("user exists");

let link = repo
.upstream_oauth_link()
.find_by_subject(&provider, "subject")
let registration: UserRegistration = repo
.user_registration()
.lookup(registration_id)
.await
.unwrap()
.expect("link exists");
.expect("user registration exists");

assert_eq!(link.user_id, Some(user.id));
assert_eq!(registration.password, None);
assert_eq!(registration.completed_at, None);
assert_eq!(registration.username, "john");

let page = repo
let email_auth_id = registration
.email_authentication_id
.expect("registration should have an email authentication");
let email_auth: UserEmailAuthentication = repo
.user_email()
.list(UserEmailFilter::new().for_user(&user), Pagination::first(1))
.lookup_authentication(email_auth_id)
.await
.unwrap();
let edge = page.edges.first().expect("email exists");

assert_eq!(edge.node.email, "john@example.com");
.unwrap()
.expect("email authentication should exist");
assert_eq!(email_auth.email, "john@example.com");
assert!(email_auth.completed_at.is_some());
}

#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
Expand Down
2 changes: 2 additions & 0 deletions crates/handlers/src/views/register/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ mod cookie;
pub(crate) mod password;
pub(crate) mod steps;

pub use self::cookie::UserRegistrationSessions as UserRegistrationSessionsCookie;

#[tracing::instrument(name = "handlers.views.register.get", skip_all)]
pub(crate) async fn get(
mut rng: BoxRng,
Expand Down
Loading
Loading