mirror of
https://github.com/CloudNebulaProject/barycenter.git
synced 2026-04-10 13:10:42 +00:00
Initial commit: Barycenter OpenID Connect Identity Provider
Barycenter is an OpenID Connect Identity Provider (IdP) implementing OAuth 2.0 Authorization Code flow with PKCE. Written in Rust using axum, SeaORM, and josekit. Features: - Authorization Code flow with PKCE (S256) - Dynamic client registration - Token endpoint with multiple auth methods - ID Token signing (RS256) - UserInfo endpoint - Discovery and JWKS publication 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
commit
64b31e40df
16 changed files with 3074 additions and 0 deletions
21
.claude/settings.local.json
Normal file
21
.claude/settings.local.json
Normal file
|
|
@ -0,0 +1,21 @@
|
|||
{
|
||||
"permissions": {
|
||||
"allow": [
|
||||
"Bash(cargo build:*)",
|
||||
"Bash(curl:*)",
|
||||
"mcp__context7__resolve-library-id",
|
||||
"mcp__context7__get-library-docs",
|
||||
"Bash(jq:*)",
|
||||
"Bash(cargo install:*)",
|
||||
"Bash(cargo nextest run:*)",
|
||||
"Bash(RUST_BACKTRACE=1 cargo nextest run:*)",
|
||||
"Bash(lsof:*)",
|
||||
"Bash(pkill:*)",
|
||||
"mcp__github__search_repositories",
|
||||
"mcp__github__get_me",
|
||||
"mcp__github__search_users"
|
||||
],
|
||||
"deny": [],
|
||||
"ask": []
|
||||
}
|
||||
}
|
||||
16
.config/nextest.toml
Normal file
16
.config/nextest.toml
Normal file
|
|
@ -0,0 +1,16 @@
|
|||
[profile.default]
|
||||
# Run tests serially to avoid port conflicts
|
||||
test-threads = 1
|
||||
# Increase timeout for integration tests that start the server
|
||||
slow-timeout = { period = "180s", terminate-after = 3 }
|
||||
# Show output for all tests
|
||||
status-level = "all"
|
||||
# Always show stdout/stderr for integration tests
|
||||
failure-output = "immediate-final"
|
||||
success-output = "final"
|
||||
|
||||
[profile.ci]
|
||||
# CI profile with stricter settings
|
||||
test-threads = 1
|
||||
retries = 2
|
||||
slow-timeout = { period = "120s", terminate-after = 3 }
|
||||
29
.gitignore
vendored
Normal file
29
.gitignore
vendored
Normal file
|
|
@ -0,0 +1,29 @@
|
|||
# Rust build artifacts
|
||||
/target/
|
||||
**/*.rs.bk
|
||||
*.pdb
|
||||
|
||||
# Cargo
|
||||
Cargo.lock
|
||||
|
||||
# IDE and editor files
|
||||
.idea/
|
||||
.vscode/
|
||||
*.swp
|
||||
*.swo
|
||||
*~
|
||||
.DS_Store
|
||||
|
||||
# Application-specific
|
||||
/keys/
|
||||
/jwks.json
|
||||
/private_key.json
|
||||
/data/jwks.json
|
||||
/data/private_key.pem
|
||||
*.db
|
||||
*.db-shm
|
||||
*.db-wal
|
||||
|
||||
# Environment and config (optional - uncomment if you want to ignore local configs)
|
||||
# config.toml
|
||||
# .env
|
||||
167
CLAUDE.md
Normal file
167
CLAUDE.md
Normal file
|
|
@ -0,0 +1,167 @@
|
|||
# CLAUDE.md
|
||||
|
||||
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
|
||||
|
||||
## Project Overview
|
||||
|
||||
Barycenter is an OpenID Connect Identity Provider (IdP) implementing OAuth 2.0 Authorization Code flow with PKCE. The project is written in Rust using axum for the web framework, SeaORM for database access (SQLite), and josekit for JOSE/JWT operations.
|
||||
|
||||
## Build and Development Commands
|
||||
|
||||
```bash
|
||||
# Build the project
|
||||
cargo build
|
||||
|
||||
# Run the application (defaults to config.toml)
|
||||
cargo run
|
||||
|
||||
# Run with custom config
|
||||
cargo run -- --config path/to/config.toml
|
||||
|
||||
# Run in release mode
|
||||
cargo build --release
|
||||
cargo run --release
|
||||
|
||||
# Check code without building
|
||||
cargo check
|
||||
|
||||
# Run tests
|
||||
cargo test
|
||||
|
||||
# Run with logging (uses RUST_LOG environment variable)
|
||||
RUST_LOG=debug cargo run
|
||||
RUST_LOG=barycenter=trace cargo run
|
||||
```
|
||||
|
||||
## Configuration
|
||||
|
||||
The application loads configuration from:
|
||||
1. Default values (defined in `src/settings.rs`)
|
||||
2. Configuration file (default: `config.toml`)
|
||||
3. Environment variables with prefix `CRABIDP__` (e.g., `CRABIDP__SERVER__PORT=9090`)
|
||||
|
||||
Environment variables use double underscores as separators for nested keys.
|
||||
|
||||
## Architecture and Module Structure
|
||||
|
||||
### Entry Point (`src/main.rs`)
|
||||
The application initializes in this order:
|
||||
1. Parse CLI arguments for config file path
|
||||
2. Load settings from config file and environment
|
||||
3. Initialize database connection and create tables via `storage::init()`
|
||||
4. Initialize JWKS manager (generates or loads RSA keys)
|
||||
5. Start web server with `web::serve()`
|
||||
|
||||
### Settings (`src/settings.rs`)
|
||||
Manages configuration with four main sections:
|
||||
- `Server`: listen address and public base URL (issuer)
|
||||
- `Database`: SQLite connection string
|
||||
- `Keys`: JWKS and private key paths, signing algorithm
|
||||
- `Federation`: trust anchor URLs (future use)
|
||||
|
||||
The `issuer()` method returns the OAuth issuer URL, preferring `public_base_url` or falling back to `http://{host}:{port}`.
|
||||
|
||||
### Storage (`src/storage.rs`)
|
||||
Database layer with raw SQL using SeaORM's `DatabaseConnection`. Tables:
|
||||
- `clients`: OAuth client registrations (client_id, client_secret, redirect_uris)
|
||||
- `auth_codes`: Authorization codes with PKCE challenge, subject, scope, nonce
|
||||
- `access_tokens`: Bearer tokens with subject, scope, expiration
|
||||
- `properties`: Key-value store for arbitrary user properties (owner, key, value)
|
||||
|
||||
All IDs and tokens are generated via `random_id()` (24 random bytes, base64url-encoded).
|
||||
|
||||
### JWKS Manager (`src/jwks.rs`)
|
||||
Handles RSA key generation, persistence, and JWT signing:
|
||||
- Generates 2048-bit RSA key on first run
|
||||
- Persists private key as JSON to `private_key_path`
|
||||
- Publishes public key set to `jwks_path`
|
||||
- Provides `sign_jwt_rs256()` for ID Token signing with kid header
|
||||
|
||||
### Web Endpoints (`src/web.rs`)
|
||||
Implements OpenID Connect and OAuth 2.0 endpoints:
|
||||
|
||||
**Discovery & Registration:**
|
||||
- `GET /.well-known/openid-configuration` - OpenID Provider metadata
|
||||
- `GET /.well-known/jwks.json` - Public signing keys
|
||||
- `POST /connect/register` - Dynamic client registration
|
||||
|
||||
**OAuth/OIDC Flow:**
|
||||
- `GET /authorize` - Authorization endpoint (issues authorization code with PKCE)
|
||||
- Currently uses fixed subject "demo-user" (pending login flow implementation per docs/next-iteration-plan.md)
|
||||
- Validates client_id, redirect_uri, scope (must include "openid"), PKCE S256
|
||||
- Returns redirect with code and state
|
||||
- `POST /token` - Token endpoint (exchanges code for tokens)
|
||||
- Supports `client_secret_basic` (Authorization header) and `client_secret_post` (form body)
|
||||
- Validates PKCE S256 code_verifier
|
||||
- Returns access_token, id_token (JWT), token_type, expires_in
|
||||
- `GET /userinfo` - UserInfo endpoint (returns claims for Bearer token)
|
||||
|
||||
**Non-Standard:**
|
||||
- `GET /properties/:owner/:key` - Get property value
|
||||
- `PUT /properties/:owner/:key` - Set property value
|
||||
- `GET /federation/trust-anchors` - List trust anchors
|
||||
|
||||
### Error Handling (`src/errors.rs`)
|
||||
Defines `CrabError` for internal error handling with conversions from common error types.
|
||||
|
||||
## Key Implementation Details
|
||||
|
||||
### PKCE Flow
|
||||
- Only S256 code challenge method is supported (plain is rejected)
|
||||
- Code challenge stored with auth code
|
||||
- Code verifier validated at token endpoint by hashing and comparing
|
||||
|
||||
### Client Authentication
|
||||
Token endpoint accepts two methods:
|
||||
1. `client_secret_basic`: HTTP Basic auth (client_id:client_secret base64-encoded)
|
||||
2. `client_secret_post`: Form parameters (client_id and client_secret in body)
|
||||
|
||||
### ID Token Claims
|
||||
Generated ID tokens include:
|
||||
- Standard claims: iss, sub, aud, exp, iat
|
||||
- Optional: nonce (if provided in authorize request)
|
||||
- at_hash: hash of access token per OIDC spec (left 128 bits of SHA-256, base64url)
|
||||
- Signed with RS256, includes kid header matching JWKS
|
||||
|
||||
### State Management
|
||||
- Authorization codes: 5 minute TTL, single-use (marked consumed)
|
||||
- Access tokens: 1 hour TTL, checked for expiration and revoked flag
|
||||
- Both stored in SQLite with timestamps
|
||||
|
||||
## Current Implementation Status
|
||||
|
||||
See `docs/oidc-conformance.md` for detailed OIDC compliance requirements.
|
||||
|
||||
**Implemented:**
|
||||
- Authorization Code flow with PKCE (S256)
|
||||
- Dynamic client registration
|
||||
- Token endpoint with client_secret_basic and client_secret_post
|
||||
- ID Token signing (RS256) with at_hash and nonce
|
||||
- UserInfo endpoint with Bearer token authentication
|
||||
- Discovery and JWKS publication
|
||||
- Property storage API
|
||||
|
||||
**Pending (see docs/next-iteration-plan.md):**
|
||||
- User authentication and session management (currently uses fixed "demo-user" subject)
|
||||
- auth_time claim in ID Token (requires session tracking)
|
||||
- Cache-Control headers on token endpoint
|
||||
- Consent flow (currently auto-consents)
|
||||
- Refresh tokens
|
||||
- Token revocation and introspection
|
||||
- OpenID Federation trust chain validation
|
||||
|
||||
## Testing and Validation
|
||||
|
||||
No automated tests currently exist. Manual testing can be done with curl commands following the OAuth 2.0 Authorization Code + PKCE flow:
|
||||
|
||||
1. Register a client via `POST /connect/register`
|
||||
2. Generate PKCE verifier and challenge
|
||||
3. Navigate to `/authorize` with required parameters
|
||||
4. Exchange authorization code at `/token` with code_verifier
|
||||
5. Access `/userinfo` with Bearer access_token
|
||||
|
||||
Example PKCE generation (bash):
|
||||
```bash
|
||||
verifier=$(openssl rand -base64 32 | tr -d '=' | tr '+/' '-_')
|
||||
challenge=$(echo -n "$verifier" | openssl dgst -binary -sha256 | base64 | tr -d '=' | tr '+/' '-_')
|
||||
```
|
||||
53
Cargo.toml
Normal file
53
Cargo.toml
Normal file
|
|
@ -0,0 +1,53 @@
|
|||
[package]
|
||||
name = "barycenter"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
license = "MIT OR Apache-2.0"
|
||||
description = "OpenID Connect IdP with federation, property storage, and auto-registration the center of gravity between multiple objects."
|
||||
|
||||
[dependencies]
|
||||
axum = { version = "0.8", features = ["json", "form"] }
|
||||
tokio = { version = "1", features = ["full"] }
|
||||
tracing = "0.1"
|
||||
tracing-subscriber = { version = "0.3", features = ["env-filter", "fmt"] }
|
||||
clap = { version = "4", features = ["derive"] }
|
||||
miette = { version = "7", features = ["fancy"] }
|
||||
thiserror = "1"
|
||||
config = "0.14"
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
serde_json = "1"
|
||||
serde_with = "3"
|
||||
|
||||
# SeaORM for SQLite
|
||||
sea-orm = { version = "1", default-features = false, features = ["sqlx-sqlite", "runtime-tokio-rustls"] }
|
||||
|
||||
# JOSE / JWKS & JWT
|
||||
josekit = "0.10"
|
||||
|
||||
chrono = { version = "0.4", features = ["serde", "clock"] }
|
||||
time = "0.3"
|
||||
rand = "0.8"
|
||||
base64ct = { version = "1", features = ["alloc"] }
|
||||
anyhow = "1"
|
||||
sha2 = "0.10"
|
||||
serde_urlencoded = "0.7"
|
||||
|
||||
# Password hashing
|
||||
argon2 = "0.5"
|
||||
|
||||
# Rate limiting
|
||||
tower = "0.5"
|
||||
tower_governor = "0.4"
|
||||
|
||||
# Validation
|
||||
regex = "1"
|
||||
url = "2"
|
||||
|
||||
[dev-dependencies]
|
||||
openidconnect = { version = "4", features = ["reqwest-blocking"] }
|
||||
oauth2 = "5"
|
||||
reqwest = { version = "0.12", features = ["blocking", "json", "cookies"] }
|
||||
urlencoding = "2"
|
||||
|
||||
[profile.release]
|
||||
debug = 1
|
||||
16
config.toml
Normal file
16
config.toml
Normal file
|
|
@ -0,0 +1,16 @@
|
|||
[server]
|
||||
host = "0.0.0.0"
|
||||
port = 8080
|
||||
# Uncomment for production with HTTPS:
|
||||
# public_base_url = "https://idp.example.com"
|
||||
|
||||
[database]
|
||||
url = "sqlite://crabidp.db?mode=rwc"
|
||||
|
||||
[keys]
|
||||
jwks_path = "data/jwks.json"
|
||||
private_key_path = "data/private_key.pem"
|
||||
alg = "RS256"
|
||||
|
||||
[federation]
|
||||
trust_anchors = []
|
||||
70
docs/next-iteration-plan.md
Normal file
70
docs/next-iteration-plan.md
Normal file
|
|
@ -0,0 +1,70 @@
|
|||
### Next Iteration Plan — OpenID Connect OP (dated 2025-11-24 16:48)
|
||||
|
||||
This plan builds on the current implementation (Authorization Code + PKCE, ID Token signing, UserInfo, client_secret_post and client_secret_basic, discovery, JWKS, dynamic registration). The goal of this iteration is to harden compliance, improve interoperability, and introduce a minimal authentication stub to replace the fixed demo subject.
|
||||
|
||||
Objectives
|
||||
- Compliance hardening: headers, error models, metadata accuracy.
|
||||
- Authentication stub: minimal user login and session handling to issue codes for a real subject instead of the demo placeholder.
|
||||
- Documentation updates and basic validation scripts/tests.
|
||||
|
||||
Scope and tasks
|
||||
1. Token endpoint response hygiene
|
||||
- Add headers per OAuth 2.0 recommendations:
|
||||
- Cache-Control: no-store
|
||||
- Pragma: no-cache
|
||||
- Ensure JSON error bodies conform to RFC 6749 (error, error_description, error_uri optional). Keep existing WWW-Authenticate for invalid_client.
|
||||
|
||||
2. ID Token claims improvements
|
||||
- Include auth_time when the OP has a known authentication time for the end-user session (see task 3 for session stub).
|
||||
- Ensure kid in JWS header is set (already implemented) and alg matches discovery. Verify exp/iat handling and clock skew tolerance notes for verifiers (doc).
|
||||
|
||||
3. Minimal authentication + consent stub
|
||||
- Introduce a basic login flow used by /authorize:
|
||||
- GET /login renders a simple page (or JSON instruction) to submit username (subject) and a fixed password placeholder.
|
||||
- POST /login sets a secure, HTTP-only cookie session with subject and auth_time; redirects back to the original /authorize request (preserve request params via a return_to parameter).
|
||||
- If an active session cookie is present at /authorize, skip login.
|
||||
- Consent: for MVP, auto-consent to requested scopes; record a TODO for explicit consent later.
|
||||
- Optionally, persist sessions in-memory for this iteration; a DB table can be added in a later iteration.
|
||||
|
||||
4. Error handling and redirects at /authorize
|
||||
- Continue using redirect-based error responses per OIDC (error, error_description, state passthrough).
|
||||
- Validate and return errors for: missing/invalid parameters, unsupported response_type, scope lacking openid, PKCE missing/invalid.
|
||||
|
||||
5. Discovery metadata accuracy
|
||||
- Verify discovery includes: userinfo_endpoint, token_endpoint_auth_methods_supported [client_secret_post, client_secret_basic], code_challenge_methods_supported [S256], response_types_supported [code], id_token_signing_alg_values_supported [RS256]. (Current implementation already does this; re-check after changes.)
|
||||
|
||||
6. Documentation updates
|
||||
- Update docs/oidc-conformance.md to mention client_secret_basic support and the new login/session stub behavior.
|
||||
- Add a short README snippet or docs/flows.md with example end-to-end curl/browser steps including the login step.
|
||||
|
||||
7. Basic validation scripts/tests
|
||||
- Add a scripts/ directory (or docs snippets) with curl commands to verify:
|
||||
- Discovery document fields.
|
||||
- Authorization + login + token exchange (with PKCE S256) producing a valid ID Token and access token.
|
||||
- UserInfo with Bearer access token and proper WWW-Authenticate on failure.
|
||||
|
||||
Non-goals (deferred)
|
||||
- Refresh tokens, rotation, revocation and introspection.
|
||||
- Rich user model and persistence (beyond minimal session stub).
|
||||
- OpenID Federation trust chain validation.
|
||||
- Key rotation and multi-key JWKS.
|
||||
|
||||
Acceptance criteria
|
||||
- /token responses include Cache-Control: no-store and Pragma: no-cache for both success and error responses.
|
||||
- /token invalid_client responses continue to include a proper WWW-Authenticate: Basic realm header.
|
||||
- ID Token includes auth_time when the user logs in during the flow (based on the session stub’s auth_time); includes nonce when provided; includes at_hash.
|
||||
- /authorize uses the logged-in user from the new session cookie; if no session, prompts to login and returns to continue the flow; redirects carry state on error.
|
||||
- Discovery still advertises capabilities accurately after changes.
|
||||
- Docs updated to reflect client_secret_basic and the login/session stub.
|
||||
- Example commands or scripts demonstrate a complete code flow using PKCE with a login step and successful token exchange and userinfo call.
|
||||
|
||||
Implementation notes
|
||||
- Keep the session cookie scoped to the IdP origin; mark HttpOnly, Secure in production, SameSite=Lax.
|
||||
- Use S256 exclusively for PKCE (plain not supported).
|
||||
- Continue to generate/sign with RS256; ensure kid header is present and published in JWKS.
|
||||
- Keep the issuer stable via server.public_base_url in production deployments.
|
||||
|
||||
Timeline and effort
|
||||
- Estimated effort: 1–2 short iterations.
|
||||
- Day 1: headers, error refinements, discovery verification, docs updates.
|
||||
- Day 2: login/session stub, auth_time claim, validation scripts.
|
||||
132
docs/oidc-conformance.md
Normal file
132
docs/oidc-conformance.md
Normal file
|
|
@ -0,0 +1,132 @@
|
|||
### OpenID Connect OP conformance plan (context7 summary)
|
||||
|
||||
This project is an OpenID Provider (OP) scaffold. This document summarizes what must be implemented to align with OpenID Connect 1.0 (and adjacent OAuth2/OAuth 2.1 guidance), what we already have, and the minimal viable roadmap.
|
||||
|
||||
Scope references (context7 up-to-date pointers):
|
||||
- OpenID Connect Core 1.0 (final)
|
||||
- OpenID Connect Discovery 1.0 (/.well-known/openid-configuration)
|
||||
- OAuth 2.0 (RFC 6749) + PKCE (RFC 7636) and OAuth 2.1 draft guidance + OAuth 2.0 Security BCP (RFC 6819 successor BCPs)
|
||||
- OpenID Connect Dynamic Client Registration 1.0
|
||||
- OpenID Federation 1.0 (for later phases)
|
||||
|
||||
|
||||
1) Endpoints required for a basic, interoperable OP
|
||||
- Authorization endpoint: GET/POST /authorize
|
||||
- Response type: code (Authorization Code Flow)
|
||||
- Required request params: client_id, redirect_uri, response_type=code, scope (includes openid), state; Recommended: nonce; If PKCE is used: code_challenge, code_challenge_method=S256
|
||||
- Validations: registered redirect_uri match, scope contains openid, client is known, response_type supported, PKCE required for public clients
|
||||
- Output: HTTP redirect to redirect_uri with code and state
|
||||
|
||||
- Token endpoint: POST /token (application/x-www-form-urlencoded)
|
||||
- Grant type: authorization_code
|
||||
- Parameters: grant_type, code, redirect_uri, client authentication
|
||||
- Client auth: client_secret_post (initial support); consider client_secret_basic later
|
||||
- PKCE verification: code_verifier must match stored code_challenge (S256)
|
||||
- Output: JSON with access_token, token_type=bearer, expires_in, id_token (JWT), possibly refresh_token
|
||||
- Error model: RFC 6749 + OIDC specific where applicable
|
||||
|
||||
- UserInfo endpoint: GET/POST /userinfo (Bearer token)
|
||||
- Input: Authorization: Bearer <access_token>
|
||||
- Output: JSON claims for the subject consistent with scopes (profile, email, etc.)
|
||||
|
||||
- Discovery endpoint: GET /.well-known/openid-configuration
|
||||
- Must publish: issuer, authorization_endpoint, token_endpoint, jwks_uri, response_types_supported, subject_types_supported, id_token_signing_alg_values_supported
|
||||
- Should publish: registration_endpoint (if supported), scopes_supported, claims_supported, grant_types_supported, token_endpoint_auth_methods_supported, code_challenge_methods_supported
|
||||
|
||||
- JWKS endpoint: GET /.well-known/jwks.json
|
||||
- Publish public keys used to sign ID Tokens; include kid values
|
||||
|
||||
- Dynamic Client Registration endpoint: POST /connect/register
|
||||
- Accept a subset of metadata: redirect_uris (required), client_name, token_endpoint_auth_method, etc.
|
||||
- Output per spec: client_id, client_secret (for confidential clients), client_id_issued_at, token_endpoint_auth_method, and echoed metadata
|
||||
- Later: registration access token and client configuration endpoint
|
||||
|
||||
|
||||
2) Tokens and claims
|
||||
- ID Token (JWT, JWS RS256 initially)
|
||||
- Required claims: iss, sub, aud, exp, iat, auth_time (if max_age requested), nonce (if provided in auth request), at_hash (if access token issued), c_hash (optional if code returned from token endpoint)
|
||||
- kid header must match a key in JWKS; alg consistent with discovery
|
||||
|
||||
- Access Token
|
||||
- Can be opaque initially; include reference in DB; expires_in typically 3600s
|
||||
- Optionally JWT access tokens later
|
||||
|
||||
- Refresh Token (optional in MVP)
|
||||
- Issue for offline_access scope; secure storage; rotation recommended
|
||||
|
||||
|
||||
3) Storage additions (DB)
|
||||
- auth_codes: code, client_id, redirect_uri, scope, subject, nonce, code_challenge (S256), created_at, expires_at, consumed
|
||||
- access_tokens: token, client_id, subject, scope, created_at, expires_at, revoked
|
||||
- refresh_tokens (optional initially)
|
||||
|
||||
|
||||
4) Security requirements (minimum)
|
||||
- Enforce PKCE S256 for public clients; allow confidential without PKCE only if policy allows (recommended: require for all)
|
||||
- Validate redirect_uri exact match against one of the client’s registered URIs
|
||||
- Validate aud (client_id) and iss in ID Token; use correct exp/iat skew bounds
|
||||
- Use state and nonce to prevent CSRF and token replay
|
||||
- Use HTTPS in production; publish an https issuer in discovery via server.public_base_url
|
||||
|
||||
|
||||
5) Discovery metadata we should publish once endpoints exist
|
||||
- issuer: <base URL>
|
||||
- authorization_endpoint: <issuer>/authorize
|
||||
- token_endpoint: <issuer>/token
|
||||
- jwks_uri: <issuer>/.well-known/jwks.json
|
||||
- registration_endpoint: <issuer>/connect/register
|
||||
- response_types_supported: ["code"]
|
||||
- grant_types_supported: ["authorization_code"]
|
||||
- subject_types_supported: ["public"]
|
||||
- id_token_signing_alg_values_supported: ["RS256"]
|
||||
- token_endpoint_auth_methods_supported: ["client_secret_post"]
|
||||
- code_challenge_methods_supported: ["S256"]
|
||||
- scopes_supported: ["openid", "profile", "email"]
|
||||
- claims_supported: ["sub", "iss", "aud", "exp", "iat", "auth_time", "nonce", "name", "given_name", "family_name", "email", "email_verified"]
|
||||
- userinfo_endpoint: <issuer>/userinfo (once implemented)
|
||||
|
||||
|
||||
6) Current status in this repository
|
||||
- Implemented:
|
||||
- Discovery endpoint (basic subset)
|
||||
- JWKS publication with key generation and persistence
|
||||
- Dynamic client auto-registration (basic)
|
||||
- Simple property storage API (non-standard)
|
||||
- Federation trust anchors stub
|
||||
|
||||
- Missing for OIDC Core compliance:
|
||||
- /authorize (Authorization Code + PKCE)
|
||||
- /token (code exchange, ID Token signing)
|
||||
- /userinfo
|
||||
- Storage for auth codes and tokens
|
||||
- Full error models and input validation across endpoints
|
||||
- Robust client registration validation + optional configuration endpoint
|
||||
|
||||
|
||||
7) Minimal viable roadmap (incremental)
|
||||
Step 1: Data model and discovery metadata
|
||||
- Add DB tables for auth_codes and access_tokens
|
||||
- Extend discovery to include grant_types_supported, token_endpoint_auth_methods_supported, code_challenge_methods_supported, claims_supported
|
||||
|
||||
Step 2: Authorization Code + PKCE
|
||||
- Implement /authorize to issue short-lived codes; validate redirect_uri, scope, client, state, nonce, PKCE
|
||||
|
||||
Step 3: Token endpoint and ID Token
|
||||
- Implement /token; client_secret_post, PKCE verification; sign ID Token with RS256 using current JWK; include required claims
|
||||
|
||||
Step 4: UserInfo
|
||||
- Implement /userinfo backed by properties or a user table; authorize via access token
|
||||
|
||||
Step 5: Hardening and cleanup
|
||||
- Proper errors per specs; input validation; token lifetimes; background pruning of consumed/expired artifacts
|
||||
- Optional: client_secret_basic, refresh tokens, rotation, revocation, introspection
|
||||
|
||||
Step 6: Federation (later)
|
||||
- Entity statement issuance, publication, and trust chain verification; policy application to registration
|
||||
|
||||
|
||||
Implementation notes
|
||||
- Keep issuer stable and correct in settings.server.public_base_url for production
|
||||
- Ensure JWKS kid selection and alg entry match discovery
|
||||
- Prefer S256 for PKCE; do not support plain
|
||||
- Add tests or curl scripts to verify end-to-end flows
|
||||
39
src/errors.rs
Normal file
39
src/errors.rs
Normal file
|
|
@ -0,0 +1,39 @@
|
|||
use miette::Diagnostic;
|
||||
use thiserror::Error;
|
||||
|
||||
#[derive(Debug, Error, Diagnostic)]
|
||||
pub enum CrabError {
|
||||
#[error("I/O error: {0}")]
|
||||
#[diagnostic(code(crabidp::io))]
|
||||
Io(#[from] std::io::Error),
|
||||
|
||||
#[error("Config error: {0}")]
|
||||
#[diagnostic(code(crabidp::config))]
|
||||
Config(#[from] config::ConfigError),
|
||||
|
||||
#[error("Serialization error: {0}")]
|
||||
#[diagnostic(code(crabidp::serde))]
|
||||
Serde(#[from] serde_json::Error),
|
||||
|
||||
#[error("Database error: {0}")]
|
||||
#[diagnostic(code(crabidp::db))]
|
||||
Db(#[from] sea_orm::DbErr),
|
||||
|
||||
#[error("JOSE error: {0}")]
|
||||
#[diagnostic(code(crabidp::jose))]
|
||||
Jose(String),
|
||||
|
||||
#[error("Bad request: {0}")]
|
||||
#[diagnostic(code(crabidp::bad_request))]
|
||||
BadRequest(String),
|
||||
|
||||
#[error("{0}")]
|
||||
#[diagnostic(code(crabidp::other))]
|
||||
Other(String),
|
||||
}
|
||||
|
||||
impl From<josekit::JoseError> for CrabError {
|
||||
fn from(value: josekit::JoseError) -> Self {
|
||||
CrabError::Jose(value.to_string())
|
||||
}
|
||||
}
|
||||
78
src/jwks.rs
Normal file
78
src/jwks.rs
Normal file
|
|
@ -0,0 +1,78 @@
|
|||
use crate::errors::CrabError;
|
||||
use crate::settings::Keys;
|
||||
use base64ct::Encoding;
|
||||
use josekit::jwk::Jwk;
|
||||
use josekit::jwt::JwtPayload;
|
||||
use josekit::jwt;
|
||||
use josekit::jws::{JwsHeader, RS256};
|
||||
use rand::RngCore;
|
||||
use serde_json::{json, Value};
|
||||
use std::fs;
|
||||
use std::sync::Arc;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct JwksManager {
|
||||
cfg: Keys,
|
||||
public_jwks_value: Arc<Value>,
|
||||
private_jwk: Arc<Jwk>,
|
||||
}
|
||||
|
||||
impl JwksManager {
|
||||
pub async fn new(cfg: Keys) -> Result<Self, CrabError> {
|
||||
// Ensure parent dirs exist
|
||||
if let Some(parent) = cfg.jwks_path.parent() { fs::create_dir_all(parent)?; }
|
||||
if let Some(parent) = cfg.private_key_path.parent() { fs::create_dir_all(parent)?; }
|
||||
|
||||
// If private key exists, load it; otherwise generate and persist both private and public
|
||||
let private_jwk = if cfg.private_key_path.exists() {
|
||||
let s = fs::read_to_string(&cfg.private_key_path)?;
|
||||
// Stored as JSON
|
||||
serde_json::from_str::<Jwk>(&s)?
|
||||
} else {
|
||||
let mut jwk = Jwk::generate_rsa_key(2048)?;
|
||||
let kid = cfg.key_id.clone().unwrap_or_else(random_kid);
|
||||
jwk.set_key_id(&kid);
|
||||
jwk.set_algorithm(cfg.alg.as_str());
|
||||
jwk.set_key_use("sig");
|
||||
// Persist private key as JSON
|
||||
let priv_json = serde_json::to_string_pretty(&jwk)?;
|
||||
fs::write(&cfg.private_key_path, priv_json)?;
|
||||
jwk
|
||||
};
|
||||
|
||||
// Ensure JWKS file exists or update from private_jwk
|
||||
if !cfg.jwks_path.exists() {
|
||||
let public = private_jwk.to_public_key()?;
|
||||
let jwk_val: Value = serde_json::to_value(public)?;
|
||||
let jwks = json!({ "keys": [jwk_val] });
|
||||
fs::write(&cfg.jwks_path, serde_json::to_string_pretty(&jwks)?)?;
|
||||
}
|
||||
|
||||
// Load public JWKS value
|
||||
let public_jwks_value: Value = serde_json::from_str(&fs::read_to_string(&cfg.jwks_path)?)?;
|
||||
|
||||
Ok(Self { cfg, public_jwks_value: Arc::new(public_jwks_value), private_jwk: Arc::new(private_jwk) })
|
||||
}
|
||||
|
||||
pub fn jwks_json(&self) -> Value { (*self.public_jwks_value).clone() }
|
||||
|
||||
pub fn private_jwk(&self) -> Jwk { (*self.private_jwk).clone() }
|
||||
|
||||
pub fn sign_jwt_rs256(&self, payload: &JwtPayload) -> Result<String, CrabError> {
|
||||
// Use RS256 signer from josekit
|
||||
let signer = RS256.signer_from_jwk(&self.private_jwk)?;
|
||||
let mut header = JwsHeader::new();
|
||||
if let Some(kid) = self.private_jwk.key_id() {
|
||||
header.set_key_id(kid);
|
||||
}
|
||||
header.set_algorithm("RS256");
|
||||
let token = jwt::encode_with_signer(payload, &header, &signer)?;
|
||||
Ok(token)
|
||||
}
|
||||
}
|
||||
|
||||
fn random_kid() -> String {
|
||||
let mut bytes = [0u8; 16];
|
||||
rand::thread_rng().fill_bytes(&mut bytes);
|
||||
base64ct::Base64UrlUnpadded::encode_string(&bytes)
|
||||
}
|
||||
60
src/main.rs
Normal file
60
src/main.rs
Normal file
|
|
@ -0,0 +1,60 @@
|
|||
mod settings;
|
||||
mod errors;
|
||||
mod jwks;
|
||||
mod web;
|
||||
mod storage;
|
||||
mod session;
|
||||
|
||||
use clap::Parser;
|
||||
use miette::{IntoDiagnostic, Result};
|
||||
use tracing_subscriber::{fmt, EnvFilter};
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(name = "crabidp", version, about = "OpenID Connect IdP (scaffold)")]
|
||||
struct Cli {
|
||||
/// Path to configuration file
|
||||
#[arg(short, long, default_value = "config.toml")]
|
||||
config: String,
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<()> {
|
||||
// logging
|
||||
let env_filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info"));
|
||||
fmt().with_env_filter(env_filter).init();
|
||||
|
||||
let cli = Cli::parse();
|
||||
|
||||
// load settings
|
||||
let settings = settings::Settings::load(&cli.config)?;
|
||||
tracing::info!(?settings, "Loaded configuration");
|
||||
|
||||
// init storage (database)
|
||||
let db = storage::init(&settings.database).await?;
|
||||
|
||||
// ensure test users exist
|
||||
ensure_test_users(&db).await?;
|
||||
|
||||
// init jwks (generate if missing)
|
||||
let jwks_mgr = jwks::JwksManager::new(settings.keys.clone()).await?;
|
||||
|
||||
// start web server
|
||||
web::serve(settings, db, jwks_mgr).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn ensure_test_users(db: &sea_orm::DatabaseConnection) -> Result<()> {
|
||||
// Check if admin exists
|
||||
if storage::get_user_by_username(db, "admin").await.into_diagnostic()?.is_none() {
|
||||
storage::create_user(
|
||||
db,
|
||||
"admin",
|
||||
"password123",
|
||||
Some("admin@example.com".to_string()),
|
||||
)
|
||||
.await
|
||||
.into_diagnostic()?;
|
||||
tracing::info!("Created default admin user (username: admin, password: password123)");
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
50
src/session.rs
Normal file
50
src/session.rs
Normal file
|
|
@ -0,0 +1,50 @@
|
|||
use crate::settings::Settings;
|
||||
use axum::http::HeaderMap;
|
||||
|
||||
pub const SESSION_COOKIE_NAME: &str = "barycenter_session";
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct SessionCookie {
|
||||
pub session_id: String,
|
||||
}
|
||||
|
||||
impl SessionCookie {
|
||||
pub fn new(session_id: String) -> Self {
|
||||
Self { session_id }
|
||||
}
|
||||
|
||||
pub fn from_headers(headers: &HeaderMap) -> Option<Self> {
|
||||
let cookie_header = headers.get(axum::http::header::COOKIE)?.to_str().ok()?;
|
||||
|
||||
// Parse cookie header for our session cookie
|
||||
for cookie in cookie_header.split(';') {
|
||||
let cookie = cookie.trim();
|
||||
if let Some(value) = cookie.strip_prefix(SESSION_COOKIE_NAME).and_then(|s| s.strip_prefix('=')) {
|
||||
return Some(Self {
|
||||
session_id: value.to_string(),
|
||||
});
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
pub fn to_cookie_header(&self, settings: &Settings) -> String {
|
||||
let secure = settings.issuer().starts_with("https://");
|
||||
let max_age = 3600; // 1 hour default
|
||||
|
||||
format!(
|
||||
"{}={}; HttpOnly; {}SameSite=Lax; Path=/; Max-Age={}",
|
||||
SESSION_COOKIE_NAME,
|
||||
self.session_id,
|
||||
if secure { "Secure; " } else { "" },
|
||||
max_age
|
||||
)
|
||||
}
|
||||
|
||||
pub fn delete_cookie_header() -> String {
|
||||
format!(
|
||||
"{}=; HttpOnly; SameSite=Lax; Path=/; Max-Age=0",
|
||||
SESSION_COOKIE_NAME
|
||||
)
|
||||
}
|
||||
}
|
||||
143
src/settings.rs
Normal file
143
src/settings.rs
Normal file
|
|
@ -0,0 +1,143 @@
|
|||
use miette::{miette, IntoDiagnostic, Result};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::path::{Path, PathBuf};
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Settings {
|
||||
pub server: Server,
|
||||
pub database: Database,
|
||||
pub keys: Keys,
|
||||
pub federation: Federation,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Server {
|
||||
pub host: String,
|
||||
pub port: u16,
|
||||
/// If set, this is used as the issuer/public base URL, e.g., https://idp.example.com
|
||||
pub public_base_url: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Database {
|
||||
/// SeaORM/SQLx connection string, e.g., sqlite://crabidp.db?mode=rwc
|
||||
pub url: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Keys {
|
||||
/// Path to persist JWKS (public keys). Default: data/jwks.json
|
||||
pub jwks_path: PathBuf,
|
||||
/// Optional explicit key id to set on generated keys
|
||||
pub key_id: Option<String>,
|
||||
/// JWS algorithm for ID tokens (currently RS256)
|
||||
pub alg: String,
|
||||
/// Path to persist the private key in PEM (PKCS#8). Default: data/private_key.pem
|
||||
pub private_key_path: PathBuf,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
pub struct Federation {
|
||||
/// List of trust anchor URLs or fingerprints (placeholder for real federation)
|
||||
pub trust_anchors: Vec<String>,
|
||||
}
|
||||
|
||||
impl Default for Server {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
host: "0.0.0.0".to_string(),
|
||||
port: 8080,
|
||||
public_base_url: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for Database {
|
||||
fn default() -> Self {
|
||||
Self { url: "sqlite://crabidp.db?mode=rwc".to_string() }
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for Keys {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
jwks_path: PathBuf::from("data/jwks.json"),
|
||||
key_id: None,
|
||||
alg: "RS256".to_string(),
|
||||
private_key_path: PathBuf::from("data/private_key.pem"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for Settings {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
server: Server::default(),
|
||||
database: Database::default(),
|
||||
keys: Keys::default(),
|
||||
federation: Federation::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Settings {
|
||||
pub fn load(path: &str) -> Result<Self> {
|
||||
let mut builder = config::Config::builder()
|
||||
.set_default("server.host", Server::default().host)
|
||||
.into_diagnostic()?
|
||||
.set_default("server.port", Server::default().port)
|
||||
.into_diagnostic()?
|
||||
.set_default(
|
||||
"database.url",
|
||||
Database::default().url,
|
||||
)
|
||||
.into_diagnostic()?
|
||||
.set_default(
|
||||
"keys.jwks_path",
|
||||
Keys::default().jwks_path.to_string_lossy().to_string(),
|
||||
)
|
||||
.into_diagnostic()?
|
||||
.set_default("keys.alg", Keys::default().alg)
|
||||
.into_diagnostic()?
|
||||
.set_default(
|
||||
"keys.private_key_path",
|
||||
Keys::default().private_key_path.to_string_lossy().to_string(),
|
||||
)
|
||||
.into_diagnostic()?;
|
||||
|
||||
// Optional file
|
||||
if Path::new(path).exists() {
|
||||
builder = builder.add_source(config::File::with_name(path));
|
||||
}
|
||||
|
||||
// Environment overrides: CRABIDP__SERVER__PORT=9090, etc.
|
||||
builder = builder.add_source(
|
||||
config::Environment::with_prefix("CRABIDP").separator("__"),
|
||||
);
|
||||
|
||||
let cfg = builder.build().into_diagnostic()?;
|
||||
let mut s: Settings = cfg.try_deserialize().into_diagnostic()?;
|
||||
|
||||
// Normalize jwks path to be relative to current dir
|
||||
if s.keys.jwks_path.is_relative() {
|
||||
s.keys.jwks_path = std::env::current_dir()
|
||||
.into_diagnostic()?
|
||||
.join(&s.keys.jwks_path);
|
||||
}
|
||||
if s.keys.private_key_path.is_relative() {
|
||||
s.keys.private_key_path = std::env::current_dir()
|
||||
.into_diagnostic()?
|
||||
.join(&s.keys.private_key_path);
|
||||
}
|
||||
|
||||
Ok(s)
|
||||
}
|
||||
|
||||
pub fn issuer(&self) -> String {
|
||||
if let Some(base) = &self.server.public_base_url {
|
||||
base.trim_end_matches('/').to_string()
|
||||
} else {
|
||||
format!("http://{}:{}", self.server.host, self.server.port)
|
||||
}
|
||||
}
|
||||
}
|
||||
793
src/storage.rs
Normal file
793
src/storage.rs
Normal file
|
|
@ -0,0 +1,793 @@
|
|||
use crate::errors::CrabError;
|
||||
use crate::settings::Database as DbCfg;
|
||||
use chrono::Utc;
|
||||
use rand::RngCore;
|
||||
use base64ct::Encoding;
|
||||
use sea_orm::{ConnectionTrait, Database, DatabaseConnection, DbBackend, Statement};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Client {
|
||||
pub client_id: String,
|
||||
pub client_secret: String,
|
||||
pub client_name: Option<String>,
|
||||
pub redirect_uris: Vec<String>,
|
||||
pub created_at: i64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct NewClient {
|
||||
pub client_name: Option<String>,
|
||||
pub redirect_uris: Vec<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct AuthCode {
|
||||
pub code: String,
|
||||
pub client_id: String,
|
||||
pub redirect_uri: String,
|
||||
pub scope: String,
|
||||
pub subject: String,
|
||||
pub nonce: Option<String>,
|
||||
pub code_challenge: String,
|
||||
pub code_challenge_method: String,
|
||||
pub created_at: i64,
|
||||
pub expires_at: i64,
|
||||
pub consumed: i64,
|
||||
pub auth_time: Option<i64>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct AccessToken {
|
||||
pub token: String,
|
||||
pub client_id: String,
|
||||
pub subject: String,
|
||||
pub scope: String,
|
||||
pub created_at: i64,
|
||||
pub expires_at: i64,
|
||||
pub revoked: i64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct User {
|
||||
pub subject: String,
|
||||
pub username: String,
|
||||
pub password_hash: String,
|
||||
pub email: Option<String>,
|
||||
pub email_verified: i64,
|
||||
pub created_at: i64,
|
||||
pub enabled: i64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Session {
|
||||
pub session_id: String,
|
||||
pub subject: String,
|
||||
pub auth_time: i64,
|
||||
pub created_at: i64,
|
||||
pub expires_at: i64,
|
||||
pub user_agent: Option<String>,
|
||||
pub ip_address: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct RefreshToken {
|
||||
pub token: String,
|
||||
pub client_id: String,
|
||||
pub subject: String,
|
||||
pub scope: String,
|
||||
pub created_at: i64,
|
||||
pub expires_at: i64,
|
||||
pub revoked: i64,
|
||||
pub parent_token: Option<String>, // For token rotation tracking
|
||||
}
|
||||
|
||||
pub async fn init(cfg: &DbCfg) -> Result<DatabaseConnection, CrabError> {
|
||||
let db = Database::connect(&cfg.url).await?;
|
||||
// bootstrap schema
|
||||
db.execute(Statement::from_string(DbBackend::Sqlite, "PRAGMA foreign_keys = ON"))
|
||||
.await?;
|
||||
|
||||
db.execute(Statement::from_string(
|
||||
DbBackend::Sqlite,
|
||||
r#"
|
||||
CREATE TABLE IF NOT EXISTS clients (
|
||||
client_id TEXT PRIMARY KEY,
|
||||
client_secret TEXT NOT NULL,
|
||||
client_name TEXT,
|
||||
redirect_uris TEXT NOT NULL,
|
||||
created_at INTEGER NOT NULL
|
||||
)
|
||||
"#
|
||||
))
|
||||
.await?;
|
||||
|
||||
db.execute(Statement::from_string(
|
||||
DbBackend::Sqlite,
|
||||
r#"
|
||||
CREATE TABLE IF NOT EXISTS properties (
|
||||
owner TEXT NOT NULL,
|
||||
key TEXT NOT NULL,
|
||||
value TEXT NOT NULL,
|
||||
updated_at INTEGER NOT NULL,
|
||||
PRIMARY KEY(owner, key)
|
||||
)
|
||||
"#
|
||||
))
|
||||
.await?;
|
||||
|
||||
db.execute(Statement::from_string(
|
||||
DbBackend::Sqlite,
|
||||
r#"
|
||||
CREATE TABLE IF NOT EXISTS auth_codes (
|
||||
code TEXT PRIMARY KEY,
|
||||
client_id TEXT NOT NULL,
|
||||
redirect_uri TEXT NOT NULL,
|
||||
scope TEXT NOT NULL,
|
||||
subject TEXT NOT NULL,
|
||||
nonce TEXT,
|
||||
code_challenge TEXT NOT NULL,
|
||||
code_challenge_method TEXT NOT NULL,
|
||||
created_at INTEGER NOT NULL,
|
||||
expires_at INTEGER NOT NULL,
|
||||
consumed INTEGER NOT NULL DEFAULT 0,
|
||||
auth_time INTEGER
|
||||
)
|
||||
"#
|
||||
))
|
||||
.await?;
|
||||
|
||||
db.execute(Statement::from_string(
|
||||
DbBackend::Sqlite,
|
||||
r#"
|
||||
CREATE TABLE IF NOT EXISTS access_tokens (
|
||||
token TEXT PRIMARY KEY,
|
||||
client_id TEXT NOT NULL,
|
||||
subject TEXT NOT NULL,
|
||||
scope TEXT NOT NULL,
|
||||
created_at INTEGER NOT NULL,
|
||||
expires_at INTEGER NOT NULL,
|
||||
revoked INTEGER NOT NULL DEFAULT 0
|
||||
)
|
||||
"#
|
||||
))
|
||||
.await?;
|
||||
|
||||
db.execute(Statement::from_string(
|
||||
DbBackend::Sqlite,
|
||||
r#"
|
||||
CREATE TABLE IF NOT EXISTS users (
|
||||
subject TEXT PRIMARY KEY,
|
||||
username TEXT NOT NULL UNIQUE,
|
||||
password_hash TEXT NOT NULL,
|
||||
email TEXT,
|
||||
email_verified INTEGER NOT NULL DEFAULT 0,
|
||||
created_at INTEGER NOT NULL,
|
||||
enabled INTEGER NOT NULL DEFAULT 1
|
||||
)
|
||||
"#
|
||||
))
|
||||
.await?;
|
||||
|
||||
db.execute(Statement::from_string(
|
||||
DbBackend::Sqlite,
|
||||
r#"
|
||||
CREATE TABLE IF NOT EXISTS sessions (
|
||||
session_id TEXT PRIMARY KEY,
|
||||
subject TEXT NOT NULL,
|
||||
auth_time INTEGER NOT NULL,
|
||||
created_at INTEGER NOT NULL,
|
||||
expires_at INTEGER NOT NULL,
|
||||
user_agent TEXT,
|
||||
ip_address TEXT
|
||||
)
|
||||
"#
|
||||
))
|
||||
.await?;
|
||||
|
||||
db.execute(Statement::from_string(
|
||||
DbBackend::Sqlite,
|
||||
"CREATE INDEX IF NOT EXISTS idx_sessions_expires ON sessions(expires_at)"
|
||||
))
|
||||
.await?;
|
||||
|
||||
db.execute(Statement::from_string(
|
||||
DbBackend::Sqlite,
|
||||
r#"
|
||||
CREATE TABLE IF NOT EXISTS refresh_tokens (
|
||||
token TEXT PRIMARY KEY,
|
||||
client_id TEXT NOT NULL,
|
||||
subject TEXT NOT NULL,
|
||||
scope TEXT NOT NULL,
|
||||
created_at INTEGER NOT NULL,
|
||||
expires_at INTEGER NOT NULL,
|
||||
revoked INTEGER NOT NULL DEFAULT 0,
|
||||
parent_token TEXT
|
||||
)
|
||||
"#
|
||||
))
|
||||
.await?;
|
||||
|
||||
db.execute(Statement::from_string(
|
||||
DbBackend::Sqlite,
|
||||
"CREATE INDEX IF NOT EXISTS idx_refresh_tokens_expires ON refresh_tokens(expires_at)"
|
||||
))
|
||||
.await?;
|
||||
|
||||
Ok(db)
|
||||
}
|
||||
|
||||
pub async fn create_client(db: &DatabaseConnection, input: NewClient) -> Result<Client, CrabError> {
|
||||
let client_id = random_id();
|
||||
let client_secret = random_id();
|
||||
let created_at = Utc::now().timestamp();
|
||||
let redirect_uris_json = serde_json::to_string(&input.redirect_uris)?;
|
||||
|
||||
db.execute(Statement::from_sql_and_values(
|
||||
DbBackend::Sqlite,
|
||||
r#"INSERT INTO clients (client_id, client_secret, client_name, redirect_uris, created_at)
|
||||
VALUES (?, ?, ?, ?, ?)"#,
|
||||
[
|
||||
client_id.clone().into(),
|
||||
client_secret.clone().into(),
|
||||
input.client_name.clone().into(),
|
||||
redirect_uris_json.into(),
|
||||
created_at.into(),
|
||||
],
|
||||
))
|
||||
.await?;
|
||||
|
||||
Ok(Client {
|
||||
client_id,
|
||||
client_secret,
|
||||
client_name: input.client_name,
|
||||
redirect_uris: input.redirect_uris,
|
||||
created_at,
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn get_property(
|
||||
db: &DatabaseConnection,
|
||||
owner: &str,
|
||||
key: &str,
|
||||
) -> Result<Option<Value>, CrabError> {
|
||||
if let Some(row) = db
|
||||
.query_one(Statement::from_sql_and_values(
|
||||
DbBackend::Sqlite,
|
||||
"SELECT value FROM properties WHERE owner = ? AND key = ?",
|
||||
[owner.into(), key.into()],
|
||||
))
|
||||
.await?
|
||||
{
|
||||
let value_str: String = row.try_get("", "value").unwrap_or_default();
|
||||
let json: Value = serde_json::from_str(&value_str)?;
|
||||
Ok(Some(json))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn set_property(
|
||||
db: &DatabaseConnection,
|
||||
owner: &str,
|
||||
key: &str,
|
||||
value: &Value,
|
||||
) -> Result<(), CrabError> {
|
||||
let now = Utc::now().timestamp();
|
||||
let json = serde_json::to_string(value)?;
|
||||
db.execute(Statement::from_sql_and_values(
|
||||
DbBackend::Sqlite,
|
||||
r#"INSERT INTO properties (owner, key, value, updated_at)
|
||||
VALUES (?, ?, ?, ?)
|
||||
ON CONFLICT(owner, key) DO UPDATE SET value = excluded.value, updated_at = excluded.updated_at"#,
|
||||
[owner.into(), key.into(), json.into(), now.into()],
|
||||
))
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn get_client(db: &DatabaseConnection, client_id: &str) -> Result<Option<Client>, CrabError> {
|
||||
if let Some(row) = db
|
||||
.query_one(Statement::from_sql_and_values(
|
||||
DbBackend::Sqlite,
|
||||
r#"SELECT client_id, client_secret, client_name, redirect_uris, created_at FROM clients WHERE client_id = ?"#,
|
||||
[client_id.into()],
|
||||
))
|
||||
.await?
|
||||
{
|
||||
let client_id: String = row.try_get("", "client_id").unwrap_or_default();
|
||||
let client_secret: String = row.try_get("", "client_secret").unwrap_or_default();
|
||||
let client_name: Option<String> = row.try_get("", "client_name").ok();
|
||||
let redirect_uris_json: String = row.try_get("", "redirect_uris").unwrap_or_default();
|
||||
let redirect_uris: Vec<String> = serde_json::from_str(&redirect_uris_json).unwrap_or_default();
|
||||
let created_at: i64 = row.try_get("", "created_at").unwrap_or_default();
|
||||
Ok(Some(Client { client_id, client_secret, client_name, redirect_uris, created_at }))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn issue_auth_code(
|
||||
db: &DatabaseConnection,
|
||||
client_id: &str,
|
||||
redirect_uri: &str,
|
||||
scope: &str,
|
||||
subject: &str,
|
||||
nonce: Option<String>,
|
||||
code_challenge: &str,
|
||||
code_challenge_method: &str,
|
||||
ttl_secs: i64,
|
||||
auth_time: Option<i64>,
|
||||
) -> Result<AuthCode, CrabError> {
|
||||
let code = random_id();
|
||||
let now = Utc::now().timestamp();
|
||||
let expires_at = now + ttl_secs;
|
||||
db.execute(Statement::from_sql_and_values(
|
||||
DbBackend::Sqlite,
|
||||
r#"INSERT INTO auth_codes (code, client_id, redirect_uri, scope, subject, nonce, code_challenge, code_challenge_method, created_at, expires_at, consumed, auth_time)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, 0, ?)"#,
|
||||
[
|
||||
code.clone().into(),
|
||||
client_id.into(),
|
||||
redirect_uri.into(),
|
||||
scope.into(),
|
||||
subject.into(),
|
||||
nonce.clone().into(),
|
||||
code_challenge.into(),
|
||||
code_challenge_method.into(),
|
||||
now.into(),
|
||||
expires_at.into(),
|
||||
auth_time.into(),
|
||||
],
|
||||
))
|
||||
.await?;
|
||||
Ok(AuthCode {
|
||||
code,
|
||||
client_id: client_id.to_string(),
|
||||
redirect_uri: redirect_uri.to_string(),
|
||||
scope: scope.to_string(),
|
||||
subject: subject.to_string(),
|
||||
nonce,
|
||||
code_challenge: code_challenge.to_string(),
|
||||
code_challenge_method: code_challenge_method.to_string(),
|
||||
created_at: now,
|
||||
expires_at,
|
||||
consumed: 0,
|
||||
auth_time,
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn consume_auth_code(
|
||||
db: &DatabaseConnection,
|
||||
code: &str,
|
||||
) -> Result<Option<AuthCode>, CrabError> {
|
||||
if let Some(row) = db
|
||||
.query_one(Statement::from_sql_and_values(
|
||||
DbBackend::Sqlite,
|
||||
r#"SELECT code, client_id, redirect_uri, scope, subject, nonce, code_challenge, code_challenge_method, created_at, expires_at, consumed, auth_time
|
||||
FROM auth_codes WHERE code = ?"#,
|
||||
[code.into()],
|
||||
))
|
||||
.await?
|
||||
{
|
||||
let consumed: i64 = row.try_get("", "consumed").unwrap_or_default();
|
||||
let expires_at: i64 = row.try_get("", "expires_at").unwrap_or_default();
|
||||
let now = Utc::now().timestamp();
|
||||
if consumed != 0 || now > expires_at {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
// Mark as consumed
|
||||
db.execute(Statement::from_sql_and_values(
|
||||
DbBackend::Sqlite,
|
||||
r#"UPDATE auth_codes SET consumed = ? WHERE code = ?"#,
|
||||
[1.into(), code.into()],
|
||||
))
|
||||
.await?;
|
||||
|
||||
let code_val: String = row.try_get("", "code").unwrap_or_default();
|
||||
let client_id: String = row.try_get("", "client_id").unwrap_or_default();
|
||||
let redirect_uri: String = row.try_get("", "redirect_uri").unwrap_or_default();
|
||||
let scope: String = row.try_get("", "scope").unwrap_or_default();
|
||||
let subject: String = row.try_get("", "subject").unwrap_or_default();
|
||||
let nonce: Option<String> = row.try_get("", "nonce").ok();
|
||||
let code_challenge: String = row.try_get("", "code_challenge").unwrap_or_default();
|
||||
let code_challenge_method: String = row.try_get("", "code_challenge_method").unwrap_or_default();
|
||||
let created_at: i64 = row.try_get("", "created_at").unwrap_or_default();
|
||||
let expires_at: i64 = row.try_get("", "expires_at").unwrap_or_default();
|
||||
let auth_time: Option<i64> = row.try_get("", "auth_time").ok();
|
||||
Ok(Some(AuthCode { code: code_val, client_id, redirect_uri, scope, subject, nonce, code_challenge, code_challenge_method, created_at, expires_at, consumed: 1, auth_time }))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn issue_access_token(
|
||||
db: &DatabaseConnection,
|
||||
client_id: &str,
|
||||
subject: &str,
|
||||
scope: &str,
|
||||
ttl_secs: i64,
|
||||
) -> Result<AccessToken, CrabError> {
|
||||
let token = random_id();
|
||||
let now = Utc::now().timestamp();
|
||||
let expires_at = now + ttl_secs;
|
||||
db.execute(Statement::from_sql_and_values(
|
||||
DbBackend::Sqlite,
|
||||
r#"INSERT INTO access_tokens (token, client_id, subject, scope, created_at, expires_at, revoked)
|
||||
VALUES (?, ?, ?, ?, ?, ?, 0)"#,
|
||||
[token.clone().into(), client_id.into(), subject.into(), scope.into(), now.into(), expires_at.into()],
|
||||
))
|
||||
.await?;
|
||||
Ok(AccessToken { token, client_id: client_id.to_string(), subject: subject.to_string(), scope: scope.to_string(), created_at: now, expires_at, revoked: 0 })
|
||||
}
|
||||
|
||||
pub async fn get_access_token(db: &DatabaseConnection, token: &str) -> Result<Option<AccessToken>, CrabError> {
|
||||
if let Some(row) = db
|
||||
.query_one(Statement::from_sql_and_values(
|
||||
DbBackend::Sqlite,
|
||||
r#"SELECT token, client_id, subject, scope, created_at, expires_at, revoked FROM access_tokens WHERE token = ?"#,
|
||||
[token.into()],
|
||||
))
|
||||
.await?
|
||||
{
|
||||
let revoked: i64 = row.try_get("", "revoked").unwrap_or_default();
|
||||
let expires_at: i64 = row.try_get("", "expires_at").unwrap_or_default();
|
||||
let now = Utc::now().timestamp();
|
||||
if revoked != 0 || now > expires_at { return Ok(None); }
|
||||
let token: String = row.try_get("", "token").unwrap_or_default();
|
||||
let client_id: String = row.try_get("", "client_id").unwrap_or_default();
|
||||
let subject: String = row.try_get("", "subject").unwrap_or_default();
|
||||
let scope: String = row.try_get("", "scope").unwrap_or_default();
|
||||
let created_at: i64 = row.try_get("", "created_at").unwrap_or_default();
|
||||
Ok(Some(AccessToken { token, client_id, subject, scope, created_at, expires_at, revoked }))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
fn random_id() -> String {
|
||||
let mut bytes = [0u8; 24];
|
||||
rand::thread_rng().fill_bytes(&mut bytes);
|
||||
base64ct::Base64UrlUnpadded::encode_string(&bytes)
|
||||
}
|
||||
|
||||
// User management functions
|
||||
|
||||
pub async fn create_user(
|
||||
db: &DatabaseConnection,
|
||||
username: &str,
|
||||
password: &str,
|
||||
email: Option<String>,
|
||||
) -> Result<User, CrabError> {
|
||||
use argon2::{Argon2, PasswordHasher};
|
||||
use argon2::password_hash::{SaltString, rand_core::OsRng};
|
||||
|
||||
let subject = random_id();
|
||||
let created_at = Utc::now().timestamp();
|
||||
|
||||
// Hash password with Argon2id
|
||||
let salt = SaltString::generate(&mut OsRng);
|
||||
let argon2 = Argon2::default();
|
||||
let password_hash = argon2
|
||||
.hash_password(password.as_bytes(), &salt)
|
||||
.map_err(|e| CrabError::Other(format!("Password hashing failed: {}", e)))?
|
||||
.to_string();
|
||||
|
||||
db.execute(Statement::from_sql_and_values(
|
||||
DbBackend::Sqlite,
|
||||
r#"INSERT INTO users (subject, username, password_hash, email, email_verified, created_at, enabled)
|
||||
VALUES (?, ?, ?, ?, 0, ?, 1)"#,
|
||||
[
|
||||
subject.clone().into(),
|
||||
username.into(),
|
||||
password_hash.clone().into(),
|
||||
email.clone().into(),
|
||||
created_at.into(),
|
||||
],
|
||||
))
|
||||
.await?;
|
||||
|
||||
Ok(User {
|
||||
subject,
|
||||
username: username.to_string(),
|
||||
password_hash,
|
||||
email,
|
||||
email_verified: 0,
|
||||
created_at,
|
||||
enabled: 1,
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn get_user_by_username(
|
||||
db: &DatabaseConnection,
|
||||
username: &str,
|
||||
) -> Result<Option<User>, CrabError> {
|
||||
if let Some(row) = db
|
||||
.query_one(Statement::from_sql_and_values(
|
||||
DbBackend::Sqlite,
|
||||
r#"SELECT subject, username, password_hash, email, email_verified, created_at, enabled
|
||||
FROM users WHERE username = ?"#,
|
||||
[username.into()],
|
||||
))
|
||||
.await?
|
||||
{
|
||||
let subject: String = row.try_get("", "subject").unwrap_or_default();
|
||||
let username: String = row.try_get("", "username").unwrap_or_default();
|
||||
let password_hash: String = row.try_get("", "password_hash").unwrap_or_default();
|
||||
let email: Option<String> = row.try_get("", "email").ok();
|
||||
let email_verified: i64 = row.try_get("", "email_verified").unwrap_or_default();
|
||||
let created_at: i64 = row.try_get("", "created_at").unwrap_or_default();
|
||||
let enabled: i64 = row.try_get("", "enabled").unwrap_or_default();
|
||||
|
||||
Ok(Some(User {
|
||||
subject,
|
||||
username,
|
||||
password_hash,
|
||||
email,
|
||||
email_verified,
|
||||
created_at,
|
||||
enabled,
|
||||
}))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn verify_user_password(
|
||||
db: &DatabaseConnection,
|
||||
username: &str,
|
||||
password: &str,
|
||||
) -> Result<Option<String>, CrabError> {
|
||||
use argon2::{Argon2, PasswordVerifier, PasswordHash};
|
||||
|
||||
let user = match get_user_by_username(db, username).await? {
|
||||
Some(u) if u.enabled == 1 => u,
|
||||
_ => return Ok(None),
|
||||
};
|
||||
|
||||
let parsed_hash = PasswordHash::new(&user.password_hash)
|
||||
.map_err(|e| CrabError::Other(format!("Invalid password hash: {}", e)))?;
|
||||
|
||||
if Argon2::default()
|
||||
.verify_password(password.as_bytes(), &parsed_hash)
|
||||
.is_ok()
|
||||
{
|
||||
Ok(Some(user.subject))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
// Session management functions
|
||||
|
||||
pub async fn create_session(
|
||||
db: &DatabaseConnection,
|
||||
subject: &str,
|
||||
ttl_secs: i64,
|
||||
user_agent: Option<String>,
|
||||
ip_address: Option<String>,
|
||||
) -> Result<Session, CrabError> {
|
||||
let session_id = random_id();
|
||||
let now = Utc::now().timestamp();
|
||||
let expires_at = now + ttl_secs;
|
||||
|
||||
db.execute(Statement::from_sql_and_values(
|
||||
DbBackend::Sqlite,
|
||||
r#"INSERT INTO sessions (session_id, subject, auth_time, created_at, expires_at, user_agent, ip_address)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?)"#,
|
||||
[
|
||||
session_id.clone().into(),
|
||||
subject.into(),
|
||||
now.into(),
|
||||
now.into(),
|
||||
expires_at.into(),
|
||||
user_agent.clone().into(),
|
||||
ip_address.clone().into(),
|
||||
],
|
||||
))
|
||||
.await?;
|
||||
|
||||
Ok(Session {
|
||||
session_id,
|
||||
subject: subject.to_string(),
|
||||
auth_time: now,
|
||||
created_at: now,
|
||||
expires_at,
|
||||
user_agent,
|
||||
ip_address,
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn get_session(
|
||||
db: &DatabaseConnection,
|
||||
session_id: &str,
|
||||
) -> Result<Option<Session>, CrabError> {
|
||||
if let Some(row) = db
|
||||
.query_one(Statement::from_sql_and_values(
|
||||
DbBackend::Sqlite,
|
||||
r#"SELECT session_id, subject, auth_time, created_at, expires_at, user_agent, ip_address
|
||||
FROM sessions WHERE session_id = ?"#,
|
||||
[session_id.into()],
|
||||
))
|
||||
.await?
|
||||
{
|
||||
let session_id: String = row.try_get("", "session_id").unwrap_or_default();
|
||||
let subject: String = row.try_get("", "subject").unwrap_or_default();
|
||||
let auth_time: i64 = row.try_get("", "auth_time").unwrap_or_default();
|
||||
let created_at: i64 = row.try_get("", "created_at").unwrap_or_default();
|
||||
let expires_at: i64 = row.try_get("", "expires_at").unwrap_or_default();
|
||||
let user_agent: Option<String> = row.try_get("", "user_agent").ok();
|
||||
let ip_address: Option<String> = row.try_get("", "ip_address").ok();
|
||||
|
||||
// Check if session is expired
|
||||
let now = Utc::now().timestamp();
|
||||
if now > expires_at {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
Ok(Some(Session {
|
||||
session_id,
|
||||
subject,
|
||||
auth_time,
|
||||
created_at,
|
||||
expires_at,
|
||||
user_agent,
|
||||
ip_address,
|
||||
}))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn delete_session(
|
||||
db: &DatabaseConnection,
|
||||
session_id: &str,
|
||||
) -> Result<(), CrabError> {
|
||||
db.execute(Statement::from_sql_and_values(
|
||||
DbBackend::Sqlite,
|
||||
"DELETE FROM sessions WHERE session_id = ?",
|
||||
[session_id.into()],
|
||||
))
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn cleanup_expired_sessions(db: &DatabaseConnection) -> Result<u64, CrabError> {
|
||||
let now = Utc::now().timestamp();
|
||||
let result = db
|
||||
.execute(Statement::from_sql_and_values(
|
||||
DbBackend::Sqlite,
|
||||
"DELETE FROM sessions WHERE expires_at < ?",
|
||||
[now.into()],
|
||||
))
|
||||
.await?;
|
||||
Ok(result.rows_affected())
|
||||
}
|
||||
|
||||
// Refresh Token Functions
|
||||
|
||||
pub async fn issue_refresh_token(
|
||||
db: &DatabaseConnection,
|
||||
client_id: &str,
|
||||
subject: &str,
|
||||
scope: &str,
|
||||
ttl_secs: i64,
|
||||
parent_token: Option<String>,
|
||||
) -> Result<RefreshToken, CrabError> {
|
||||
let token = random_id();
|
||||
let now = Utc::now().timestamp();
|
||||
let expires_at = now + ttl_secs;
|
||||
|
||||
db.execute(Statement::from_sql_and_values(
|
||||
DbBackend::Sqlite,
|
||||
r#"INSERT INTO refresh_tokens (token, client_id, subject, scope, created_at, expires_at, revoked, parent_token)
|
||||
VALUES (?, ?, ?, ?, ?, ?, 0, ?)"#,
|
||||
[
|
||||
token.clone().into(),
|
||||
client_id.into(),
|
||||
subject.into(),
|
||||
scope.into(),
|
||||
now.into(),
|
||||
expires_at.into(),
|
||||
parent_token.clone().into(),
|
||||
],
|
||||
))
|
||||
.await?;
|
||||
|
||||
Ok(RefreshToken {
|
||||
token,
|
||||
client_id: client_id.to_string(),
|
||||
subject: subject.to_string(),
|
||||
scope: scope.to_string(),
|
||||
created_at: now,
|
||||
expires_at,
|
||||
revoked: 0,
|
||||
parent_token,
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn get_refresh_token(
|
||||
db: &DatabaseConnection,
|
||||
token: &str,
|
||||
) -> Result<Option<RefreshToken>, CrabError> {
|
||||
let result = db
|
||||
.query_one(Statement::from_sql_and_values(
|
||||
DbBackend::Sqlite,
|
||||
r#"SELECT token, client_id, subject, scope, created_at, expires_at, revoked, parent_token
|
||||
FROM refresh_tokens WHERE token = ?"#,
|
||||
[token.into()],
|
||||
))
|
||||
.await?;
|
||||
|
||||
if let Some(row) = result {
|
||||
let token: String = row.try_get("", "token")?;
|
||||
let client_id: String = row.try_get("", "client_id")?;
|
||||
let subject: String = row.try_get("", "subject")?;
|
||||
let scope: String = row.try_get("", "scope")?;
|
||||
let created_at: i64 = row.try_get("", "created_at")?;
|
||||
let expires_at: i64 = row.try_get("", "expires_at")?;
|
||||
let revoked: i64 = row.try_get("", "revoked")?;
|
||||
let parent_token: Option<String> = row.try_get("", "parent_token").ok();
|
||||
|
||||
// Check if token is expired or revoked
|
||||
let now = Utc::now().timestamp();
|
||||
if revoked != 0 || now > expires_at {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
Ok(Some(RefreshToken {
|
||||
token,
|
||||
client_id,
|
||||
subject,
|
||||
scope,
|
||||
created_at,
|
||||
expires_at,
|
||||
revoked,
|
||||
parent_token,
|
||||
}))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn revoke_refresh_token(
|
||||
db: &DatabaseConnection,
|
||||
token: &str,
|
||||
) -> Result<(), CrabError> {
|
||||
db.execute(Statement::from_sql_and_values(
|
||||
DbBackend::Sqlite,
|
||||
"UPDATE refresh_tokens SET revoked = 1 WHERE token = ?",
|
||||
[token.into()],
|
||||
))
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn rotate_refresh_token(
|
||||
db: &DatabaseConnection,
|
||||
old_token: &str,
|
||||
client_id: &str,
|
||||
subject: &str,
|
||||
scope: &str,
|
||||
ttl_secs: i64,
|
||||
) -> Result<RefreshToken, CrabError> {
|
||||
// Revoke the old token
|
||||
revoke_refresh_token(db, old_token).await?;
|
||||
|
||||
// Issue a new token with the old token as parent
|
||||
issue_refresh_token(db, client_id, subject, scope, ttl_secs, Some(old_token.to_string())).await
|
||||
}
|
||||
|
||||
pub async fn cleanup_expired_refresh_tokens(db: &DatabaseConnection) -> Result<u64, CrabError> {
|
||||
let now = Utc::now().timestamp();
|
||||
let result = db
|
||||
.execute(Statement::from_sql_and_values(
|
||||
DbBackend::Sqlite,
|
||||
"DELETE FROM refresh_tokens WHERE expires_at < ?",
|
||||
[now.into()],
|
||||
))
|
||||
.await?;
|
||||
Ok(result.rows_affected())
|
||||
}
|
||||
928
src/web.rs
Normal file
928
src/web.rs
Normal file
|
|
@ -0,0 +1,928 @@
|
|||
//! See docs/axum.md for axum usage patterns and router/state conventions.
|
||||
//! See docs/oidc-conformance.md for OIDC OP requirements, current status, and roadmap.
|
||||
//! This module exposes the HTTP endpoints, some of which are discovery and JWKS
|
||||
//! as required by OpenID Connect. Authorization, token, and userinfo will follow
|
||||
//! per the conformance plan.
|
||||
use crate::jwks::JwksManager;
|
||||
use crate::errors::CrabError;
|
||||
use crate::settings::Settings;
|
||||
use crate::storage;
|
||||
use crate::session::SessionCookie;
|
||||
use axum::extract::{Path, State, Query, Form};
|
||||
use axum::http::{StatusCode, HeaderMap, HeaderName, HeaderValue, Request};
|
||||
use axum::response::{IntoResponse, Html, Redirect};
|
||||
use axum::routing::{get, post};
|
||||
use axum::{Json, Router};
|
||||
use axum::middleware::{self, Next};
|
||||
use axum::body::Body;
|
||||
use miette::IntoDiagnostic;
|
||||
use sea_orm::DatabaseConnection;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::{json, Value};
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
use sha2::{Digest, Sha256};
|
||||
use base64ct::{Base64UrlUnpadded, Encoding, Base64};
|
||||
use axum::response::Response;
|
||||
use std::time::SystemTime;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct AppState {
|
||||
pub settings: Arc<Settings>,
|
||||
pub db: DatabaseConnection,
|
||||
pub jwks: JwksManager,
|
||||
}
|
||||
|
||||
// Security headers middleware
|
||||
async fn security_headers(request: Request<Body>, next: Next) -> impl IntoResponse {
|
||||
let mut response = next.run(request).await;
|
||||
let headers = response.headers_mut();
|
||||
|
||||
// X-Frame-Options: Prevent clickjacking
|
||||
headers.insert(
|
||||
HeaderName::from_static("x-frame-options"),
|
||||
HeaderValue::from_static("DENY"),
|
||||
);
|
||||
|
||||
// X-Content-Type-Options: Prevent MIME sniffing
|
||||
headers.insert(
|
||||
HeaderName::from_static("x-content-type-options"),
|
||||
HeaderValue::from_static("nosniff"),
|
||||
);
|
||||
|
||||
// X-XSS-Protection: Legacy XSS protection (still useful for older browsers)
|
||||
headers.insert(
|
||||
HeaderName::from_static("x-xss-protection"),
|
||||
HeaderValue::from_static("1; mode=block"),
|
||||
);
|
||||
|
||||
// Content-Security-Policy: Restrict resource loading
|
||||
headers.insert(
|
||||
HeaderName::from_static("content-security-policy"),
|
||||
HeaderValue::from_static("default-src 'self'; script-src 'self'; style-src 'self' 'unsafe-inline'; img-src 'self' data:; form-action 'self'"),
|
||||
);
|
||||
|
||||
// Referrer-Policy: Control referrer information
|
||||
headers.insert(
|
||||
HeaderName::from_static("referrer-policy"),
|
||||
HeaderValue::from_static("strict-origin-when-cross-origin"),
|
||||
);
|
||||
|
||||
// Permissions-Policy: Disable unnecessary browser features
|
||||
headers.insert(
|
||||
HeaderName::from_static("permissions-policy"),
|
||||
HeaderValue::from_static("geolocation=(), microphone=(), camera=()"),
|
||||
);
|
||||
|
||||
response
|
||||
}
|
||||
|
||||
pub async fn serve(settings: Settings, db: DatabaseConnection, jwks: JwksManager) -> miette::Result<()> {
|
||||
let state = AppState { settings: Arc::new(settings), db, jwks };
|
||||
|
||||
// NOTE: Rate limiting should be implemented at the reverse proxy level (nginx, traefik, etc.)
|
||||
// for production deployments. This is more efficient and flexible than application-level
|
||||
// rate limiting. Configure your reverse proxy with limits like:
|
||||
// - Token endpoint: 10 req/min per IP
|
||||
// - Login endpoint: 5 attempts/min per IP
|
||||
// - Authorize endpoint: 20 req/min per IP
|
||||
|
||||
let router = Router::new()
|
||||
.route("/.well-known/openid-configuration", get(discovery))
|
||||
.route("/.well-known/jwks.json", get(jwks_handler))
|
||||
.route("/connect/register", post(register_client))
|
||||
.route("/properties/{owner}/{key}", get(get_property))
|
||||
.route("/federation/trust-anchors", get(trust_anchors))
|
||||
.route("/register", post(register_user))
|
||||
.route("/login", get(login_page).post(login_submit))
|
||||
.route("/logout", get(logout))
|
||||
.route("/authorize", get(authorize))
|
||||
.route("/token", post(token))
|
||||
.route("/userinfo", get(userinfo))
|
||||
.layer(middleware::from_fn(security_headers))
|
||||
.with_state(state.clone());
|
||||
|
||||
let addr: SocketAddr = format!("{}:{}", state.settings.server.host, state.settings.server.port)
|
||||
.parse()
|
||||
.map_err(|e| miette::miette!("bad listen addr: {e}"))?;
|
||||
tracing::info!(%addr, "listening");
|
||||
tracing::warn!("Rate limiting should be configured at the reverse proxy level for production");
|
||||
let listener = tokio::net::TcpListener::bind(addr).await.into_diagnostic()?;
|
||||
axum::serve(listener, router).await.into_diagnostic()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn discovery(State(state): State<AppState>) -> impl IntoResponse {
|
||||
let issuer = state.settings.issuer();
|
||||
let metadata = json!({
|
||||
"issuer": issuer,
|
||||
"authorization_endpoint": format!("{}/authorize", issuer),
|
||||
"token_endpoint": format!("{}/token", issuer),
|
||||
"jwks_uri": format!("{}/.well-known/jwks.json", issuer),
|
||||
"registration_endpoint": format!("{}/connect/register", issuer),
|
||||
"userinfo_endpoint": format!("{}/userinfo", issuer),
|
||||
"scopes_supported": ["openid", "profile", "email", "offline_access"],
|
||||
"response_types_supported": ["code", "id_token", "id_token token"],
|
||||
"subject_types_supported": ["public"],
|
||||
"id_token_signing_alg_values_supported": [state.settings.keys.alg],
|
||||
// Additional recommended metadata for better interoperability
|
||||
"grant_types_supported": ["authorization_code", "refresh_token", "implicit"],
|
||||
"token_endpoint_auth_methods_supported": ["client_secret_post", "client_secret_basic"],
|
||||
"code_challenge_methods_supported": ["S256"],
|
||||
"claims_supported": [
|
||||
"sub", "iss", "aud", "exp", "iat", "auth_time", "nonce",
|
||||
"name", "given_name", "family_name", "email", "email_verified"
|
||||
],
|
||||
// OIDC Core 1.0 features
|
||||
"prompt_values_supported": ["none", "login", "consent", "select_account"],
|
||||
"display_values_supported": ["page"],
|
||||
"ui_locales_supported": ["en"],
|
||||
"claim_types_supported": ["normal"],
|
||||
});
|
||||
Json(metadata)
|
||||
}
|
||||
|
||||
async fn jwks_handler(State(state): State<AppState>) -> impl IntoResponse {
|
||||
Json(state.jwks.jwks_json())
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct AuthorizeQuery {
|
||||
client_id: String,
|
||||
redirect_uri: String,
|
||||
response_type: String,
|
||||
scope: String,
|
||||
state: Option<String>,
|
||||
nonce: Option<String>,
|
||||
code_challenge: Option<String>,
|
||||
code_challenge_method: Option<String>,
|
||||
prompt: Option<String>,
|
||||
display: Option<String>,
|
||||
ui_locales: Option<String>,
|
||||
claims_locales: Option<String>,
|
||||
max_age: Option<String>,
|
||||
acr_values: Option<String>,
|
||||
}
|
||||
|
||||
fn url_append_query(mut base: String, params: &[(&str, String)]) -> String {
|
||||
let qs = serde_urlencoded::to_string(
|
||||
params.iter().map(|(k, v)| (k.to_string(), v.clone())).collect::<Vec<(String, String)>>(),
|
||||
).unwrap_or_default();
|
||||
if base.contains('?') { base.push('&'); } else { base.push('?'); }
|
||||
base.push_str(&qs);
|
||||
base
|
||||
}
|
||||
|
||||
fn oauth_error_redirect(redirect_uri: &str, state: Option<&str>, error: &str, desc: &str) -> axum::response::Redirect {
|
||||
let mut params = vec![("error", error.to_string())];
|
||||
if !desc.is_empty() { params.push(("error_description", desc.to_string())); }
|
||||
if let Some(s) = state { params.push(("state", s.to_string())); }
|
||||
let loc = url_append_query(redirect_uri.to_string(), ¶ms);
|
||||
axum::response::Redirect::temporary(&loc)
|
||||
}
|
||||
|
||||
// OIDC-specific error codes per OpenID Connect Core 1.0 Section 3.1.2.6
|
||||
// login_required: Authentication is required but prompt=none was specified
|
||||
// consent_required: Consent is required but prompt=none was specified
|
||||
// interaction_required: User interaction is required but prompt=none was specified
|
||||
// account_selection_required: Account selection is required but prompt=none was specified
|
||||
fn oidc_error_redirect(redirect_uri: &str, state: Option<&str>, error: &str) -> axum::response::Redirect {
|
||||
oauth_error_redirect(redirect_uri, state, error, "")
|
||||
}
|
||||
|
||||
async fn authorize(State(state): State<AppState>, headers: HeaderMap, Query(q): Query<AuthorizeQuery>) -> impl IntoResponse {
|
||||
// Validate response_type - support code, id_token, and id_token token
|
||||
let valid_response_types = ["code", "id_token", "id_token token"];
|
||||
if !valid_response_types.contains(&q.response_type.as_str()) {
|
||||
return oauth_error_redirect(&q.redirect_uri, q.state.as_deref(), "unsupported_response_type",
|
||||
"only response_type=code, id_token, or 'id_token token' supported").into_response();
|
||||
}
|
||||
// Validate scope includes openid
|
||||
if !q.scope.split_whitespace().any(|s| s == "openid") {
|
||||
return oauth_error_redirect(&q.redirect_uri, q.state.as_deref(), "invalid_scope", "scope must include openid").into_response();
|
||||
}
|
||||
// Require PKCE S256
|
||||
let (code_challenge, ccm) = match (&q.code_challenge, &q.code_challenge_method) {
|
||||
(Some(cc), Some(m)) if m == "S256" => (cc.clone(), m.clone()),
|
||||
_ => {
|
||||
return oauth_error_redirect(&q.redirect_uri, q.state.as_deref(), "invalid_request", "PKCE (S256) required").into_response();
|
||||
}
|
||||
};
|
||||
|
||||
// Lookup client
|
||||
let client = match storage::get_client(&state.db, &q.client_id).await {
|
||||
Ok(Some(c)) => c,
|
||||
Ok(None) => return oauth_error_redirect(&q.redirect_uri, q.state.as_deref(), "unauthorized_client", "unknown client_id").into_response(),
|
||||
Err(_) => return oauth_error_redirect(&q.redirect_uri, q.state.as_deref(), "server_error", "db error").into_response(),
|
||||
};
|
||||
// Validate redirect_uri exact match
|
||||
if !client.redirect_uris.iter().any(|u| u == &q.redirect_uri) {
|
||||
return oauth_error_redirect(&q.redirect_uri, q.state.as_deref(), "invalid_request", "redirect_uri mismatch").into_response();
|
||||
}
|
||||
|
||||
// Parse prompt parameter (can be space-separated list)
|
||||
let prompt_values: Vec<&str> = q.prompt.as_ref()
|
||||
.map(|p| p.split_whitespace().collect())
|
||||
.unwrap_or_default();
|
||||
|
||||
let has_prompt_none = prompt_values.contains(&"none");
|
||||
let has_prompt_login = prompt_values.contains(&"login");
|
||||
let has_prompt_select_account = prompt_values.contains(&"select_account");
|
||||
|
||||
// Check for existing session (but ignore if prompt=login or prompt=select_account)
|
||||
let session_opt = if has_prompt_login || has_prompt_select_account {
|
||||
None // Force re-authentication
|
||||
} else if let Some(cookie) = SessionCookie::from_headers(&headers) {
|
||||
storage::get_session(&state.db, &cookie.session_id).await.ok().flatten()
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// Handle max_age parameter - requires re-authentication if session is too old
|
||||
let needs_fresh_auth = if let Some(max_age_str) = &q.max_age {
|
||||
if let Ok(max_age) = max_age_str.parse::<i64>() {
|
||||
if let Some(ref sess) = session_opt {
|
||||
let age = chrono::Utc::now().timestamp() - sess.auth_time;
|
||||
age > max_age
|
||||
} else {
|
||||
false
|
||||
}
|
||||
} else {
|
||||
false
|
||||
}
|
||||
} else {
|
||||
false
|
||||
};
|
||||
|
||||
let (subject, auth_time) = match session_opt {
|
||||
Some(sess) if sess.expires_at > chrono::Utc::now().timestamp() && !needs_fresh_auth => {
|
||||
(sess.subject.clone(), Some(sess.auth_time))
|
||||
}
|
||||
_ => {
|
||||
// No valid session or session too old
|
||||
// If prompt=none, return error instead of redirecting
|
||||
if has_prompt_none {
|
||||
return oidc_error_redirect(&q.redirect_uri, q.state.as_deref(), "login_required").into_response();
|
||||
}
|
||||
|
||||
// Build return_to URL with all parameters
|
||||
let mut return_params = vec![
|
||||
("client_id", q.client_id.clone()),
|
||||
("redirect_uri", q.redirect_uri.clone()),
|
||||
("response_type", q.response_type.clone()),
|
||||
("scope", q.scope.clone()),
|
||||
("code_challenge", code_challenge.clone()),
|
||||
("code_challenge_method", ccm.clone()),
|
||||
];
|
||||
if let Some(s) = &q.state {
|
||||
return_params.push(("state", s.clone()));
|
||||
}
|
||||
if let Some(n) = &q.nonce {
|
||||
return_params.push(("nonce", n.clone()));
|
||||
}
|
||||
if let Some(p) = &q.prompt {
|
||||
return_params.push(("prompt", p.clone()));
|
||||
}
|
||||
if let Some(d) = &q.display {
|
||||
return_params.push(("display", d.clone()));
|
||||
}
|
||||
if let Some(ui) = &q.ui_locales {
|
||||
return_params.push(("ui_locales", ui.clone()));
|
||||
}
|
||||
if let Some(cl) = &q.claims_locales {
|
||||
return_params.push(("claims_locales", cl.clone()));
|
||||
}
|
||||
if let Some(ma) = &q.max_age {
|
||||
return_params.push(("max_age", ma.clone()));
|
||||
}
|
||||
if let Some(acr) = &q.acr_values {
|
||||
return_params.push(("acr_values", acr.clone()));
|
||||
}
|
||||
|
||||
let return_to = url_append_query("/authorize".to_string(),
|
||||
&return_params.iter().map(|(k, v)| (*k, v.clone())).collect::<Vec<_>>());
|
||||
let login_url = format!("/login?return_to={}", urlencoded(&return_to));
|
||||
return Redirect::temporary(&login_url).into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let scope = q.scope.clone();
|
||||
let nonce = q.nonce.clone();
|
||||
|
||||
// Handle different response types
|
||||
match q.response_type.as_str() {
|
||||
"code" => {
|
||||
// Authorization Code Flow - issue auth code
|
||||
let ttl = 300; // 5 minutes
|
||||
match storage::issue_auth_code(&state.db, &q.client_id, &q.redirect_uri, &scope, &subject, nonce, &code_challenge, &ccm, ttl, auth_time).await {
|
||||
Ok(code) => {
|
||||
let mut params = vec![("code", code.code.clone())];
|
||||
if let Some(s) = &q.state { params.push(("state", s.clone())); }
|
||||
let loc = url_append_query(q.redirect_uri.clone(), ¶ms);
|
||||
axum::response::Redirect::temporary(&loc).into_response()
|
||||
}
|
||||
Err(_) => oauth_error_redirect(&q.redirect_uri, q.state.as_deref(), "server_error", "could not issue code").into_response(),
|
||||
}
|
||||
}
|
||||
"id_token" => {
|
||||
// Implicit Flow - return ID token in fragment
|
||||
// Require nonce for implicit flow (OIDC Core 1.0 Section 3.2.2.1)
|
||||
if nonce.is_none() {
|
||||
return oauth_error_redirect(&q.redirect_uri, q.state.as_deref(), "invalid_request", "nonce required for implicit flow").into_response();
|
||||
}
|
||||
|
||||
match build_id_token(&state, &q.client_id, &subject, nonce.as_deref(), auth_time, None).await {
|
||||
Ok(id_token) => {
|
||||
let mut fragment_params = vec![("id_token", id_token)];
|
||||
if let Some(s) = &q.state {
|
||||
fragment_params.push(("state", s.clone()));
|
||||
}
|
||||
let fragment = serde_urlencoded::to_string(&fragment_params).unwrap_or_default();
|
||||
let loc = format!("{}#{}", q.redirect_uri, fragment);
|
||||
axum::response::Redirect::temporary(&loc).into_response()
|
||||
}
|
||||
Err(_) => oauth_error_redirect(&q.redirect_uri, q.state.as_deref(), "server_error", "could not generate id_token").into_response(),
|
||||
}
|
||||
}
|
||||
"id_token token" => {
|
||||
// Implicit Flow - return both ID token and access token in fragment
|
||||
// Require nonce for implicit flow
|
||||
if nonce.is_none() {
|
||||
return oauth_error_redirect(&q.redirect_uri, q.state.as_deref(), "invalid_request", "nonce required for implicit flow").into_response();
|
||||
}
|
||||
|
||||
// Issue access token
|
||||
let access = match storage::issue_access_token(&state.db, &q.client_id, &subject, &scope, 3600).await {
|
||||
Ok(t) => t,
|
||||
Err(_) => return oauth_error_redirect(&q.redirect_uri, q.state.as_deref(), "server_error", "could not issue access token").into_response(),
|
||||
};
|
||||
|
||||
// Build ID token with at_hash
|
||||
match build_id_token(&state, &q.client_id, &subject, nonce.as_deref(), auth_time, Some(&access.token)).await {
|
||||
Ok(id_token) => {
|
||||
let mut fragment_params = vec![
|
||||
("access_token", access.token),
|
||||
("token_type", "bearer".to_string()),
|
||||
("expires_in", "3600".to_string()),
|
||||
("id_token", id_token),
|
||||
];
|
||||
if let Some(s) = &q.state {
|
||||
fragment_params.push(("state", s.clone()));
|
||||
}
|
||||
let fragment = serde_urlencoded::to_string(&fragment_params).unwrap_or_default();
|
||||
let loc = format!("{}#{}", q.redirect_uri, fragment);
|
||||
axum::response::Redirect::temporary(&loc).into_response()
|
||||
}
|
||||
Err(_) => oauth_error_redirect(&q.redirect_uri, q.state.as_deref(), "server_error", "could not generate id_token").into_response(),
|
||||
}
|
||||
}
|
||||
_ => oauth_error_redirect(&q.redirect_uri, q.state.as_deref(), "unsupported_response_type", "invalid response_type").into_response(),
|
||||
}
|
||||
}
|
||||
|
||||
// Helper function to build ID token
|
||||
async fn build_id_token(
|
||||
state: &AppState,
|
||||
client_id: &str,
|
||||
subject: &str,
|
||||
nonce: Option<&str>,
|
||||
auth_time: Option<i64>,
|
||||
access_token: Option<&str>, // For at_hash calculation
|
||||
) -> Result<String, CrabError> {
|
||||
let now = SystemTime::now();
|
||||
let exp_unix = now.duration_since(std::time::UNIX_EPOCH).unwrap().as_secs() as i64 + 3600;
|
||||
|
||||
let mut payload = josekit::jwt::JwtPayload::new();
|
||||
payload.set_issuer(state.settings.issuer());
|
||||
payload.set_subject(subject.to_string());
|
||||
payload.set_audience(vec![client_id.to_string()]);
|
||||
payload.set_issued_at(&now);
|
||||
let _ = payload.set_claim("exp", Some(serde_json::json!(exp_unix)));
|
||||
|
||||
if let Some(n) = nonce {
|
||||
let _ = payload.set_claim("nonce", Some(serde_json::Value::String(n.to_string())));
|
||||
}
|
||||
|
||||
if let Some(at) = auth_time {
|
||||
let _ = payload.set_claim("auth_time", Some(serde_json::json!(at)));
|
||||
}
|
||||
|
||||
// Add at_hash if access_token is provided (for id_token token response type)
|
||||
if let Some(token) = access_token {
|
||||
let mut hasher = Sha256::new();
|
||||
hasher.update(token.as_bytes());
|
||||
let digest = hasher.finalize();
|
||||
let half = &digest[..16]; // left-most 128 bits
|
||||
let at_hash = Base64UrlUnpadded::encode_string(half);
|
||||
let _ = payload.set_claim("at_hash", Some(serde_json::Value::String(at_hash)));
|
||||
}
|
||||
|
||||
state.jwks.sign_jwt_rs256(&payload).map_err(|e| CrabError::Other(e.to_string()))
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct TokenRequest {
|
||||
grant_type: String,
|
||||
code: Option<String>,
|
||||
redirect_uri: Option<String>,
|
||||
client_id: Option<String>,
|
||||
client_secret: Option<String>,
|
||||
code_verifier: Option<String>,
|
||||
refresh_token: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct TokenResponse {
|
||||
access_token: String,
|
||||
token_type: String,
|
||||
expires_in: i64,
|
||||
id_token: Option<String>,
|
||||
refresh_token: Option<String>,
|
||||
}
|
||||
|
||||
fn pkce_s256(verifier: &str) -> String {
|
||||
let mut hasher = Sha256::new();
|
||||
hasher.update(verifier.as_bytes());
|
||||
let digest = hasher.finalize();
|
||||
// Take full digest then base64url without padding
|
||||
Base64UrlUnpadded::encode_string(&digest)
|
||||
}
|
||||
|
||||
fn json_with_headers(status: StatusCode, value: Value, headers: &[(&str, String)]) -> Response {
|
||||
let mut resp = (status, Json(value)).into_response();
|
||||
let h = resp.headers_mut();
|
||||
for (name, val) in headers {
|
||||
if let (Ok(n), Ok(v)) = (HeaderName::from_bytes(name.as_bytes()), HeaderValue::from_str(val)) {
|
||||
h.insert(n, v);
|
||||
}
|
||||
}
|
||||
resp
|
||||
}
|
||||
|
||||
async fn token(State(state): State<AppState>, headers: HeaderMap, Form(req): Form<TokenRequest>) -> impl IntoResponse {
|
||||
// Validate grant_type
|
||||
match req.grant_type.as_str() {
|
||||
"authorization_code" => handle_authorization_code_grant(state, headers, req).await,
|
||||
"refresh_token" => handle_refresh_token_grant(state, headers, req).await,
|
||||
_ => (StatusCode::BAD_REQUEST, Json(json!({"error":"unsupported_grant_type"}))).into_response(),
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_authorization_code_grant(state: AppState, headers: HeaderMap, req: TokenRequest) -> Response {
|
||||
// Client authentication: client_secret_basic preferred, then client_secret_post
|
||||
let (client_id, client_secret) = match authenticate_client(&headers, &req) {
|
||||
Ok(pair) => pair,
|
||||
Err(resp) => return resp,
|
||||
};
|
||||
|
||||
let client = match storage::get_client(&state.db, &client_id).await {
|
||||
Ok(Some(c)) => c,
|
||||
_ => {
|
||||
return json_with_headers(
|
||||
StatusCode::UNAUTHORIZED,
|
||||
json!({"error":"invalid_client"}),
|
||||
&[("www-authenticate", "Basic realm=\"token\", error=\"invalid_client\"".to_string())],
|
||||
)
|
||||
}
|
||||
};
|
||||
if client.client_secret != client_secret {
|
||||
return json_with_headers(
|
||||
StatusCode::UNAUTHORIZED,
|
||||
json!({"error":"invalid_client"}),
|
||||
&[("www-authenticate", "Basic realm=\"token\", error=\"invalid_client\"".to_string())],
|
||||
);
|
||||
}
|
||||
|
||||
// Require code
|
||||
let code = match req.code {
|
||||
Some(c) => c,
|
||||
None => return (StatusCode::BAD_REQUEST, Json(json!({"error":"invalid_request","error_description":"code required"}))).into_response(),
|
||||
};
|
||||
|
||||
// Consume code
|
||||
let code_row = match storage::consume_auth_code(&state.db, &code).await {
|
||||
Ok(Some(c)) => c,
|
||||
_ => return (StatusCode::BAD_REQUEST, Json(json!({"error":"invalid_grant"}))).into_response(),
|
||||
};
|
||||
|
||||
// Validate code binding
|
||||
let redirect_uri = req.redirect_uri.unwrap_or_default();
|
||||
if code_row.client_id != client_id || code_row.redirect_uri != redirect_uri {
|
||||
return (StatusCode::BAD_REQUEST, Json(json!({"error":"invalid_grant"}))).into_response();
|
||||
}
|
||||
|
||||
// Validate PKCE S256
|
||||
let verifier = match &req.code_verifier {
|
||||
Some(v) => v,
|
||||
None => return (StatusCode::BAD_REQUEST, Json(json!({"error":"invalid_request","error_description":"code_verifier required"}))).into_response(),
|
||||
};
|
||||
if code_row.code_challenge_method != "S256" || pkce_s256(verifier) != code_row.code_challenge {
|
||||
return (StatusCode::BAD_REQUEST, Json(json!({"error":"invalid_grant","error_description":"pkce verification failed"}))).into_response();
|
||||
}
|
||||
|
||||
// Issue access token
|
||||
let access = match storage::issue_access_token(&state.db, &client_id, &code_row.subject, &code_row.scope, 3600).await {
|
||||
Ok(t) => t,
|
||||
Err(e) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error":"server_error","details":e.to_string()}))).into_response(),
|
||||
};
|
||||
|
||||
// Build ID Token using helper function
|
||||
let id_token = match build_id_token(&state, &client_id, &code_row.subject, code_row.nonce.as_deref(), code_row.auth_time, Some(&access.token)).await {
|
||||
Ok(t) => t,
|
||||
Err(e) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error":"server_error","details":e.to_string()}))).into_response(),
|
||||
};
|
||||
|
||||
// Issue refresh token if offline_access scope was requested
|
||||
let refresh_token = if code_row.scope.split_whitespace().any(|s| s == "offline_access") {
|
||||
match storage::issue_refresh_token(&state.db, &client_id, &code_row.subject, &code_row.scope, 2592000, None).await {
|
||||
Ok(rt) => Some(rt.token),
|
||||
Err(_) => None, // Don't fail the whole request if refresh token issuance fails
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let resp = TokenResponse {
|
||||
access_token: access.token,
|
||||
token_type: "bearer".into(),
|
||||
expires_in: 3600,
|
||||
id_token: Some(id_token),
|
||||
refresh_token,
|
||||
};
|
||||
|
||||
// Add Cache-Control: no-store as required by OAuth 2.0 and OIDC specs
|
||||
json_with_headers(
|
||||
StatusCode::OK,
|
||||
serde_json::to_value(resp).unwrap(),
|
||||
&[
|
||||
("cache-control", "no-store".to_string()),
|
||||
("pragma", "no-cache".to_string()),
|
||||
],
|
||||
)
|
||||
}
|
||||
|
||||
async fn handle_refresh_token_grant(state: AppState, headers: HeaderMap, req: TokenRequest) -> Response {
|
||||
// Client authentication
|
||||
let (client_id, client_secret) = match authenticate_client(&headers, &req) {
|
||||
Ok(pair) => pair,
|
||||
Err(resp) => return resp,
|
||||
};
|
||||
|
||||
let client = match storage::get_client(&state.db, &client_id).await {
|
||||
Ok(Some(c)) => c,
|
||||
_ => {
|
||||
return json_with_headers(
|
||||
StatusCode::UNAUTHORIZED,
|
||||
json!({"error":"invalid_client"}),
|
||||
&[("www-authenticate", "Basic realm=\"token\", error=\"invalid_client\"".to_string())],
|
||||
)
|
||||
}
|
||||
};
|
||||
if client.client_secret != client_secret {
|
||||
return json_with_headers(
|
||||
StatusCode::UNAUTHORIZED,
|
||||
json!({"error":"invalid_client"}),
|
||||
&[("www-authenticate", "Basic realm=\"token\", error=\"invalid_client\"".to_string())],
|
||||
);
|
||||
}
|
||||
|
||||
// Require refresh_token
|
||||
let refresh_token_str = match req.refresh_token {
|
||||
Some(rt) => rt,
|
||||
None => return (StatusCode::BAD_REQUEST, Json(json!({"error":"invalid_request","error_description":"refresh_token required"}))).into_response(),
|
||||
};
|
||||
|
||||
// Get and validate refresh token
|
||||
let refresh_token = match storage::get_refresh_token(&state.db, &refresh_token_str).await {
|
||||
Ok(Some(rt)) => rt,
|
||||
_ => return (StatusCode::BAD_REQUEST, Json(json!({"error":"invalid_grant","error_description":"invalid refresh_token"}))).into_response(),
|
||||
};
|
||||
|
||||
// Validate client_id matches
|
||||
if refresh_token.client_id != client_id {
|
||||
return (StatusCode::BAD_REQUEST, Json(json!({"error":"invalid_grant"}))).into_response();
|
||||
}
|
||||
|
||||
// Issue new access token
|
||||
let access = match storage::issue_access_token(&state.db, &client_id, &refresh_token.subject, &refresh_token.scope, 3600).await {
|
||||
Ok(t) => t,
|
||||
Err(e) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error":"server_error","details":e.to_string()}))).into_response(),
|
||||
};
|
||||
|
||||
// Build ID Token (no nonce, no auth_time for refresh grants)
|
||||
let id_token = match build_id_token(&state, &client_id, &refresh_token.subject, None, None, Some(&access.token)).await {
|
||||
Ok(t) => t,
|
||||
Err(e) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error":"server_error","details":e.to_string()}))).into_response(),
|
||||
};
|
||||
|
||||
// Rotate refresh token (issue new one and revoke old one)
|
||||
let new_refresh_token = match storage::rotate_refresh_token(&state.db, &refresh_token_str, &client_id, &refresh_token.subject, &refresh_token.scope, 2592000).await {
|
||||
Ok(rt) => Some(rt.token),
|
||||
Err(_) => None, // Don't fail the whole request if rotation fails
|
||||
};
|
||||
|
||||
let resp = TokenResponse {
|
||||
access_token: access.token,
|
||||
token_type: "bearer".into(),
|
||||
expires_in: 3600,
|
||||
id_token: Some(id_token),
|
||||
refresh_token: new_refresh_token,
|
||||
};
|
||||
|
||||
// Add Cache-Control: no-store as required by OAuth 2.0 and OIDC specs
|
||||
json_with_headers(
|
||||
StatusCode::OK,
|
||||
serde_json::to_value(resp).unwrap(),
|
||||
&[
|
||||
("cache-control", "no-store".to_string()),
|
||||
("pragma", "no-cache".to_string()),
|
||||
],
|
||||
)
|
||||
}
|
||||
|
||||
fn authenticate_client(headers: &HeaderMap, req: &TokenRequest) -> Result<(String, String), Response> {
|
||||
// Try client_secret_basic first (Authorization header)
|
||||
let mut basic_client: Option<(String, String)> = None;
|
||||
if let Some(auth_val) = headers.get(axum::http::header::AUTHORIZATION).and_then(|h| h.to_str().ok()) {
|
||||
if let Some(b64) = auth_val.strip_prefix("Basic ") {
|
||||
if let Ok(mut decoded) = Base64::decode_vec(b64) {
|
||||
if let Ok(s) = String::from_utf8(std::mem::take(&mut decoded)) {
|
||||
if let Some((id, sec)) = s.split_once(':') {
|
||||
basic_client = Some((id.to_string(), sec.to_string()));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(pair) = basic_client {
|
||||
Ok(pair)
|
||||
} else {
|
||||
// Try client_secret_post (form body)
|
||||
match (req.client_id.clone(), req.client_secret.clone()) {
|
||||
(Some(id), Some(sec)) => Ok((id, sec)),
|
||||
_ => {
|
||||
Err(json_with_headers(
|
||||
StatusCode::UNAUTHORIZED,
|
||||
json!({"error":"invalid_client","error_description":"missing client authentication"}),
|
||||
&[("www-authenticate", "Basic realm=\"token\", error=\"invalid_client\"".to_string())],
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn userinfo(State(state): State<AppState>, req: axum::http::Request<axum::body::Body>) -> impl IntoResponse {
|
||||
// Extract bearer token
|
||||
let token_opt = req.headers().get(axum::http::header::AUTHORIZATION).and_then(|h| h.to_str().ok()).and_then(|s| s.strip_prefix("Bearer ")).map(|s| s.to_string());
|
||||
let token = match token_opt {
|
||||
Some(t) => t,
|
||||
None => {
|
||||
return json_with_headers(
|
||||
StatusCode::UNAUTHORIZED,
|
||||
json!({"error":"invalid_token"}),
|
||||
&[("www-authenticate", "Bearer realm=\"userinfo\", error=\"invalid_token\"".to_string())],
|
||||
)
|
||||
}
|
||||
};
|
||||
let token_row = match storage::get_access_token(&state.db, &token).await {
|
||||
Ok(Some(t)) => t,
|
||||
_ => {
|
||||
return json_with_headers(
|
||||
StatusCode::UNAUTHORIZED,
|
||||
json!({"error":"invalid_token"}),
|
||||
&[("www-authenticate", "Bearer realm=\"userinfo\", error=\"invalid_token\"".to_string())],
|
||||
)
|
||||
}
|
||||
};
|
||||
let mut claims = serde_json::Map::new();
|
||||
claims.insert("sub".to_string(), serde_json::Value::String(token_row.subject.clone()));
|
||||
// Optional: email claims from properties
|
||||
if let Ok(Some(email)) = storage::get_property(&state.db, &token_row.subject, "email").await {
|
||||
if let Some(email_str) = email.as_str() {
|
||||
claims.insert("email".to_string(), serde_json::Value::String(email_str.to_string()));
|
||||
}
|
||||
}
|
||||
if let Ok(Some(verified)) = storage::get_property(&state.db, &token_row.subject, "email_verified").await {
|
||||
claims.insert("email_verified".to_string(), verified);
|
||||
}
|
||||
(StatusCode::OK, Json(serde_json::Value::Object(claims))).into_response()
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct RegistrationRequest {
|
||||
client_name: Option<String>,
|
||||
redirect_uris: Vec<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct RegistrationResponse {
|
||||
client_id: String,
|
||||
client_secret: String,
|
||||
client_name: Option<String>,
|
||||
redirect_uris: Vec<String>,
|
||||
client_id_issued_at: i64,
|
||||
token_endpoint_auth_method: String,
|
||||
}
|
||||
|
||||
async fn register_client(State(state): State<AppState>, Json(req): Json<RegistrationRequest>) -> impl IntoResponse {
|
||||
if req.redirect_uris.is_empty() {
|
||||
return (StatusCode::BAD_REQUEST, Json(json!({"error": "invalid_client_metadata", "error_description": "redirect_uris required"}))).into_response();
|
||||
}
|
||||
let input = storage::NewClient { client_name: req.client_name.clone(), redirect_uris: req.redirect_uris.clone() };
|
||||
match storage::create_client(&state.db, input).await {
|
||||
Ok(client) => {
|
||||
let resp = RegistrationResponse {
|
||||
client_id: client.client_id,
|
||||
client_secret: client.client_secret,
|
||||
client_name: client.client_name,
|
||||
redirect_uris: client.redirect_uris,
|
||||
client_id_issued_at: client.created_at,
|
||||
token_endpoint_auth_method: "client_secret_post".into(),
|
||||
};
|
||||
(StatusCode::CREATED, Json(serde_json::to_value(resp).unwrap())).into_response()
|
||||
}
|
||||
Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": e.to_string()}))).into_response(),
|
||||
}
|
||||
}
|
||||
|
||||
async fn get_property(State(state): State<AppState>, Path((owner, key)): Path<(String, String)>) -> impl IntoResponse {
|
||||
match storage::get_property(&state.db, &owner, &key).await {
|
||||
Ok(Some(v)) => (StatusCode::OK, Json(v)).into_response(),
|
||||
Ok(None) => (StatusCode::NOT_FOUND, Json(json!({"error": "not_found"}))).into_response(),
|
||||
Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": e.to_string()}))).into_response(),
|
||||
}
|
||||
}
|
||||
|
||||
async fn set_property(State(state): State<AppState>, Path((owner, key)): Path<(String, String)>, Json(v): Json<Value>) -> impl IntoResponse {
|
||||
match storage::set_property(&state.db, &owner, &key, &v).await {
|
||||
Ok(_) => (StatusCode::NO_CONTENT, ()).into_response(),
|
||||
Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": e.to_string()}))).into_response(),
|
||||
}
|
||||
}
|
||||
|
||||
async fn trust_anchors(State(state): State<AppState>) -> impl IntoResponse {
|
||||
Json(json!({ "trust_anchors": state.settings.federation.trust_anchors }))
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct LoginQuery {
|
||||
return_to: Option<String>,
|
||||
error: Option<String>,
|
||||
}
|
||||
|
||||
async fn login_page(Query(q): Query<LoginQuery>) -> impl IntoResponse {
|
||||
let error_html = if let Some(err) = q.error {
|
||||
format!("<p style='color: red;'>{}</p>", html_escape(&err))
|
||||
} else {
|
||||
String::new()
|
||||
};
|
||||
|
||||
let return_to = q.return_to.unwrap_or_default();
|
||||
|
||||
let html = format!(r#"
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<title>Login - Barycenter OpenID Provider</title>
|
||||
<style>
|
||||
body {{ font-family: Arial, sans-serif; max-width: 400px; margin: 100px auto; padding: 20px; }}
|
||||
h1 {{ color: #333; }}
|
||||
label {{ display: block; margin-top: 10px; }}
|
||||
input[type="text"], input[type="password"] {{ width: 100%; padding: 8px; margin-top: 5px; box-sizing: border-box; }}
|
||||
button {{ margin-top: 20px; padding: 10px 20px; background-color: #007bff; color: white; border: none; cursor: pointer; }}
|
||||
button:hover {{ background-color: #0056b3; }}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<h1>Login</h1>
|
||||
{error_html}
|
||||
<form method="POST" action="/login">
|
||||
<input type="hidden" name="return_to" value="{return_to}">
|
||||
<label>
|
||||
Username:
|
||||
<input type="text" name="username" required autofocus>
|
||||
</label>
|
||||
<label>
|
||||
Password:
|
||||
<input type="password" name="password" required>
|
||||
</label>
|
||||
<button type="submit">Login</button>
|
||||
</form>
|
||||
</body>
|
||||
</html>
|
||||
"#);
|
||||
|
||||
Html(html)
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct LoginForm {
|
||||
username: String,
|
||||
password: String,
|
||||
return_to: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct RegisterForm {
|
||||
username: String,
|
||||
password: String,
|
||||
email: Option<String>,
|
||||
}
|
||||
|
||||
async fn register_user(
|
||||
State(state): State<AppState>,
|
||||
Form(form): Form<RegisterForm>,
|
||||
) -> impl IntoResponse {
|
||||
// Create the user
|
||||
match storage::create_user(&state.db, &form.username, &form.password, form.email).await {
|
||||
Ok(_) => {
|
||||
// Return success response
|
||||
Response::builder()
|
||||
.status(StatusCode::CREATED)
|
||||
.body(Body::from("User created"))
|
||||
.unwrap()
|
||||
.into_response()
|
||||
}
|
||||
Err(e) => {
|
||||
// Return error response
|
||||
Response::builder()
|
||||
.status(StatusCode::BAD_REQUEST)
|
||||
.body(Body::from(format!("Failed to create user: {}", e)))
|
||||
.unwrap()
|
||||
.into_response()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn login_submit(
|
||||
State(state): State<AppState>,
|
||||
headers: HeaderMap,
|
||||
Form(form): Form<LoginForm>,
|
||||
) -> impl IntoResponse {
|
||||
// Verify credentials
|
||||
let subject = match storage::verify_user_password(&state.db, &form.username, &form.password).await {
|
||||
Ok(Some(sub)) => sub,
|
||||
_ => {
|
||||
// Redirect back to login with error
|
||||
let return_to = urlencoded(&form.return_to.unwrap_or_default());
|
||||
let error = urlencoded("Invalid username or password");
|
||||
return Redirect::temporary(&format!("/login?error={error}&return_to={return_to}")).into_response();
|
||||
}
|
||||
};
|
||||
|
||||
// Create session
|
||||
let user_agent = headers
|
||||
.get(axum::http::header::USER_AGENT)
|
||||
.and_then(|h| h.to_str().ok())
|
||||
.map(String::from);
|
||||
|
||||
let session = match storage::create_session(&state.db, &subject, 3600, user_agent, None).await {
|
||||
Ok(s) => s,
|
||||
Err(_) => {
|
||||
let return_to = urlencoded(&form.return_to.unwrap_or_default());
|
||||
let error = urlencoded("Failed to create session");
|
||||
return Redirect::temporary(&format!("/login?error={error}&return_to={return_to}")).into_response();
|
||||
}
|
||||
};
|
||||
|
||||
// Set cookie and redirect
|
||||
let cookie = SessionCookie::new(session.session_id);
|
||||
let cookie_header = cookie.to_cookie_header(&state.settings);
|
||||
|
||||
let redirect_url = form.return_to.unwrap_or_else(|| "/".to_string());
|
||||
|
||||
Response::builder()
|
||||
.status(StatusCode::SEE_OTHER)
|
||||
.header(axum::http::header::SET_COOKIE, cookie_header)
|
||||
.header(axum::http::header::LOCATION, redirect_url)
|
||||
.body(Body::empty())
|
||||
.unwrap()
|
||||
.into_response()
|
||||
}
|
||||
|
||||
async fn logout(State(state): State<AppState>, headers: HeaderMap) -> impl IntoResponse {
|
||||
if let Some(cookie) = SessionCookie::from_headers(&headers) {
|
||||
let _ = storage::delete_session(&state.db, &cookie.session_id).await;
|
||||
}
|
||||
|
||||
Response::builder()
|
||||
.status(StatusCode::SEE_OTHER)
|
||||
.header(axum::http::header::SET_COOKIE, SessionCookie::delete_cookie_header())
|
||||
.header(axum::http::header::LOCATION, "/")
|
||||
.body(Body::empty())
|
||||
.unwrap()
|
||||
.into_response()
|
||||
}
|
||||
|
||||
fn html_escape(s: &str) -> String {
|
||||
s.replace('&', "&")
|
||||
.replace('<', "<")
|
||||
.replace('>', ">")
|
||||
.replace('"', """)
|
||||
.replace('\'', "'")
|
||||
}
|
||||
|
||||
fn urlencoded(s: &str) -> String {
|
||||
serde_urlencoded::to_string(&[("", s)]).unwrap_or_default().trim_start_matches('=').to_string()
|
||||
}
|
||||
479
tests/integration_test.rs
Normal file
479
tests/integration_test.rs
Normal file
|
|
@ -0,0 +1,479 @@
|
|||
use std::process::{Child, Command};
|
||||
use std::thread;
|
||||
use std::time::Duration;
|
||||
use base64ct::Encoding;
|
||||
use sha2::Digest;
|
||||
|
||||
/// Helper to start the barycenter server for integration tests
|
||||
struct TestServer {
|
||||
process: Child,
|
||||
base_url: String,
|
||||
}
|
||||
|
||||
impl TestServer {
|
||||
fn start() -> Self {
|
||||
let port = 8080;
|
||||
let base_url = format!("http://0.0.0.0:{}", port);
|
||||
|
||||
let process = Command::new("cargo")
|
||||
.args(["run", "--release", "--"])
|
||||
.env("RUST_LOG", "error")
|
||||
.stdout(std::process::Stdio::null())
|
||||
.stderr(std::process::Stdio::null())
|
||||
.spawn()
|
||||
.expect("Failed to start server");
|
||||
|
||||
// Wait for server to start - give it more time for first compilation
|
||||
thread::sleep(Duration::from_secs(5));
|
||||
|
||||
// Verify server is running by checking discovery endpoint
|
||||
let client = reqwest::blocking::Client::new();
|
||||
let max_retries = 30;
|
||||
for i in 0..max_retries {
|
||||
if let Ok(_) = client
|
||||
.get(format!("{}/.well-known/openid-configuration", base_url))
|
||||
.send()
|
||||
{
|
||||
println!("Server started successfully");
|
||||
return Self { process, base_url };
|
||||
}
|
||||
if i < max_retries - 1 {
|
||||
thread::sleep(Duration::from_secs(1));
|
||||
}
|
||||
}
|
||||
|
||||
panic!("Server failed to start within timeout");
|
||||
}
|
||||
|
||||
fn base_url(&self) -> &str {
|
||||
&self.base_url
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for TestServer {
|
||||
fn drop(&mut self) {
|
||||
let _ = self.process.kill();
|
||||
let _ = self.process.wait();
|
||||
}
|
||||
}
|
||||
|
||||
/// Register a test client with the IdP
|
||||
fn register_client(base_url: &str) -> (String, String, String) {
|
||||
let client = reqwest::blocking::Client::new();
|
||||
let redirect_uri = "http://localhost:3000/callback";
|
||||
|
||||
let response = client
|
||||
.post(format!("{}/connect/register", base_url))
|
||||
.json(&serde_json::json!({
|
||||
"redirect_uris": [redirect_uri],
|
||||
"token_endpoint_auth_method": "client_secret_basic"
|
||||
}))
|
||||
.send()
|
||||
.expect("Failed to register client")
|
||||
.json::<serde_json::Value>()
|
||||
.expect("Failed to parse registration response");
|
||||
|
||||
let client_id = response["client_id"]
|
||||
.as_str()
|
||||
.expect("No client_id in response")
|
||||
.to_string();
|
||||
let client_secret = response["client_secret"]
|
||||
.as_str()
|
||||
.expect("No client_secret in response")
|
||||
.to_string();
|
||||
|
||||
(client_id, client_secret, redirect_uri.to_string())
|
||||
}
|
||||
|
||||
/// Perform login and return an HTTP client with session cookie
|
||||
fn login_and_get_client(base_url: &str, username: &str, password: &str) -> (reqwest::blocking::Client, std::sync::Arc<reqwest::cookie::Jar>) {
|
||||
let jar = std::sync::Arc::new(reqwest::cookie::Jar::default());
|
||||
let client = reqwest::blocking::ClientBuilder::new()
|
||||
.cookie_provider(jar.clone())
|
||||
.redirect(reqwest::redirect::Policy::none())
|
||||
.build()
|
||||
.expect("Failed to build client");
|
||||
|
||||
// First register the user
|
||||
let _register_response = client
|
||||
.post(format!("{}/register", base_url))
|
||||
.form(&[
|
||||
("username", username),
|
||||
("password", password),
|
||||
("email", "test@example.com"),
|
||||
])
|
||||
.send()
|
||||
.expect("Failed to register user");
|
||||
|
||||
// Then login to create a session
|
||||
let _login_response = client
|
||||
.post(format!("{}/login", base_url))
|
||||
.form(&[
|
||||
("username", username),
|
||||
("password", password),
|
||||
])
|
||||
.send()
|
||||
.expect("Failed to login");
|
||||
|
||||
(client, jar)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_openidconnect_authorization_code_flow() {
|
||||
use openidconnect::{
|
||||
core::{CoreClient, CoreProviderMetadata},
|
||||
AuthorizationCode, ClientId, ClientSecret, CsrfToken, IssuerUrl,
|
||||
Nonce, OAuth2TokenResponse, PkceCodeChallenge, RedirectUrl, Scope, TokenResponse,
|
||||
};
|
||||
|
||||
let server = TestServer::start();
|
||||
let (client_id, client_secret, redirect_uri) = register_client(server.base_url());
|
||||
let (authenticated_client, _jar) = login_and_get_client(server.base_url(), "testuser", "testpass123");
|
||||
|
||||
let issuer_url = IssuerUrl::new(server.base_url().to_string())
|
||||
.expect("Invalid issuer URL");
|
||||
|
||||
let http_client = reqwest::blocking::ClientBuilder::new()
|
||||
.redirect(reqwest::redirect::Policy::none())
|
||||
.build()
|
||||
.expect("Failed to build HTTP client");
|
||||
|
||||
let provider_metadata = CoreProviderMetadata::discover(&issuer_url, &http_client)
|
||||
.expect("Failed to discover provider metadata");
|
||||
|
||||
let client = CoreClient::from_provider_metadata(
|
||||
provider_metadata,
|
||||
ClientId::new(client_id.clone()),
|
||||
Some(ClientSecret::new(client_secret.clone())),
|
||||
)
|
||||
.set_redirect_uri(RedirectUrl::new(redirect_uri.clone()).expect("Invalid redirect URI"));
|
||||
|
||||
let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
|
||||
|
||||
use openidconnect::core::CoreAuthenticationFlow;
|
||||
let (auth_url, csrf_token, _nonce) = client
|
||||
.authorize_url(
|
||||
CoreAuthenticationFlow::AuthorizationCode,
|
||||
CsrfToken::new_random,
|
||||
Nonce::new_random
|
||||
)
|
||||
.add_scope(Scope::new("openid".to_string()))
|
||||
.add_scope(Scope::new("profile".to_string()))
|
||||
.add_scope(Scope::new("email".to_string()))
|
||||
.set_pkce_challenge(pkce_challenge)
|
||||
.url();
|
||||
|
||||
println!("Authorization URL: {}", auth_url);
|
||||
|
||||
let auth_response = authenticated_client
|
||||
.get(auth_url.as_str())
|
||||
.send()
|
||||
.expect("Failed to request authorization");
|
||||
|
||||
let status = auth_response.status();
|
||||
assert!(status.is_redirection(), "Expected redirect, got {}", status);
|
||||
|
||||
let location = auth_response
|
||||
.headers()
|
||||
.get("location")
|
||||
.expect("No location header")
|
||||
.to_str()
|
||||
.expect("Invalid location header");
|
||||
|
||||
println!("Redirect location: {}", location);
|
||||
|
||||
let redirect_url_parsed = if location.starts_with("http") {
|
||||
url::Url::parse(location).expect("Invalid redirect URL")
|
||||
} else {
|
||||
let base_url_for_redirect = url::Url::parse(&redirect_uri).expect("Invalid redirect URI");
|
||||
base_url_for_redirect.join(location).expect("Invalid redirect URL")
|
||||
};
|
||||
let code = redirect_url_parsed
|
||||
.query_pairs()
|
||||
.find(|(k, _)| k == "code")
|
||||
.map(|(_, v)| v.to_string())
|
||||
.expect("No code in redirect");
|
||||
|
||||
let returned_state = redirect_url_parsed
|
||||
.query_pairs()
|
||||
.find(|(k, _)| k == "state")
|
||||
.map(|(_, v)| v.to_string())
|
||||
.expect("No state in redirect");
|
||||
|
||||
assert_eq!(returned_state, *csrf_token.secret());
|
||||
|
||||
let http_client = reqwest::blocking::ClientBuilder::new()
|
||||
.redirect(reqwest::redirect::Policy::none())
|
||||
.build()
|
||||
.expect("Failed to build HTTP client");
|
||||
|
||||
let token_response = client
|
||||
.exchange_code(AuthorizationCode::new(code))
|
||||
.expect("Failed to create code exchange request")
|
||||
.set_pkce_verifier(pkce_verifier)
|
||||
.request(&http_client)
|
||||
.expect("Failed to exchange code for token");
|
||||
|
||||
assert!(token_response.access_token().secret().len() > 0);
|
||||
assert!(token_response.id_token().is_some());
|
||||
|
||||
let id_token = token_response.id_token().expect("No ID token");
|
||||
|
||||
// For testing purposes, we'll decode the ID token without signature verification
|
||||
// In production, signature verification is critical and performed by the library
|
||||
use base64ct::Encoding;
|
||||
let id_token_str = id_token.to_string();
|
||||
let parts: Vec<&str> = id_token_str.split('.').collect();
|
||||
assert_eq!(parts.len(), 3, "ID token should have 3 parts");
|
||||
|
||||
let payload = base64ct::Base64UrlUnpadded::decode_vec(parts[1])
|
||||
.expect("Failed to decode ID token payload");
|
||||
let claims: serde_json::Value = serde_json::from_slice(&payload)
|
||||
.expect("Failed to parse ID token claims");
|
||||
|
||||
// Verify required claims
|
||||
assert!(claims["sub"].is_string());
|
||||
assert!(!claims["sub"].as_str().unwrap().is_empty());
|
||||
assert_eq!(claims["iss"].as_str().unwrap(), server.base_url());
|
||||
assert_eq!(claims["aud"].as_str().unwrap(), client_id);
|
||||
assert!(claims["exp"].is_number());
|
||||
assert!(claims["iat"].is_number());
|
||||
assert!(claims["nonce"].is_string());
|
||||
|
||||
println!("✓ openidconnect-rs: Authorization Code + PKCE flow successful");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_oauth2_authorization_code_flow() {
|
||||
use oauth2::{
|
||||
basic::BasicClient, AuthUrl, AuthorizationCode, ClientId,
|
||||
ClientSecret, CsrfToken, PkceCodeChallenge, RedirectUrl, Scope, TokenResponse, TokenUrl,
|
||||
};
|
||||
|
||||
let server = TestServer::start();
|
||||
let (client_id, client_secret, redirect_uri) = register_client(server.base_url());
|
||||
let (authenticated_client, _jar) = login_and_get_client(server.base_url(), "testuser2", "testpass123");
|
||||
|
||||
let http_client_blocking = reqwest::blocking::Client::new();
|
||||
let discovery_response = http_client_blocking
|
||||
.get(format!(
|
||||
"{}/.well-known/openid-configuration",
|
||||
server.base_url()
|
||||
))
|
||||
.send()
|
||||
.expect("Failed to fetch discovery")
|
||||
.json::<serde_json::Value>()
|
||||
.expect("Failed to parse discovery");
|
||||
|
||||
let auth_url = AuthUrl::new(
|
||||
discovery_response["authorization_endpoint"]
|
||||
.as_str()
|
||||
.expect("No authorization_endpoint")
|
||||
.to_string(),
|
||||
)
|
||||
.expect("Invalid auth URL");
|
||||
|
||||
let token_url = TokenUrl::new(
|
||||
discovery_response["token_endpoint"]
|
||||
.as_str()
|
||||
.expect("No token_endpoint")
|
||||
.to_string(),
|
||||
)
|
||||
.expect("Invalid token URL");
|
||||
|
||||
let client = BasicClient::new(ClientId::new(client_id.clone()))
|
||||
.set_client_secret(ClientSecret::new(client_secret.clone()))
|
||||
.set_auth_uri(auth_url)
|
||||
.set_token_uri(token_url)
|
||||
.set_redirect_uri(RedirectUrl::new(redirect_uri.clone()).expect("Invalid redirect URI"));
|
||||
|
||||
let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
|
||||
|
||||
let (auth_url, csrf_token) = client
|
||||
.authorize_url(CsrfToken::new_random)
|
||||
.add_scope(Scope::new("openid".to_string()))
|
||||
.add_scope(Scope::new("profile".to_string()))
|
||||
.add_scope(Scope::new("email".to_string()))
|
||||
.set_pkce_challenge(pkce_challenge)
|
||||
.url();
|
||||
|
||||
println!("Authorization URL: {}", auth_url);
|
||||
|
||||
let auth_response = authenticated_client
|
||||
.get(auth_url.as_str())
|
||||
.send()
|
||||
.expect("Failed to request authorization");
|
||||
|
||||
let status = auth_response.status();
|
||||
assert!(status.is_redirection(), "Expected redirect, got {}", status);
|
||||
|
||||
let location = auth_response
|
||||
.headers()
|
||||
.get("location")
|
||||
.expect("No location header")
|
||||
.to_str()
|
||||
.expect("Invalid location header");
|
||||
|
||||
println!("Redirect location: {}", location);
|
||||
|
||||
let redirect_url_parsed = if location.starts_with("http") {
|
||||
url::Url::parse(location).expect("Invalid redirect URL")
|
||||
} else {
|
||||
let base_url_for_redirect = url::Url::parse(&redirect_uri).expect("Invalid redirect URI");
|
||||
base_url_for_redirect.join(location).expect("Invalid redirect URL")
|
||||
};
|
||||
let code = redirect_url_parsed
|
||||
.query_pairs()
|
||||
.find(|(k, _)| k == "code")
|
||||
.map(|(_, v)| v.to_string())
|
||||
.expect("No code in redirect");
|
||||
|
||||
let returned_state = redirect_url_parsed
|
||||
.query_pairs()
|
||||
.find(|(k, _)| k == "state")
|
||||
.map(|(_, v)| v.to_string())
|
||||
.expect("No state in redirect");
|
||||
|
||||
assert_eq!(returned_state, *csrf_token.secret());
|
||||
|
||||
let http_client = reqwest::blocking::Client::new();
|
||||
let token_response = client
|
||||
.exchange_code(AuthorizationCode::new(code))
|
||||
.set_pkce_verifier(pkce_verifier)
|
||||
.request(&http_client)
|
||||
.expect("Failed to exchange code for token");
|
||||
|
||||
assert!(token_response.access_token().secret().len() > 0);
|
||||
assert!(token_response.expires_in().is_some());
|
||||
|
||||
let access_token = token_response.access_token().secret();
|
||||
|
||||
let http_client_blocking = reqwest::blocking::Client::new();
|
||||
let userinfo_response = http_client_blocking
|
||||
.get(format!("{}/userinfo", server.base_url()))
|
||||
.bearer_auth(access_token)
|
||||
.send()
|
||||
.expect("Failed to fetch userinfo")
|
||||
.json::<serde_json::Value>()
|
||||
.expect("Failed to parse userinfo");
|
||||
|
||||
// Verify subject exists and is a non-empty string
|
||||
assert!(userinfo_response["sub"].is_string());
|
||||
assert!(!userinfo_response["sub"].as_str().unwrap().is_empty());
|
||||
|
||||
println!("✓ oauth2-rs: Authorization Code + PKCE flow successful");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_security_headers() {
|
||||
let server = TestServer::start();
|
||||
|
||||
let client = reqwest::blocking::Client::new();
|
||||
let response = client
|
||||
.get(format!("{}/.well-known/openid-configuration", server.base_url()))
|
||||
.send()
|
||||
.expect("Failed to fetch discovery");
|
||||
|
||||
assert_eq!(
|
||||
response.headers().get("x-frame-options").unwrap(),
|
||||
"DENY"
|
||||
);
|
||||
assert_eq!(
|
||||
response.headers().get("x-content-type-options").unwrap(),
|
||||
"nosniff"
|
||||
);
|
||||
assert_eq!(
|
||||
response.headers().get("x-xss-protection").unwrap(),
|
||||
"1; mode=block"
|
||||
);
|
||||
assert!(response.headers().get("content-security-policy").is_some());
|
||||
assert!(response.headers().get("referrer-policy").is_some());
|
||||
assert!(response.headers().get("permissions-policy").is_some());
|
||||
|
||||
println!("✓ Security headers are present");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_token_endpoint_cache_control() {
|
||||
let server = TestServer::start();
|
||||
let (client_id, client_secret, redirect_uri) = register_client(server.base_url());
|
||||
let (authenticated_client, _jar) = login_and_get_client(server.base_url(), "testuser3", "testpass123");
|
||||
|
||||
let http_client = reqwest::blocking::ClientBuilder::new()
|
||||
.redirect(reqwest::redirect::Policy::none())
|
||||
.build()
|
||||
.expect("Failed to build HTTP client");
|
||||
|
||||
let discovery_response = http_client
|
||||
.get(format!("{}/.well-known/openid-configuration", server.base_url()))
|
||||
.send()
|
||||
.expect("Failed to fetch discovery")
|
||||
.json::<serde_json::Value>()
|
||||
.expect("Failed to parse discovery");
|
||||
|
||||
let auth_url = discovery_response["authorization_endpoint"]
|
||||
.as_str()
|
||||
.expect("No authorization_endpoint");
|
||||
|
||||
let pkce_verifier = "test_verifier_1234567890123456789012345678901234567890";
|
||||
let challenge_hash = sha2::Sha256::digest(pkce_verifier.as_bytes());
|
||||
let pkce_challenge = base64ct::Base64UrlUnpadded::encode_string(&challenge_hash);
|
||||
|
||||
let auth_request = format!(
|
||||
"{}?client_id={}&redirect_uri={}&response_type=code&scope=openid&state=test_state&nonce=test_nonce&code_challenge={}&code_challenge_method=S256",
|
||||
auth_url,
|
||||
urlencoding::encode(&client_id),
|
||||
urlencoding::encode(&redirect_uri),
|
||||
urlencoding::encode(&pkce_challenge)
|
||||
);
|
||||
|
||||
let auth_response = authenticated_client
|
||||
.get(&auth_request)
|
||||
.send()
|
||||
.expect("Failed to request authorization");
|
||||
|
||||
let location = auth_response
|
||||
.headers()
|
||||
.get("location")
|
||||
.expect("No location header")
|
||||
.to_str()
|
||||
.expect("Invalid location header");
|
||||
|
||||
println!("Redirect location: {}", location);
|
||||
|
||||
let redirect_url_parsed = if location.starts_with("http") {
|
||||
url::Url::parse(location).expect("Invalid redirect URL")
|
||||
} else {
|
||||
let base_url_for_redirect = url::Url::parse(&redirect_uri).expect("Invalid redirect URI");
|
||||
base_url_for_redirect.join(location).expect("Invalid redirect URL")
|
||||
};
|
||||
let code = redirect_url_parsed
|
||||
.query_pairs()
|
||||
.find(|(k, _)| k == "code")
|
||||
.map(|(_, v)| v.to_string())
|
||||
.expect("No code in redirect");
|
||||
|
||||
let auth_header = base64ct::Base64::encode_string(
|
||||
format!("{}:{}", client_id, client_secret).as_bytes()
|
||||
);
|
||||
|
||||
let token_response = http_client
|
||||
.post(format!("{}/token", server.base_url()))
|
||||
.header("Authorization", format!("Basic {}", auth_header))
|
||||
.form(&[
|
||||
("grant_type", "authorization_code"),
|
||||
("code", &code),
|
||||
("redirect_uri", &redirect_uri),
|
||||
("code_verifier", pkce_verifier),
|
||||
])
|
||||
.send()
|
||||
.expect("Failed to exchange token");
|
||||
|
||||
assert_eq!(
|
||||
token_response.headers().get("cache-control").unwrap(),
|
||||
"no-store"
|
||||
);
|
||||
assert_eq!(token_response.headers().get("pragma").unwrap(), "no-cache");
|
||||
|
||||
println!("✓ Token endpoint has correct Cache-Control headers");
|
||||
}
|
||||
Loading…
Add table
Reference in a new issue