feat: add rate limiting, request ID, and CORS middleware

This commit is contained in:
Till Wegmueller 2026-04-03 19:43:21 +02:00
parent 66b3de433f
commit df3cc1eb91
No known key found for this signature in database
6 changed files with 170 additions and 1 deletions

View file

@ -5,16 +5,45 @@ pub mod links;
pub mod tokens; pub mod tokens;
mod webfinger; mod webfinger;
use axum::middleware as axum_mw;
use axum::Router; use axum::Router;
use tower_http::cors::{Any, CorsLayer};
use crate::middleware::rate_limit::{self, KeyedLimiter};
use crate::state::AppState; use crate::state::AppState;
pub fn router(state: AppState) -> Router { pub fn router(state: AppState) -> Router {
Router::new() let public_limiter = KeyedLimiter::new(state.settings.rate_limit.public_rpm);
let api_limiter = KeyedLimiter::new(state.settings.rate_limit.api_rpm);
let public_cors = CorsLayer::new()
.allow_origin(Any)
.allow_methods(Any)
.allow_headers(Any);
// Public endpoints: webfinger + host-meta with per-IP rate limit + CORS
let public_routes = Router::new()
.merge(webfinger::router()) .merge(webfinger::router())
.merge(host_meta::router()) .merge(host_meta::router())
.layer(public_cors)
.layer(axum_mw::from_fn(move |req, next| {
let limiter = public_limiter.clone();
rate_limit::rate_limit_by_ip(limiter, req, next)
}));
// API endpoints with per-token rate limit (no wildcard CORS)
let api_routes = Router::new()
.merge(domains::router()) .merge(domains::router())
.merge(tokens::router()) .merge(tokens::router())
.merge(links::router()) .merge(links::router())
.layer(axum_mw::from_fn(move |req, next| {
let limiter = api_limiter.clone();
rate_limit::rate_limit_by_token(limiter, req, next)
}));
Router::new()
.merge(public_routes)
.merge(api_routes)
.merge(health::router()) .merge(health::router())
.with_state(state) .with_state(state)
} }

View file

@ -5,5 +5,6 @@ pub mod config;
pub mod entity; pub mod entity;
pub mod error; pub mod error;
pub mod handler; pub mod handler;
pub mod middleware;
pub mod reaper; pub mod reaper;
pub mod state; pub mod state;

2
src/middleware/mod.rs Normal file
View file

@ -0,0 +1,2 @@
pub mod rate_limit;
pub mod request_id;

View file

@ -0,0 +1,82 @@
use axum::body::Body;
use axum::http::{Request, StatusCode};
use axum::middleware::Next;
use axum::response::{IntoResponse, Response};
use dashmap::DashMap;
use governor::clock::DefaultClock;
use governor::state::{InMemoryState, NotKeyed};
use governor::{Quota, RateLimiter};
use std::num::NonZeroU32;
use std::sync::Arc;
/// Per-key rate limiter using DashMap for keyed limiting (per IP or per token).
#[derive(Clone)]
pub struct KeyedLimiter {
limiters: Arc<DashMap<String, Arc<RateLimiter<NotKeyed, InMemoryState, DefaultClock>>>>,
quota: Quota,
}
impl KeyedLimiter {
pub fn new(rpm: u32) -> Self {
let quota = Quota::per_minute(NonZeroU32::new(rpm).expect("rpm must be > 0"));
Self {
limiters: Arc::new(DashMap::new()),
quota,
}
}
pub fn check_key(&self, key: &str) -> bool {
let limiter = self
.limiters
.entry(key.to_string())
.or_insert_with(|| Arc::new(RateLimiter::direct(self.quota)))
.clone();
limiter.check().is_ok()
}
}
/// Rate limit middleware for public endpoints (keyed by client IP).
pub async fn rate_limit_by_ip(limiter: KeyedLimiter, request: Request<Body>, next: Next) -> Response {
// Extract IP from x-forwarded-for or fall back to "unknown"
let ip = request
.headers()
.get("x-forwarded-for")
.and_then(|v| v.to_str().ok())
.and_then(|v| v.split(',').next())
.map(|s| s.trim().to_string())
.unwrap_or_else(|| "unknown".to_string());
if !limiter.check_key(&ip) {
return (
StatusCode::TOO_MANY_REQUESTS,
[("retry-after", "60")],
"rate limited",
)
.into_response();
}
next.run(request).await
}
/// Rate limit middleware for API endpoints (keyed by Bearer token prefix).
pub async fn rate_limit_by_token(limiter: KeyedLimiter, request: Request<Body>, next: Next) -> Response {
let key = request
.headers()
.get("authorization")
.and_then(|v| v.to_str().ok())
.and_then(|v| v.strip_prefix("Bearer "))
.and_then(|t| t.split('.').next()) // Use token ID prefix as key
.unwrap_or("anonymous")
.to_string();
if !limiter.check_key(&key) {
return (
StatusCode::TOO_MANY_REQUESTS,
[("retry-after", "60")],
"rate limited",
)
.into_response();
}
next.run(request).await
}

View file

@ -0,0 +1,23 @@
use axum::body::Body;
use axum::http::{HeaderName, HeaderValue, Request};
use axum::middleware::Next;
use axum::response::Response;
use uuid::Uuid;
static X_REQUEST_ID: HeaderName = HeaderName::from_static("x-request-id");
/// Middleware that generates a unique request ID and attaches it to both the
/// request (for downstream handlers) and the response (for clients).
pub async fn request_id(mut request: Request<Body>, next: Next) -> Response {
let id = Uuid::new_v4().to_string();
request
.headers_mut()
.insert(X_REQUEST_ID.clone(), HeaderValue::from_str(&id).unwrap());
let mut response = next.run(request).await;
response
.headers_mut()
.insert(X_REQUEST_ID.clone(), HeaderValue::from_str(&id).unwrap());
response
}

32
tests/test_rate_limit.rs Normal file
View file

@ -0,0 +1,32 @@
mod common;
use axum_test::TestServer;
use webfingerd::handler;
#[tokio::test]
async fn test_public_rate_limiting() {
let mut settings = common::test_settings();
settings.rate_limit.public_rpm = 2; // Very low for testing
let state = common::test_state_with_settings(settings).await;
let app = handler::router(state);
let server = TestServer::new(app);
// First two requests should succeed (even with 404)
server
.get("/.well-known/webfinger")
.add_query_param("resource", "acct:a@a.com")
.await;
server
.get("/.well-known/webfinger")
.add_query_param("resource", "acct:b@b.com")
.await;
// Third should be rate limited
let response = server
.get("/.well-known/webfinger")
.add_query_param("resource", "acct:c@c.com")
.await;
response.assert_status(axum::http::StatusCode::TOO_MANY_REQUESTS);
}