diff --git a/.claude/settings.local.json b/.claude/settings.local.json index b76f3e9..3e3ad92 100644 --- a/.claude/settings.local.json +++ b/.claude/settings.local.json @@ -41,7 +41,8 @@ "Bash(sqlite3:*)", "Bash(rustc:*)", "Bash(docker build:*)", - "Bash(git commit -m \"$(cat <<''EOF''\nfix(docker): Add missing client-wasm directory and update Rust version\n\n- Add COPY client-wasm to Dockerfile to include workspace member\n- Update Rust base image from 1.91 to 1.92\n- Fixes CI build failure: \"failed to load manifest for workspace member client-wasm\"\n\nšŸ¤– Generated with [Claude Code](https://claude.com/claude-code)\n\nCo-Authored-By: Claude Sonnet 4.5 \nEOF\n)\")" + "Bash(git commit -m \"$(cat <<''EOF''\nfix(docker): Add missing client-wasm directory and update Rust version\n\n- Add COPY client-wasm to Dockerfile to include workspace member\n- Update Rust base image from 1.91 to 1.92\n- Fixes CI build failure: \"failed to load manifest for workspace member client-wasm\"\n\nšŸ¤– Generated with [Claude Code](https://claude.com/claude-code)\n\nCo-Authored-By: Claude Sonnet 4.5 \nEOF\n)\")", + "WebFetch(domain:datatracker.ietf.org)" ], "deny": [], "ask": [] diff --git a/Cargo.toml b/Cargo.toml index cffbbc5..c60b662 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -61,6 +61,7 @@ tower-http = { version = "0.6", features = ["fs"] } # Validation regex = "1" url = "2" +urlencoding = "2" # GraphQL Admin API seaography = { version = "1", features = ["with-decimal", "with-chrono", "with-uuid"] } diff --git a/migration/src/lib.rs b/migration/src/lib.rs index a6713eb..d434b45 100644 --- a/migration/src/lib.rs +++ b/migration/src/lib.rs @@ -4,6 +4,7 @@ mod m20250101_000001_initial_schema; mod m20250107_000001_add_passkeys; mod m20250107_000002_extend_sessions_users; mod m20250108_000001_add_consent_table; +mod m20250109_000001_add_device_codes; pub struct Migrator; @@ -15,6 +16,7 @@ impl MigratorTrait for Migrator { Box::new(m20250107_000001_add_passkeys::Migration), Box::new(m20250107_000002_extend_sessions_users::Migration), Box::new(m20250108_000001_add_consent_table::Migration), + Box::new(m20250109_000001_add_device_codes::Migration), ] } } diff --git a/migration/src/m20250109_000001_add_device_codes.rs b/migration/src/m20250109_000001_add_device_codes.rs new file mode 100644 index 0000000..38738f7 --- /dev/null +++ b/migration/src/m20250109_000001_add_device_codes.rs @@ -0,0 +1,121 @@ +use sea_orm_migration::prelude::*; + +#[derive(DeriveMigrationName)] +pub struct Migration; + +#[async_trait::async_trait] +impl MigrationTrait for Migration { + async fn up(&self, manager: &SchemaManager) -> Result<(), DbErr> { + // Create device_codes table for OAuth 2.0 Device Authorization Grant (RFC 8628) + manager + .create_table( + Table::create() + .table(DeviceCode::Table) + .if_not_exists() + .col( + ColumnDef::new(DeviceCode::DeviceCode) + .string() + .not_null() + .primary_key(), + ) + .col(ColumnDef::new(DeviceCode::UserCode).string().not_null()) + .col(ColumnDef::new(DeviceCode::ClientId).string().not_null()) + .col(ColumnDef::new(DeviceCode::ClientName).string()) + .col(ColumnDef::new(DeviceCode::Scope).string().not_null()) + .col(ColumnDef::new(DeviceCode::DeviceInfo).string()) + .col( + ColumnDef::new(DeviceCode::CreatedAt) + .big_integer() + .not_null(), + ) + .col( + ColumnDef::new(DeviceCode::ExpiresAt) + .big_integer() + .not_null(), + ) + .col(ColumnDef::new(DeviceCode::LastPollAt).big_integer()) + .col( + ColumnDef::new(DeviceCode::Interval) + .integer() + .not_null() + .default(5), + ) + .col( + ColumnDef::new(DeviceCode::Status) + .string() + .not_null() + .default("pending"), + ) + .col(ColumnDef::new(DeviceCode::Subject).string()) + .col(ColumnDef::new(DeviceCode::AuthTime).big_integer()) + .col(ColumnDef::new(DeviceCode::Amr).string()) + .col(ColumnDef::new(DeviceCode::Acr).string()) + .to_owned(), + ) + .await?; + + // Create index on user_code for fast lookups during verification + manager + .create_index( + Index::create() + .if_not_exists() + .name("idx_device_codes_user_code") + .table(DeviceCode::Table) + .col(DeviceCode::UserCode) + .to_owned(), + ) + .await?; + + // Create index on expires_at for efficient cleanup job + manager + .create_index( + Index::create() + .if_not_exists() + .name("idx_device_codes_expires_at") + .table(DeviceCode::Table) + .col(DeviceCode::ExpiresAt) + .to_owned(), + ) + .await?; + + // Create index on status for filtering pending/approved codes + manager + .create_index( + Index::create() + .if_not_exists() + .name("idx_device_codes_status") + .table(DeviceCode::Table) + .col(DeviceCode::Status) + .to_owned(), + ) + .await?; + + Ok(()) + } + + async fn down(&self, manager: &SchemaManager) -> Result<(), DbErr> { + manager + .drop_table(Table::drop().table(DeviceCode::Table).to_owned()) + .await + } +} + +#[derive(DeriveIden)] +enum DeviceCode { + Table, + DeviceCode, + UserCode, + ClientId, + ClientName, + Scope, + DeviceInfo, + CreatedAt, + ExpiresAt, + LastPollAt, + Interval, + Status, + Subject, + AuthTime, + Amr, + Acr, +} diff --git a/src/admin_mutations.rs b/src/admin_mutations.rs index 3a6cbc3..71e7e43 100644 --- a/src/admin_mutations.rs +++ b/src/admin_mutations.rs @@ -1,7 +1,7 @@ use async_graphql::*; use sea_orm::{ ActiveModelTrait, ColumnTrait, DatabaseConnection, EntityTrait, IntoActiveModel, QueryFilter, - QueryOrder, QuerySelect, Set, + QuerySelect, Set, }; use std::sync::Arc; @@ -172,6 +172,11 @@ impl AdminQuery { description: "Clean up expired WebAuthn challenges".to_string(), schedule: "Every 5 minutes".to_string(), }, + JobInfo { + name: "cleanup_expired_device_codes".to_string(), + description: "Clean up expired device authorization codes".to_string(), + schedule: "Hourly at :45".to_string(), + }, ]) } diff --git a/src/entities/device_code.rs b/src/entities/device_code.rs new file mode 100644 index 0000000..beb93a9 --- /dev/null +++ b/src/entities/device_code.rs @@ -0,0 +1,28 @@ +use sea_orm::entity::prelude::*; +use serde::{Deserialize, Serialize}; + +#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel, Serialize, Deserialize)] +#[sea_orm(table_name = "device_codes")] +pub struct Model { + #[sea_orm(primary_key, auto_increment = false)] + pub device_code: String, + pub user_code: String, + pub client_id: String, + pub client_name: Option, + pub scope: String, + pub device_info: Option, // JSON: {ip_address, user_agent} + pub created_at: i64, + pub expires_at: i64, + pub last_poll_at: Option, + pub interval: i64, + pub status: String, // "pending" | "approved" | "denied" | "consumed" + pub subject: Option, + pub auth_time: Option, + pub amr: Option, // JSON array: ["pwd", "hwk"] + pub acr: Option, // "aal1" or "aal2" +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation {} + +impl ActiveModelBehavior for ActiveModel {} diff --git a/src/entities/mod.rs b/src/entities/mod.rs index de3d2d9..4cd758b 100644 --- a/src/entities/mod.rs +++ b/src/entities/mod.rs @@ -2,6 +2,7 @@ pub mod access_token; pub mod auth_code; pub mod client; pub mod consent; +pub mod device_code; pub mod job_execution; pub mod passkey; pub mod property; @@ -14,6 +15,7 @@ pub use access_token::Entity as AccessToken; pub use auth_code::Entity as AuthCode; pub use client::Entity as Client; pub use consent::Entity as Consent; +pub use device_code::Entity as DeviceCode; pub use job_execution::Entity as JobExecution; pub use passkey::Entity as Passkey; pub use property::Entity as Property; diff --git a/src/jobs.rs b/src/jobs.rs index e021cce..1b01d65 100644 --- a/src/jobs.rs +++ b/src/jobs.rs @@ -123,13 +123,49 @@ pub async fn init_scheduler(db: DatabaseConnection) -> Result { + info!("Cleaned up {} expired device codes", count); + if let Some(id) = execution_id { + let _ = + complete_job_execution(&db, id, true, None, Some(count as i64)).await; + } + } + Err(e) => { + error!("Failed to cleanup expired device codes: {}", e); + if let Some(id) = execution_id { + let _ = + complete_job_execution(&db, id, false, Some(e.to_string()), None).await; + } + } + } + }) + }) + .map_err(|e| CrabError::Other(format!("Failed to create cleanup device codes job: {}", e)))?; + + sched + .add(cleanup_device_codes_job) + .await + .map_err(|e| CrabError::Other(format!("Failed to add cleanup device codes job: {}", e)))?; + // Start the scheduler sched .start() .await .map_err(|e| CrabError::Other(format!("Failed to start job scheduler: {}", e)))?; - info!("Job scheduler started with {} jobs", 3); + info!("Job scheduler started with {} jobs", 4); Ok(sched) } @@ -197,6 +233,7 @@ pub async fn trigger_job_manually( "cleanup_expired_sessions" => storage::cleanup_expired_sessions(db).await, "cleanup_expired_refresh_tokens" => storage::cleanup_expired_refresh_tokens(db).await, "cleanup_expired_challenges" => storage::cleanup_expired_challenges(db).await, + "cleanup_expired_device_codes" => storage::cleanup_expired_device_codes(db).await, _ => { return Err(CrabError::Other(format!("Unknown job name: {}", job_name))); } diff --git a/src/storage.rs b/src/storage.rs index 819654b..eb3d7d0 100644 --- a/src/storage.rs +++ b/src/storage.rs @@ -3,7 +3,7 @@ use crate::errors::CrabError; use crate::settings::Database as DbCfg; use base64ct::Encoding; use chrono::Utc; -use rand::RngCore; +use rand::{Rng, RngCore}; use sea_orm::{ ActiveModelTrait, ColumnTrait, Database, DatabaseConnection, EntityTrait, QueryFilter, QueryOrder, Set, @@ -92,6 +92,25 @@ pub struct RefreshToken { pub parent_token: Option, // For token rotation tracking } +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct DeviceCode { + pub device_code: String, + pub user_code: String, + pub client_id: String, + pub client_name: Option, + pub scope: String, + pub device_info: Option, // JSON: {ip_address, user_agent} + pub created_at: i64, + pub expires_at: i64, + pub last_poll_at: Option, + pub interval: i64, + pub status: String, // "pending" | "approved" | "denied" | "consumed" + pub subject: Option, + pub auth_time: Option, + pub amr: Option, // JSON array: ["pwd", "hwk"] + pub acr: Option, // "aal1" or "aal2" +} + pub async fn init(cfg: &DbCfg) -> Result { let db = Database::connect(&cfg.url).await?; Ok(db) @@ -352,6 +371,25 @@ fn random_id() -> String { base64ct::Base64UrlUnpadded::encode_string(&bytes) } +/// Generate 8-character base-20 user code in format XXXX-XXXX +/// Alphabet: BCDFGHJKLMNPQRSTVWXZ (consonants only, no ambiguous chars) +/// Entropy: 20^8 = ~43 bits +fn generate_user_code() -> String { + const ALPHABET: &[u8] = b"BCDFGHJKLMNPQRSTVWXZ"; + let mut rng = rand::thread_rng(); + let mut code = String::with_capacity(9); + + for i in 0..8 { + if i == 4 { + code.push('-'); + } + let idx = rng.gen_range(0..ALPHABET.len()); + code.push(ALPHABET[idx] as char); + } + + code +} + // User management functions pub async fn create_user( @@ -1186,6 +1224,305 @@ pub async fn revoke_consents_for_client( Ok(()) } +// Device Authorization Grant (RFC 8628) functions + +/// Create a new device authorization code +pub async fn create_device_code( + db: &DatabaseConnection, + client_id: &str, + client_name: Option, + scope: &str, + device_info: Option, +) -> Result { + let device_code = random_id(); + let user_code = generate_user_code(); + let now = Utc::now().timestamp(); + let expires_at = now + 1800; // 30 minutes + + let device_code_model = entities::device_code::ActiveModel { + device_code: Set(device_code.clone()), + user_code: Set(user_code.clone()), + client_id: Set(client_id.to_string()), + client_name: Set(client_name.clone()), + scope: Set(scope.to_string()), + device_info: Set(device_info.clone()), + created_at: Set(now), + expires_at: Set(expires_at), + last_poll_at: Set(None), + interval: Set(5), + status: Set("pending".to_string()), + subject: Set(None), + auth_time: Set(None), + amr: Set(None), + acr: Set(None), + }; + + device_code_model.insert(db).await?; + + Ok(DeviceCode { + device_code, + user_code, + client_id: client_id.to_string(), + client_name, + scope: scope.to_string(), + device_info, + created_at: now, + expires_at, + last_poll_at: None, + interval: 5, + status: "pending".to_string(), + subject: None, + auth_time: None, + amr: None, + acr: None, + }) +} + +/// Get device code by device_code +pub async fn get_device_code( + db: &DatabaseConnection, + device_code: &str, +) -> Result, CrabError> { + use entities::device_code::{Column, Entity}; + + let now = Utc::now().timestamp(); + + let result = Entity::find() + .filter(Column::DeviceCode.eq(device_code)) + .one(db) + .await?; + + match result { + Some(dc) => { + // Check if expired + if dc.expires_at < now { + return Ok(None); + } + + Ok(Some(DeviceCode { + device_code: dc.device_code, + user_code: dc.user_code, + client_id: dc.client_id, + client_name: dc.client_name, + scope: dc.scope, + device_info: dc.device_info, + created_at: dc.created_at, + expires_at: dc.expires_at, + last_poll_at: dc.last_poll_at, + interval: dc.interval, + status: dc.status, + subject: dc.subject, + auth_time: dc.auth_time, + amr: dc.amr, + acr: dc.acr, + })) + } + None => Ok(None), + } +} + +/// Get device code by user_code +pub async fn get_device_code_by_user_code( + db: &DatabaseConnection, + user_code: &str, +) -> Result, CrabError> { + use entities::device_code::{Column, Entity}; + + let now = Utc::now().timestamp(); + + let result = Entity::find() + .filter(Column::UserCode.eq(user_code)) + .one(db) + .await?; + + match result { + Some(dc) => { + // Check if expired + if dc.expires_at < now { + return Ok(None); + } + + Ok(Some(DeviceCode { + device_code: dc.device_code, + user_code: dc.user_code, + client_id: dc.client_id, + client_name: dc.client_name, + scope: dc.scope, + device_info: dc.device_info, + created_at: dc.created_at, + expires_at: dc.expires_at, + last_poll_at: dc.last_poll_at, + interval: dc.interval, + status: dc.status, + subject: dc.subject, + auth_time: dc.auth_time, + amr: dc.amr, + acr: dc.acr, + })) + } + None => Ok(None), + } +} + +/// Update device code polling metadata +pub async fn update_device_code_poll( + db: &DatabaseConnection, + device_code: &str, +) -> Result<(), CrabError> { + use entities::device_code::{Column, Entity}; + + let now = Utc::now().timestamp(); + + if let Some(dc) = Entity::find() + .filter(Column::DeviceCode.eq(device_code)) + .one(db) + .await? + { + let mut active: entities::device_code::ActiveModel = dc.into(); + active.last_poll_at = Set(Some(now)); + active.update(db).await?; + } + + Ok(()) +} + +/// Approve device code with user authentication context +pub async fn approve_device_code( + db: &DatabaseConnection, + device_code: &str, + subject: &str, + auth_time: i64, + amr: Option, + acr: Option, +) -> Result<(), CrabError> { + use entities::device_code::{Column, Entity}; + + if let Some(dc) = Entity::find() + .filter(Column::DeviceCode.eq(device_code)) + .one(db) + .await? + { + let mut active: entities::device_code::ActiveModel = dc.into(); + active.status = Set("approved".to_string()); + active.subject = Set(Some(subject.to_string())); + active.auth_time = Set(Some(auth_time)); + active.amr = Set(amr); + active.acr = Set(acr); + active.update(db).await?; + } + + Ok(()) +} + +/// Deny device code +pub async fn deny_device_code( + db: &DatabaseConnection, + device_code: &str, +) -> Result<(), CrabError> { + use entities::device_code::{Column, Entity}; + + if let Some(dc) = Entity::find() + .filter(Column::DeviceCode.eq(device_code)) + .one(db) + .await? + { + let mut active: entities::device_code::ActiveModel = dc.into(); + active.status = Set("denied".to_string()); + active.update(db).await?; + } + + Ok(()) +} + +/// Consume device code (mark as used) and return its data +pub async fn consume_device_code( + db: &DatabaseConnection, + device_code: &str, +) -> Result, CrabError> { + use entities::device_code::{Column, Entity}; + + let now = Utc::now().timestamp(); + + if let Some(dc) = Entity::find() + .filter(Column::DeviceCode.eq(device_code)) + .one(db) + .await? + { + // Check if expired + if dc.expires_at < now { + return Ok(None); + } + + // Check if status is approved + if dc.status != "approved" { + return Ok(None); + } + + // Mark as consumed + let mut active: entities::device_code::ActiveModel = dc.clone().into(); + active.status = Set("consumed".to_string()); + active.update(db).await?; + + // Return the device code data + Ok(Some(DeviceCode { + device_code: dc.device_code, + user_code: dc.user_code, + client_id: dc.client_id, + client_name: dc.client_name, + scope: dc.scope, + device_info: dc.device_info, + created_at: dc.created_at, + expires_at: dc.expires_at, + last_poll_at: dc.last_poll_at, + interval: dc.interval, + status: "consumed".to_string(), + subject: dc.subject, + auth_time: dc.auth_time, + amr: dc.amr, + acr: dc.acr, + })) + } else { + Ok(None) + } +} + +/// Increment device code polling interval by 5 seconds (for slow_down) +pub async fn increment_device_code_interval( + db: &DatabaseConnection, + device_code: &str, +) -> Result<(), CrabError> { + use entities::device_code::{Column, Entity}; + + let now = Utc::now().timestamp(); + + if let Some(dc) = Entity::find() + .filter(Column::DeviceCode.eq(device_code)) + .one(db) + .await? + { + let mut active: entities::device_code::ActiveModel = dc.into(); + active.interval = Set(active.interval.clone().unwrap() + 5); + active.last_poll_at = Set(Some(now)); + active.update(db).await?; + } + + Ok(()) +} + +/// Cleanup expired device codes +pub async fn cleanup_expired_device_codes(db: &DatabaseConnection) -> Result { + use entities::device_code::{Column, Entity}; + + let now = Utc::now().timestamp(); + + let result = Entity::delete_many() + .filter(Column::ExpiresAt.lt(now)) + .exec(db) + .await?; + + Ok(result.rows_affected) +} + #[cfg(test)] mod tests { use super::*; diff --git a/src/web.rs b/src/web.rs index 859c67e..f40985a 100644 --- a/src/web.rs +++ b/src/web.rs @@ -26,6 +26,7 @@ use std::net::SocketAddr; use std::sync::Arc; use std::time::SystemTime; use tower_http::services::ServeDir; +use urlencoding; #[derive(Clone)] pub struct AppState { @@ -122,6 +123,11 @@ pub async fn serve( .route("/token", post(token)) .route("/revoke", post(token_revoke)) .route("/userinfo", get(userinfo)) + // Device Authorization Grant (RFC 8628) endpoints + .route("/device_authorization", post(device_authorization)) + .route("/device", get(device_page)) + .route("/device/verify", post(device_verify)) + .route("/device/consent", post(device_consent)) // WebAuthn / Passkey endpoints .route("/webauthn/register/start", post(passkey_register_start)) .route("/webauthn/register/finish", post(passkey_register_finish)) @@ -200,6 +206,7 @@ async fn discovery(State(state): State) -> impl IntoResponse { "issuer": issuer, "authorization_endpoint": format!("{}/authorize", issuer), "token_endpoint": format!("{}/token", issuer), + "device_authorization_endpoint": format!("{}/device_authorization", issuer), "jwks_uri": format!("{}/.well-known/jwks.json", issuer), "registration_endpoint": format!("{}/connect/register", issuer), "userinfo_endpoint": format!("{}/userinfo", issuer), @@ -208,7 +215,7 @@ async fn discovery(State(state): State) -> impl IntoResponse { "subject_types_supported": ["public"], "id_token_signing_alg_values_supported": [state.settings.keys.alg], // Additional recommended metadata for better interoperability - "grant_types_supported": ["authorization_code", "refresh_token", "implicit"], + "grant_types_supported": ["authorization_code", "refresh_token", "implicit", "urn:ietf:params:oauth:grant-type:device_code"], "token_endpoint_auth_methods_supported": ["client_secret_post", "client_secret_basic"], "code_challenge_methods_supported": ["S256"], "claims_supported": [ @@ -1114,6 +1121,7 @@ struct TokenRequest { client_secret: Option, code_verifier: Option, refresh_token: Option, + device_code: Option, // For device_code grant } #[derive(Debug, Serialize)] @@ -1156,6 +1164,9 @@ async fn token( match req.grant_type.as_str() { "authorization_code" => handle_authorization_code_grant(state, headers, req).await, "refresh_token" => handle_refresh_token_grant(state, headers, req).await, + "urn:ietf:params:oauth:grant-type:device_code" => { + handle_device_code_grant(state, headers, req).await + } _ => ( StatusCode::BAD_REQUEST, Json(json!({"error":"unsupported_grant_type"})), @@ -1591,6 +1602,247 @@ async fn handle_refresh_token_grant( ) } +async fn handle_device_code_grant( + state: AppState, + _headers: HeaderMap, + req: TokenRequest, +) -> Response { + // Require device_code parameter + let device_code_str = match req.device_code { + Some(dc) => dc, + None => { + return ( + StatusCode::BAD_REQUEST, + Json(json!({ + "error": "invalid_request", + "error_description": "device_code required" + })), + ) + .into_response() + } + }; + + // Require client_id (device flow uses public clients, so no secret required) + let client_id = match req.client_id { + Some(cid) => cid, + None => { + return ( + StatusCode::BAD_REQUEST, + Json(json!({ + "error": "invalid_request", + "error_description": "client_id required" + })), + ) + .into_response() + } + }; + + // Get device code + let device_code = match storage::get_device_code(&state.db, &device_code_str).await { + Ok(Some(dc)) => dc, + _ => { + return ( + StatusCode::BAD_REQUEST, + Json(json!({ + "error": "invalid_grant", + "error_description": "device_code not found or expired" + })), + ) + .into_response() + } + }; + + // Verify device_code bound to same client_id + if device_code.client_id != client_id { + return ( + StatusCode::BAD_REQUEST, + Json(json!({ + "error": "invalid_grant", + "error_description": "device_code not bound to this client" + })), + ) + .into_response(); + } + + // Rate limiting: check if polling too fast + if let Some(last_poll) = device_code.last_poll_at { + let now = chrono::Utc::now().timestamp(); + let elapsed = now - last_poll; + + if elapsed < device_code.interval { + // Polling too fast - increment interval and return slow_down + let _ = storage::increment_device_code_interval(&state.db, &device_code_str).await; + + return ( + StatusCode::BAD_REQUEST, + Json(json!({ + "error": "slow_down", + "error_description": "Polling too frequently" + })), + ) + .into_response(); + } + } + + // Update last_poll_at + let _ = storage::update_device_code_poll(&state.db, &device_code_str).await; + + // Check status + match device_code.status.as_str() { + "pending" => { + return ( + StatusCode::BAD_REQUEST, + Json(json!({ + "error": "authorization_pending", + "error_description": "User has not yet authorized the device" + })), + ) + .into_response() + } + "denied" => { + return ( + StatusCode::BAD_REQUEST, + Json(json!({ + "error": "access_denied", + "error_description": "User denied the authorization request" + })), + ) + .into_response() + } + "consumed" => { + return ( + StatusCode::BAD_REQUEST, + Json(json!({ + "error": "invalid_grant", + "error_description": "device_code already used" + })), + ) + .into_response() + } + "approved" => { + // Continue to token issuance + } + _ => { + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({ + "error": "server_error", + "error_description": "Unknown device_code status" + })), + ) + .into_response() + } + } + + // Consume device code + let consumed_device_code = match storage::consume_device_code(&state.db, &device_code_str).await + { + Ok(Some(dc)) => dc, + _ => { + return ( + StatusCode::BAD_REQUEST, + Json(json!({ + "error": "invalid_grant", + "error_description": "Failed to consume device_code" + })), + ) + .into_response() + } + }; + + let subject = consumed_device_code.subject.unwrap(); + + // Issue access token + let access = match storage::issue_access_token( + &state.db, + &client_id, + &subject, + &consumed_device_code.scope, + 3600, + ) + .await + { + Ok(t) => t, + Err(e) => { + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({ + "error": "server_error", + "details": e.to_string() + })), + ) + .into_response() + } + }; + + // Build ID Token with profile claims + let id_token = match build_id_token( + &state, + &client_id, + &subject, + None, // No nonce for device flow + consumed_device_code.auth_time, + Some(&access.token), + consumed_device_code.amr.as_deref(), + consumed_device_code.acr.as_deref(), + ) + .await + { + Ok(t) => t, + Err(e) => { + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({ + "error": "server_error", + "details": e.to_string() + })), + ) + .into_response() + } + }; + + // Issue refresh token if offline_access scope requested + let refresh_token_opt = if consumed_device_code + .scope + .split_whitespace() + .any(|s| s == "offline_access") + { + match storage::issue_refresh_token( + &state.db, + &client_id, + &subject, + &consumed_device_code.scope, + 2592000, // 30 days + None, // No parent token for device flow + ) + .await + { + Ok(rt) => Some(rt.token), + Err(_) => None, + } + } else { + None + }; + + let resp = TokenResponse { + access_token: access.token, + token_type: "bearer".into(), + expires_in: 3600, + id_token: Some(id_token), + refresh_token: refresh_token_opt, + }; + + // Add Cache-Control: no-store as required by OAuth 2.0 and OIDC specs + json_with_headers( + StatusCode::OK, + serde_json::to_value(resp).unwrap(), + &[ + ("cache-control", "no-store".to_string()), + ("pragma", "no-cache".to_string()), + ], + ) +} + fn authenticate_client( headers: &HeaderMap, req: &TokenRequest, @@ -2948,3 +3200,426 @@ async fn update_passkey_handler( Ok(StatusCode::NO_CONTENT) } + +// ======================================== +// Device Authorization Grant (RFC 8628) +// ======================================== + +#[derive(Deserialize)] +struct DeviceAuthorizationRequest { + client_id: Option, + client_name: Option, + scope: Option, +} + +#[derive(Serialize)] +struct DeviceAuthorizationResponse { + device_code: String, + user_code: String, + verification_uri: String, + verification_uri_complete: String, + expires_in: i64, + interval: i64, + #[serde(skip_serializing_if = "Option::is_none")] + client_id: Option, + #[serde(skip_serializing_if = "Option::is_none")] + client_secret: Option, +} + +/// POST /device_authorization - RFC 8628 Device Authorization Endpoint +async fn device_authorization( + State(state): State, + headers: HeaderMap, + Form(req): Form, +) -> Result, (StatusCode, Json)> { + // Extract device info + let user_agent = headers + .get("user-agent") + .and_then(|h| h.to_str().ok()) + .unwrap_or("Unknown") + .to_string(); + + let ip_address = headers + .get("x-forwarded-for") + .and_then(|h| h.to_str().ok()) + .and_then(|s| s.split(',').next()) + .unwrap_or("Unknown") + .to_string(); + + let device_info = serde_json::json!({ + "ip_address": ip_address, + "user_agent": user_agent, + }) + .to_string(); + + // Determine client_id (auto-register if not provided) + let (client_id, client_secret, client_name, auto_registered) = if let Some(cid) = req.client_id + { + // Validate existing client + let client = storage::get_client(&state.db, &cid).await.map_err(|e| { + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({ + "error": "server_error", + "error_description": format!("Database error: {}", e) + })), + ) + })?; + + if client.is_none() { + return Err(( + StatusCode::UNAUTHORIZED, + Json(json!({ + "error": "invalid_client", + "error_description": "Client not found" + })), + )); + } + + let client = client.unwrap(); + (client.client_id, client.client_secret, client.client_name, false) + } else { + // Auto-register new client + let new_client_name = req.client_name.unwrap_or_else(|| "Auto-registered Device".to_string()); + let new_client = storage::NewClient { + client_name: Some(new_client_name.clone()), + redirect_uris: vec![], // Device flow doesn't use redirect URIs + }; + + let client = storage::create_client(&state.db, new_client) + .await + .map_err(|e| { + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({ + "error": "server_error", + "error_description": format!("Failed to create client: {}", e) + })), + ) + })?; + + (client.client_id, client.client_secret, Some(new_client_name), true) + }; + + // Validate scope (must include "openid" for OIDC) + let scope = req.scope.unwrap_or_else(|| "openid".to_string()); + if !scope.split_whitespace().any(|s| s == "openid") { + return Err(( + StatusCode::BAD_REQUEST, + Json(json!({ + "error": "invalid_scope", + "error_description": "Scope must include 'openid'" + })), + )); + } + + // Create device code + let device_code = storage::create_device_code( + &state.db, + &client_id, + client_name, + &scope, + Some(device_info), + ) + .await + .map_err(|e| { + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({ + "error": "server_error", + "error_description": format!("Failed to create device code: {}", e) + })), + ) + })?; + + // Build URIs + let issuer = state.settings.issuer(); + let verification_uri = format!("{}/device", issuer); + let verification_uri_complete = format!("{}/device?user_code={}", issuer, device_code.user_code); + + Ok(Json(DeviceAuthorizationResponse { + device_code: device_code.device_code, + user_code: device_code.user_code, + verification_uri, + verification_uri_complete, + expires_in: 1800, + interval: 5, + client_id: if auto_registered { Some(client_id) } else { None }, + client_secret: if auto_registered { Some(client_secret) } else { None }, + })) +} + +#[derive(Deserialize)] +struct DevicePageQuery { + user_code: Option, +} + +/// GET /device - Device verification page +async fn device_page( + State(state): State, + headers: HeaderMap, + Query(query): Query, +) -> Result, Redirect> { + // Check if user has session + let session_cookie = SessionCookie::from_headers(&headers); + if let Some(cookie) = session_cookie { + if let Ok(Some(_session)) = storage::get_session(&state.db, &cookie.session_id).await { + // User is authenticated, show form + let prefilled_code = query.user_code.as_deref().unwrap_or(""); + + let html = format!( + r#" + + + + + Device Verification + + + +
+

Device Verification

+
+

Enter the code shown on your device to authorize access.

+

Format: XXXX-XXXX (8 characters)

+
+
+ + +
+
+ +"#, + prefilled_code + ); + + return Ok(Html(html)); + } + } + + // No session, redirect to login + let return_to = if let Some(code) = query.user_code { + format!("/login?return_to={}", urlencoding::encode(&format!("/device?user_code={}", code))) + } else { + "/login?return_to=/device".to_string() + }; + + Err(Redirect::to(&return_to)) +} + +#[derive(Deserialize)] +struct DeviceVerifyRequest { + user_code: String, +} + +/// POST /device/verify - Verify user code and show consent page +async fn device_verify( + State(state): State, + headers: HeaderMap, + Form(req): Form, +) -> Result, (StatusCode, String)> { + // Require authenticated session + let session_cookie = SessionCookie::from_headers(&headers) + .ok_or((StatusCode::UNAUTHORIZED, "Not authenticated".to_string()))?; + + let _session = storage::get_session(&state.db, &session_cookie.session_id) + .await + .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))? + .ok_or((StatusCode::UNAUTHORIZED, "Session not found".to_string()))?; + + // Lookup device code by user_code + let device_code = storage::get_device_code_by_user_code(&state.db, &req.user_code.to_uppercase()) + .await + .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))? + .ok_or((StatusCode::NOT_FOUND, "Device code not found or expired".to_string()))?; + + // Parse device_info + let device_info: serde_json::Value = serde_json::from_str(device_code.device_info.as_deref().unwrap_or("{}")) + .unwrap_or(json!({})); + + let ip_address = device_info["ip_address"].as_str().unwrap_or("Unknown"); + let user_agent = device_info["user_agent"].as_str().unwrap_or("Unknown"); + + // Render consent page + let html = format!( + r#" + + + + + Authorize Device + + + +
+

Authorize Device

+
+ Verify this is your device! Only approve if you recognize the device information below. +
+
{}
+
+
+
Client:
+
{}
+
Requested Scopes:
+
{}
+
IP Address:
+
{}
+
User Agent:
+
{}
+
+
+
+
+ + + +
+
+ + + +
+
+
+ +"#, + device_code.user_code, + device_code.client_name.as_deref().unwrap_or("Unknown Application"), + device_code.scope, + ip_address, + user_agent, + device_code.user_code, + device_code.user_code + ); + + Ok(Html(html)) +} + +#[derive(Deserialize)] +struct DeviceConsentRequest { + user_code: String, + #[serde(default)] + approved: bool, +} + +/// POST /device/consent - Handle device authorization approval/denial +async fn device_consent( + State(state): State, + headers: HeaderMap, + Form(req): Form, +) -> Result, (StatusCode, String)> { + // Require authenticated session + let session_cookie = SessionCookie::from_headers(&headers) + .ok_or((StatusCode::UNAUTHORIZED, "Not authenticated".to_string()))?; + + let session = storage::get_session(&state.db, &session_cookie.session_id) + .await + .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))? + .ok_or((StatusCode::UNAUTHORIZED, "Session not found".to_string()))?; + + // Lookup device code by user_code + let device_code = storage::get_device_code_by_user_code(&state.db, &req.user_code.to_uppercase()) + .await + .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))? + .ok_or((StatusCode::NOT_FOUND, "Device code not found or expired".to_string()))?; + + // TODO: Check 2FA requirements (admin-enforced, high-value scopes, max_age) + // For now, we'll skip 2FA checks and proceed directly + + if req.approved { + // Approve the device code + storage::approve_device_code( + &state.db, + &device_code.device_code, + &session.subject, + session.auth_time, + session.amr, + session.acr, + ) + .await + .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?; + + // Show success page + Ok(Html( + r#" + + + + + Device Approved + + + +
+
āœ“
+

Device Approved

+

You can now return to your device and continue.

+
+ +"# + .to_string(), + )) + } else { + // Deny the device code + storage::deny_device_code(&state.db, &device_code.device_code) + .await + .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?; + + // Show denial page + Ok(Html( + r#" + + + + + Device Denied + + + +
+
āœ—
+

Device Access Denied

+

The authorization request has been rejected.

+
+ +"# + .to_string(), + )) + } +}