From 86c88d8aeeb7e9fcd83e1ac087c52f56ee3a5ef1 Mon Sep 17 00:00:00 2001 From: Till Wegmueller Date: Tue, 6 Jan 2026 10:56:23 +0100 Subject: [PATCH] Commit work in progress Signed-off-by: Till Wegmueller --- src/jobs.rs | 117 ++++++----- src/jwks.rs | 3 +- src/storage.rs | 493 ++++++++++++++++++++++++++++++++++++++++------- src/user_sync.rs | 25 +-- src/web.rs | 41 ++-- 5 files changed, 529 insertions(+), 150 deletions(-) diff --git a/src/jobs.rs b/src/jobs.rs index 7d821d7..58770e4 100644 --- a/src/jobs.rs +++ b/src/jobs.rs @@ -22,23 +22,23 @@ pub async fn init_scheduler(db: DatabaseConnection) -> Result { info!("Cleaned up {} expired sessions", count); if let Some(id) = execution_id { let _ = - complete_job_execution(&db, id, true, None, Some(count as i64)).await; + complete_job_execution(db, id, true, None, Some(count as i64)).await; } } Err(e) => { error!("Failed to cleanup expired sessions: {}", e); if let Some(id) = execution_id { let _ = - complete_job_execution(&db, id, false, Some(e.to_string()), None).await; + complete_job_execution(db, id, false, Some(e.to_string()), None).await; } } } @@ -58,23 +58,23 @@ pub async fn init_scheduler(db: DatabaseConnection) -> Result { info!("Cleaned up {} expired refresh tokens", count); if let Some(id) = execution_id { let _ = - complete_job_execution(&db, id, true, None, Some(count as i64)).await; + complete_job_execution(db, id, true, None, Some(count as i64)).await; } } Err(e) => { error!("Failed to cleanup expired refresh tokens: {}", e); if let Some(id) = execution_id { let _ = - complete_job_execution(&db, id, false, Some(e.to_string()), None).await; + complete_job_execution(db, id, false, Some(e.to_string()), None).await; } } } @@ -94,23 +94,23 @@ pub async fn init_scheduler(db: DatabaseConnection) -> Result { info!("Cleaned up {} expired WebAuthn challenges", count); if let Some(id) = execution_id { let _ = - complete_job_execution(&db, id, true, None, Some(count as i64)).await; + complete_job_execution(db, id, true, None, Some(count as i64)).await; } } Err(e) => { error!("Failed to cleanup expired challenges: {}", e); if let Some(id) = execution_id { let _ = - complete_job_execution(&db, id, false, Some(e.to_string()), None).await; + complete_job_execution(db, id, false, Some(e.to_string()), None).await; } } } @@ -226,28 +226,43 @@ mod tests { use sea_orm_migration::MigratorTrait; use tempfile::NamedTempFile; - /// Helper to create an in-memory test database - async fn test_db() -> DatabaseConnection { - let temp_file = NamedTempFile::new().expect("Failed to create temp file"); - let db_path = temp_file.path().to_str().expect("Invalid temp file path"); - let db_url = format!("sqlite://{}?mode=rwc", db_path); + /// Test database helper that keeps temp file alive + struct TestDb { + connection: DatabaseConnection, + _temp_file: NamedTempFile, + } - let db = Database::connect(&db_url) - .await - .expect("Failed to connect to test database"); + impl TestDb { + async fn new() -> Self { + let temp_file = NamedTempFile::new().expect("Failed to create temp file"); + let db_path = temp_file.path().to_str().expect("Invalid temp file path"); + let db_url = format!("sqlite://{}?mode=rwc", db_path); - migration::Migrator::up(&db, None) - .await - .expect("Failed to run migrations"); + let connection = Database::connect(&db_url) + .await + .expect("Failed to connect to test database"); - db + migration::Migrator::up(&connection, None) + .await + .expect("Failed to run migrations"); + + Self { + connection, + _temp_file: temp_file, + } + } + + fn connection(&self) -> &DatabaseConnection { + &self.connection + } } #[tokio::test] async fn test_start_job_execution() { - let db = test_db().await; + let test_db = TestDb::new().await; + let db = test_db.connection(); - let execution_id = start_job_execution(&db, "test_job") + let execution_id = start_job_execution(db, "test_job") .await .expect("Failed to start job execution"); @@ -257,7 +272,7 @@ mod tests { use entities::job_execution::{Column, Entity}; let execution = Entity::find() .filter(Column::Id.eq(execution_id)) - .one(&db) + .one(db) .await .expect("Failed to query job execution") .expect("Job execution not found"); @@ -270,13 +285,14 @@ mod tests { #[tokio::test] async fn test_complete_job_execution_success() { - let db = test_db().await; + let test_db = TestDb::new().await; + let db = test_db.connection(); - let execution_id = start_job_execution(&db, "test_job") + let execution_id = start_job_execution(db, "test_job") .await .expect("Failed to start job execution"); - complete_job_execution(&db, execution_id, true, None, Some(42)) + complete_job_execution(db, execution_id, true, None, Some(42)) .await .expect("Failed to complete job execution"); @@ -284,7 +300,7 @@ mod tests { use entities::job_execution::{Column, Entity}; let execution = Entity::find() .filter(Column::Id.eq(execution_id)) - .one(&db) + .one(db) .await .expect("Failed to query job execution") .expect("Job execution not found"); @@ -297,9 +313,10 @@ mod tests { #[tokio::test] async fn test_complete_job_execution_failure() { - let db = test_db().await; + let test_db = TestDb::new().await; + let db = test_db.connection(); - let execution_id = start_job_execution(&db, "test_job") + let execution_id = start_job_execution(db, "test_job") .await .expect("Failed to start job execution"); @@ -317,7 +334,7 @@ mod tests { use entities::job_execution::{Column, Entity}; let execution = Entity::find() .filter(Column::Id.eq(execution_id)) - .one(&db) + .one(db) .await .expect("Failed to query job execution") .expect("Job execution not found"); @@ -330,20 +347,21 @@ mod tests { #[tokio::test] async fn test_trigger_job_manually_cleanup_sessions() { - let db = test_db().await; + let test_db = TestDb::new().await; + let db = test_db.connection(); // Create an expired session - let user = storage::create_user(&db, "testuser", "password123", None) + let user = storage::create_user(db, "testuser", "password123", None) .await .expect("Failed to create user"); let past_auth_time = Utc::now().timestamp() - 7200; // 2 hours ago - storage::create_session(&db, &user.subject, past_auth_time, 3600, None, None) // 1 hour TTL + storage::create_session(db, &user.subject, past_auth_time, 3600, None, None) // 1 hour TTL .await .expect("Failed to create session"); // Trigger cleanup job - trigger_job_manually(&db, "cleanup_expired_sessions") + trigger_job_manually(db, "cleanup_expired_sessions") .await .expect("Failed to trigger job"); @@ -351,7 +369,7 @@ mod tests { use entities::job_execution::{Column, Entity}; let execution = Entity::find() .filter(Column::JobName.eq("cleanup_expired_sessions")) - .one(&db) + .one(db) .await .expect("Failed to query job execution") .expect("Job execution not found"); @@ -362,10 +380,11 @@ mod tests { #[tokio::test] async fn test_trigger_job_manually_cleanup_tokens() { - let db = test_db().await; + let test_db = TestDb::new().await; + let db = test_db.connection(); // Trigger cleanup_expired_refresh_tokens job - trigger_job_manually(&db, "cleanup_expired_refresh_tokens") + trigger_job_manually(db, "cleanup_expired_refresh_tokens") .await .expect("Failed to trigger job"); @@ -373,7 +392,7 @@ mod tests { use entities::job_execution::{Column, Entity}; let execution = Entity::find() .filter(Column::JobName.eq("cleanup_expired_refresh_tokens")) - .one(&db) + .one(db) .await .expect("Failed to query job execution") .expect("Job execution not found"); @@ -383,9 +402,10 @@ mod tests { #[tokio::test] async fn test_trigger_job_manually_invalid_name() { - let db = test_db().await; + let test_db = TestDb::new().await; + let db = test_db.connection(); - let result = trigger_job_manually(&db, "invalid_job_name").await; + let result = trigger_job_manually(db, "invalid_job_name").await; assert!(result.is_err()); match result { @@ -398,14 +418,15 @@ mod tests { #[tokio::test] async fn test_job_execution_records_processed() { - let db = test_db().await; + let test_db = TestDb::new().await; + let db = test_db.connection(); - let execution_id = start_job_execution(&db, "test_job") + let execution_id = start_job_execution(db, "test_job") .await .expect("Failed to start job execution"); // Complete with specific record count - complete_job_execution(&db, execution_id, true, None, Some(123)) + complete_job_execution(db, execution_id, true, None, Some(123)) .await .expect("Failed to complete job execution"); @@ -413,7 +434,7 @@ mod tests { use entities::job_execution::{Column, Entity}; let execution = Entity::find() .filter(Column::Id.eq(execution_id)) - .one(&db) + .one(db) .await .expect("Failed to query job execution") .expect("Job execution not found"); diff --git a/src/jwks.rs b/src/jwks.rs index a2b2f93..c844751 100644 --- a/src/jwks.rs +++ b/src/jwks.rs @@ -229,7 +229,8 @@ mod tests { let mut payload = JwtPayload::new(); payload.set_issuer("https://example.com"); payload.set_subject("user123"); - payload.set_expires_at(&(chrono::Utc::now() + chrono::Duration::hours(1))); + let exp_time: std::time::SystemTime = (chrono::Utc::now() + chrono::Duration::hours(1)).into(); + payload.set_expires_at(&exp_time); // Sign the JWT let token = manager.sign_jwt_rs256(&payload) diff --git a/src/storage.rs b/src/storage.rs index c7a5a06..4ea7b60 100644 --- a/src/storage.rs +++ b/src/storage.rs @@ -5,7 +5,7 @@ use base64ct::Encoding; use chrono::Utc; use rand::RngCore; use sea_orm::{ - ActiveModelTrait, ColumnTrait, Database, DatabaseConnection, EntityTrait, QueryFilter, Set, + ActiveModelTrait, ColumnTrait, Database, DatabaseConnection, EntityTrait, QueryFilter, QueryOrder, Set, }; use serde::{Deserialize, Serialize}; use serde_json::Value; @@ -61,6 +61,8 @@ pub struct User { pub email_verified: i64, pub created_at: i64, pub enabled: i64, + pub requires_2fa: i64, + pub passkey_enrolled_at: Option, } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -72,6 +74,9 @@ pub struct Session { pub expires_at: i64, pub user_agent: Option, pub ip_address: Option, + pub amr: Option, + pub acr: Option, + pub mfa_verified: i64, } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -414,6 +419,8 @@ pub async fn get_user_by_username( email_verified: model.email_verified, created_at: model.created_at, enabled: model.enabled, + requires_2fa: model.requires_2fa, + passkey_enrolled_at: model.passkey_enrolled_at, })) } else { Ok(None) @@ -445,26 +452,35 @@ pub async fn verify_user_password( } } -/// Update user enabled and email_verified flags +/// Update user fields (enabled, email, requires_2fa) pub async fn update_user( db: &DatabaseConnection, - username: &str, + subject: &str, enabled: bool, - email_verified: bool, + email: Option, + requires_2fa: Option, ) -> Result<(), CrabError> { use entities::user::{Column, Entity}; - // Find the user + // Find the user by subject let user = Entity::find() - .filter(Column::Username.eq(username)) + .filter(Column::Subject.eq(subject)) .one(db) .await? - .ok_or_else(|| CrabError::Other(format!("User not found: {}", username)))?; + .ok_or_else(|| CrabError::Other(format!("User not found: {}", subject)))?; // Update the user let mut active: entities::user::ActiveModel = user.into(); active.enabled = Set(if enabled { 1 } else { 0 }); - active.email_verified = Set(if email_verified { 1 } else { 0 }); + + if let Some(email_val) = email { + active.email = Set(Some(email_val)); + } + + if let Some(requires_2fa_val) = requires_2fa { + active.requires_2fa = Set(if requires_2fa_val { 1 } else { 0 }); + } + active.update(db).await?; Ok(()) @@ -498,22 +514,26 @@ pub async fn update_user_email( pub async fn create_session( db: &DatabaseConnection, subject: &str, + auth_time: i64, ttl_secs: i64, user_agent: Option, ip_address: Option, ) -> Result { let session_id = random_id(); let now = Utc::now().timestamp(); - let expires_at = now + ttl_secs; + let expires_at = auth_time + ttl_secs; let session = entities::session::ActiveModel { session_id: Set(session_id.clone()), subject: Set(subject.to_string()), - auth_time: Set(now), + auth_time: Set(auth_time), created_at: Set(now), expires_at: Set(expires_at), user_agent: Set(user_agent.clone()), ip_address: Set(ip_address.clone()), + amr: Set(None), + acr: Set(None), + mfa_verified: Set(0), }; session.insert(db).await?; @@ -521,11 +541,14 @@ pub async fn create_session( Ok(Session { session_id, subject: subject.to_string(), - auth_time: now, + auth_time, created_at: now, expires_at, user_agent, ip_address, + amr: None, + acr: None, + mfa_verified: 0, }) } @@ -554,6 +577,9 @@ pub async fn get_session( expires_at: model.expires_at, user_agent: model.user_agent, ip_address: model.ip_address, + amr: model.amr, + acr: model.acr, + mfa_verified: model.mfa_verified, })) } else { Ok(None) @@ -706,6 +732,316 @@ pub async fn cleanup_expired_refresh_tokens(db: &DatabaseConnection) -> Result Result, CrabError> { + use entities::user::{Column, Entity}; + + if let Some(model) = Entity::find() + .filter(Column::Subject.eq(subject)) + .one(db) + .await? + { + Ok(Some(User { + subject: model.subject, + username: model.username, + password_hash: model.password_hash, + email: model.email, + email_verified: model.email_verified, + created_at: model.created_at, + enabled: model.enabled, + requires_2fa: model.requires_2fa, + passkey_enrolled_at: model.passkey_enrolled_at, + })) + } else { + Ok(None) + } +} + +// Passkey management functions + +pub async fn create_passkey( + db: &DatabaseConnection, + credential_id: &str, + subject: &str, + passkey_json: &str, + counter: i64, + aaguid: Option, + backup_eligible: bool, + backup_state: bool, + transports: Option, + name: Option<&str>, +) -> Result { + use entities::passkey; + + let now = Utc::now().timestamp(); + + let passkey = passkey::ActiveModel { + credential_id: Set(credential_id.to_string()), + subject: Set(subject.to_string()), + public_key_cose: Set(passkey_json.to_string()), + counter: Set(counter), + aaguid: Set(aaguid), + backup_eligible: Set(if backup_eligible { 1 } else { 0 }), + backup_state: Set(if backup_state { 1 } else { 0 }), + transports: Set(transports), + name: Set(name.map(|n| n.to_string())), + created_at: Set(now), + last_used_at: Set(None), + }; + + let result = passkey.insert(db).await?; + + // Update user's passkey_enrolled_at if this is their first passkey + use entities::user::{Column as UserColumn, Entity as UserEntity}; + + if let Some(user) = UserEntity::find() + .filter(UserColumn::Subject.eq(subject)) + .one(db) + .await? + { + if user.passkey_enrolled_at.is_none() { + let mut user_active: entities::user::ActiveModel = user.into(); + user_active.passkey_enrolled_at = Set(Some(now)); + user_active.update(db).await?; + } + } + + Ok(result) +} + +pub async fn get_passkey_by_credential_id( + db: &DatabaseConnection, + credential_id: &str, +) -> Result, CrabError> { + use entities::passkey::{Column, Entity}; + + let passkey = Entity::find() + .filter(Column::CredentialId.eq(credential_id)) + .one(db) + .await?; + + Ok(passkey) +} + +pub async fn get_passkeys_by_subject( + db: &DatabaseConnection, + subject: &str, +) -> Result, CrabError> { + use entities::passkey::{Column, Entity}; + + let passkeys = Entity::find() + .filter(Column::Subject.eq(subject)) + .all(db) + .await?; + + Ok(passkeys) +} + +pub async fn update_passkey_counter( + db: &DatabaseConnection, + credential_id: &str, + new_counter: i64, +) -> Result<(), CrabError> { + use entities::passkey::{Column, Entity}; + + let now = Utc::now().timestamp(); + + if let Some(passkey) = Entity::find() + .filter(Column::CredentialId.eq(credential_id)) + .one(db) + .await? + { + let mut active: entities::passkey::ActiveModel = passkey.into(); + active.counter = Set(new_counter); + active.last_used_at = Set(Some(now)); + active.update(db).await?; + } + + Ok(()) +} + +pub async fn update_passkey_name( + db: &DatabaseConnection, + credential_id: &str, + name: Option, +) -> Result<(), CrabError> { + use entities::passkey::{Column, Entity}; + + if let Some(passkey) = Entity::find() + .filter(Column::CredentialId.eq(credential_id)) + .one(db) + .await? + { + let mut active: entities::passkey::ActiveModel = passkey.into(); + active.name = Set(name); + active.update(db).await?; + } + + Ok(()) +} + +pub async fn delete_passkey( + db: &DatabaseConnection, + credential_id: &str, +) -> Result<(), CrabError> { + use entities::passkey::{Column, Entity}; + + Entity::delete_many() + .filter(Column::CredentialId.eq(credential_id)) + .exec(db) + .await?; + + Ok(()) +} + +// WebAuthn challenge management + +pub async fn create_webauthn_challenge( + db: &DatabaseConnection, + challenge: &str, + subject: Option<&str>, + session_id: Option<&str>, + challenge_type: &str, + options_json: &str, +) -> Result { + use entities::webauthn_challenge; + + let now = Utc::now().timestamp(); + let expires_at = now + 300; // 5 minutes + + let challenge_model = webauthn_challenge::ActiveModel { + challenge: Set(challenge.to_string()), + subject: Set(subject.map(|s| s.to_string())), + session_id: Set(session_id.map(|s| s.to_string())), + challenge_type: Set(challenge_type.to_string()), + options_json: Set(options_json.to_string()), + created_at: Set(now), + expires_at: Set(expires_at), + }; + + let result = challenge_model.insert(db).await?; + Ok(result) +} + +pub async fn get_latest_webauthn_challenge_by_subject( + db: &DatabaseConnection, + subject: &str, + challenge_type: &str, +) -> Result, CrabError> { + use entities::webauthn_challenge::{Column, Entity}; + + let now = Utc::now().timestamp(); + + let challenge = Entity::find() + .filter(Column::Subject.eq(subject)) + .filter(Column::ChallengeType.eq(challenge_type)) + .filter(Column::ExpiresAt.gt(now)) + .order_by_desc(Column::CreatedAt) + .one(db) + .await?; + + Ok(challenge) +} + +pub async fn delete_webauthn_challenge( + db: &DatabaseConnection, + challenge: &str, +) -> Result<(), CrabError> { + use entities::webauthn_challenge::{Column, Entity}; + + Entity::delete_many() + .filter(Column::Challenge.eq(challenge)) + .exec(db) + .await?; + + Ok(()) +} + +pub async fn cleanup_expired_challenges(db: &DatabaseConnection) -> Result { + use entities::webauthn_challenge::{Column, Entity}; + + let now = Utc::now().timestamp(); + let result = Entity::delete_many() + .filter(Column::ExpiresAt.lt(now)) + .exec(db) + .await?; + + Ok(result.rows_affected) +} + +// Session authentication context management + +pub async fn update_session_auth_context( + db: &DatabaseConnection, + session_id: &str, + amr: Option<&str>, + acr: Option<&str>, + mfa_verified: Option, +) -> Result<(), CrabError> { + use entities::session::{Column, Entity}; + + if let Some(session) = Entity::find() + .filter(Column::SessionId.eq(session_id)) + .one(db) + .await? + { + let mut active: entities::session::ActiveModel = session.into(); + + if let Some(amr_val) = amr { + active.amr = Set(Some(amr_val.to_string())); + } + + if let Some(acr_val) = acr { + active.acr = Set(Some(acr_val.to_string())); + } + + if let Some(mfa_val) = mfa_verified { + active.mfa_verified = Set(if mfa_val { 1 } else { 0 }); + } + + active.update(db).await?; + } + + Ok(()) +} + +pub async fn append_session_amr( + db: &DatabaseConnection, + session_id: &str, + new_amr: &str, +) -> Result<(), CrabError> { + use entities::session::{Column, Entity}; + + if let Some(session) = Entity::find() + .filter(Column::SessionId.eq(session_id)) + .one(db) + .await? + { + let mut amr_array: Vec = if let Some(existing_amr) = &session.amr { + serde_json::from_str(existing_amr).unwrap_or_else(|_| vec![]) + } else { + vec![] + }; + + // Only append if not already present + if !amr_array.contains(&new_amr.to_string()) { + amr_array.push(new_amr.to_string()); + } + + let amr_json = serde_json::to_string(&amr_array)?; + + let mut active: entities::session::ActiveModel = session.into(); + active.amr = Set(Some(amr_json)); + active.update(db).await?; + } + + Ok(()) +} + #[cfg(test)] mod tests { use super::*; @@ -811,8 +1147,7 @@ mod tests { .expect("Failed to get client") .expect("Client not found"); - let parsed_uris: Vec = serde_json::from_str(&retrieved.redirect_uris) - .expect("Failed to parse redirect_uris"); + let parsed_uris = retrieved.redirect_uris.clone(); assert_eq!(parsed_uris, uris); } @@ -827,18 +1162,21 @@ mod tests { let code = issue_auth_code( &db, - "test_subject", "test_client_id", - "openid profile", - Some("test_nonce"), "http://localhost:3000/callback", - Some("challenge_string"), - Some("S256"), + "openid profile", + "test_subject", + Some("test_nonce".to_string()), + "challenge_string", + "S256", + 300, // 5 minutes TTL + None, // auth_time ) .await .expect("Failed to issue auth code"); - assert!(!code.is_empty()); + assert!(!code.code.is_empty()); + assert_eq!(code.subject, "test_subject"); } #[tokio::test] @@ -847,18 +1185,20 @@ mod tests { let code = issue_auth_code( &db, - "test_subject", "test_client_id", - "openid profile", - Some("test_nonce"), "http://localhost:3000/callback", - Some("challenge_string"), - Some("S256"), + "openid profile", + "test_subject", + Some("test_nonce".to_string()), + "challenge_string", + "S256", + 300, + None, ) .await .expect("Failed to issue auth code"); - let auth_code = consume_auth_code(&db, &code) + let auth_code = consume_auth_code(&db, &code.code) .await .expect("Failed to consume auth code") .expect("Auth code not found"); @@ -874,25 +1214,27 @@ mod tests { let code = issue_auth_code( &db, - "test_subject", "test_client_id", - "openid profile", - None, "http://localhost:3000/callback", + "openid profile", + "test_subject", None, - None, + "", + "", + 300, // TTL + None, // auth_time ) .await .expect("Failed to issue auth code"); // First consumption succeeds - consume_auth_code(&db, &code) + consume_auth_code(&db, &code.code) .await .expect("Failed to consume auth code") .expect("Auth code not found"); // Second consumption returns None - let result = consume_auth_code(&db, &code) + let result = consume_auth_code(&db, &code.code) .await .expect("Query failed"); @@ -905,13 +1247,15 @@ mod tests { let code = issue_auth_code( &db, - "test_subject", "test_client_id", - "openid profile", - None, "http://localhost:3000/callback", + "openid profile", + "test_subject", None, - None, + "", + "", + 300, // TTL + None, // auth_time ) .await .expect("Failed to issue auth code"); @@ -925,13 +1269,13 @@ mod tests { Entity::update_many() .col_expr(Column::ExpiresAt, sea_orm::sea_query::Expr::value(past_timestamp)) - .filter(Column::Code.eq(&code)) + .filter(Column::Code.eq(&code.code)) .exec(&db) .await .expect("Failed to update expiry"); // Consumption should return None for expired code - let result = consume_auth_code(&db, &code) + let result = consume_auth_code(&db, &code.code) .await .expect("Query failed"); @@ -944,24 +1288,26 @@ mod tests { let code = issue_auth_code( &db, - "test_subject", "test_client_id", - "openid profile", - None, "http://localhost:3000/callback", - Some("challenge_string"), - Some("S256"), + "openid profile", + "test_subject", + None, + "challenge_string", + "S256", + 300, + None, ) .await .expect("Failed to issue auth code"); - let auth_code = consume_auth_code(&db, &code) + let auth_code = consume_auth_code(&db, &code.code) .await .expect("Failed to consume auth code") .expect("Auth code not found"); - assert_eq!(auth_code.code_challenge, Some("challenge_string".to_string())); - assert_eq!(auth_code.code_challenge_method, Some("S256".to_string())); + assert_eq!(auth_code.code_challenge, "challenge_string"); + assert_eq!(auth_code.code_challenge_method, "S256"); } // ============================================================================ @@ -972,22 +1318,26 @@ mod tests { async fn test_issue_access_token() { let db = test_db().await; - let token = issue_access_token(&db, "test_subject", "test_client_id", "openid profile") + let token = issue_access_token(&db, "test_subject", "test_client_id", "openid profile", + 3600, // TTL + ) .await .expect("Failed to issue access token"); - assert!(!token.is_empty()); + assert!(!token.token.is_empty()); } #[tokio::test] async fn test_get_access_token_valid() { let db = test_db().await; - let token = issue_access_token(&db, "test_subject", "test_client_id", "openid profile") + let token = issue_access_token(&db, "test_subject", "test_client_id", "openid profile", + 3600, // TTL + ) .await .expect("Failed to issue access token"); - let access_token = get_access_token(&db, &token) + let access_token = get_access_token(&db, &token.token) .await .expect("Failed to get access token") .expect("Access token not found"); @@ -1001,7 +1351,9 @@ mod tests { async fn test_get_access_token_expired() { let db = test_db().await; - let token = issue_access_token(&db, "test_subject", "test_client_id", "openid profile") + let token = issue_access_token(&db, "test_subject", "test_client_id", "openid profile", + 3600, // TTL + ) .await .expect("Failed to issue access token"); @@ -1013,13 +1365,13 @@ mod tests { Entity::update_many() .col_expr(Column::ExpiresAt, sea_orm::sea_query::Expr::value(past_timestamp)) - .filter(Column::Token.eq(&token)) + .filter(Column::Token.eq(&token.token)) .exec(&db) .await .expect("Failed to update expiry"); // Should return None for expired token - let result = get_access_token(&db, &token) + let result = get_access_token(&db, &token.token) .await .expect("Query failed"); @@ -1030,7 +1382,9 @@ mod tests { async fn test_get_access_token_revoked() { let db = test_db().await; - let token = issue_access_token(&db, "test_subject", "test_client_id", "openid profile") + let token = issue_access_token(&db, "test_subject", "test_client_id", "openid profile", + 3600, // TTL + ) .await .expect("Failed to issue access token"); @@ -1040,13 +1394,13 @@ mod tests { Entity::update_many() .col_expr(Column::Revoked, sea_orm::sea_query::Expr::value(1)) - .filter(Column::Token.eq(&token)) + .filter(Column::Token.eq(&token.token)) .exec(&db) .await .expect("Failed to revoke token"); // Should return None for revoked token - let result = get_access_token(&db, &token) + let result = get_access_token(&db, &token.token) .await .expect("Query failed"); @@ -1058,38 +1412,44 @@ mod tests { let db = test_db().await; // Create initial refresh token - let token1 = issue_refresh_token(&db, "test_subject", "test_client_id", "openid profile", None) + let token1 = issue_refresh_token(&db, "test_subject", "test_client_id", "openid profile", + 86400, // TTL + None, // parent_token + ) .await .expect("Failed to issue refresh token"); // Rotate to new token - let token2 = issue_refresh_token(&db, "test_subject", "test_client_id", "openid profile", Some(&token1)) + let token2 = issue_refresh_token(&db, "test_client_id", "test_subject", "openid profile", 86400, Some(token1.token.clone())) .await .expect("Failed to rotate refresh token"); // Verify parent chain - let rt2 = get_refresh_token(&db, &token2) + let rt2 = get_refresh_token(&db, &token2.token) .await .expect("Failed to get token") .expect("Token not found"); - assert_eq!(rt2.parent_refresh_token, Some(token1)); + assert_eq!(rt2.parent_token, Some(token1.token)); } #[tokio::test] async fn test_revoke_refresh_token() { let db = test_db().await; - let token = issue_refresh_token(&db, "test_subject", "test_client_id", "openid profile", None) + let token = issue_refresh_token(&db, "test_subject", "test_client_id", "openid profile", + 86400, // TTL + None, // parent_token + ) .await .expect("Failed to issue refresh token"); - revoke_refresh_token(&db, &token) + revoke_refresh_token(&db, &token.token) .await .expect("Failed to revoke token"); // Should return None for revoked token - let result = get_refresh_token(&db, &token) + let result = get_refresh_token(&db, &token.token) .await .expect("Query failed"); @@ -1140,12 +1500,13 @@ mod tests { .await .expect("Failed to create user"); - let user = verify_user_password(&db, "testuser", "password123") + let subject = verify_user_password(&db, "testuser", "password123") .await .expect("Failed to verify password") .expect("Verification failed"); - assert_eq!(user.username, "testuser"); + // Verify it's a valid subject (not empty) + assert!(!subject.is_empty()); } #[tokio::test] @@ -1192,7 +1553,7 @@ mod tests { .await .expect("Failed to create user"); - update_user(&db, &user.subject, false, Some("test@example.com"), Some(true)) + update_user(&db, &user.subject, false, Some("test@example.com".to_string()), Some(true)) .await .expect("Failed to update user"); @@ -1214,7 +1575,7 @@ mod tests { .await .expect("Failed to create user"); - update_user(&db, &user.subject, true, Some("new@example.com"), None) + update_user(&db, &user.subject, true, Some("new@example.com".to_string()), None) .await .expect("Failed to update email"); diff --git a/src/user_sync.rs b/src/user_sync.rs index f5d65e8..0460d4c 100644 --- a/src/user_sync.rs +++ b/src/user_sync.rs @@ -98,7 +98,7 @@ async fn sync_user(db: &DatabaseConnection, user_def: &UserDefinition) -> Result None => { // Create new user tracing::info!("Creating user: {}", user_def.username); - storage::create_user( + let user = storage::create_user( db, &user_def.username, &user_def.password, @@ -107,13 +107,14 @@ async fn sync_user(db: &DatabaseConnection, user_def: &UserDefinition) -> Result .await .into_diagnostic()?; - // Update enabled and email_verified flags if needed - if !user_def.enabled || user_def.email_verified { + // Update enabled flag if needed (email already set during creation) + if !user_def.enabled { storage::update_user( db, - &user_def.username, + &user.subject, user_def.enabled, - user_def.email_verified, + None, // email already set + None, // requires_2fa not set during sync ) .await .into_diagnostic()?; @@ -128,24 +129,18 @@ async fn sync_user(db: &DatabaseConnection, user_def: &UserDefinition) -> Result (existing_user.email_verified == 1) == user_def.email_verified; let email_matches = existing_user.email == user_def.email; - if !enabled_matches || !email_verified_matches || !email_matches { + if !enabled_matches || !email_matches { tracing::info!("Updating user: {}", user_def.username); storage::update_user( db, - &user_def.username, + &existing_user.subject, user_def.enabled, - user_def.email_verified, + if !email_matches { user_def.email.clone() } else { None }, + None, // requires_2fa not set during sync ) .await .into_diagnostic()?; - // Update email if it changed - if !email_matches { - storage::update_user_email(db, &user_def.username, user_def.email.clone()) - .await - .into_diagnostic()?; - } - SyncResult::Updated } else { SyncResult::Unchanged diff --git a/src/web.rs b/src/web.rs index 1de18c0..0711732 100644 --- a/src/web.rs +++ b/src/web.rs @@ -1633,7 +1633,8 @@ async fn login_submit( .and_then(|h| h.to_str().ok()) .map(String::from); - let session = match storage::create_session(&state.db, &subject, 3600, user_agent, None).await { + let now = chrono::Utc::now().timestamp(); + let session = match storage::create_session(&state.db, &subject, now, 3600, user_agent, None).await { Ok(s) => s, Err(_) => { let return_to = urlencoded(&form.return_to.unwrap_or_default()); @@ -1918,11 +1919,10 @@ async fn passkey_register_start( storage::create_webauthn_challenge( &state.db, &challenge_b64, - Some(session.subject.clone()), + Some(&session.subject), None, "registration", &options_json, - 300, // 5 minutes ) .await .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?; @@ -2009,12 +2009,12 @@ async fn passkey_register_finish( &cred_id_b64, &session.subject, &passkey_json, - 0, // counter - TODO: extract from passkey when we understand the API - None, // aaguid - false, // backup_eligible - TODO: extract from passkey - false, // backup_state - TODO: extract from passkey - None, // transports - req.name.clone(), // Name from request + 0, // counter - TODO: extract from passkey when we understand the API + None, // aaguid + false, // backup_eligible - TODO: extract from passkey + false, // backup_state - TODO: extract from passkey + None, // transports + req.name.as_deref(), // Name from request ) .await .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?; @@ -2091,11 +2091,10 @@ async fn passkey_auth_start( storage::create_webauthn_challenge( &state.db, &challenge_b64, - subject, + subject.as_deref(), None, "authentication", &options_json, - 300, // 5 minutes ) .await .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?; @@ -2161,17 +2160,20 @@ async fn passkey_auth_finish( }; // Create session - let session = storage::create_session(&state.db, &passkey.subject, 3600, None, None) + let now = chrono::Utc::now().timestamp(); + let session = storage::create_session(&state.db, &passkey.subject, now, 3600, None, None) .await .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?; // Update session with passkey AMR + let amr_json = serde_json::to_string(&amr) + .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?; storage::update_session_auth_context( &state.db, &session.session_id, - Some(amr), - Some("aal1".to_string()), - false, + Some(&amr_json), + Some("aal1"), + Some(false), ) .await .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?; @@ -2249,11 +2251,10 @@ async fn passkey_2fa_start( storage::create_webauthn_challenge( &state.db, &challenge_b64, - Some(session.subject.clone()), - Some(session.session_id.clone()), + Some(&session.subject), + Some(&session.session_id), "2fa", &options_json, - 300, ) .await .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?; @@ -2340,8 +2341,8 @@ async fn passkey_2fa_finish( &state.db, &session.session_id, None, - Some("aal2".to_string()), - true, + Some("aal2"), + Some(true), ) .await .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;