// mas_handlers/upstream_oauth2/authorize.rs

// Copyright 2024 New Vector Ltd.
// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
//
// SPDX-License-Identifier: AGPL-3.0-only
// Please see LICENSE in the repository root for full details.

7use axum::{
8    extract::{Path, Query, State},
9    response::{IntoResponse, Redirect},
10};
11use hyper::StatusCode;
12use mas_axum_utils::{cookies::CookieJar, record_error};
13use mas_data_model::UpstreamOAuthProvider;
14use mas_oidc_client::requests::authorization_code::AuthorizationRequestData;
15use mas_router::{PostAuthAction, UrlBuilder};
16use mas_storage::{
17    BoxClock, BoxRepository, BoxRng,
18    upstream_oauth2::{UpstreamOAuthProviderRepository, UpstreamOAuthSessionRepository},
19};
20use thiserror::Error;
21use ulid::Ulid;
22
23use super::{UpstreamSessionsCookie, cache::LazyProviderInfos};
24use crate::{
25    impl_from_error_for_route, upstream_oauth2::cache::MetadataCache,
26    views::shared::OptionalPostAuthAction,
27};
28
29#[derive(Debug, Error)]
30pub(crate) enum RouteError {
31    #[error("Provider not found")]
32    ProviderNotFound,
33
34    #[error(transparent)]
35    Internal(Box<dyn std::error::Error>),
36}
37
38impl_from_error_for_route!(mas_oidc_client::error::DiscoveryError);
39impl_from_error_for_route!(mas_oidc_client::error::AuthorizationError);
40impl_from_error_for_route!(mas_storage::RepositoryError);
41
42impl IntoResponse for RouteError {
43    fn into_response(self) -> axum::response::Response {
44        let sentry_event_id = record_error!(self, Self::Internal(_));
45        let response = match self {
46            Self::ProviderNotFound => (StatusCode::NOT_FOUND, "Provider not found").into_response(),
47            Self::Internal(e) => (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response(),
48        };
49
50        (sentry_event_id, response).into_response()
51    }
52}
53
54#[tracing::instrument(
55    name = "handlers.upstream_oauth2.authorize.get",
56    fields(upstream_oauth_provider.id = %provider_id),
57    skip_all,
58)]
59pub(crate) async fn get(
60    mut rng: BoxRng,
61    clock: BoxClock,
62    State(metadata_cache): State<MetadataCache>,
63    mut repo: BoxRepository,
64    State(url_builder): State<UrlBuilder>,
65    State(http_client): State<reqwest::Client>,
66    cookie_jar: CookieJar,
67    Path(provider_id): Path<Ulid>,
68    Query(query): Query<OptionalPostAuthAction>,
69) -> Result<impl IntoResponse, RouteError> {
70    let provider = repo
71        .upstream_oauth_provider()
72        .lookup(provider_id)
73        .await?
74        .filter(UpstreamOAuthProvider::enabled)
75        .ok_or(RouteError::ProviderNotFound)?;
76
77    // First, discover the provider
78    // This is done lazyly according to provider.discovery_mode and the various
79    // endpoint overrides
80    let mut lazy_metadata = LazyProviderInfos::new(&metadata_cache, &provider, &http_client);
81    lazy_metadata.maybe_discover().await?;
82
83    let redirect_uri = url_builder.upstream_oauth_callback(provider.id);
84
85    let mut data = AuthorizationRequestData::new(
86        provider.client_id.clone(),
87        provider.scope.clone(),
88        redirect_uri,
89    );
90
91    if let Some(response_mode) = provider.response_mode {
92        data = data.with_response_mode(response_mode.into());
93    }
94
95    // Forward the raw login hint upstream for the provider to handle however it
96    // sees fit
97    if provider.forward_login_hint {
98        if let Some(PostAuthAction::ContinueAuthorizationGrant { id }) = &query.post_auth_action {
99            if let Some(login_hint) = repo
100                .oauth2_authorization_grant()
101                .lookup(*id)
102                .await?
103                .and_then(|grant| grant.login_hint)
104            {
105                data = data.with_login_hint(login_hint);
106            }
107        }
108    }
109
110    let data = if let Some(methods) = lazy_metadata.pkce_methods().await? {
111        data.with_code_challenge_methods_supported(methods)
112    } else {
113        data
114    };
115
116    // Build an authorization request for it
117    let (mut url, data) = mas_oidc_client::requests::authorization_code::build_authorization_url(
118        lazy_metadata.authorization_endpoint().await?.clone(),
119        data,
120        &mut rng,
121    )?;
122
123    // We do that in a block because params borrows url mutably
124    {
125        // Add any additional parameters to the query
126        let mut params = url.query_pairs_mut();
127        for (key, value) in &provider.additional_authorization_parameters {
128            params.append_pair(key, value);
129        }
130    }
131
132    let session = repo
133        .upstream_oauth_session()
134        .add(
135            &mut rng,
136            &clock,
137            &provider,
138            data.state.clone(),
139            data.code_challenge_verifier,
140            data.nonce,
141        )
142        .await?;
143
144    let cookie_jar = UpstreamSessionsCookie::load(&cookie_jar)
145        .add(session.id, provider.id, data.state, query.post_auth_action)
146        .save(cookie_jar, &clock);
147
148    repo.save().await?;
149
150    Ok((cookie_jar, Redirect::temporary(url.as_str())))
151}