From d59411ca60c7a8b32fcfc3c6d09b9942b5e2dd0a Mon Sep 17 00:00:00 2001 From: Till Wegmueller Date: Thu, 9 Apr 2026 21:14:48 +0200 Subject: [PATCH] Add session lifecycle, token-based identity, and client token support MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Core session management infrastructure for Phase 3: - Session module: SessionId, SessionToken, SessionState enum with validated transitions (Creating→Active⇄Suspended→Destroyed), Session struct with timeout tracking, SessionRegistry with O(1) lookup by ID and token (10 unit tests) - Protocol: ClientHello gains optional token field, ServerHello gains resumed flag and token echo, SessionStatus/SessionEvent enums added to ControlMessage for session state notifications - Server: SessionRegistry in headless CalloopData, sessions created on client connect (with token lookup for resumption), suspended on disconnect (not destroyed), periodic cleanup of expired sessions - Client: --token CLI flag, persistent token at ~/.config/wayray/token with auto-generation (random 16-byte hex), token sent in ClientHello, resumed state logged from ServerHello --- Cargo.lock | 1 + crates/wayray-protocol/src/codec.rs | 1 + crates/wayray-protocol/src/messages.rs | 33 ++++ crates/wrclient/Cargo.toml | 1 + crates/wrclient/src/main.rs | 65 ++++++- crates/wrclient/src/network.rs | 12 ++ crates/wrsrvd/src/backend/headless.rs | 63 ++++++- crates/wrsrvd/src/main.rs | 1 + crates/wrsrvd/src/network.rs | 12 +- crates/wrsrvd/src/session/mod.rs | 5 + crates/wrsrvd/src/session/registry.rs | 232 ++++++++++++++++++++++++ crates/wrsrvd/src/session/types.rs | 239 +++++++++++++++++++++++++ 12 files changed, 653 insertions(+), 12 deletions(-) create mode 100644 crates/wrsrvd/src/session/mod.rs create mode 100644 crates/wrsrvd/src/session/registry.rs create mode 100644 crates/wrsrvd/src/session/types.rs diff --git a/Cargo.lock b/Cargo.lock index 5487f8c..3c3dd9f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3839,6 +3839,7 @@ name = "wrclient" version = "0.1.0" dependencies = [ "env_logger", + "getrandom 0.3.4", "miette", "pollster", "quinn", diff --git a/crates/wayray-protocol/src/codec.rs b/crates/wayray-protocol/src/codec.rs index 16bd6b2..bc3faa9 100644 --- a/crates/wayray-protocol/src/codec.rs +++ b/crates/wayray-protocol/src/codec.rs @@ -71,6 +71,7 @@ mod tests { let msg = ControlMessage::ClientHello(ClientHello { version: 1, capabilities: vec!["display".to_string()], + token: Some("test".to_string()), }); let encoded = encode(&msg).unwrap(); let (len, payload) = read_length_prefix(&encoded).unwrap(); diff --git a/crates/wayray-protocol/src/messages.rs b/crates/wayray-protocol/src/messages.rs index b5b5bd3..afd3f1b 100644 --- a/crates/wayray-protocol/src/messages.rs +++ b/crates/wayray-protocol/src/messages.rs @@ -11,6 +11,9 @@ use serde::{Deserialize, Serialize}; pub struct ClientHello { pub version: u32, pub capabilities: Vec, + /// Session token for session lookup/creation. + /// If None, the server creates a new session with a generated token. + pub token: Option, } #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] @@ -19,6 +22,35 @@ pub struct ServerHello { pub session_id: u64, pub output_width: u32, pub output_height: u32, + /// Whether this is a resumed session or a new one. + pub resumed: bool, + /// The token bound to this session (echoed back to the client). + pub token: String, +} + +/// Session lifecycle state as seen by the client. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +pub enum SessionStatus { + /// New session being created. + Creating, + /// Session is active. + Active, + /// Session was suspended and is being resumed. + Resuming, +} + +/// Session-related control events sent by the server. +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub enum SessionEvent { + /// Session state changed. + StateChanged { + session_id: u64, + status: SessionStatus, + }, + /// Session is being suspended (client should prepare for disconnect). + Suspending { session_id: u64 }, + /// Session has been destroyed (client should disconnect). + Destroyed { session_id: u64 }, } #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] @@ -43,6 +75,7 @@ pub enum ControlMessage { Ping(Ping), Pong(Pong), FrameAck(FrameAck), + SessionEvent(SessionEvent), } // ── Display channel (server → client, unidirectional) ─────────────── diff --git a/crates/wrclient/Cargo.toml b/crates/wrclient/Cargo.toml index 8aa2556..4bbff5d 100644 --- a/crates/wrclient/Cargo.toml +++ b/crates/wrclient/Cargo.toml @@ -17,6 +17,7 @@ winit = "0.30" wgpu = "24" env_logger = "0.11" pollster = "0.4" +getrandom = "0.3" [dev-dependencies] rcgen.workspace = true diff --git a/crates/wrclient/src/main.rs b/crates/wrclient/src/main.rs index ebcbdc4..ee03bc5 100644 --- a/crates/wrclient/src/main.rs +++ b/crates/wrclient/src/main.rs @@ -196,6 +196,48 @@ impl ApplicationHandler for App { } } +/// Load or generate a persistent session token. +/// +/// Reads from `~/.config/wayray/token`. If the file doesn't exist, +/// generates a random hex token and writes it. +fn load_or_generate_token() -> String { + let config_dir = dirs_path().join("wayray"); + let token_path = config_dir.join("token"); + + if let Ok(token) = std::fs::read_to_string(&token_path) { + let token = token.trim().to_string(); + if !token.is_empty() { + return token; + } + } + + // Generate a random 16-byte hex token. + let mut bytes = [0u8; 16]; + getrandom::fill(&mut bytes).expect("failed to generate random token"); + let token = bytes.iter().map(|b| format!("{b:02x}")).collect::(); + + // Persist it. + if let Err(e) = std::fs::create_dir_all(&config_dir) { + warn!(error = %e, "failed to create config dir"); + } else if let Err(e) = std::fs::write(&token_path, &token) { + warn!(error = %e, "failed to persist token"); + } else { + info!(path = %token_path.display(), "session token persisted"); + } + + token +} + +/// Get the user's config directory base path. +fn dirs_path() -> std::path::PathBuf { + std::env::var("XDG_CONFIG_HOME") + .map(std::path::PathBuf::from) + .unwrap_or_else(|_| { + let home = std::env::var("HOME").unwrap_or_else(|_| "/tmp".to_string()); + std::path::PathBuf::from(home).join(".config") + }) +} + fn main() { // Initialize tracing with RUST_LOG env filtering. tracing_subscriber::fmt() @@ -206,20 +248,28 @@ fn main() { .init(); let args: Vec = std::env::args().collect(); - if args.len() != 2 { - eprintln!("Usage: wrclient :"); - std::process::exit(1); - } - let (server_addr, server_name) = match network::resolve_server_addr(&args[1]) { + // Parse: wrclient : [--token ] + let addr_arg = args.get(1).cloned().unwrap_or_else(|| { + eprintln!("Usage: wrclient : [--token ]"); + std::process::exit(1); + }); + + let token = if let Some(pos) = args.iter().position(|a| a == "--token") { + args.get(pos + 1).cloned() + } else { + Some(load_or_generate_token()) + }; + + let (server_addr, server_name) = match network::resolve_server_addr(&addr_arg) { Ok(result) => result, Err(e) => { - eprintln!("Invalid server address '{}': {}", args[1], e); + eprintln!("Invalid server address '{}': {}", addr_arg, e); std::process::exit(1); } }; - info!(server = %server_addr, name = %server_name, "connecting to server"); + info!(server = %server_addr, name = %server_name, token = ?token, "connecting to server"); let (frame_tx, frame_rx) = mpsc::channel::(); let (input_tx, input_rx) = mpsc::channel::(); @@ -250,6 +300,7 @@ fn main() { server_addr, server_name, capabilities: vec!["display".to_string()], + token, }; let (_endpoint, mut conn) = match network::connect(&config).await { diff --git a/crates/wrclient/src/network.rs b/crates/wrclient/src/network.rs index f0ef1c8..524a8a3 100644 --- a/crates/wrclient/src/network.rs +++ b/crates/wrclient/src/network.rs @@ -29,6 +29,8 @@ pub struct ClientConfig { pub server_name: String, /// Client capabilities to advertise in the hello. pub capabilities: Vec, + /// Session token for session binding. If None, no token is sent. + pub token: Option, } /// Resolve a "host:port" string to a SocketAddr, supporting both @@ -64,6 +66,7 @@ impl Default for ClientConfig { server_addr: "127.0.0.1:4433".parse().unwrap(), server_name: "localhost".to_string(), capabilities: vec!["display".to_string()], + token: None, } } } @@ -214,6 +217,7 @@ pub async fn connect( let client_hello = ControlMessage::ClientHello(ClientHello { version: wayray_protocol::PROTOCOL_VERSION, capabilities: config.capabilities.clone(), + token: config.token.clone(), }); write_message(&mut control_send, &client_hello).await?; info!("sent ClientHello"); @@ -227,6 +231,8 @@ pub async fn connect( session_id = hello.session_id, width = hello.output_width, height = hello.output_height, + resumed = hello.resumed, + token = %hello.token, "received ServerHello" ); hello @@ -337,6 +343,8 @@ mod tests { session_id: 99, output_width: 1920, output_height: 1080, + resumed: false, + token: String::new(), }); write_message(&mut control_send, &server_hello) .await @@ -353,6 +361,7 @@ mod tests { server_addr: addr, server_name: "localhost".to_string(), capabilities: vec!["test".to_string()], + token: Some("test-token".to_string()), }; let (_endpoint, mut conn) = connect(&config).await.unwrap(); @@ -381,6 +390,8 @@ mod tests { session_id: 1, output_width: 800, output_height: 600, + resumed: false, + token: String::new(), }); write_message(&mut control_send, &server_hello) .await @@ -406,6 +417,7 @@ mod tests { server_addr: addr, server_name: "localhost".to_string(), capabilities: vec![], + token: None, }; let (_endpoint, mut conn) = connect(&config).await.unwrap(); diff --git a/crates/wrsrvd/src/backend/headless.rs b/crates/wrsrvd/src/backend/headless.rs index db4defc..67111c8 100644 --- a/crates/wrsrvd/src/backend/headless.rs +++ b/crates/wrsrvd/src/backend/headless.rs @@ -30,6 +30,7 @@ use wayray_protocol::messages::FrameUpdate; use crate::errors::WayRayError; use crate::handlers::ClientState; use crate::network::{CompositorToNet, NetToCompositor, NetworkHandle}; +use crate::session::{SessionRegistry, SessionState, SessionToken}; use crate::state::WayRay; /// Dark grey clear color for the compositor background. @@ -50,6 +51,10 @@ struct CalloopData { frame_sequence: u64, /// Whether a remote client is currently connected. client_connected: bool, + /// Session registry tracking all sessions by token. + session_registry: SessionRegistry, + /// The currently active session ID (if any). + active_session: Option, } /// Run the compositor with the headless PixmanRenderer backend. @@ -149,6 +154,8 @@ pub fn run( previous_frame, frame_sequence: 0, client_connected: false, + session_registry: SessionRegistry::new(), + active_session: None, }; let running = Arc::new(AtomicBool::new(true)); @@ -159,10 +166,22 @@ pub fn run( info!("entering headless main event loop"); + let mut cleanup_counter: u32 = 0; + while running.load(Ordering::SeqCst) { // Drain network events (input from remote clients, connection state). drain_network_events(&mut calloop_data); + // Periodically clean up expired suspended sessions (~every 60s at 60fps). + cleanup_counter += 1; + if cleanup_counter >= 3600 { + cleanup_counter = 0; + let expired = calloop_data.session_registry.cleanup_expired(); + if !expired.is_empty() { + calloop_data.session_registry.purge_destroyed(); + } + } + // Dispatch Wayland clients. calloop_data .display @@ -196,15 +215,53 @@ fn drain_network_events(data: &mut CalloopData) { data.state.inject_network_input(input_msg); } Ok(NetToCompositor::ClientConnected(hello)) => { + let token_str = hello.token.clone().unwrap_or_else(|| { + // Generate a token if the client didn't provide one. + format!("auto-{}", data.session_registry.list().count() + 1) + }); + let token = SessionToken::new(token_str); + + // Look up existing session or create new one. + let (session_id, resumed) = + if let Some(existing) = data.session_registry.find_by_token(&token) { + let id = existing.id; + let was_suspended = existing.state == SessionState::Suspended; + if was_suspended { + if let Err(e) = data.session_registry.activate(id) { + warn!(error = %e, "failed to resume session"); + } else { + info!(%id, "session resumed"); + } + } + (id, was_suspended) + } else { + let id = data.session_registry.create_session(token.clone()); + if let Err(e) = data.session_registry.activate(id) { + warn!(error = %e, "failed to activate new session"); + } + (id, false) + }; + + data.active_session = Some(session_id); + data.client_connected = true; + info!( version = hello.version, - capabilities = ?hello.capabilities, + %session_id, + resumed, + %token, "remote client connected" ); - data.client_connected = true; } Ok(NetToCompositor::ClientDisconnected) => { - info!("remote client disconnected"); + // Suspend the session instead of destroying it. + if let Some(session_id) = data.active_session.take() { + if let Err(e) = data.session_registry.suspend(session_id) { + warn!(error = %e, "failed to suspend session"); + } else { + info!(%session_id, "session suspended, waiting for reconnect"); + } + } data.client_connected = false; } Ok(NetToCompositor::Control(ctrl)) => { diff --git a/crates/wrsrvd/src/main.rs b/crates/wrsrvd/src/main.rs index 7750e48..cbdd361 100644 --- a/crates/wrsrvd/src/main.rs +++ b/crates/wrsrvd/src/main.rs @@ -2,6 +2,7 @@ mod backend; mod errors; mod handlers; pub mod network; +pub mod session; mod state; mod wm; diff --git a/crates/wrsrvd/src/network.rs b/crates/wrsrvd/src/network.rs index 4bfe5a4..f7e020e 100644 --- a/crates/wrsrvd/src/network.rs +++ b/crates/wrsrvd/src/network.rs @@ -253,13 +253,15 @@ async fn handle_connection( return Err(format!("expected ClientHello, got {client_hello:?}").into()); }; info!(version = hello.version, "received ClientHello"); + let token = hello.token.clone().unwrap_or_default(); let _ = compositor_tx.send(NetToCompositor::ClientConnected(hello)); - let server_hello = ControlMessage::ServerHello(ServerHello { version: wayray_protocol::PROTOCOL_VERSION, - session_id: 1, // TODO: real session management + session_id: 1, // Assigned by session registry in task 3 output_width: config.output_width, output_height: config.output_height, + resumed: false, + token, }); write_message(&mut control_send, &server_hello).await?; info!("sent ServerHello"); @@ -473,6 +475,8 @@ mod tests { session_id: 42, output_width: 1920, output_height: 1080, + resumed: false, + token: "test-token".to_string(), }); write_message(&mut control_send, &server_hello) .await @@ -497,6 +501,7 @@ mod tests { let client_hello = ControlMessage::ClientHello(ClientHello { version: wayray_protocol::PROTOCOL_VERSION, capabilities: vec!["display".to_string()], + token: Some("test-token".to_string()), }); write_message(&mut control_send, &client_hello) .await @@ -538,6 +543,8 @@ mod tests { session_id: 1, output_width: 1280, output_height: 720, + resumed: false, + token: String::new(), }); write_message(&mut control_send, &server_hello) .await @@ -572,6 +579,7 @@ mod tests { let client_hello = ControlMessage::ClientHello(ClientHello { version: wayray_protocol::PROTOCOL_VERSION, capabilities: vec![], + token: None, }); write_message(&mut control_send, &client_hello) .await diff --git a/crates/wrsrvd/src/session/mod.rs b/crates/wrsrvd/src/session/mod.rs new file mode 100644 index 0000000..c4619fb --- /dev/null +++ b/crates/wrsrvd/src/session/mod.rs @@ -0,0 +1,5 @@ +pub mod registry; +pub mod types; + +pub use registry::SessionRegistry; +pub use types::{Session, SessionId, SessionState, SessionToken}; diff --git a/crates/wrsrvd/src/session/registry.rs b/crates/wrsrvd/src/session/registry.rs new file mode 100644 index 0000000..44e173f --- /dev/null +++ b/crates/wrsrvd/src/session/registry.rs @@ -0,0 +1,232 @@ +use std::collections::HashMap; + +use tracing::{info, warn}; + +use super::types::{Session, SessionId, SessionState, SessionToken, SessionTransitionError}; + +/// In-memory session registry. +/// +/// Provides O(1) lookup by both session ID and token. Tracks all sessions +/// including suspended ones (until they time out and are cleaned up). +pub struct SessionRegistry { + sessions: HashMap, + /// Reverse index: token → session ID for fast lookup on client connect. + token_index: HashMap, + next_id: u64, +} + +impl SessionRegistry { + pub fn new() -> Self { + Self { + sessions: HashMap::new(), + token_index: HashMap::new(), + next_id: 1, + } + } + + /// Create a new session for the given token. Returns the session ID. + pub fn create_session(&mut self, token: SessionToken) -> SessionId { + let id = SessionId::from_raw(self.next_id); + self.next_id += 1; + + let session = Session::new(id, token.clone()); + self.token_index.insert(token.clone(), id); + self.sessions.insert(id, session); + + info!(%id, %token, "session created"); + id + } + + /// Look up a session by its token. + pub fn find_by_token(&self, token: &SessionToken) -> Option<&Session> { + let id = self.token_index.get(token)?; + self.sessions.get(id) + } + + /// Look up a session by ID. + pub fn get(&self, id: SessionId) -> Option<&Session> { + self.sessions.get(&id) + } + + /// Transition a session to a new state. + pub fn transition( + &mut self, + id: SessionId, + next: SessionState, + ) -> Result<(), SessionTransitionError> { + let session = self.sessions.get_mut(&id).ok_or(SessionTransitionError { + session_id: id, + from: SessionState::Destroyed, + to: next, + })?; + + let from = session.state; + session.transition(next)?; + info!(%id, %from, %next, "session state transition"); + + // Clean up destroyed sessions from the index. + if next == SessionState::Destroyed { + self.token_index.remove(&session.token); + } + + Ok(()) + } + + /// Activate a session (Creating → Active or Suspended → Active). + pub fn activate(&mut self, id: SessionId) -> Result<(), SessionTransitionError> { + self.transition(id, SessionState::Active) + } + + /// Suspend a session (Active → Suspended). + pub fn suspend(&mut self, id: SessionId) -> Result<(), SessionTransitionError> { + self.transition(id, SessionState::Suspended) + } + + /// Destroy a session. + pub fn destroy(&mut self, id: SessionId) -> Result<(), SessionTransitionError> { + self.transition(id, SessionState::Destroyed) + } + + /// Set the user for a session (after authentication). + pub fn set_user(&mut self, id: SessionId, user: String) { + if let Some(session) = self.sessions.get_mut(&id) { + session.user = Some(user); + } + } + + /// Clean up expired suspended sessions. Returns IDs of destroyed sessions. + pub fn cleanup_expired(&mut self) -> Vec { + let expired: Vec = self + .sessions + .values() + .filter(|s| s.is_suspend_expired()) + .map(|s| s.id) + .collect(); + + for id in &expired { + if let Err(e) = self.destroy(*id) { + warn!(%id, error = %e, "failed to destroy expired session"); + } else { + info!(%id, "expired suspended session destroyed"); + } + } + + expired + } + + /// Remove destroyed sessions from memory entirely. + pub fn purge_destroyed(&mut self) { + self.sessions + .retain(|_, s| s.state != SessionState::Destroyed); + } + + /// List all sessions (for admin queries). + pub fn list(&self) -> impl Iterator { + self.sessions.values() + } + + /// Count sessions by state. + pub fn count_by_state(&self, state: SessionState) -> usize { + self.sessions.values().filter(|s| s.state == state).count() + } +} + +impl Default for SessionRegistry { + fn default() -> Self { + Self::new() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn create_and_lookup() { + let mut reg = SessionRegistry::new(); + let token = SessionToken::new("abc-123"); + let id = reg.create_session(token.clone()); + + assert!(reg.get(id).is_some()); + assert!(reg.find_by_token(&token).is_some()); + assert_eq!(reg.find_by_token(&token).unwrap().id, id); + } + + #[test] + fn lifecycle_through_registry() { + let mut reg = SessionRegistry::new(); + let token = SessionToken::new("tok"); + let id = reg.create_session(token.clone()); + + assert_eq!(reg.get(id).unwrap().state, SessionState::Creating); + + reg.activate(id).unwrap(); + assert_eq!(reg.get(id).unwrap().state, SessionState::Active); + + reg.suspend(id).unwrap(); + assert_eq!(reg.get(id).unwrap().state, SessionState::Suspended); + + // Resume + reg.activate(id).unwrap(); + assert_eq!(reg.get(id).unwrap().state, SessionState::Active); + + reg.destroy(id).unwrap(); + // Token index should be cleaned up + assert!(reg.find_by_token(&token).is_none()); + } + + #[test] + fn cleanup_expired() { + let mut reg = SessionRegistry::new(); + let id = reg.create_session(SessionToken::new("t1")); + reg.activate(id).unwrap(); + reg.suspend(id).unwrap(); + + // Set zero timeout so it expires immediately + reg.sessions.get_mut(&id).unwrap().suspend_timeout = std::time::Duration::ZERO; + + let expired = reg.cleanup_expired(); + assert_eq!(expired.len(), 1); + assert_eq!(expired[0], id); + assert_eq!(reg.get(id).unwrap().state, SessionState::Destroyed); + } + + #[test] + fn purge_destroyed() { + let mut reg = SessionRegistry::new(); + let id = reg.create_session(SessionToken::new("t")); + reg.activate(id).unwrap(); + reg.destroy(id).unwrap(); + + assert!(reg.get(id).is_some()); // Still in memory + reg.purge_destroyed(); + assert!(reg.get(id).is_none()); // Gone + } + + #[test] + fn multiple_sessions() { + let mut reg = SessionRegistry::new(); + let id1 = reg.create_session(SessionToken::new("tok-a")); + let id2 = reg.create_session(SessionToken::new("tok-b")); + + assert_ne!(id1, id2); + reg.activate(id1).unwrap(); + reg.activate(id2).unwrap(); + + assert_eq!(reg.count_by_state(SessionState::Active), 2); + + reg.suspend(id1).unwrap(); + assert_eq!(reg.count_by_state(SessionState::Active), 1); + assert_eq!(reg.count_by_state(SessionState::Suspended), 1); + } + + #[test] + fn set_user() { + let mut reg = SessionRegistry::new(); + let id = reg.create_session(SessionToken::new("t")); + assert!(reg.get(id).unwrap().user.is_none()); + + reg.set_user(id, "alice".to_string()); + assert_eq!(reg.get(id).unwrap().user.as_deref(), Some("alice")); + } +} diff --git a/crates/wrsrvd/src/session/types.rs b/crates/wrsrvd/src/session/types.rs new file mode 100644 index 0000000..98e4681 --- /dev/null +++ b/crates/wrsrvd/src/session/types.rs @@ -0,0 +1,239 @@ +use std::time::{Duration, Instant}; + +/// Unique identifier for a compositor session. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct SessionId(u64); + +impl SessionId { + pub fn from_raw(id: u64) -> Self { + Self(id) + } + + pub fn raw(self) -> u64 { + self.0 + } +} + +impl std::fmt::Display for SessionId { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "session-{}", self.0) + } +} + +/// Opaque token that identifies a session across connections. +/// +/// This is NOT the user's identity — it's a session locator. Different +/// token backends (software UUID, smart card, NFC) all produce these. +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct SessionToken(pub String); + +impl SessionToken { + pub fn new(token: impl Into) -> Self { + Self(token.into()) + } + + pub fn as_str(&self) -> &str { + &self.0 + } +} + +impl std::fmt::Display for SessionToken { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0) + } +} + +/// Session lifecycle states per ADR-004. +/// +/// ```text +/// [Creating] → [Active] ⇄ [Suspended] → [Destroyed] +/// ↓ ↑ +/// [Destroyed] ←─────────────────┘ +/// ``` +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum SessionState { + /// Session is being created (compositor starting, greeter launching). + Creating, + /// Session is active with a connected client. + Active, + /// Client disconnected; session preserved, rendering paused. + Suspended, + /// Session terminated; all resources cleaned up. + Destroyed, +} + +impl SessionState { + /// Check if transitioning from this state to `next` is valid. + pub fn can_transition_to(self, next: SessionState) -> bool { + matches!( + (self, next), + (SessionState::Creating, SessionState::Active) + | (SessionState::Creating, SessionState::Destroyed) + | (SessionState::Active, SessionState::Suspended) + | (SessionState::Active, SessionState::Destroyed) + | (SessionState::Suspended, SessionState::Active) + | (SessionState::Suspended, SessionState::Destroyed) + ) + } +} + +impl std::fmt::Display for SessionState { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + SessionState::Creating => write!(f, "creating"), + SessionState::Active => write!(f, "active"), + SessionState::Suspended => write!(f, "suspended"), + SessionState::Destroyed => write!(f, "destroyed"), + } + } +} + +/// A compositor session binding a token to a running desktop. +#[derive(Debug)] +pub struct Session { + pub id: SessionId, + pub token: SessionToken, + pub state: SessionState, + pub created_at: Instant, + /// When the session entered Suspended state (for timeout tracking). + pub suspended_at: Option, + /// How long a suspended session survives before being destroyed. + pub suspend_timeout: Duration, + /// Username (set after authentication). + pub user: Option, +} + +impl Session { + /// Create a new session in the Creating state. + pub fn new(id: SessionId, token: SessionToken) -> Self { + Self { + id, + token, + state: SessionState::Creating, + created_at: Instant::now(), + suspended_at: None, + suspend_timeout: Duration::from_secs(24 * 60 * 60), // 24 hours default + user: None, + } + } + + /// Attempt a state transition. Returns Ok(()) if valid, Err with reason if not. + pub fn transition(&mut self, next: SessionState) -> Result<(), SessionTransitionError> { + if !self.state.can_transition_to(next) { + return Err(SessionTransitionError { + session_id: self.id, + from: self.state, + to: next, + }); + } + + if next == SessionState::Suspended { + self.suspended_at = Some(Instant::now()); + } else { + self.suspended_at = None; + } + + self.state = next; + Ok(()) + } + + /// Check if a suspended session has exceeded its timeout. + pub fn is_suspend_expired(&self) -> bool { + if self.state != SessionState::Suspended { + return false; + } + self.suspended_at + .is_some_and(|t| t.elapsed() >= self.suspend_timeout) + } +} + +/// Error when an invalid session state transition is attempted. +#[derive(Debug)] +pub struct SessionTransitionError { + pub session_id: SessionId, + pub from: SessionState, + pub to: SessionState, +} + +impl std::fmt::Display for SessionTransitionError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "invalid session transition for {}: {} → {}", + self.session_id, self.from, self.to + ) + } +} + +impl std::error::Error for SessionTransitionError {} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn valid_lifecycle_transitions() { + let mut session = Session::new(SessionId::from_raw(1), SessionToken::new("test-token")); + + assert_eq!(session.state, SessionState::Creating); + + // Creating → Active + session.transition(SessionState::Active).unwrap(); + assert_eq!(session.state, SessionState::Active); + assert!(session.suspended_at.is_none()); + + // Active → Suspended + session.transition(SessionState::Suspended).unwrap(); + assert_eq!(session.state, SessionState::Suspended); + assert!(session.suspended_at.is_some()); + + // Suspended → Active (resume) + session.transition(SessionState::Active).unwrap(); + assert_eq!(session.state, SessionState::Active); + assert!(session.suspended_at.is_none()); + + // Active → Destroyed + session.transition(SessionState::Destroyed).unwrap(); + assert_eq!(session.state, SessionState::Destroyed); + } + + #[test] + fn invalid_transitions_rejected() { + let mut session = Session::new(SessionId::from_raw(1), SessionToken::new("t")); + + // Creating → Suspended is invalid + assert!(session.transition(SessionState::Suspended).is_err()); + + // Creating → Active is valid + session.transition(SessionState::Active).unwrap(); + + // Active → Creating is invalid + assert!(session.transition(SessionState::Creating).is_err()); + + // Destroyed is terminal + session.transition(SessionState::Destroyed).unwrap(); + assert!(session.transition(SessionState::Active).is_err()); + assert!(session.transition(SessionState::Suspended).is_err()); + } + + #[test] + fn creating_can_be_destroyed() { + let mut session = Session::new(SessionId::from_raw(1), SessionToken::new("t")); + // Creating → Destroyed (e.g., creation failed) + session.transition(SessionState::Destroyed).unwrap(); + } + + #[test] + fn suspend_timeout_tracking() { + let mut session = Session::new(SessionId::from_raw(1), SessionToken::new("t")); + session.transition(SessionState::Active).unwrap(); + session.transition(SessionState::Suspended).unwrap(); + + // With default 24h timeout, should not be expired yet + assert!(!session.is_suspend_expired()); + + // Set a zero timeout to simulate expiry + session.suspend_timeout = Duration::ZERO; + assert!(session.is_suspend_expired()); + } +}