Commit work in progress

Signed-off-by: Till Wegmueller <toasterson@gmail.com>
This commit is contained in:
Till Wegmueller 2026-01-06 10:56:23 +01:00
parent d7bdd51164
commit 86c88d8aee
No known key found for this signature in database
5 changed files with 529 additions and 150 deletions

View file

@ -22,23 +22,23 @@ pub async fn init_scheduler(db: DatabaseConnection) -> Result<JobScheduler, Crab
let db = db_clone.clone(); let db = db_clone.clone();
Box::pin(async move { Box::pin(async move {
info!("Running cleanup_expired_sessions job"); info!("Running cleanup_expired_sessions job");
let execution_id = start_job_execution(&db, "cleanup_expired_sessions") let execution_id = start_job_execution(db, "cleanup_expired_sessions")
.await .await
.ok(); .ok();
match storage::cleanup_expired_sessions(&db).await { match storage::cleanup_expired_sessions(db).await {
Ok(count) => { Ok(count) => {
info!("Cleaned up {} expired sessions", count); info!("Cleaned up {} expired sessions", count);
if let Some(id) = execution_id { if let Some(id) = execution_id {
let _ = let _ =
complete_job_execution(&db, id, true, None, Some(count as i64)).await; complete_job_execution(db, id, true, None, Some(count as i64)).await;
} }
} }
Err(e) => { Err(e) => {
error!("Failed to cleanup expired sessions: {}", e); error!("Failed to cleanup expired sessions: {}", e);
if let Some(id) = execution_id { if let Some(id) = execution_id {
let _ = let _ =
complete_job_execution(&db, id, false, Some(e.to_string()), None).await; complete_job_execution(db, id, false, Some(e.to_string()), None).await;
} }
} }
} }
@ -58,23 +58,23 @@ pub async fn init_scheduler(db: DatabaseConnection) -> Result<JobScheduler, Crab
let db = db_clone.clone(); let db = db_clone.clone();
Box::pin(async move { Box::pin(async move {
info!("Running cleanup_expired_refresh_tokens job"); info!("Running cleanup_expired_refresh_tokens job");
let execution_id = start_job_execution(&db, "cleanup_expired_refresh_tokens") let execution_id = start_job_execution(db, "cleanup_expired_refresh_tokens")
.await .await
.ok(); .ok();
match storage::cleanup_expired_refresh_tokens(&db).await { match storage::cleanup_expired_refresh_tokens(db).await {
Ok(count) => { Ok(count) => {
info!("Cleaned up {} expired refresh tokens", count); info!("Cleaned up {} expired refresh tokens", count);
if let Some(id) = execution_id { if let Some(id) = execution_id {
let _ = let _ =
complete_job_execution(&db, id, true, None, Some(count as i64)).await; complete_job_execution(db, id, true, None, Some(count as i64)).await;
} }
} }
Err(e) => { Err(e) => {
error!("Failed to cleanup expired refresh tokens: {}", e); error!("Failed to cleanup expired refresh tokens: {}", e);
if let Some(id) = execution_id { if let Some(id) = execution_id {
let _ = let _ =
complete_job_execution(&db, id, false, Some(e.to_string()), None).await; complete_job_execution(db, id, false, Some(e.to_string()), None).await;
} }
} }
} }
@ -94,23 +94,23 @@ pub async fn init_scheduler(db: DatabaseConnection) -> Result<JobScheduler, Crab
let db = db_clone.clone(); let db = db_clone.clone();
Box::pin(async move { Box::pin(async move {
info!("Running cleanup_expired_challenges job"); info!("Running cleanup_expired_challenges job");
let execution_id = start_job_execution(&db, "cleanup_expired_challenges") let execution_id = start_job_execution(db, "cleanup_expired_challenges")
.await .await
.ok(); .ok();
match storage::cleanup_expired_challenges(&db).await { match storage::cleanup_expired_challenges(db).await {
Ok(count) => { Ok(count) => {
info!("Cleaned up {} expired WebAuthn challenges", count); info!("Cleaned up {} expired WebAuthn challenges", count);
if let Some(id) = execution_id { if let Some(id) = execution_id {
let _ = let _ =
complete_job_execution(&db, id, true, None, Some(count as i64)).await; complete_job_execution(db, id, true, None, Some(count as i64)).await;
} }
} }
Err(e) => { Err(e) => {
error!("Failed to cleanup expired challenges: {}", e); error!("Failed to cleanup expired challenges: {}", e);
if let Some(id) = execution_id { if let Some(id) = execution_id {
let _ = let _ =
complete_job_execution(&db, id, false, Some(e.to_string()), None).await; complete_job_execution(db, id, false, Some(e.to_string()), None).await;
} }
} }
} }
@ -226,28 +226,43 @@ mod tests {
use sea_orm_migration::MigratorTrait; use sea_orm_migration::MigratorTrait;
use tempfile::NamedTempFile; use tempfile::NamedTempFile;
/// Helper to create an in-memory test database /// Test database helper that keeps temp file alive
async fn test_db() -> DatabaseConnection { struct TestDb {
let temp_file = NamedTempFile::new().expect("Failed to create temp file"); connection: DatabaseConnection,
let db_path = temp_file.path().to_str().expect("Invalid temp file path"); _temp_file: NamedTempFile,
let db_url = format!("sqlite://{}?mode=rwc", db_path); }
let db = Database::connect(&db_url) impl TestDb {
.await async fn new() -> Self {
.expect("Failed to connect to test database"); let temp_file = NamedTempFile::new().expect("Failed to create temp file");
let db_path = temp_file.path().to_str().expect("Invalid temp file path");
let db_url = format!("sqlite://{}?mode=rwc", db_path);
migration::Migrator::up(&db, None) let connection = Database::connect(&db_url)
.await .await
.expect("Failed to run migrations"); .expect("Failed to connect to test database");
db migration::Migrator::up(&connection, None)
.await
.expect("Failed to run migrations");
Self {
connection,
_temp_file: temp_file,
}
}
fn connection(&self) -> &DatabaseConnection {
&self.connection
}
} }
#[tokio::test] #[tokio::test]
async fn test_start_job_execution() { async fn test_start_job_execution() {
let db = test_db().await; let test_db = TestDb::new().await;
let db = test_db.connection();
let execution_id = start_job_execution(&db, "test_job") let execution_id = start_job_execution(db, "test_job")
.await .await
.expect("Failed to start job execution"); .expect("Failed to start job execution");
@ -257,7 +272,7 @@ mod tests {
use entities::job_execution::{Column, Entity}; use entities::job_execution::{Column, Entity};
let execution = Entity::find() let execution = Entity::find()
.filter(Column::Id.eq(execution_id)) .filter(Column::Id.eq(execution_id))
.one(&db) .one(db)
.await .await
.expect("Failed to query job execution") .expect("Failed to query job execution")
.expect("Job execution not found"); .expect("Job execution not found");
@ -270,13 +285,14 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_complete_job_execution_success() { async fn test_complete_job_execution_success() {
let db = test_db().await; let test_db = TestDb::new().await;
let db = test_db.connection();
let execution_id = start_job_execution(&db, "test_job") let execution_id = start_job_execution(db, "test_job")
.await .await
.expect("Failed to start job execution"); .expect("Failed to start job execution");
complete_job_execution(&db, execution_id, true, None, Some(42)) complete_job_execution(db, execution_id, true, None, Some(42))
.await .await
.expect("Failed to complete job execution"); .expect("Failed to complete job execution");
@ -284,7 +300,7 @@ mod tests {
use entities::job_execution::{Column, Entity}; use entities::job_execution::{Column, Entity};
let execution = Entity::find() let execution = Entity::find()
.filter(Column::Id.eq(execution_id)) .filter(Column::Id.eq(execution_id))
.one(&db) .one(db)
.await .await
.expect("Failed to query job execution") .expect("Failed to query job execution")
.expect("Job execution not found"); .expect("Job execution not found");
@ -297,9 +313,10 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_complete_job_execution_failure() { async fn test_complete_job_execution_failure() {
let db = test_db().await; let test_db = TestDb::new().await;
let db = test_db.connection();
let execution_id = start_job_execution(&db, "test_job") let execution_id = start_job_execution(db, "test_job")
.await .await
.expect("Failed to start job execution"); .expect("Failed to start job execution");
@ -317,7 +334,7 @@ mod tests {
use entities::job_execution::{Column, Entity}; use entities::job_execution::{Column, Entity};
let execution = Entity::find() let execution = Entity::find()
.filter(Column::Id.eq(execution_id)) .filter(Column::Id.eq(execution_id))
.one(&db) .one(db)
.await .await
.expect("Failed to query job execution") .expect("Failed to query job execution")
.expect("Job execution not found"); .expect("Job execution not found");
@ -330,20 +347,21 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_trigger_job_manually_cleanup_sessions() { async fn test_trigger_job_manually_cleanup_sessions() {
let db = test_db().await; let test_db = TestDb::new().await;
let db = test_db.connection();
// Create an expired session // Create an expired session
let user = storage::create_user(&db, "testuser", "password123", None) let user = storage::create_user(db, "testuser", "password123", None)
.await .await
.expect("Failed to create user"); .expect("Failed to create user");
let past_auth_time = Utc::now().timestamp() - 7200; // 2 hours ago let past_auth_time = Utc::now().timestamp() - 7200; // 2 hours ago
storage::create_session(&db, &user.subject, past_auth_time, 3600, None, None) // 1 hour TTL storage::create_session(db, &user.subject, past_auth_time, 3600, None, None) // 1 hour TTL
.await .await
.expect("Failed to create session"); .expect("Failed to create session");
// Trigger cleanup job // Trigger cleanup job
trigger_job_manually(&db, "cleanup_expired_sessions") trigger_job_manually(db, "cleanup_expired_sessions")
.await .await
.expect("Failed to trigger job"); .expect("Failed to trigger job");
@ -351,7 +369,7 @@ mod tests {
use entities::job_execution::{Column, Entity}; use entities::job_execution::{Column, Entity};
let execution = Entity::find() let execution = Entity::find()
.filter(Column::JobName.eq("cleanup_expired_sessions")) .filter(Column::JobName.eq("cleanup_expired_sessions"))
.one(&db) .one(db)
.await .await
.expect("Failed to query job execution") .expect("Failed to query job execution")
.expect("Job execution not found"); .expect("Job execution not found");
@ -362,10 +380,11 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_trigger_job_manually_cleanup_tokens() { async fn test_trigger_job_manually_cleanup_tokens() {
let db = test_db().await; let test_db = TestDb::new().await;
let db = test_db.connection();
// Trigger cleanup_expired_refresh_tokens job // Trigger cleanup_expired_refresh_tokens job
trigger_job_manually(&db, "cleanup_expired_refresh_tokens") trigger_job_manually(db, "cleanup_expired_refresh_tokens")
.await .await
.expect("Failed to trigger job"); .expect("Failed to trigger job");
@ -373,7 +392,7 @@ mod tests {
use entities::job_execution::{Column, Entity}; use entities::job_execution::{Column, Entity};
let execution = Entity::find() let execution = Entity::find()
.filter(Column::JobName.eq("cleanup_expired_refresh_tokens")) .filter(Column::JobName.eq("cleanup_expired_refresh_tokens"))
.one(&db) .one(db)
.await .await
.expect("Failed to query job execution") .expect("Failed to query job execution")
.expect("Job execution not found"); .expect("Job execution not found");
@ -383,9 +402,10 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_trigger_job_manually_invalid_name() { async fn test_trigger_job_manually_invalid_name() {
let db = test_db().await; let test_db = TestDb::new().await;
let db = test_db.connection();
let result = trigger_job_manually(&db, "invalid_job_name").await; let result = trigger_job_manually(db, "invalid_job_name").await;
assert!(result.is_err()); assert!(result.is_err());
match result { match result {
@ -398,14 +418,15 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_job_execution_records_processed() { async fn test_job_execution_records_processed() {
let db = test_db().await; let test_db = TestDb::new().await;
let db = test_db.connection();
let execution_id = start_job_execution(&db, "test_job") let execution_id = start_job_execution(db, "test_job")
.await .await
.expect("Failed to start job execution"); .expect("Failed to start job execution");
// Complete with specific record count // Complete with specific record count
complete_job_execution(&db, execution_id, true, None, Some(123)) complete_job_execution(db, execution_id, true, None, Some(123))
.await .await
.expect("Failed to complete job execution"); .expect("Failed to complete job execution");
@ -413,7 +434,7 @@ mod tests {
use entities::job_execution::{Column, Entity}; use entities::job_execution::{Column, Entity};
let execution = Entity::find() let execution = Entity::find()
.filter(Column::Id.eq(execution_id)) .filter(Column::Id.eq(execution_id))
.one(&db) .one(db)
.await .await
.expect("Failed to query job execution") .expect("Failed to query job execution")
.expect("Job execution not found"); .expect("Job execution not found");

View file

@ -229,7 +229,8 @@ mod tests {
let mut payload = JwtPayload::new(); let mut payload = JwtPayload::new();
payload.set_issuer("https://example.com"); payload.set_issuer("https://example.com");
payload.set_subject("user123"); payload.set_subject("user123");
payload.set_expires_at(&(chrono::Utc::now() + chrono::Duration::hours(1))); let exp_time: std::time::SystemTime = (chrono::Utc::now() + chrono::Duration::hours(1)).into();
payload.set_expires_at(&exp_time);
// Sign the JWT // Sign the JWT
let token = manager.sign_jwt_rs256(&payload) let token = manager.sign_jwt_rs256(&payload)

View file

@ -5,7 +5,7 @@ use base64ct::Encoding;
use chrono::Utc; use chrono::Utc;
use rand::RngCore; use rand::RngCore;
use sea_orm::{ use sea_orm::{
ActiveModelTrait, ColumnTrait, Database, DatabaseConnection, EntityTrait, QueryFilter, Set, ActiveModelTrait, ColumnTrait, Database, DatabaseConnection, EntityTrait, QueryFilter, QueryOrder, Set,
}; };
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_json::Value; use serde_json::Value;
@ -61,6 +61,8 @@ pub struct User {
pub email_verified: i64, pub email_verified: i64,
pub created_at: i64, pub created_at: i64,
pub enabled: i64, pub enabled: i64,
pub requires_2fa: i64,
pub passkey_enrolled_at: Option<i64>,
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
@ -72,6 +74,9 @@ pub struct Session {
pub expires_at: i64, pub expires_at: i64,
pub user_agent: Option<String>, pub user_agent: Option<String>,
pub ip_address: Option<String>, pub ip_address: Option<String>,
pub amr: Option<String>,
pub acr: Option<String>,
pub mfa_verified: i64,
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
@ -414,6 +419,8 @@ pub async fn get_user_by_username(
email_verified: model.email_verified, email_verified: model.email_verified,
created_at: model.created_at, created_at: model.created_at,
enabled: model.enabled, enabled: model.enabled,
requires_2fa: model.requires_2fa,
passkey_enrolled_at: model.passkey_enrolled_at,
})) }))
} else { } else {
Ok(None) Ok(None)
@ -445,26 +452,35 @@ pub async fn verify_user_password(
} }
} }
/// Update user enabled and email_verified flags /// Update user fields (enabled, email, requires_2fa)
pub async fn update_user( pub async fn update_user(
db: &DatabaseConnection, db: &DatabaseConnection,
username: &str, subject: &str,
enabled: bool, enabled: bool,
email_verified: bool, email: Option<String>,
requires_2fa: Option<bool>,
) -> Result<(), CrabError> { ) -> Result<(), CrabError> {
use entities::user::{Column, Entity}; use entities::user::{Column, Entity};
// Find the user // Find the user by subject
let user = Entity::find() let user = Entity::find()
.filter(Column::Username.eq(username)) .filter(Column::Subject.eq(subject))
.one(db) .one(db)
.await? .await?
.ok_or_else(|| CrabError::Other(format!("User not found: {}", username)))?; .ok_or_else(|| CrabError::Other(format!("User not found: {}", subject)))?;
// Update the user // Update the user
let mut active: entities::user::ActiveModel = user.into(); let mut active: entities::user::ActiveModel = user.into();
active.enabled = Set(if enabled { 1 } else { 0 }); active.enabled = Set(if enabled { 1 } else { 0 });
active.email_verified = Set(if email_verified { 1 } else { 0 });
if let Some(email_val) = email {
active.email = Set(Some(email_val));
}
if let Some(requires_2fa_val) = requires_2fa {
active.requires_2fa = Set(if requires_2fa_val { 1 } else { 0 });
}
active.update(db).await?; active.update(db).await?;
Ok(()) Ok(())
@ -498,22 +514,26 @@ pub async fn update_user_email(
pub async fn create_session( pub async fn create_session(
db: &DatabaseConnection, db: &DatabaseConnection,
subject: &str, subject: &str,
auth_time: i64,
ttl_secs: i64, ttl_secs: i64,
user_agent: Option<String>, user_agent: Option<String>,
ip_address: Option<String>, ip_address: Option<String>,
) -> Result<Session, CrabError> { ) -> Result<Session, CrabError> {
let session_id = random_id(); let session_id = random_id();
let now = Utc::now().timestamp(); let now = Utc::now().timestamp();
let expires_at = now + ttl_secs; let expires_at = auth_time + ttl_secs;
let session = entities::session::ActiveModel { let session = entities::session::ActiveModel {
session_id: Set(session_id.clone()), session_id: Set(session_id.clone()),
subject: Set(subject.to_string()), subject: Set(subject.to_string()),
auth_time: Set(now), auth_time: Set(auth_time),
created_at: Set(now), created_at: Set(now),
expires_at: Set(expires_at), expires_at: Set(expires_at),
user_agent: Set(user_agent.clone()), user_agent: Set(user_agent.clone()),
ip_address: Set(ip_address.clone()), ip_address: Set(ip_address.clone()),
amr: Set(None),
acr: Set(None),
mfa_verified: Set(0),
}; };
session.insert(db).await?; session.insert(db).await?;
@ -521,11 +541,14 @@ pub async fn create_session(
Ok(Session { Ok(Session {
session_id, session_id,
subject: subject.to_string(), subject: subject.to_string(),
auth_time: now, auth_time,
created_at: now, created_at: now,
expires_at, expires_at,
user_agent, user_agent,
ip_address, ip_address,
amr: None,
acr: None,
mfa_verified: 0,
}) })
} }
@ -554,6 +577,9 @@ pub async fn get_session(
expires_at: model.expires_at, expires_at: model.expires_at,
user_agent: model.user_agent, user_agent: model.user_agent,
ip_address: model.ip_address, ip_address: model.ip_address,
amr: model.amr,
acr: model.acr,
mfa_verified: model.mfa_verified,
})) }))
} else { } else {
Ok(None) Ok(None)
@ -706,6 +732,316 @@ pub async fn cleanup_expired_refresh_tokens(db: &DatabaseConnection) -> Result<u
Ok(result.rows_affected) Ok(result.rows_affected)
} }
// User helper functions
pub async fn get_user_by_subject(
db: &DatabaseConnection,
subject: &str,
) -> Result<Option<User>, CrabError> {
use entities::user::{Column, Entity};
if let Some(model) = Entity::find()
.filter(Column::Subject.eq(subject))
.one(db)
.await?
{
Ok(Some(User {
subject: model.subject,
username: model.username,
password_hash: model.password_hash,
email: model.email,
email_verified: model.email_verified,
created_at: model.created_at,
enabled: model.enabled,
requires_2fa: model.requires_2fa,
passkey_enrolled_at: model.passkey_enrolled_at,
}))
} else {
Ok(None)
}
}
// Passkey management functions
pub async fn create_passkey(
db: &DatabaseConnection,
credential_id: &str,
subject: &str,
passkey_json: &str,
counter: i64,
aaguid: Option<String>,
backup_eligible: bool,
backup_state: bool,
transports: Option<String>,
name: Option<&str>,
) -> Result<entities::passkey::Model, CrabError> {
use entities::passkey;
let now = Utc::now().timestamp();
let passkey = passkey::ActiveModel {
credential_id: Set(credential_id.to_string()),
subject: Set(subject.to_string()),
public_key_cose: Set(passkey_json.to_string()),
counter: Set(counter),
aaguid: Set(aaguid),
backup_eligible: Set(if backup_eligible { 1 } else { 0 }),
backup_state: Set(if backup_state { 1 } else { 0 }),
transports: Set(transports),
name: Set(name.map(|n| n.to_string())),
created_at: Set(now),
last_used_at: Set(None),
};
let result = passkey.insert(db).await?;
// Update user's passkey_enrolled_at if this is their first passkey
use entities::user::{Column as UserColumn, Entity as UserEntity};
if let Some(user) = UserEntity::find()
.filter(UserColumn::Subject.eq(subject))
.one(db)
.await?
{
if user.passkey_enrolled_at.is_none() {
let mut user_active: entities::user::ActiveModel = user.into();
user_active.passkey_enrolled_at = Set(Some(now));
user_active.update(db).await?;
}
}
Ok(result)
}
pub async fn get_passkey_by_credential_id(
db: &DatabaseConnection,
credential_id: &str,
) -> Result<Option<entities::passkey::Model>, CrabError> {
use entities::passkey::{Column, Entity};
let passkey = Entity::find()
.filter(Column::CredentialId.eq(credential_id))
.one(db)
.await?;
Ok(passkey)
}
pub async fn get_passkeys_by_subject(
db: &DatabaseConnection,
subject: &str,
) -> Result<Vec<entities::passkey::Model>, CrabError> {
use entities::passkey::{Column, Entity};
let passkeys = Entity::find()
.filter(Column::Subject.eq(subject))
.all(db)
.await?;
Ok(passkeys)
}
pub async fn update_passkey_counter(
db: &DatabaseConnection,
credential_id: &str,
new_counter: i64,
) -> Result<(), CrabError> {
use entities::passkey::{Column, Entity};
let now = Utc::now().timestamp();
if let Some(passkey) = Entity::find()
.filter(Column::CredentialId.eq(credential_id))
.one(db)
.await?
{
let mut active: entities::passkey::ActiveModel = passkey.into();
active.counter = Set(new_counter);
active.last_used_at = Set(Some(now));
active.update(db).await?;
}
Ok(())
}
pub async fn update_passkey_name(
db: &DatabaseConnection,
credential_id: &str,
name: Option<String>,
) -> Result<(), CrabError> {
use entities::passkey::{Column, Entity};
if let Some(passkey) = Entity::find()
.filter(Column::CredentialId.eq(credential_id))
.one(db)
.await?
{
let mut active: entities::passkey::ActiveModel = passkey.into();
active.name = Set(name);
active.update(db).await?;
}
Ok(())
}
pub async fn delete_passkey(
db: &DatabaseConnection,
credential_id: &str,
) -> Result<(), CrabError> {
use entities::passkey::{Column, Entity};
Entity::delete_many()
.filter(Column::CredentialId.eq(credential_id))
.exec(db)
.await?;
Ok(())
}
// WebAuthn challenge management
pub async fn create_webauthn_challenge(
db: &DatabaseConnection,
challenge: &str,
subject: Option<&str>,
session_id: Option<&str>,
challenge_type: &str,
options_json: &str,
) -> Result<entities::webauthn_challenge::Model, CrabError> {
use entities::webauthn_challenge;
let now = Utc::now().timestamp();
let expires_at = now + 300; // 5 minutes
let challenge_model = webauthn_challenge::ActiveModel {
challenge: Set(challenge.to_string()),
subject: Set(subject.map(|s| s.to_string())),
session_id: Set(session_id.map(|s| s.to_string())),
challenge_type: Set(challenge_type.to_string()),
options_json: Set(options_json.to_string()),
created_at: Set(now),
expires_at: Set(expires_at),
};
let result = challenge_model.insert(db).await?;
Ok(result)
}
pub async fn get_latest_webauthn_challenge_by_subject(
db: &DatabaseConnection,
subject: &str,
challenge_type: &str,
) -> Result<Option<entities::webauthn_challenge::Model>, CrabError> {
use entities::webauthn_challenge::{Column, Entity};
let now = Utc::now().timestamp();
let challenge = Entity::find()
.filter(Column::Subject.eq(subject))
.filter(Column::ChallengeType.eq(challenge_type))
.filter(Column::ExpiresAt.gt(now))
.order_by_desc(Column::CreatedAt)
.one(db)
.await?;
Ok(challenge)
}
pub async fn delete_webauthn_challenge(
db: &DatabaseConnection,
challenge: &str,
) -> Result<(), CrabError> {
use entities::webauthn_challenge::{Column, Entity};
Entity::delete_many()
.filter(Column::Challenge.eq(challenge))
.exec(db)
.await?;
Ok(())
}
pub async fn cleanup_expired_challenges(db: &DatabaseConnection) -> Result<u64, CrabError> {
use entities::webauthn_challenge::{Column, Entity};
let now = Utc::now().timestamp();
let result = Entity::delete_many()
.filter(Column::ExpiresAt.lt(now))
.exec(db)
.await?;
Ok(result.rows_affected)
}
// Session authentication context management
pub async fn update_session_auth_context(
db: &DatabaseConnection,
session_id: &str,
amr: Option<&str>,
acr: Option<&str>,
mfa_verified: Option<bool>,
) -> Result<(), CrabError> {
use entities::session::{Column, Entity};
if let Some(session) = Entity::find()
.filter(Column::SessionId.eq(session_id))
.one(db)
.await?
{
let mut active: entities::session::ActiveModel = session.into();
if let Some(amr_val) = amr {
active.amr = Set(Some(amr_val.to_string()));
}
if let Some(acr_val) = acr {
active.acr = Set(Some(acr_val.to_string()));
}
if let Some(mfa_val) = mfa_verified {
active.mfa_verified = Set(if mfa_val { 1 } else { 0 });
}
active.update(db).await?;
}
Ok(())
}
pub async fn append_session_amr(
db: &DatabaseConnection,
session_id: &str,
new_amr: &str,
) -> Result<(), CrabError> {
use entities::session::{Column, Entity};
if let Some(session) = Entity::find()
.filter(Column::SessionId.eq(session_id))
.one(db)
.await?
{
let mut amr_array: Vec<String> = if let Some(existing_amr) = &session.amr {
serde_json::from_str(existing_amr).unwrap_or_else(|_| vec![])
} else {
vec![]
};
// Only append if not already present
if !amr_array.contains(&new_amr.to_string()) {
amr_array.push(new_amr.to_string());
}
let amr_json = serde_json::to_string(&amr_array)?;
let mut active: entities::session::ActiveModel = session.into();
active.amr = Set(Some(amr_json));
active.update(db).await?;
}
Ok(())
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
@ -811,8 +1147,7 @@ mod tests {
.expect("Failed to get client") .expect("Failed to get client")
.expect("Client not found"); .expect("Client not found");
let parsed_uris: Vec<String> = serde_json::from_str(&retrieved.redirect_uris) let parsed_uris = retrieved.redirect_uris.clone();
.expect("Failed to parse redirect_uris");
assert_eq!(parsed_uris, uris); assert_eq!(parsed_uris, uris);
} }
@ -827,18 +1162,21 @@ mod tests {
let code = issue_auth_code( let code = issue_auth_code(
&db, &db,
"test_subject",
"test_client_id", "test_client_id",
"openid profile",
Some("test_nonce"),
"http://localhost:3000/callback", "http://localhost:3000/callback",
Some("challenge_string"), "openid profile",
Some("S256"), "test_subject",
Some("test_nonce".to_string()),
"challenge_string",
"S256",
300, // 5 minutes TTL
None, // auth_time
) )
.await .await
.expect("Failed to issue auth code"); .expect("Failed to issue auth code");
assert!(!code.is_empty()); assert!(!code.code.is_empty());
assert_eq!(code.subject, "test_subject");
} }
#[tokio::test] #[tokio::test]
@ -847,18 +1185,20 @@ mod tests {
let code = issue_auth_code( let code = issue_auth_code(
&db, &db,
"test_subject",
"test_client_id", "test_client_id",
"openid profile",
Some("test_nonce"),
"http://localhost:3000/callback", "http://localhost:3000/callback",
Some("challenge_string"), "openid profile",
Some("S256"), "test_subject",
Some("test_nonce".to_string()),
"challenge_string",
"S256",
300,
None,
) )
.await .await
.expect("Failed to issue auth code"); .expect("Failed to issue auth code");
let auth_code = consume_auth_code(&db, &code) let auth_code = consume_auth_code(&db, &code.code)
.await .await
.expect("Failed to consume auth code") .expect("Failed to consume auth code")
.expect("Auth code not found"); .expect("Auth code not found");
@ -874,25 +1214,27 @@ mod tests {
let code = issue_auth_code( let code = issue_auth_code(
&db, &db,
"test_subject",
"test_client_id", "test_client_id",
"openid profile",
None,
"http://localhost:3000/callback", "http://localhost:3000/callback",
"openid profile",
"test_subject",
None, None,
None, "",
"",
300, // TTL
None, // auth_time
) )
.await .await
.expect("Failed to issue auth code"); .expect("Failed to issue auth code");
// First consumption succeeds // First consumption succeeds
consume_auth_code(&db, &code) consume_auth_code(&db, &code.code)
.await .await
.expect("Failed to consume auth code") .expect("Failed to consume auth code")
.expect("Auth code not found"); .expect("Auth code not found");
// Second consumption returns None // Second consumption returns None
let result = consume_auth_code(&db, &code) let result = consume_auth_code(&db, &code.code)
.await .await
.expect("Query failed"); .expect("Query failed");
@ -905,13 +1247,15 @@ mod tests {
let code = issue_auth_code( let code = issue_auth_code(
&db, &db,
"test_subject",
"test_client_id", "test_client_id",
"openid profile",
None,
"http://localhost:3000/callback", "http://localhost:3000/callback",
"openid profile",
"test_subject",
None, None,
None, "",
"",
300, // TTL
None, // auth_time
) )
.await .await
.expect("Failed to issue auth code"); .expect("Failed to issue auth code");
@ -925,13 +1269,13 @@ mod tests {
Entity::update_many() Entity::update_many()
.col_expr(Column::ExpiresAt, sea_orm::sea_query::Expr::value(past_timestamp)) .col_expr(Column::ExpiresAt, sea_orm::sea_query::Expr::value(past_timestamp))
.filter(Column::Code.eq(&code)) .filter(Column::Code.eq(&code.code))
.exec(&db) .exec(&db)
.await .await
.expect("Failed to update expiry"); .expect("Failed to update expiry");
// Consumption should return None for expired code // Consumption should return None for expired code
let result = consume_auth_code(&db, &code) let result = consume_auth_code(&db, &code.code)
.await .await
.expect("Query failed"); .expect("Query failed");
@ -944,24 +1288,26 @@ mod tests {
let code = issue_auth_code( let code = issue_auth_code(
&db, &db,
"test_subject",
"test_client_id", "test_client_id",
"openid profile",
None,
"http://localhost:3000/callback", "http://localhost:3000/callback",
Some("challenge_string"), "openid profile",
Some("S256"), "test_subject",
None,
"challenge_string",
"S256",
300,
None,
) )
.await .await
.expect("Failed to issue auth code"); .expect("Failed to issue auth code");
let auth_code = consume_auth_code(&db, &code) let auth_code = consume_auth_code(&db, &code.code)
.await .await
.expect("Failed to consume auth code") .expect("Failed to consume auth code")
.expect("Auth code not found"); .expect("Auth code not found");
assert_eq!(auth_code.code_challenge, Some("challenge_string".to_string())); assert_eq!(auth_code.code_challenge, "challenge_string");
assert_eq!(auth_code.code_challenge_method, Some("S256".to_string())); assert_eq!(auth_code.code_challenge_method, "S256");
} }
// ============================================================================ // ============================================================================
@ -972,22 +1318,26 @@ mod tests {
async fn test_issue_access_token() { async fn test_issue_access_token() {
let db = test_db().await; let db = test_db().await;
let token = issue_access_token(&db, "test_subject", "test_client_id", "openid profile") let token = issue_access_token(&db, "test_subject", "test_client_id", "openid profile",
3600, // TTL
)
.await .await
.expect("Failed to issue access token"); .expect("Failed to issue access token");
assert!(!token.is_empty()); assert!(!token.token.is_empty());
} }
#[tokio::test] #[tokio::test]
async fn test_get_access_token_valid() { async fn test_get_access_token_valid() {
let db = test_db().await; let db = test_db().await;
let token = issue_access_token(&db, "test_subject", "test_client_id", "openid profile") let token = issue_access_token(&db, "test_subject", "test_client_id", "openid profile",
3600, // TTL
)
.await .await
.expect("Failed to issue access token"); .expect("Failed to issue access token");
let access_token = get_access_token(&db, &token) let access_token = get_access_token(&db, &token.token)
.await .await
.expect("Failed to get access token") .expect("Failed to get access token")
.expect("Access token not found"); .expect("Access token not found");
@ -1001,7 +1351,9 @@ mod tests {
async fn test_get_access_token_expired() { async fn test_get_access_token_expired() {
let db = test_db().await; let db = test_db().await;
let token = issue_access_token(&db, "test_subject", "test_client_id", "openid profile") let token = issue_access_token(&db, "test_subject", "test_client_id", "openid profile",
3600, // TTL
)
.await .await
.expect("Failed to issue access token"); .expect("Failed to issue access token");
@ -1013,13 +1365,13 @@ mod tests {
Entity::update_many() Entity::update_many()
.col_expr(Column::ExpiresAt, sea_orm::sea_query::Expr::value(past_timestamp)) .col_expr(Column::ExpiresAt, sea_orm::sea_query::Expr::value(past_timestamp))
.filter(Column::Token.eq(&token)) .filter(Column::Token.eq(&token.token))
.exec(&db) .exec(&db)
.await .await
.expect("Failed to update expiry"); .expect("Failed to update expiry");
// Should return None for expired token // Should return None for expired token
let result = get_access_token(&db, &token) let result = get_access_token(&db, &token.token)
.await .await
.expect("Query failed"); .expect("Query failed");
@ -1030,7 +1382,9 @@ mod tests {
async fn test_get_access_token_revoked() { async fn test_get_access_token_revoked() {
let db = test_db().await; let db = test_db().await;
let token = issue_access_token(&db, "test_subject", "test_client_id", "openid profile") let token = issue_access_token(&db, "test_subject", "test_client_id", "openid profile",
3600, // TTL
)
.await .await
.expect("Failed to issue access token"); .expect("Failed to issue access token");
@ -1040,13 +1394,13 @@ mod tests {
Entity::update_many() Entity::update_many()
.col_expr(Column::Revoked, sea_orm::sea_query::Expr::value(1)) .col_expr(Column::Revoked, sea_orm::sea_query::Expr::value(1))
.filter(Column::Token.eq(&token)) .filter(Column::Token.eq(&token.token))
.exec(&db) .exec(&db)
.await .await
.expect("Failed to revoke token"); .expect("Failed to revoke token");
// Should return None for revoked token // Should return None for revoked token
let result = get_access_token(&db, &token) let result = get_access_token(&db, &token.token)
.await .await
.expect("Query failed"); .expect("Query failed");
@ -1058,38 +1412,44 @@ mod tests {
let db = test_db().await; let db = test_db().await;
// Create initial refresh token // Create initial refresh token
let token1 = issue_refresh_token(&db, "test_subject", "test_client_id", "openid profile", None) let token1 = issue_refresh_token(&db, "test_subject", "test_client_id", "openid profile",
86400, // TTL
None, // parent_token
)
.await .await
.expect("Failed to issue refresh token"); .expect("Failed to issue refresh token");
// Rotate to new token // Rotate to new token
let token2 = issue_refresh_token(&db, "test_subject", "test_client_id", "openid profile", Some(&token1)) let token2 = issue_refresh_token(&db, "test_client_id", "test_subject", "openid profile", 86400, Some(token1.token.clone()))
.await .await
.expect("Failed to rotate refresh token"); .expect("Failed to rotate refresh token");
// Verify parent chain // Verify parent chain
let rt2 = get_refresh_token(&db, &token2) let rt2 = get_refresh_token(&db, &token2.token)
.await .await
.expect("Failed to get token") .expect("Failed to get token")
.expect("Token not found"); .expect("Token not found");
assert_eq!(rt2.parent_refresh_token, Some(token1)); assert_eq!(rt2.parent_token, Some(token1.token));
} }
#[tokio::test] #[tokio::test]
async fn test_revoke_refresh_token() { async fn test_revoke_refresh_token() {
let db = test_db().await; let db = test_db().await;
let token = issue_refresh_token(&db, "test_subject", "test_client_id", "openid profile", None) let token = issue_refresh_token(&db, "test_subject", "test_client_id", "openid profile",
86400, // TTL
None, // parent_token
)
.await .await
.expect("Failed to issue refresh token"); .expect("Failed to issue refresh token");
revoke_refresh_token(&db, &token) revoke_refresh_token(&db, &token.token)
.await .await
.expect("Failed to revoke token"); .expect("Failed to revoke token");
// Should return None for revoked token // Should return None for revoked token
let result = get_refresh_token(&db, &token) let result = get_refresh_token(&db, &token.token)
.await .await
.expect("Query failed"); .expect("Query failed");
@ -1140,12 +1500,13 @@ mod tests {
.await .await
.expect("Failed to create user"); .expect("Failed to create user");
let user = verify_user_password(&db, "testuser", "password123") let subject = verify_user_password(&db, "testuser", "password123")
.await .await
.expect("Failed to verify password") .expect("Failed to verify password")
.expect("Verification failed"); .expect("Verification failed");
assert_eq!(user.username, "testuser"); // Verify it's a valid subject (not empty)
assert!(!subject.is_empty());
} }
#[tokio::test] #[tokio::test]
@ -1192,7 +1553,7 @@ mod tests {
.await .await
.expect("Failed to create user"); .expect("Failed to create user");
update_user(&db, &user.subject, false, Some("test@example.com"), Some(true)) update_user(&db, &user.subject, false, Some("test@example.com".to_string()), Some(true))
.await .await
.expect("Failed to update user"); .expect("Failed to update user");
@ -1214,7 +1575,7 @@ mod tests {
.await .await
.expect("Failed to create user"); .expect("Failed to create user");
update_user(&db, &user.subject, true, Some("new@example.com"), None) update_user(&db, &user.subject, true, Some("new@example.com".to_string()), None)
.await .await
.expect("Failed to update email"); .expect("Failed to update email");

View file

@ -98,7 +98,7 @@ async fn sync_user(db: &DatabaseConnection, user_def: &UserDefinition) -> Result
None => { None => {
// Create new user // Create new user
tracing::info!("Creating user: {}", user_def.username); tracing::info!("Creating user: {}", user_def.username);
storage::create_user( let user = storage::create_user(
db, db,
&user_def.username, &user_def.username,
&user_def.password, &user_def.password,
@ -107,13 +107,14 @@ async fn sync_user(db: &DatabaseConnection, user_def: &UserDefinition) -> Result
.await .await
.into_diagnostic()?; .into_diagnostic()?;
// Update enabled and email_verified flags if needed // Update enabled flag if needed (email already set during creation)
if !user_def.enabled || user_def.email_verified { if !user_def.enabled {
storage::update_user( storage::update_user(
db, db,
&user_def.username, &user.subject,
user_def.enabled, user_def.enabled,
user_def.email_verified, None, // email already set
None, // requires_2fa not set during sync
) )
.await .await
.into_diagnostic()?; .into_diagnostic()?;
@ -128,24 +129,18 @@ async fn sync_user(db: &DatabaseConnection, user_def: &UserDefinition) -> Result
(existing_user.email_verified == 1) == user_def.email_verified; (existing_user.email_verified == 1) == user_def.email_verified;
let email_matches = existing_user.email == user_def.email; let email_matches = existing_user.email == user_def.email;
if !enabled_matches || !email_verified_matches || !email_matches { if !enabled_matches || !email_matches {
tracing::info!("Updating user: {}", user_def.username); tracing::info!("Updating user: {}", user_def.username);
storage::update_user( storage::update_user(
db, db,
&user_def.username, &existing_user.subject,
user_def.enabled, user_def.enabled,
user_def.email_verified, if !email_matches { user_def.email.clone() } else { None },
None, // requires_2fa not set during sync
) )
.await .await
.into_diagnostic()?; .into_diagnostic()?;
// Update email if it changed
if !email_matches {
storage::update_user_email(db, &user_def.username, user_def.email.clone())
.await
.into_diagnostic()?;
}
SyncResult::Updated SyncResult::Updated
} else { } else {
SyncResult::Unchanged SyncResult::Unchanged

View file

@ -1633,7 +1633,8 @@ async fn login_submit(
.and_then(|h| h.to_str().ok()) .and_then(|h| h.to_str().ok())
.map(String::from); .map(String::from);
let session = match storage::create_session(&state.db, &subject, 3600, user_agent, None).await { let now = chrono::Utc::now().timestamp();
let session = match storage::create_session(&state.db, &subject, now, 3600, user_agent, None).await {
Ok(s) => s, Ok(s) => s,
Err(_) => { Err(_) => {
let return_to = urlencoded(&form.return_to.unwrap_or_default()); let return_to = urlencoded(&form.return_to.unwrap_or_default());
@ -1918,11 +1919,10 @@ async fn passkey_register_start(
storage::create_webauthn_challenge( storage::create_webauthn_challenge(
&state.db, &state.db,
&challenge_b64, &challenge_b64,
Some(session.subject.clone()), Some(&session.subject),
None, None,
"registration", "registration",
&options_json, &options_json,
300, // 5 minutes
) )
.await .await
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?; .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
@ -2009,12 +2009,12 @@ async fn passkey_register_finish(
&cred_id_b64, &cred_id_b64,
&session.subject, &session.subject,
&passkey_json, &passkey_json,
0, // counter - TODO: extract from passkey when we understand the API 0, // counter - TODO: extract from passkey when we understand the API
None, // aaguid None, // aaguid
false, // backup_eligible - TODO: extract from passkey false, // backup_eligible - TODO: extract from passkey
false, // backup_state - TODO: extract from passkey false, // backup_state - TODO: extract from passkey
None, // transports None, // transports
req.name.clone(), // Name from request req.name.as_deref(), // Name from request
) )
.await .await
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?; .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
@ -2091,11 +2091,10 @@ async fn passkey_auth_start(
storage::create_webauthn_challenge( storage::create_webauthn_challenge(
&state.db, &state.db,
&challenge_b64, &challenge_b64,
subject, subject.as_deref(),
None, None,
"authentication", "authentication",
&options_json, &options_json,
300, // 5 minutes
) )
.await .await
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?; .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
@ -2161,17 +2160,20 @@ async fn passkey_auth_finish(
}; };
// Create session // Create session
let session = storage::create_session(&state.db, &passkey.subject, 3600, None, None) let now = chrono::Utc::now().timestamp();
let session = storage::create_session(&state.db, &passkey.subject, now, 3600, None, None)
.await .await
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?; .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
// Update session with passkey AMR // Update session with passkey AMR
let amr_json = serde_json::to_string(&amr)
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
storage::update_session_auth_context( storage::update_session_auth_context(
&state.db, &state.db,
&session.session_id, &session.session_id,
Some(amr), Some(&amr_json),
Some("aal1".to_string()), Some("aal1"),
false, Some(false),
) )
.await .await
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?; .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
@ -2249,11 +2251,10 @@ async fn passkey_2fa_start(
storage::create_webauthn_challenge( storage::create_webauthn_challenge(
&state.db, &state.db,
&challenge_b64, &challenge_b64,
Some(session.subject.clone()), Some(&session.subject),
Some(session.session_id.clone()), Some(&session.session_id),
"2fa", "2fa",
&options_json, &options_json,
300,
) )
.await .await
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?; .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
@ -2340,8 +2341,8 @@ async fn passkey_2fa_finish(
&state.db, &state.db,
&session.session_id, &session.session_id,
None, None,
Some("aal2".to_string()), Some("aal2"),
true, Some(true),
) )
.await .await
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?; .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;