mirror of
https://github.com/CloudNebulaProject/barycenter.git
synced 2026-04-10 13:10:42 +00:00
929 lines
38 KiB
Rust
929 lines
38 KiB
Rust
|
|
//! See docs/axum.md for axum usage patterns and router/state conventions.
|
||
|
|
//! See docs/oidc-conformance.md for OIDC OP requirements, current status, and roadmap.
|
||
|
|
//! This module exposes the provider's HTTP endpoints: discovery and JWKS as
//! required by OpenID Connect, together with client registration, login,
//! authorization, token, and userinfo.
|
||
|
|
use crate::jwks::JwksManager;
|
||
|
|
use crate::errors::CrabError;
|
||
|
|
use crate::settings::Settings;
|
||
|
|
use crate::storage;
|
||
|
|
use crate::session::SessionCookie;
|
||
|
|
use axum::extract::{Path, State, Query, Form};
|
||
|
|
use axum::http::{StatusCode, HeaderMap, HeaderName, HeaderValue, Request};
|
||
|
|
use axum::response::{IntoResponse, Html, Redirect};
|
||
|
|
use axum::routing::{get, post};
|
||
|
|
use axum::{Json, Router};
|
||
|
|
use axum::middleware::{self, Next};
|
||
|
|
use axum::body::Body;
|
||
|
|
use miette::IntoDiagnostic;
|
||
|
|
use sea_orm::DatabaseConnection;
|
||
|
|
use serde::{Deserialize, Serialize};
|
||
|
|
use serde_json::{json, Value};
|
||
|
|
use std::net::SocketAddr;
|
||
|
|
use std::sync::Arc;
|
||
|
|
use sha2::{Digest, Sha256};
|
||
|
|
use base64ct::{Base64UrlUnpadded, Encoding, Base64};
|
||
|
|
use axum::response::Response;
|
||
|
|
use std::time::SystemTime;
|
||
|
|
|
||
|
|
#[derive(Clone)]
|
||
|
|
pub struct AppState {
|
||
|
|
pub settings: Arc<Settings>,
|
||
|
|
pub db: DatabaseConnection,
|
||
|
|
pub jwks: JwksManager,
|
||
|
|
}
|
||
|
|
|
||
|
|
// Security headers middleware
|
||
|
|
async fn security_headers(request: Request<Body>, next: Next) -> impl IntoResponse {
|
||
|
|
let mut response = next.run(request).await;
|
||
|
|
let headers = response.headers_mut();
|
||
|
|
|
||
|
|
// X-Frame-Options: Prevent clickjacking
|
||
|
|
headers.insert(
|
||
|
|
HeaderName::from_static("x-frame-options"),
|
||
|
|
HeaderValue::from_static("DENY"),
|
||
|
|
);
|
||
|
|
|
||
|
|
// X-Content-Type-Options: Prevent MIME sniffing
|
||
|
|
headers.insert(
|
||
|
|
HeaderName::from_static("x-content-type-options"),
|
||
|
|
HeaderValue::from_static("nosniff"),
|
||
|
|
);
|
||
|
|
|
||
|
|
// X-XSS-Protection: Legacy XSS protection (still useful for older browsers)
|
||
|
|
headers.insert(
|
||
|
|
HeaderName::from_static("x-xss-protection"),
|
||
|
|
HeaderValue::from_static("1; mode=block"),
|
||
|
|
);
|
||
|
|
|
||
|
|
// Content-Security-Policy: Restrict resource loading
|
||
|
|
headers.insert(
|
||
|
|
HeaderName::from_static("content-security-policy"),
|
||
|
|
HeaderValue::from_static("default-src 'self'; script-src 'self'; style-src 'self' 'unsafe-inline'; img-src 'self' data:; form-action 'self'"),
|
||
|
|
);
|
||
|
|
|
||
|
|
// Referrer-Policy: Control referrer information
|
||
|
|
headers.insert(
|
||
|
|
HeaderName::from_static("referrer-policy"),
|
||
|
|
HeaderValue::from_static("strict-origin-when-cross-origin"),
|
||
|
|
);
|
||
|
|
|
||
|
|
// Permissions-Policy: Disable unnecessary browser features
|
||
|
|
headers.insert(
|
||
|
|
HeaderName::from_static("permissions-policy"),
|
||
|
|
HeaderValue::from_static("geolocation=(), microphone=(), camera=()"),
|
||
|
|
);
|
||
|
|
|
||
|
|
response
|
||
|
|
}
|
||
|
|
|
||
|
|
pub async fn serve(settings: Settings, db: DatabaseConnection, jwks: JwksManager) -> miette::Result<()> {
|
||
|
|
let state = AppState { settings: Arc::new(settings), db, jwks };
|
||
|
|
|
||
|
|
// NOTE: Rate limiting should be implemented at the reverse proxy level (nginx, traefik, etc.)
|
||
|
|
// for production deployments. This is more efficient and flexible than application-level
|
||
|
|
// rate limiting. Configure your reverse proxy with limits like:
|
||
|
|
// - Token endpoint: 10 req/min per IP
|
||
|
|
// - Login endpoint: 5 attempts/min per IP
|
||
|
|
// - Authorize endpoint: 20 req/min per IP
|
||
|
|
|
||
|
|
let router = Router::new()
|
||
|
|
.route("/.well-known/openid-configuration", get(discovery))
|
||
|
|
.route("/.well-known/jwks.json", get(jwks_handler))
|
||
|
|
.route("/connect/register", post(register_client))
|
||
|
|
.route("/properties/{owner}/{key}", get(get_property))
|
||
|
|
.route("/federation/trust-anchors", get(trust_anchors))
|
||
|
|
.route("/register", post(register_user))
|
||
|
|
.route("/login", get(login_page).post(login_submit))
|
||
|
|
.route("/logout", get(logout))
|
||
|
|
.route("/authorize", get(authorize))
|
||
|
|
.route("/token", post(token))
|
||
|
|
.route("/userinfo", get(userinfo))
|
||
|
|
.layer(middleware::from_fn(security_headers))
|
||
|
|
.with_state(state.clone());
|
||
|
|
|
||
|
|
let addr: SocketAddr = format!("{}:{}", state.settings.server.host, state.settings.server.port)
|
||
|
|
.parse()
|
||
|
|
.map_err(|e| miette::miette!("bad listen addr: {e}"))?;
|
||
|
|
tracing::info!(%addr, "listening");
|
||
|
|
tracing::warn!("Rate limiting should be configured at the reverse proxy level for production");
|
||
|
|
let listener = tokio::net::TcpListener::bind(addr).await.into_diagnostic()?;
|
||
|
|
axum::serve(listener, router).await.into_diagnostic()?;
|
||
|
|
Ok(())
|
||
|
|
}
|
||
|
|
|
||
|
|
async fn discovery(State(state): State<AppState>) -> impl IntoResponse {
|
||
|
|
let issuer = state.settings.issuer();
|
||
|
|
let metadata = json!({
|
||
|
|
"issuer": issuer,
|
||
|
|
"authorization_endpoint": format!("{}/authorize", issuer),
|
||
|
|
"token_endpoint": format!("{}/token", issuer),
|
||
|
|
"jwks_uri": format!("{}/.well-known/jwks.json", issuer),
|
||
|
|
"registration_endpoint": format!("{}/connect/register", issuer),
|
||
|
|
"userinfo_endpoint": format!("{}/userinfo", issuer),
|
||
|
|
"scopes_supported": ["openid", "profile", "email", "offline_access"],
|
||
|
|
"response_types_supported": ["code", "id_token", "id_token token"],
|
||
|
|
"subject_types_supported": ["public"],
|
||
|
|
"id_token_signing_alg_values_supported": [state.settings.keys.alg],
|
||
|
|
// Additional recommended metadata for better interoperability
|
||
|
|
"grant_types_supported": ["authorization_code", "refresh_token", "implicit"],
|
||
|
|
"token_endpoint_auth_methods_supported": ["client_secret_post", "client_secret_basic"],
|
||
|
|
"code_challenge_methods_supported": ["S256"],
|
||
|
|
"claims_supported": [
|
||
|
|
"sub", "iss", "aud", "exp", "iat", "auth_time", "nonce",
|
||
|
|
"name", "given_name", "family_name", "email", "email_verified"
|
||
|
|
],
|
||
|
|
// OIDC Core 1.0 features
|
||
|
|
"prompt_values_supported": ["none", "login", "consent", "select_account"],
|
||
|
|
"display_values_supported": ["page"],
|
||
|
|
"ui_locales_supported": ["en"],
|
||
|
|
"claim_types_supported": ["normal"],
|
||
|
|
});
|
||
|
|
Json(metadata)
|
||
|
|
}
|
||
|
|
|
||
|
|
async fn jwks_handler(State(state): State<AppState>) -> impl IntoResponse {
|
||
|
|
Json(state.jwks.jwks_json())
|
||
|
|
}
|
||
|
|
|
||
|
|
#[derive(Debug, Deserialize)]
|
||
|
|
struct AuthorizeQuery {
|
||
|
|
client_id: String,
|
||
|
|
redirect_uri: String,
|
||
|
|
response_type: String,
|
||
|
|
scope: String,
|
||
|
|
state: Option<String>,
|
||
|
|
nonce: Option<String>,
|
||
|
|
code_challenge: Option<String>,
|
||
|
|
code_challenge_method: Option<String>,
|
||
|
|
prompt: Option<String>,
|
||
|
|
display: Option<String>,
|
||
|
|
ui_locales: Option<String>,
|
||
|
|
claims_locales: Option<String>,
|
||
|
|
max_age: Option<String>,
|
||
|
|
acr_values: Option<String>,
|
||
|
|
}
|
||
|
|
|
||
|
|
fn url_append_query(mut base: String, params: &[(&str, String)]) -> String {
|
||
|
|
let qs = serde_urlencoded::to_string(
|
||
|
|
params.iter().map(|(k, v)| (k.to_string(), v.clone())).collect::<Vec<(String, String)>>(),
|
||
|
|
).unwrap_or_default();
|
||
|
|
if base.contains('?') { base.push('&'); } else { base.push('?'); }
|
||
|
|
base.push_str(&qs);
|
||
|
|
base
|
||
|
|
}
|
||
|
|
|
||
|
|
fn oauth_error_redirect(redirect_uri: &str, state: Option<&str>, error: &str, desc: &str) -> axum::response::Redirect {
|
||
|
|
let mut params = vec![("error", error.to_string())];
|
||
|
|
if !desc.is_empty() { params.push(("error_description", desc.to_string())); }
|
||
|
|
if let Some(s) = state { params.push(("state", s.to_string())); }
|
||
|
|
let loc = url_append_query(redirect_uri.to_string(), ¶ms);
|
||
|
|
axum::response::Redirect::temporary(&loc)
|
||
|
|
}
|
||
|
|
|
||
|
|
// OIDC-specific error codes per OpenID Connect Core 1.0 Section 3.1.2.6
|
||
|
|
// login_required: Authentication is required but prompt=none was specified
|
||
|
|
// consent_required: Consent is required but prompt=none was specified
|
||
|
|
// interaction_required: User interaction is required but prompt=none was specified
|
||
|
|
// account_selection_required: Account selection is required but prompt=none was specified
|
||
|
|
fn oidc_error_redirect(redirect_uri: &str, state: Option<&str>, error: &str) -> axum::response::Redirect {
|
||
|
|
oauth_error_redirect(redirect_uri, state, error, "")
|
||
|
|
}
|
||
|
|
|
||
|
|
async fn authorize(State(state): State<AppState>, headers: HeaderMap, Query(q): Query<AuthorizeQuery>) -> impl IntoResponse {
|
||
|
|
// Validate response_type - support code, id_token, and id_token token
|
||
|
|
let valid_response_types = ["code", "id_token", "id_token token"];
|
||
|
|
if !valid_response_types.contains(&q.response_type.as_str()) {
|
||
|
|
return oauth_error_redirect(&q.redirect_uri, q.state.as_deref(), "unsupported_response_type",
|
||
|
|
"only response_type=code, id_token, or 'id_token token' supported").into_response();
|
||
|
|
}
|
||
|
|
// Validate scope includes openid
|
||
|
|
if !q.scope.split_whitespace().any(|s| s == "openid") {
|
||
|
|
return oauth_error_redirect(&q.redirect_uri, q.state.as_deref(), "invalid_scope", "scope must include openid").into_response();
|
||
|
|
}
|
||
|
|
// Require PKCE S256
|
||
|
|
let (code_challenge, ccm) = match (&q.code_challenge, &q.code_challenge_method) {
|
||
|
|
(Some(cc), Some(m)) if m == "S256" => (cc.clone(), m.clone()),
|
||
|
|
_ => {
|
||
|
|
return oauth_error_redirect(&q.redirect_uri, q.state.as_deref(), "invalid_request", "PKCE (S256) required").into_response();
|
||
|
|
}
|
||
|
|
};
|
||
|
|
|
||
|
|
// Lookup client
|
||
|
|
let client = match storage::get_client(&state.db, &q.client_id).await {
|
||
|
|
Ok(Some(c)) => c,
|
||
|
|
Ok(None) => return oauth_error_redirect(&q.redirect_uri, q.state.as_deref(), "unauthorized_client", "unknown client_id").into_response(),
|
||
|
|
Err(_) => return oauth_error_redirect(&q.redirect_uri, q.state.as_deref(), "server_error", "db error").into_response(),
|
||
|
|
};
|
||
|
|
// Validate redirect_uri exact match
|
||
|
|
if !client.redirect_uris.iter().any(|u| u == &q.redirect_uri) {
|
||
|
|
return oauth_error_redirect(&q.redirect_uri, q.state.as_deref(), "invalid_request", "redirect_uri mismatch").into_response();
|
||
|
|
}
|
||
|
|
|
||
|
|
// Parse prompt parameter (can be space-separated list)
|
||
|
|
let prompt_values: Vec<&str> = q.prompt.as_ref()
|
||
|
|
.map(|p| p.split_whitespace().collect())
|
||
|
|
.unwrap_or_default();
|
||
|
|
|
||
|
|
let has_prompt_none = prompt_values.contains(&"none");
|
||
|
|
let has_prompt_login = prompt_values.contains(&"login");
|
||
|
|
let has_prompt_select_account = prompt_values.contains(&"select_account");
|
||
|
|
|
||
|
|
// Check for existing session (but ignore if prompt=login or prompt=select_account)
|
||
|
|
let session_opt = if has_prompt_login || has_prompt_select_account {
|
||
|
|
None // Force re-authentication
|
||
|
|
} else if let Some(cookie) = SessionCookie::from_headers(&headers) {
|
||
|
|
storage::get_session(&state.db, &cookie.session_id).await.ok().flatten()
|
||
|
|
} else {
|
||
|
|
None
|
||
|
|
};
|
||
|
|
|
||
|
|
// Handle max_age parameter - requires re-authentication if session is too old
|
||
|
|
let needs_fresh_auth = if let Some(max_age_str) = &q.max_age {
|
||
|
|
if let Ok(max_age) = max_age_str.parse::<i64>() {
|
||
|
|
if let Some(ref sess) = session_opt {
|
||
|
|
let age = chrono::Utc::now().timestamp() - sess.auth_time;
|
||
|
|
age > max_age
|
||
|
|
} else {
|
||
|
|
false
|
||
|
|
}
|
||
|
|
} else {
|
||
|
|
false
|
||
|
|
}
|
||
|
|
} else {
|
||
|
|
false
|
||
|
|
};
|
||
|
|
|
||
|
|
let (subject, auth_time) = match session_opt {
|
||
|
|
Some(sess) if sess.expires_at > chrono::Utc::now().timestamp() && !needs_fresh_auth => {
|
||
|
|
(sess.subject.clone(), Some(sess.auth_time))
|
||
|
|
}
|
||
|
|
_ => {
|
||
|
|
// No valid session or session too old
|
||
|
|
// If prompt=none, return error instead of redirecting
|
||
|
|
if has_prompt_none {
|
||
|
|
return oidc_error_redirect(&q.redirect_uri, q.state.as_deref(), "login_required").into_response();
|
||
|
|
}
|
||
|
|
|
||
|
|
// Build return_to URL with all parameters
|
||
|
|
let mut return_params = vec![
|
||
|
|
("client_id", q.client_id.clone()),
|
||
|
|
("redirect_uri", q.redirect_uri.clone()),
|
||
|
|
("response_type", q.response_type.clone()),
|
||
|
|
("scope", q.scope.clone()),
|
||
|
|
("code_challenge", code_challenge.clone()),
|
||
|
|
("code_challenge_method", ccm.clone()),
|
||
|
|
];
|
||
|
|
if let Some(s) = &q.state {
|
||
|
|
return_params.push(("state", s.clone()));
|
||
|
|
}
|
||
|
|
if let Some(n) = &q.nonce {
|
||
|
|
return_params.push(("nonce", n.clone()));
|
||
|
|
}
|
||
|
|
if let Some(p) = &q.prompt {
|
||
|
|
return_params.push(("prompt", p.clone()));
|
||
|
|
}
|
||
|
|
if let Some(d) = &q.display {
|
||
|
|
return_params.push(("display", d.clone()));
|
||
|
|
}
|
||
|
|
if let Some(ui) = &q.ui_locales {
|
||
|
|
return_params.push(("ui_locales", ui.clone()));
|
||
|
|
}
|
||
|
|
if let Some(cl) = &q.claims_locales {
|
||
|
|
return_params.push(("claims_locales", cl.clone()));
|
||
|
|
}
|
||
|
|
if let Some(ma) = &q.max_age {
|
||
|
|
return_params.push(("max_age", ma.clone()));
|
||
|
|
}
|
||
|
|
if let Some(acr) = &q.acr_values {
|
||
|
|
return_params.push(("acr_values", acr.clone()));
|
||
|
|
}
|
||
|
|
|
||
|
|
let return_to = url_append_query("/authorize".to_string(),
|
||
|
|
&return_params.iter().map(|(k, v)| (*k, v.clone())).collect::<Vec<_>>());
|
||
|
|
let login_url = format!("/login?return_to={}", urlencoded(&return_to));
|
||
|
|
return Redirect::temporary(&login_url).into_response();
|
||
|
|
}
|
||
|
|
};
|
||
|
|
|
||
|
|
let scope = q.scope.clone();
|
||
|
|
let nonce = q.nonce.clone();
|
||
|
|
|
||
|
|
// Handle different response types
|
||
|
|
match q.response_type.as_str() {
|
||
|
|
"code" => {
|
||
|
|
// Authorization Code Flow - issue auth code
|
||
|
|
let ttl = 300; // 5 minutes
|
||
|
|
match storage::issue_auth_code(&state.db, &q.client_id, &q.redirect_uri, &scope, &subject, nonce, &code_challenge, &ccm, ttl, auth_time).await {
|
||
|
|
Ok(code) => {
|
||
|
|
let mut params = vec![("code", code.code.clone())];
|
||
|
|
if let Some(s) = &q.state { params.push(("state", s.clone())); }
|
||
|
|
let loc = url_append_query(q.redirect_uri.clone(), ¶ms);
|
||
|
|
axum::response::Redirect::temporary(&loc).into_response()
|
||
|
|
}
|
||
|
|
Err(_) => oauth_error_redirect(&q.redirect_uri, q.state.as_deref(), "server_error", "could not issue code").into_response(),
|
||
|
|
}
|
||
|
|
}
|
||
|
|
"id_token" => {
|
||
|
|
// Implicit Flow - return ID token in fragment
|
||
|
|
// Require nonce for implicit flow (OIDC Core 1.0 Section 3.2.2.1)
|
||
|
|
if nonce.is_none() {
|
||
|
|
return oauth_error_redirect(&q.redirect_uri, q.state.as_deref(), "invalid_request", "nonce required for implicit flow").into_response();
|
||
|
|
}
|
||
|
|
|
||
|
|
match build_id_token(&state, &q.client_id, &subject, nonce.as_deref(), auth_time, None).await {
|
||
|
|
Ok(id_token) => {
|
||
|
|
let mut fragment_params = vec![("id_token", id_token)];
|
||
|
|
if let Some(s) = &q.state {
|
||
|
|
fragment_params.push(("state", s.clone()));
|
||
|
|
}
|
||
|
|
let fragment = serde_urlencoded::to_string(&fragment_params).unwrap_or_default();
|
||
|
|
let loc = format!("{}#{}", q.redirect_uri, fragment);
|
||
|
|
axum::response::Redirect::temporary(&loc).into_response()
|
||
|
|
}
|
||
|
|
Err(_) => oauth_error_redirect(&q.redirect_uri, q.state.as_deref(), "server_error", "could not generate id_token").into_response(),
|
||
|
|
}
|
||
|
|
}
|
||
|
|
"id_token token" => {
|
||
|
|
// Implicit Flow - return both ID token and access token in fragment
|
||
|
|
// Require nonce for implicit flow
|
||
|
|
if nonce.is_none() {
|
||
|
|
return oauth_error_redirect(&q.redirect_uri, q.state.as_deref(), "invalid_request", "nonce required for implicit flow").into_response();
|
||
|
|
}
|
||
|
|
|
||
|
|
// Issue access token
|
||
|
|
let access = match storage::issue_access_token(&state.db, &q.client_id, &subject, &scope, 3600).await {
|
||
|
|
Ok(t) => t,
|
||
|
|
Err(_) => return oauth_error_redirect(&q.redirect_uri, q.state.as_deref(), "server_error", "could not issue access token").into_response(),
|
||
|
|
};
|
||
|
|
|
||
|
|
// Build ID token with at_hash
|
||
|
|
match build_id_token(&state, &q.client_id, &subject, nonce.as_deref(), auth_time, Some(&access.token)).await {
|
||
|
|
Ok(id_token) => {
|
||
|
|
let mut fragment_params = vec![
|
||
|
|
("access_token", access.token),
|
||
|
|
("token_type", "bearer".to_string()),
|
||
|
|
("expires_in", "3600".to_string()),
|
||
|
|
("id_token", id_token),
|
||
|
|
];
|
||
|
|
if let Some(s) = &q.state {
|
||
|
|
fragment_params.push(("state", s.clone()));
|
||
|
|
}
|
||
|
|
let fragment = serde_urlencoded::to_string(&fragment_params).unwrap_or_default();
|
||
|
|
let loc = format!("{}#{}", q.redirect_uri, fragment);
|
||
|
|
axum::response::Redirect::temporary(&loc).into_response()
|
||
|
|
}
|
||
|
|
Err(_) => oauth_error_redirect(&q.redirect_uri, q.state.as_deref(), "server_error", "could not generate id_token").into_response(),
|
||
|
|
}
|
||
|
|
}
|
||
|
|
_ => oauth_error_redirect(&q.redirect_uri, q.state.as_deref(), "unsupported_response_type", "invalid response_type").into_response(),
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// Helper function to build ID token
|
||
|
|
async fn build_id_token(
|
||
|
|
state: &AppState,
|
||
|
|
client_id: &str,
|
||
|
|
subject: &str,
|
||
|
|
nonce: Option<&str>,
|
||
|
|
auth_time: Option<i64>,
|
||
|
|
access_token: Option<&str>, // For at_hash calculation
|
||
|
|
) -> Result<String, CrabError> {
|
||
|
|
let now = SystemTime::now();
|
||
|
|
let exp_unix = now.duration_since(std::time::UNIX_EPOCH).unwrap().as_secs() as i64 + 3600;
|
||
|
|
|
||
|
|
let mut payload = josekit::jwt::JwtPayload::new();
|
||
|
|
payload.set_issuer(state.settings.issuer());
|
||
|
|
payload.set_subject(subject.to_string());
|
||
|
|
payload.set_audience(vec![client_id.to_string()]);
|
||
|
|
payload.set_issued_at(&now);
|
||
|
|
let _ = payload.set_claim("exp", Some(serde_json::json!(exp_unix)));
|
||
|
|
|
||
|
|
if let Some(n) = nonce {
|
||
|
|
let _ = payload.set_claim("nonce", Some(serde_json::Value::String(n.to_string())));
|
||
|
|
}
|
||
|
|
|
||
|
|
if let Some(at) = auth_time {
|
||
|
|
let _ = payload.set_claim("auth_time", Some(serde_json::json!(at)));
|
||
|
|
}
|
||
|
|
|
||
|
|
// Add at_hash if access_token is provided (for id_token token response type)
|
||
|
|
if let Some(token) = access_token {
|
||
|
|
let mut hasher = Sha256::new();
|
||
|
|
hasher.update(token.as_bytes());
|
||
|
|
let digest = hasher.finalize();
|
||
|
|
let half = &digest[..16]; // left-most 128 bits
|
||
|
|
let at_hash = Base64UrlUnpadded::encode_string(half);
|
||
|
|
let _ = payload.set_claim("at_hash", Some(serde_json::Value::String(at_hash)));
|
||
|
|
}
|
||
|
|
|
||
|
|
state.jwks.sign_jwt_rs256(&payload).map_err(|e| CrabError::Other(e.to_string()))
|
||
|
|
}
|
||
|
|
|
||
|
|
#[derive(Debug, Deserialize)]
|
||
|
|
struct TokenRequest {
|
||
|
|
grant_type: String,
|
||
|
|
code: Option<String>,
|
||
|
|
redirect_uri: Option<String>,
|
||
|
|
client_id: Option<String>,
|
||
|
|
client_secret: Option<String>,
|
||
|
|
code_verifier: Option<String>,
|
||
|
|
refresh_token: Option<String>,
|
||
|
|
}
|
||
|
|
|
||
|
|
#[derive(Debug, Serialize)]
|
||
|
|
struct TokenResponse {
|
||
|
|
access_token: String,
|
||
|
|
token_type: String,
|
||
|
|
expires_in: i64,
|
||
|
|
id_token: Option<String>,
|
||
|
|
refresh_token: Option<String>,
|
||
|
|
}
|
||
|
|
|
||
|
|
fn pkce_s256(verifier: &str) -> String {
|
||
|
|
let mut hasher = Sha256::new();
|
||
|
|
hasher.update(verifier.as_bytes());
|
||
|
|
let digest = hasher.finalize();
|
||
|
|
// Take full digest then base64url without padding
|
||
|
|
Base64UrlUnpadded::encode_string(&digest)
|
||
|
|
}
|
||
|
|
|
||
|
|
fn json_with_headers(status: StatusCode, value: Value, headers: &[(&str, String)]) -> Response {
|
||
|
|
let mut resp = (status, Json(value)).into_response();
|
||
|
|
let h = resp.headers_mut();
|
||
|
|
for (name, val) in headers {
|
||
|
|
if let (Ok(n), Ok(v)) = (HeaderName::from_bytes(name.as_bytes()), HeaderValue::from_str(val)) {
|
||
|
|
h.insert(n, v);
|
||
|
|
}
|
||
|
|
}
|
||
|
|
resp
|
||
|
|
}
|
||
|
|
|
||
|
|
async fn token(State(state): State<AppState>, headers: HeaderMap, Form(req): Form<TokenRequest>) -> impl IntoResponse {
|
||
|
|
// Validate grant_type
|
||
|
|
match req.grant_type.as_str() {
|
||
|
|
"authorization_code" => handle_authorization_code_grant(state, headers, req).await,
|
||
|
|
"refresh_token" => handle_refresh_token_grant(state, headers, req).await,
|
||
|
|
_ => (StatusCode::BAD_REQUEST, Json(json!({"error":"unsupported_grant_type"}))).into_response(),
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
async fn handle_authorization_code_grant(state: AppState, headers: HeaderMap, req: TokenRequest) -> Response {
|
||
|
|
// Client authentication: client_secret_basic preferred, then client_secret_post
|
||
|
|
let (client_id, client_secret) = match authenticate_client(&headers, &req) {
|
||
|
|
Ok(pair) => pair,
|
||
|
|
Err(resp) => return resp,
|
||
|
|
};
|
||
|
|
|
||
|
|
let client = match storage::get_client(&state.db, &client_id).await {
|
||
|
|
Ok(Some(c)) => c,
|
||
|
|
_ => {
|
||
|
|
return json_with_headers(
|
||
|
|
StatusCode::UNAUTHORIZED,
|
||
|
|
json!({"error":"invalid_client"}),
|
||
|
|
&[("www-authenticate", "Basic realm=\"token\", error=\"invalid_client\"".to_string())],
|
||
|
|
)
|
||
|
|
}
|
||
|
|
};
|
||
|
|
if client.client_secret != client_secret {
|
||
|
|
return json_with_headers(
|
||
|
|
StatusCode::UNAUTHORIZED,
|
||
|
|
json!({"error":"invalid_client"}),
|
||
|
|
&[("www-authenticate", "Basic realm=\"token\", error=\"invalid_client\"".to_string())],
|
||
|
|
);
|
||
|
|
}
|
||
|
|
|
||
|
|
// Require code
|
||
|
|
let code = match req.code {
|
||
|
|
Some(c) => c,
|
||
|
|
None => return (StatusCode::BAD_REQUEST, Json(json!({"error":"invalid_request","error_description":"code required"}))).into_response(),
|
||
|
|
};
|
||
|
|
|
||
|
|
// Consume code
|
||
|
|
let code_row = match storage::consume_auth_code(&state.db, &code).await {
|
||
|
|
Ok(Some(c)) => c,
|
||
|
|
_ => return (StatusCode::BAD_REQUEST, Json(json!({"error":"invalid_grant"}))).into_response(),
|
||
|
|
};
|
||
|
|
|
||
|
|
// Validate code binding
|
||
|
|
let redirect_uri = req.redirect_uri.unwrap_or_default();
|
||
|
|
if code_row.client_id != client_id || code_row.redirect_uri != redirect_uri {
|
||
|
|
return (StatusCode::BAD_REQUEST, Json(json!({"error":"invalid_grant"}))).into_response();
|
||
|
|
}
|
||
|
|
|
||
|
|
// Validate PKCE S256
|
||
|
|
let verifier = match &req.code_verifier {
|
||
|
|
Some(v) => v,
|
||
|
|
None => return (StatusCode::BAD_REQUEST, Json(json!({"error":"invalid_request","error_description":"code_verifier required"}))).into_response(),
|
||
|
|
};
|
||
|
|
if code_row.code_challenge_method != "S256" || pkce_s256(verifier) != code_row.code_challenge {
|
||
|
|
return (StatusCode::BAD_REQUEST, Json(json!({"error":"invalid_grant","error_description":"pkce verification failed"}))).into_response();
|
||
|
|
}
|
||
|
|
|
||
|
|
// Issue access token
|
||
|
|
let access = match storage::issue_access_token(&state.db, &client_id, &code_row.subject, &code_row.scope, 3600).await {
|
||
|
|
Ok(t) => t,
|
||
|
|
Err(e) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error":"server_error","details":e.to_string()}))).into_response(),
|
||
|
|
};
|
||
|
|
|
||
|
|
// Build ID Token using helper function
|
||
|
|
let id_token = match build_id_token(&state, &client_id, &code_row.subject, code_row.nonce.as_deref(), code_row.auth_time, Some(&access.token)).await {
|
||
|
|
Ok(t) => t,
|
||
|
|
Err(e) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error":"server_error","details":e.to_string()}))).into_response(),
|
||
|
|
};
|
||
|
|
|
||
|
|
// Issue refresh token if offline_access scope was requested
|
||
|
|
let refresh_token = if code_row.scope.split_whitespace().any(|s| s == "offline_access") {
|
||
|
|
match storage::issue_refresh_token(&state.db, &client_id, &code_row.subject, &code_row.scope, 2592000, None).await {
|
||
|
|
Ok(rt) => Some(rt.token),
|
||
|
|
Err(_) => None, // Don't fail the whole request if refresh token issuance fails
|
||
|
|
}
|
||
|
|
} else {
|
||
|
|
None
|
||
|
|
};
|
||
|
|
|
||
|
|
let resp = TokenResponse {
|
||
|
|
access_token: access.token,
|
||
|
|
token_type: "bearer".into(),
|
||
|
|
expires_in: 3600,
|
||
|
|
id_token: Some(id_token),
|
||
|
|
refresh_token,
|
||
|
|
};
|
||
|
|
|
||
|
|
// Add Cache-Control: no-store as required by OAuth 2.0 and OIDC specs
|
||
|
|
json_with_headers(
|
||
|
|
StatusCode::OK,
|
||
|
|
serde_json::to_value(resp).unwrap(),
|
||
|
|
&[
|
||
|
|
("cache-control", "no-store".to_string()),
|
||
|
|
("pragma", "no-cache".to_string()),
|
||
|
|
],
|
||
|
|
)
|
||
|
|
}
|
||
|
|
|
||
|
|
async fn handle_refresh_token_grant(state: AppState, headers: HeaderMap, req: TokenRequest) -> Response {
|
||
|
|
// Client authentication
|
||
|
|
let (client_id, client_secret) = match authenticate_client(&headers, &req) {
|
||
|
|
Ok(pair) => pair,
|
||
|
|
Err(resp) => return resp,
|
||
|
|
};
|
||
|
|
|
||
|
|
let client = match storage::get_client(&state.db, &client_id).await {
|
||
|
|
Ok(Some(c)) => c,
|
||
|
|
_ => {
|
||
|
|
return json_with_headers(
|
||
|
|
StatusCode::UNAUTHORIZED,
|
||
|
|
json!({"error":"invalid_client"}),
|
||
|
|
&[("www-authenticate", "Basic realm=\"token\", error=\"invalid_client\"".to_string())],
|
||
|
|
)
|
||
|
|
}
|
||
|
|
};
|
||
|
|
if client.client_secret != client_secret {
|
||
|
|
return json_with_headers(
|
||
|
|
StatusCode::UNAUTHORIZED,
|
||
|
|
json!({"error":"invalid_client"}),
|
||
|
|
&[("www-authenticate", "Basic realm=\"token\", error=\"invalid_client\"".to_string())],
|
||
|
|
);
|
||
|
|
}
|
||
|
|
|
||
|
|
// Require refresh_token
|
||
|
|
let refresh_token_str = match req.refresh_token {
|
||
|
|
Some(rt) => rt,
|
||
|
|
None => return (StatusCode::BAD_REQUEST, Json(json!({"error":"invalid_request","error_description":"refresh_token required"}))).into_response(),
|
||
|
|
};
|
||
|
|
|
||
|
|
// Get and validate refresh token
|
||
|
|
let refresh_token = match storage::get_refresh_token(&state.db, &refresh_token_str).await {
|
||
|
|
Ok(Some(rt)) => rt,
|
||
|
|
_ => return (StatusCode::BAD_REQUEST, Json(json!({"error":"invalid_grant","error_description":"invalid refresh_token"}))).into_response(),
|
||
|
|
};
|
||
|
|
|
||
|
|
// Validate client_id matches
|
||
|
|
if refresh_token.client_id != client_id {
|
||
|
|
return (StatusCode::BAD_REQUEST, Json(json!({"error":"invalid_grant"}))).into_response();
|
||
|
|
}
|
||
|
|
|
||
|
|
// Issue new access token
|
||
|
|
let access = match storage::issue_access_token(&state.db, &client_id, &refresh_token.subject, &refresh_token.scope, 3600).await {
|
||
|
|
Ok(t) => t,
|
||
|
|
Err(e) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error":"server_error","details":e.to_string()}))).into_response(),
|
||
|
|
};
|
||
|
|
|
||
|
|
// Build ID Token (no nonce, no auth_time for refresh grants)
|
||
|
|
let id_token = match build_id_token(&state, &client_id, &refresh_token.subject, None, None, Some(&access.token)).await {
|
||
|
|
Ok(t) => t,
|
||
|
|
Err(e) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error":"server_error","details":e.to_string()}))).into_response(),
|
||
|
|
};
|
||
|
|
|
||
|
|
// Rotate refresh token (issue new one and revoke old one)
|
||
|
|
let new_refresh_token = match storage::rotate_refresh_token(&state.db, &refresh_token_str, &client_id, &refresh_token.subject, &refresh_token.scope, 2592000).await {
|
||
|
|
Ok(rt) => Some(rt.token),
|
||
|
|
Err(_) => None, // Don't fail the whole request if rotation fails
|
||
|
|
};
|
||
|
|
|
||
|
|
let resp = TokenResponse {
|
||
|
|
access_token: access.token,
|
||
|
|
token_type: "bearer".into(),
|
||
|
|
expires_in: 3600,
|
||
|
|
id_token: Some(id_token),
|
||
|
|
refresh_token: new_refresh_token,
|
||
|
|
};
|
||
|
|
|
||
|
|
// Add Cache-Control: no-store as required by OAuth 2.0 and OIDC specs
|
||
|
|
json_with_headers(
|
||
|
|
StatusCode::OK,
|
||
|
|
serde_json::to_value(resp).unwrap(),
|
||
|
|
&[
|
||
|
|
("cache-control", "no-store".to_string()),
|
||
|
|
("pragma", "no-cache".to_string()),
|
||
|
|
],
|
||
|
|
)
|
||
|
|
}
|
||
|
|
|
||
|
|
fn authenticate_client(headers: &HeaderMap, req: &TokenRequest) -> Result<(String, String), Response> {
|
||
|
|
// Try client_secret_basic first (Authorization header)
|
||
|
|
let mut basic_client: Option<(String, String)> = None;
|
||
|
|
if let Some(auth_val) = headers.get(axum::http::header::AUTHORIZATION).and_then(|h| h.to_str().ok()) {
|
||
|
|
if let Some(b64) = auth_val.strip_prefix("Basic ") {
|
||
|
|
if let Ok(mut decoded) = Base64::decode_vec(b64) {
|
||
|
|
if let Ok(s) = String::from_utf8(std::mem::take(&mut decoded)) {
|
||
|
|
if let Some((id, sec)) = s.split_once(':') {
|
||
|
|
basic_client = Some((id.to_string(), sec.to_string()));
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
if let Some(pair) = basic_client {
|
||
|
|
Ok(pair)
|
||
|
|
} else {
|
||
|
|
// Try client_secret_post (form body)
|
||
|
|
match (req.client_id.clone(), req.client_secret.clone()) {
|
||
|
|
(Some(id), Some(sec)) => Ok((id, sec)),
|
||
|
|
_ => {
|
||
|
|
Err(json_with_headers(
|
||
|
|
StatusCode::UNAUTHORIZED,
|
||
|
|
json!({"error":"invalid_client","error_description":"missing client authentication"}),
|
||
|
|
&[("www-authenticate", "Basic realm=\"token\", error=\"invalid_client\"".to_string())],
|
||
|
|
))
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
async fn userinfo(State(state): State<AppState>, req: axum::http::Request<axum::body::Body>) -> impl IntoResponse {
|
||
|
|
// Extract bearer token
|
||
|
|
let token_opt = req.headers().get(axum::http::header::AUTHORIZATION).and_then(|h| h.to_str().ok()).and_then(|s| s.strip_prefix("Bearer ")).map(|s| s.to_string());
|
||
|
|
let token = match token_opt {
|
||
|
|
Some(t) => t,
|
||
|
|
None => {
|
||
|
|
return json_with_headers(
|
||
|
|
StatusCode::UNAUTHORIZED,
|
||
|
|
json!({"error":"invalid_token"}),
|
||
|
|
&[("www-authenticate", "Bearer realm=\"userinfo\", error=\"invalid_token\"".to_string())],
|
||
|
|
)
|
||
|
|
}
|
||
|
|
};
|
||
|
|
let token_row = match storage::get_access_token(&state.db, &token).await {
|
||
|
|
Ok(Some(t)) => t,
|
||
|
|
_ => {
|
||
|
|
return json_with_headers(
|
||
|
|
StatusCode::UNAUTHORIZED,
|
||
|
|
json!({"error":"invalid_token"}),
|
||
|
|
&[("www-authenticate", "Bearer realm=\"userinfo\", error=\"invalid_token\"".to_string())],
|
||
|
|
)
|
||
|
|
}
|
||
|
|
};
|
||
|
|
let mut claims = serde_json::Map::new();
|
||
|
|
claims.insert("sub".to_string(), serde_json::Value::String(token_row.subject.clone()));
|
||
|
|
// Optional: email claims from properties
|
||
|
|
if let Ok(Some(email)) = storage::get_property(&state.db, &token_row.subject, "email").await {
|
||
|
|
if let Some(email_str) = email.as_str() {
|
||
|
|
claims.insert("email".to_string(), serde_json::Value::String(email_str.to_string()));
|
||
|
|
}
|
||
|
|
}
|
||
|
|
if let Ok(Some(verified)) = storage::get_property(&state.db, &token_row.subject, "email_verified").await {
|
||
|
|
claims.insert("email_verified".to_string(), verified);
|
||
|
|
}
|
||
|
|
(StatusCode::OK, Json(serde_json::Value::Object(claims))).into_response()
|
||
|
|
}
|
||
|
|
|
||
|
|
#[derive(Debug, Deserialize)]
|
||
|
|
struct RegistrationRequest {
|
||
|
|
client_name: Option<String>,
|
||
|
|
redirect_uris: Vec<String>,
|
||
|
|
}
|
||
|
|
|
||
|
|
#[derive(Debug, Serialize)]
|
||
|
|
struct RegistrationResponse {
|
||
|
|
client_id: String,
|
||
|
|
client_secret: String,
|
||
|
|
client_name: Option<String>,
|
||
|
|
redirect_uris: Vec<String>,
|
||
|
|
client_id_issued_at: i64,
|
||
|
|
token_endpoint_auth_method: String,
|
||
|
|
}
|
||
|
|
|
||
|
|
async fn register_client(State(state): State<AppState>, Json(req): Json<RegistrationRequest>) -> impl IntoResponse {
|
||
|
|
if req.redirect_uris.is_empty() {
|
||
|
|
return (StatusCode::BAD_REQUEST, Json(json!({"error": "invalid_client_metadata", "error_description": "redirect_uris required"}))).into_response();
|
||
|
|
}
|
||
|
|
let input = storage::NewClient { client_name: req.client_name.clone(), redirect_uris: req.redirect_uris.clone() };
|
||
|
|
match storage::create_client(&state.db, input).await {
|
||
|
|
Ok(client) => {
|
||
|
|
let resp = RegistrationResponse {
|
||
|
|
client_id: client.client_id,
|
||
|
|
client_secret: client.client_secret,
|
||
|
|
client_name: client.client_name,
|
||
|
|
redirect_uris: client.redirect_uris,
|
||
|
|
client_id_issued_at: client.created_at,
|
||
|
|
token_endpoint_auth_method: "client_secret_post".into(),
|
||
|
|
};
|
||
|
|
(StatusCode::CREATED, Json(serde_json::to_value(resp).unwrap())).into_response()
|
||
|
|
}
|
||
|
|
Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": e.to_string()}))).into_response(),
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
async fn get_property(State(state): State<AppState>, Path((owner, key)): Path<(String, String)>) -> impl IntoResponse {
|
||
|
|
match storage::get_property(&state.db, &owner, &key).await {
|
||
|
|
Ok(Some(v)) => (StatusCode::OK, Json(v)).into_response(),
|
||
|
|
Ok(None) => (StatusCode::NOT_FOUND, Json(json!({"error": "not_found"}))).into_response(),
|
||
|
|
Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": e.to_string()}))).into_response(),
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
async fn set_property(State(state): State<AppState>, Path((owner, key)): Path<(String, String)>, Json(v): Json<Value>) -> impl IntoResponse {
|
||
|
|
match storage::set_property(&state.db, &owner, &key, &v).await {
|
||
|
|
Ok(_) => (StatusCode::NO_CONTENT, ()).into_response(),
|
||
|
|
Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": e.to_string()}))).into_response(),
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
async fn trust_anchors(State(state): State<AppState>) -> impl IntoResponse {
|
||
|
|
Json(json!({ "trust_anchors": state.settings.federation.trust_anchors }))
|
||
|
|
}
|
||
|
|
|
||
|
|
#[derive(Debug, Deserialize)]
|
||
|
|
struct LoginQuery {
|
||
|
|
return_to: Option<String>,
|
||
|
|
error: Option<String>,
|
||
|
|
}
|
||
|
|
|
||
|
|
async fn login_page(Query(q): Query<LoginQuery>) -> impl IntoResponse {
|
||
|
|
let error_html = if let Some(err) = q.error {
|
||
|
|
format!("<p style='color: red;'>{}</p>", html_escape(&err))
|
||
|
|
} else {
|
||
|
|
String::new()
|
||
|
|
};
|
||
|
|
|
||
|
|
let return_to = q.return_to.unwrap_or_default();
|
||
|
|
|
||
|
|
let html = format!(r#"
|
||
|
|
<!DOCTYPE html>
|
||
|
|
<html>
|
||
|
|
<head>
|
||
|
|
<title>Login - Barycenter OpenID Provider</title>
|
||
|
|
<style>
|
||
|
|
body {{ font-family: Arial, sans-serif; max-width: 400px; margin: 100px auto; padding: 20px; }}
|
||
|
|
h1 {{ color: #333; }}
|
||
|
|
label {{ display: block; margin-top: 10px; }}
|
||
|
|
input[type="text"], input[type="password"] {{ width: 100%; padding: 8px; margin-top: 5px; box-sizing: border-box; }}
|
||
|
|
button {{ margin-top: 20px; padding: 10px 20px; background-color: #007bff; color: white; border: none; cursor: pointer; }}
|
||
|
|
button:hover {{ background-color: #0056b3; }}
|
||
|
|
</style>
|
||
|
|
</head>
|
||
|
|
<body>
|
||
|
|
<h1>Login</h1>
|
||
|
|
{error_html}
|
||
|
|
<form method="POST" action="/login">
|
||
|
|
<input type="hidden" name="return_to" value="{return_to}">
|
||
|
|
<label>
|
||
|
|
Username:
|
||
|
|
<input type="text" name="username" required autofocus>
|
||
|
|
</label>
|
||
|
|
<label>
|
||
|
|
Password:
|
||
|
|
<input type="password" name="password" required>
|
||
|
|
</label>
|
||
|
|
<button type="submit">Login</button>
|
||
|
|
</form>
|
||
|
|
</body>
|
||
|
|
</html>
|
||
|
|
"#);
|
||
|
|
|
||
|
|
Html(html)
|
||
|
|
}
|
||
|
|
|
||
|
|
#[derive(Debug, Deserialize)]
|
||
|
|
struct LoginForm {
|
||
|
|
username: String,
|
||
|
|
password: String,
|
||
|
|
return_to: Option<String>,
|
||
|
|
}
|
||
|
|
|
||
|
|
#[derive(Debug, Deserialize)]
|
||
|
|
struct RegisterForm {
|
||
|
|
username: String,
|
||
|
|
password: String,
|
||
|
|
email: Option<String>,
|
||
|
|
}
|
||
|
|
|
||
|
|
async fn register_user(
|
||
|
|
State(state): State<AppState>,
|
||
|
|
Form(form): Form<RegisterForm>,
|
||
|
|
) -> impl IntoResponse {
|
||
|
|
// Create the user
|
||
|
|
match storage::create_user(&state.db, &form.username, &form.password, form.email).await {
|
||
|
|
Ok(_) => {
|
||
|
|
// Return success response
|
||
|
|
Response::builder()
|
||
|
|
.status(StatusCode::CREATED)
|
||
|
|
.body(Body::from("User created"))
|
||
|
|
.unwrap()
|
||
|
|
.into_response()
|
||
|
|
}
|
||
|
|
Err(e) => {
|
||
|
|
// Return error response
|
||
|
|
Response::builder()
|
||
|
|
.status(StatusCode::BAD_REQUEST)
|
||
|
|
.body(Body::from(format!("Failed to create user: {}", e)))
|
||
|
|
.unwrap()
|
||
|
|
.into_response()
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
async fn login_submit(
|
||
|
|
State(state): State<AppState>,
|
||
|
|
headers: HeaderMap,
|
||
|
|
Form(form): Form<LoginForm>,
|
||
|
|
) -> impl IntoResponse {
|
||
|
|
// Verify credentials
|
||
|
|
let subject = match storage::verify_user_password(&state.db, &form.username, &form.password).await {
|
||
|
|
Ok(Some(sub)) => sub,
|
||
|
|
_ => {
|
||
|
|
// Redirect back to login with error
|
||
|
|
let return_to = urlencoded(&form.return_to.unwrap_or_default());
|
||
|
|
let error = urlencoded("Invalid username or password");
|
||
|
|
return Redirect::temporary(&format!("/login?error={error}&return_to={return_to}")).into_response();
|
||
|
|
}
|
||
|
|
};
|
||
|
|
|
||
|
|
// Create session
|
||
|
|
let user_agent = headers
|
||
|
|
.get(axum::http::header::USER_AGENT)
|
||
|
|
.and_then(|h| h.to_str().ok())
|
||
|
|
.map(String::from);
|
||
|
|
|
||
|
|
let session = match storage::create_session(&state.db, &subject, 3600, user_agent, None).await {
|
||
|
|
Ok(s) => s,
|
||
|
|
Err(_) => {
|
||
|
|
let return_to = urlencoded(&form.return_to.unwrap_or_default());
|
||
|
|
let error = urlencoded("Failed to create session");
|
||
|
|
return Redirect::temporary(&format!("/login?error={error}&return_to={return_to}")).into_response();
|
||
|
|
}
|
||
|
|
};
|
||
|
|
|
||
|
|
// Set cookie and redirect
|
||
|
|
let cookie = SessionCookie::new(session.session_id);
|
||
|
|
let cookie_header = cookie.to_cookie_header(&state.settings);
|
||
|
|
|
||
|
|
let redirect_url = form.return_to.unwrap_or_else(|| "/".to_string());
|
||
|
|
|
||
|
|
Response::builder()
|
||
|
|
.status(StatusCode::SEE_OTHER)
|
||
|
|
.header(axum::http::header::SET_COOKIE, cookie_header)
|
||
|
|
.header(axum::http::header::LOCATION, redirect_url)
|
||
|
|
.body(Body::empty())
|
||
|
|
.unwrap()
|
||
|
|
.into_response()
|
||
|
|
}
|
||
|
|
|
||
|
|
async fn logout(State(state): State<AppState>, headers: HeaderMap) -> impl IntoResponse {
|
||
|
|
if let Some(cookie) = SessionCookie::from_headers(&headers) {
|
||
|
|
let _ = storage::delete_session(&state.db, &cookie.session_id).await;
|
||
|
|
}
|
||
|
|
|
||
|
|
Response::builder()
|
||
|
|
.status(StatusCode::SEE_OTHER)
|
||
|
|
.header(axum::http::header::SET_COOKIE, SessionCookie::delete_cookie_header())
|
||
|
|
.header(axum::http::header::LOCATION, "/")
|
||
|
|
.body(Body::empty())
|
||
|
|
.unwrap()
|
||
|
|
.into_response()
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Escapes `s` for safe inclusion in HTML text and quoted attribute values.
///
/// Replaces `&`, `<`, `>`, `"` and `'` with their character references; all
/// other characters are copied unchanged. The input is walked once, so an
/// ampersand produced by one replacement is never escaped a second time.
fn html_escape(s: &str) -> String {
    let mut escaped = String::with_capacity(s.len());
    for ch in s.chars() {
        match ch {
            '&' => escaped.push_str("&amp;"),
            '<' => escaped.push_str("&lt;"),
            '>' => escaped.push_str("&gt;"),
            '"' => escaped.push_str("&quot;"),
            '\'' => escaped.push_str("&#x27;"),
            other => escaped.push(other),
        }
    }
    escaped
}
|
||
|
|
|
||
|
|
fn urlencoded(s: &str) -> String {
|
||
|
|
serde_urlencoded::to_string(&[("", s)]).unwrap_or_default().trim_start_matches('=').to_string()
|
||
|
|
}
|