diff --git a/src/handler/health.rs b/src/handler/health.rs index 69dfde2..2057fd3 100644 --- a/src/handler/health.rs +++ b/src/handler/health.rs @@ -1,11 +1,23 @@ +use axum::extract::State; use axum::http::StatusCode; use axum::routing::get; use axum::Router; +use sea_orm::{ConnectionTrait, Statement}; use crate::state::AppState; -async fn healthz() -> StatusCode { - StatusCode::OK +async fn healthz(State(state): State) -> StatusCode { + match state + .db + .execute(Statement::from_string( + sea_orm::DatabaseBackend::Sqlite, + "SELECT 1".to_string(), + )) + .await + { + Ok(_) => StatusCode::OK, + Err(_) => StatusCode::SERVICE_UNAVAILABLE, + } } pub fn router() -> Router { diff --git a/src/handler/metrics.rs b/src/handler/metrics.rs new file mode 100644 index 0000000..b33632f --- /dev/null +++ b/src/handler/metrics.rs @@ -0,0 +1,13 @@ +use axum::extract::State; +use axum::routing::get; +use axum::Router; + +use crate::state::AppState; + +async fn metrics(State(state): State) -> String { + state.metrics_handle.render() +} + +pub fn router() -> Router { + Router::new().route("/metrics", get(metrics)) +} diff --git a/src/handler/mod.rs b/src/handler/mod.rs index 5e17f69..7435a86 100644 --- a/src/handler/mod.rs +++ b/src/handler/mod.rs @@ -2,6 +2,7 @@ pub mod domains; mod health; mod host_meta; pub mod links; +mod metrics; pub mod tokens; mod webfinger; @@ -45,5 +46,6 @@ pub fn router(state: AppState) -> Router { .merge(public_routes) .merge(api_routes) .merge(health::router()) + .merge(metrics::router()) .with_state(state) } diff --git a/src/handler/webfinger.rs b/src/handler/webfinger.rs index 7cf2c13..b421f51 100644 --- a/src/handler/webfinger.rs +++ b/src/handler/webfinger.rs @@ -35,15 +35,36 @@ async fn webfinger( State(state): State, uri: Uri, ) -> AppResult { + let start = std::time::Instant::now(); + let (resource_opt, rels) = parse_webfinger_query(&uri); let resource = resource_opt .ok_or_else(|| AppError::BadRequest("missing resource parameter".into()))?; - let cached = state - .cache - .get(&resource) - .ok_or(AppError::NotFound)?; + // Extract domain from resource for metrics labeling + let resource_domain = resource + .split('@') + .nth(1) + .or_else(|| { + // Handle URI-style resources like https://domain/path + resource.split("://").nth(1).and_then(|s| s.split('/').next()) + }) + .unwrap_or("unknown") + .to_string(); + + let cached = match state.cache.get(&resource) { + Some(c) => { + metrics::counter!("webfinger_queries_total", "domain" => resource_domain.clone(), "status" => "hit").increment(1); + c + } + None => { + metrics::counter!("webfinger_queries_total", "domain" => resource_domain.clone(), "status" => "miss").increment(1); + let elapsed = start.elapsed().as_secs_f64(); + metrics::histogram!("webfinger_query_duration_seconds").record(elapsed); + return Err(AppError::NotFound); + } + }; let links: Vec = cached .links @@ -94,6 +115,9 @@ async fn webfinger( response_body.insert("links".into(), json!(links)); + let elapsed = start.elapsed().as_secs_f64(); + metrics::histogram!("webfinger_query_duration_seconds").record(elapsed); + Ok(( [ (header::CONTENT_TYPE, "application/jrd+json"), diff --git a/src/state.rs b/src/state.rs index da8de5f..8e218da 100644 --- a/src/state.rs +++ b/src/state.rs @@ -1,3 +1,4 @@ +use metrics_exporter_prometheus::PrometheusHandle; use sea_orm::DatabaseConnection; use std::sync::Arc; @@ -11,4 +12,5 @@ pub struct AppState { pub cache: Cache, pub settings: Arc, pub challenge_verifier: Arc, + pub metrics_handle: PrometheusHandle, } diff --git a/tests/common/mod.rs b/tests/common/mod.rs index 7a04c02..8c86dde 100644 --- a/tests/common/mod.rs +++ b/tests/common/mod.rs @@ -1,3 +1,4 @@ +use metrics_exporter_prometheus::PrometheusBuilder; use sea_orm::{ConnectOptions, ConnectionTrait, Database, DatabaseConnection, Statement}; use sea_orm_migration::MigratorTrait; use std::sync::Arc; @@ -57,10 +58,20 @@ pub async fn test_state_with_settings(settings: Settings) -> AppState { let db = setup_test_db().await; let cache = Cache::new(); cache.hydrate(&db).await.unwrap(); + // Each test gets its own metrics recorder. If install_recorder fails because + // another test already installed one, build a standalone recorder and grab its handle. + let metrics_handle = PrometheusBuilder::new() + .install_recorder() + .unwrap_or_else(|_| { + let recorder = PrometheusBuilder::new().build_recorder(); + recorder.handle() + }); + AppState { db, cache, settings: Arc::new(settings), challenge_verifier: Arc::new(webfingerd::challenge::MockChallengeVerifier), + metrics_handle, } }