mirror of
https://github.com/CloudNebulaProject/webfingerd.git
synced 2026-04-10 13:10:41 +00:00
feat: add rate limiting, request ID, and CORS middleware
This commit is contained in:
parent
66b3de433f
commit
df3cc1eb91
6 changed files with 170 additions and 1 deletions
|
|
@ -5,16 +5,45 @@ pub mod links;
|
||||||
pub mod tokens;
|
pub mod tokens;
|
||||||
mod webfinger;
|
mod webfinger;
|
||||||
|
|
||||||
|
use axum::middleware as axum_mw;
|
||||||
use axum::Router;
|
use axum::Router;
|
||||||
|
use tower_http::cors::{Any, CorsLayer};
|
||||||
|
|
||||||
|
use crate::middleware::rate_limit::{self, KeyedLimiter};
|
||||||
use crate::state::AppState;
|
use crate::state::AppState;
|
||||||
|
|
||||||
pub fn router(state: AppState) -> Router {
|
pub fn router(state: AppState) -> Router {
|
||||||
Router::new()
|
let public_limiter = KeyedLimiter::new(state.settings.rate_limit.public_rpm);
|
||||||
|
let api_limiter = KeyedLimiter::new(state.settings.rate_limit.api_rpm);
|
||||||
|
|
||||||
|
let public_cors = CorsLayer::new()
|
||||||
|
.allow_origin(Any)
|
||||||
|
.allow_methods(Any)
|
||||||
|
.allow_headers(Any);
|
||||||
|
|
||||||
|
// Public endpoints: webfinger + host-meta with per-IP rate limit + CORS
|
||||||
|
let public_routes = Router::new()
|
||||||
.merge(webfinger::router())
|
.merge(webfinger::router())
|
||||||
.merge(host_meta::router())
|
.merge(host_meta::router())
|
||||||
|
.layer(public_cors)
|
||||||
|
.layer(axum_mw::from_fn(move |req, next| {
|
||||||
|
let limiter = public_limiter.clone();
|
||||||
|
rate_limit::rate_limit_by_ip(limiter, req, next)
|
||||||
|
}));
|
||||||
|
|
||||||
|
// API endpoints with per-token rate limit (no wildcard CORS)
|
||||||
|
let api_routes = Router::new()
|
||||||
.merge(domains::router())
|
.merge(domains::router())
|
||||||
.merge(tokens::router())
|
.merge(tokens::router())
|
||||||
.merge(links::router())
|
.merge(links::router())
|
||||||
|
.layer(axum_mw::from_fn(move |req, next| {
|
||||||
|
let limiter = api_limiter.clone();
|
||||||
|
rate_limit::rate_limit_by_token(limiter, req, next)
|
||||||
|
}));
|
||||||
|
|
||||||
|
Router::new()
|
||||||
|
.merge(public_routes)
|
||||||
|
.merge(api_routes)
|
||||||
.merge(health::router())
|
.merge(health::router())
|
||||||
.with_state(state)
|
.with_state(state)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -5,5 +5,6 @@ pub mod config;
|
||||||
pub mod entity;
|
pub mod entity;
|
||||||
pub mod error;
|
pub mod error;
|
||||||
pub mod handler;
|
pub mod handler;
|
||||||
|
pub mod middleware;
|
||||||
pub mod reaper;
|
pub mod reaper;
|
||||||
pub mod state;
|
pub mod state;
|
||||||
|
|
|
||||||
2
src/middleware/mod.rs
Normal file
2
src/middleware/mod.rs
Normal file
|
|
@ -0,0 +1,2 @@
|
||||||
|
pub mod rate_limit;
|
||||||
|
pub mod request_id;
|
||||||
82
src/middleware/rate_limit.rs
Normal file
82
src/middleware/rate_limit.rs
Normal file
|
|
@ -0,0 +1,82 @@
|
||||||
|
use axum::body::Body;
|
||||||
|
use axum::http::{Request, StatusCode};
|
||||||
|
use axum::middleware::Next;
|
||||||
|
use axum::response::{IntoResponse, Response};
|
||||||
|
use dashmap::DashMap;
|
||||||
|
use governor::clock::DefaultClock;
|
||||||
|
use governor::state::{InMemoryState, NotKeyed};
|
||||||
|
use governor::{Quota, RateLimiter};
|
||||||
|
use std::num::NonZeroU32;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
/// Per-key rate limiter using DashMap for keyed limiting (per IP or per token).
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct KeyedLimiter {
|
||||||
|
limiters: Arc<DashMap<String, Arc<RateLimiter<NotKeyed, InMemoryState, DefaultClock>>>>,
|
||||||
|
quota: Quota,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl KeyedLimiter {
|
||||||
|
pub fn new(rpm: u32) -> Self {
|
||||||
|
let quota = Quota::per_minute(NonZeroU32::new(rpm).expect("rpm must be > 0"));
|
||||||
|
Self {
|
||||||
|
limiters: Arc::new(DashMap::new()),
|
||||||
|
quota,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn check_key(&self, key: &str) -> bool {
|
||||||
|
let limiter = self
|
||||||
|
.limiters
|
||||||
|
.entry(key.to_string())
|
||||||
|
.or_insert_with(|| Arc::new(RateLimiter::direct(self.quota)))
|
||||||
|
.clone();
|
||||||
|
limiter.check().is_ok()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Rate limit middleware for public endpoints (keyed by client IP).
|
||||||
|
pub async fn rate_limit_by_ip(limiter: KeyedLimiter, request: Request<Body>, next: Next) -> Response {
|
||||||
|
// Extract IP from x-forwarded-for or fall back to "unknown"
|
||||||
|
let ip = request
|
||||||
|
.headers()
|
||||||
|
.get("x-forwarded-for")
|
||||||
|
.and_then(|v| v.to_str().ok())
|
||||||
|
.and_then(|v| v.split(',').next())
|
||||||
|
.map(|s| s.trim().to_string())
|
||||||
|
.unwrap_or_else(|| "unknown".to_string());
|
||||||
|
|
||||||
|
if !limiter.check_key(&ip) {
|
||||||
|
return (
|
||||||
|
StatusCode::TOO_MANY_REQUESTS,
|
||||||
|
[("retry-after", "60")],
|
||||||
|
"rate limited",
|
||||||
|
)
|
||||||
|
.into_response();
|
||||||
|
}
|
||||||
|
|
||||||
|
next.run(request).await
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Rate limit middleware for API endpoints (keyed by Bearer token prefix).
|
||||||
|
pub async fn rate_limit_by_token(limiter: KeyedLimiter, request: Request<Body>, next: Next) -> Response {
|
||||||
|
let key = request
|
||||||
|
.headers()
|
||||||
|
.get("authorization")
|
||||||
|
.and_then(|v| v.to_str().ok())
|
||||||
|
.and_then(|v| v.strip_prefix("Bearer "))
|
||||||
|
.and_then(|t| t.split('.').next()) // Use token ID prefix as key
|
||||||
|
.unwrap_or("anonymous")
|
||||||
|
.to_string();
|
||||||
|
|
||||||
|
if !limiter.check_key(&key) {
|
||||||
|
return (
|
||||||
|
StatusCode::TOO_MANY_REQUESTS,
|
||||||
|
[("retry-after", "60")],
|
||||||
|
"rate limited",
|
||||||
|
)
|
||||||
|
.into_response();
|
||||||
|
}
|
||||||
|
|
||||||
|
next.run(request).await
|
||||||
|
}
|
||||||
23
src/middleware/request_id.rs
Normal file
23
src/middleware/request_id.rs
Normal file
|
|
@ -0,0 +1,23 @@
|
||||||
|
use axum::body::Body;
|
||||||
|
use axum::http::{HeaderName, HeaderValue, Request};
|
||||||
|
use axum::middleware::Next;
|
||||||
|
use axum::response::Response;
|
||||||
|
use uuid::Uuid;
|
||||||
|
|
||||||
|
static X_REQUEST_ID: HeaderName = HeaderName::from_static("x-request-id");
|
||||||
|
|
||||||
|
/// Middleware that generates a unique request ID and attaches it to both the
|
||||||
|
/// request (for downstream handlers) and the response (for clients).
|
||||||
|
pub async fn request_id(mut request: Request<Body>, next: Next) -> Response {
|
||||||
|
let id = Uuid::new_v4().to_string();
|
||||||
|
request
|
||||||
|
.headers_mut()
|
||||||
|
.insert(X_REQUEST_ID.clone(), HeaderValue::from_str(&id).unwrap());
|
||||||
|
|
||||||
|
let mut response = next.run(request).await;
|
||||||
|
response
|
||||||
|
.headers_mut()
|
||||||
|
.insert(X_REQUEST_ID.clone(), HeaderValue::from_str(&id).unwrap());
|
||||||
|
|
||||||
|
response
|
||||||
|
}
|
||||||
32
tests/test_rate_limit.rs
Normal file
32
tests/test_rate_limit.rs
Normal file
|
|
@ -0,0 +1,32 @@
|
||||||
|
mod common;
|
||||||
|
|
||||||
|
use axum_test::TestServer;
|
||||||
|
use webfingerd::handler;
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_public_rate_limiting() {
|
||||||
|
let mut settings = common::test_settings();
|
||||||
|
settings.rate_limit.public_rpm = 2; // Very low for testing
|
||||||
|
|
||||||
|
let state = common::test_state_with_settings(settings).await;
|
||||||
|
let app = handler::router(state);
|
||||||
|
let server = TestServer::new(app);
|
||||||
|
|
||||||
|
// First two requests should succeed (even with 404)
|
||||||
|
server
|
||||||
|
.get("/.well-known/webfinger")
|
||||||
|
.add_query_param("resource", "acct:a@a.com")
|
||||||
|
.await;
|
||||||
|
server
|
||||||
|
.get("/.well-known/webfinger")
|
||||||
|
.add_query_param("resource", "acct:b@b.com")
|
||||||
|
.await;
|
||||||
|
|
||||||
|
// Third should be rate limited
|
||||||
|
let response = server
|
||||||
|
.get("/.well-known/webfinger")
|
||||||
|
.add_query_param("resource", "acct:c@c.com")
|
||||||
|
.await;
|
||||||
|
|
||||||
|
response.assert_status(axum::http::StatusCode::TOO_MANY_REQUESTS);
|
||||||
|
}
|
||||||
Loading…
Add table
Reference in a new issue