From df3cc1eb917d040efc9c182faf70f8940d9e8fa1 Mon Sep 17 00:00:00 2001 From: Till Wegmueller Date: Fri, 3 Apr 2026 19:43:21 +0200 Subject: [PATCH] feat: add rate limiting, request ID, and CORS middleware --- src/handler/mod.rs | 31 +++++++++++++- src/lib.rs | 1 + src/middleware/mod.rs | 2 + src/middleware/rate_limit.rs | 82 ++++++++++++++++++++++++++++++++++++ src/middleware/request_id.rs | 23 ++++++++++ tests/test_rate_limit.rs | 32 ++++++++++++++ 6 files changed, 170 insertions(+), 1 deletion(-) create mode 100644 src/middleware/mod.rs create mode 100644 src/middleware/rate_limit.rs create mode 100644 src/middleware/request_id.rs create mode 100644 tests/test_rate_limit.rs diff --git a/src/handler/mod.rs b/src/handler/mod.rs index 10f5d8e..5e17f69 100644 --- a/src/handler/mod.rs +++ b/src/handler/mod.rs @@ -5,16 +5,45 @@ pub mod links; pub mod tokens; mod webfinger; +use axum::middleware as axum_mw; use axum::Router; +use tower_http::cors::{Any, CorsLayer}; + +use crate::middleware::rate_limit::{self, KeyedLimiter}; use crate::state::AppState; pub fn router(state: AppState) -> Router { - Router::new() + let public_limiter = KeyedLimiter::new(state.settings.rate_limit.public_rpm); + let api_limiter = KeyedLimiter::new(state.settings.rate_limit.api_rpm); + + let public_cors = CorsLayer::new() + .allow_origin(Any) + .allow_methods(Any) + .allow_headers(Any); + + // Public endpoints: webfinger + host-meta with per-IP rate limit + CORS + let public_routes = Router::new() .merge(webfinger::router()) .merge(host_meta::router()) + .layer(public_cors) + .layer(axum_mw::from_fn(move |req, next| { + let limiter = public_limiter.clone(); + rate_limit::rate_limit_by_ip(limiter, req, next) + })); + + // API endpoints with per-token rate limit (no wildcard CORS) + let api_routes = Router::new() .merge(domains::router()) .merge(tokens::router()) .merge(links::router()) + .layer(axum_mw::from_fn(move |req, next| { + let limiter = api_limiter.clone(); + rate_limit::rate_limit_by_token(limiter, req, next) + })); + + Router::new() + .merge(public_routes) + .merge(api_routes) .merge(health::router()) .with_state(state) } diff --git a/src/lib.rs b/src/lib.rs index 8c30282..4d52411 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -5,5 +5,6 @@ pub mod config; pub mod entity; pub mod error; pub mod handler; +pub mod middleware; pub mod reaper; pub mod state; diff --git a/src/middleware/mod.rs b/src/middleware/mod.rs new file mode 100644 index 0000000..bf6d350 --- /dev/null +++ b/src/middleware/mod.rs @@ -0,0 +1,2 @@ +pub mod rate_limit; +pub mod request_id; diff --git a/src/middleware/rate_limit.rs b/src/middleware/rate_limit.rs new file mode 100644 index 0000000..a3352bc --- /dev/null +++ b/src/middleware/rate_limit.rs @@ -0,0 +1,82 @@ +use axum::body::Body; +use axum::http::{Request, StatusCode}; +use axum::middleware::Next; +use axum::response::{IntoResponse, Response}; +use dashmap::DashMap; +use governor::clock::DefaultClock; +use governor::state::{InMemoryState, NotKeyed}; +use governor::{Quota, RateLimiter}; +use std::num::NonZeroU32; +use std::sync::Arc; + +/// Per-key rate limiter using DashMap for keyed limiting (per IP or per token). +#[derive(Clone)] +pub struct KeyedLimiter { + limiters: Arc>>>, + quota: Quota, +} + +impl KeyedLimiter { + pub fn new(rpm: u32) -> Self { + let quota = Quota::per_minute(NonZeroU32::new(rpm).expect("rpm must be > 0")); + Self { + limiters: Arc::new(DashMap::new()), + quota, + } + } + + pub fn check_key(&self, key: &str) -> bool { + let limiter = self + .limiters + .entry(key.to_string()) + .or_insert_with(|| Arc::new(RateLimiter::direct(self.quota))) + .clone(); + limiter.check().is_ok() + } +} + +/// Rate limit middleware for public endpoints (keyed by client IP). +pub async fn rate_limit_by_ip(limiter: KeyedLimiter, request: Request, next: Next) -> Response { + // Extract IP from x-forwarded-for or fall back to "unknown" + let ip = request + .headers() + .get("x-forwarded-for") + .and_then(|v| v.to_str().ok()) + .and_then(|v| v.split(',').next()) + .map(|s| s.trim().to_string()) + .unwrap_or_else(|| "unknown".to_string()); + + if !limiter.check_key(&ip) { + return ( + StatusCode::TOO_MANY_REQUESTS, + [("retry-after", "60")], + "rate limited", + ) + .into_response(); + } + + next.run(request).await +} + +/// Rate limit middleware for API endpoints (keyed by Bearer token prefix). +pub async fn rate_limit_by_token(limiter: KeyedLimiter, request: Request, next: Next) -> Response { + let key = request + .headers() + .get("authorization") + .and_then(|v| v.to_str().ok()) + .and_then(|v| v.strip_prefix("Bearer ")) + .and_then(|t| t.split('.').next()) // Use token ID prefix as key + .unwrap_or("anonymous") + .to_string(); + + if !limiter.check_key(&key) { + return ( + StatusCode::TOO_MANY_REQUESTS, + [("retry-after", "60")], + "rate limited", + ) + .into_response(); + } + + next.run(request).await +} diff --git a/src/middleware/request_id.rs b/src/middleware/request_id.rs new file mode 100644 index 0000000..657f2c0 --- /dev/null +++ b/src/middleware/request_id.rs @@ -0,0 +1,23 @@ +use axum::body::Body; +use axum::http::{HeaderName, HeaderValue, Request}; +use axum::middleware::Next; +use axum::response::Response; +use uuid::Uuid; + +static X_REQUEST_ID: HeaderName = HeaderName::from_static("x-request-id"); + +/// Middleware that generates a unique request ID and attaches it to both the +/// request (for downstream handlers) and the response (for clients). +pub async fn request_id(mut request: Request, next: Next) -> Response { + let id = Uuid::new_v4().to_string(); + request + .headers_mut() + .insert(X_REQUEST_ID.clone(), HeaderValue::from_str(&id).unwrap()); + + let mut response = next.run(request).await; + response + .headers_mut() + .insert(X_REQUEST_ID.clone(), HeaderValue::from_str(&id).unwrap()); + + response +} diff --git a/tests/test_rate_limit.rs b/tests/test_rate_limit.rs new file mode 100644 index 0000000..265d1cb --- /dev/null +++ b/tests/test_rate_limit.rs @@ -0,0 +1,32 @@ +mod common; + +use axum_test::TestServer; +use webfingerd::handler; + +#[tokio::test] +async fn test_public_rate_limiting() { + let mut settings = common::test_settings(); + settings.rate_limit.public_rpm = 2; // Very low for testing + + let state = common::test_state_with_settings(settings).await; + let app = handler::router(state); + let server = TestServer::new(app); + + // First two requests should succeed (even with 404) + server + .get("/.well-known/webfinger") + .add_query_param("resource", "acct:a@a.com") + .await; + server + .get("/.well-known/webfinger") + .add_query_param("resource", "acct:b@b.com") + .await; + + // Third should be rate limited + let response = server + .get("/.well-known/webfinger") + .add_query_param("resource", "acct:c@c.com") + .await; + + response.assert_status(axum::http::StatusCode::TOO_MANY_REQUESTS); +}