mirror of
https://github.com/CloudNebulaProject/wayray.git
synced 2026-04-10 13:10:41 +00:00
Add session lifecycle, token-based identity, and client token support
Core session management infrastructure for Phase 3: - Session module: SessionId, SessionToken, SessionState enum with validated transitions (Creating→Active⇄Suspended→Destroyed), Session struct with timeout tracking, SessionRegistry with O(1) lookup by ID and token (10 unit tests) - Protocol: ClientHello gains optional token field, ServerHello gains resumed flag and token echo, SessionStatus/SessionEvent enums added to ControlMessage for session state notifications - Server: SessionRegistry in headless CalloopData, sessions created on client connect (with token lookup for resumption), suspended on disconnect (not destroyed), periodic cleanup of expired sessions - Client: --token CLI flag, persistent token at ~/.config/wayray/token with auto-generation (random 16-byte hex), token sent in ClientHello, resumed state logged from ServerHello
This commit is contained in:
parent
04b8b2f28b
commit
d59411ca60
12 changed files with 653 additions and 12 deletions
1
Cargo.lock
generated
1
Cargo.lock
generated
|
|
@ -3839,6 +3839,7 @@ name = "wrclient"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"env_logger",
|
"env_logger",
|
||||||
|
"getrandom 0.3.4",
|
||||||
"miette",
|
"miette",
|
||||||
"pollster",
|
"pollster",
|
||||||
"quinn",
|
"quinn",
|
||||||
|
|
|
||||||
|
|
@ -71,6 +71,7 @@ mod tests {
|
||||||
let msg = ControlMessage::ClientHello(ClientHello {
|
let msg = ControlMessage::ClientHello(ClientHello {
|
||||||
version: 1,
|
version: 1,
|
||||||
capabilities: vec!["display".to_string()],
|
capabilities: vec!["display".to_string()],
|
||||||
|
token: Some("test".to_string()),
|
||||||
});
|
});
|
||||||
let encoded = encode(&msg).unwrap();
|
let encoded = encode(&msg).unwrap();
|
||||||
let (len, payload) = read_length_prefix(&encoded).unwrap();
|
let (len, payload) = read_length_prefix(&encoded).unwrap();
|
||||||
|
|
|
||||||
|
|
@ -11,6 +11,9 @@ use serde::{Deserialize, Serialize};
|
||||||
pub struct ClientHello {
|
pub struct ClientHello {
|
||||||
pub version: u32,
|
pub version: u32,
|
||||||
pub capabilities: Vec<String>,
|
pub capabilities: Vec<String>,
|
||||||
|
/// Session token for session lookup/creation.
|
||||||
|
/// If None, the server creates a new session with a generated token.
|
||||||
|
pub token: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||||
|
|
@ -19,6 +22,35 @@ pub struct ServerHello {
|
||||||
pub session_id: u64,
|
pub session_id: u64,
|
||||||
pub output_width: u32,
|
pub output_width: u32,
|
||||||
pub output_height: u32,
|
pub output_height: u32,
|
||||||
|
/// Whether this is a resumed session or a new one.
|
||||||
|
pub resumed: bool,
|
||||||
|
/// The token bound to this session (echoed back to the client).
|
||||||
|
pub token: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Session lifecycle state as seen by the client.
|
||||||
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||||
|
pub enum SessionStatus {
|
||||||
|
/// New session being created.
|
||||||
|
Creating,
|
||||||
|
/// Session is active.
|
||||||
|
Active,
|
||||||
|
/// Session was suspended and is being resumed.
|
||||||
|
Resuming,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Session-related control events sent by the server.
|
||||||
|
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||||
|
pub enum SessionEvent {
|
||||||
|
/// Session state changed.
|
||||||
|
StateChanged {
|
||||||
|
session_id: u64,
|
||||||
|
status: SessionStatus,
|
||||||
|
},
|
||||||
|
/// Session is being suspended (client should prepare for disconnect).
|
||||||
|
Suspending { session_id: u64 },
|
||||||
|
/// Session has been destroyed (client should disconnect).
|
||||||
|
Destroyed { session_id: u64 },
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||||
|
|
@ -43,6 +75,7 @@ pub enum ControlMessage {
|
||||||
Ping(Ping),
|
Ping(Ping),
|
||||||
Pong(Pong),
|
Pong(Pong),
|
||||||
FrameAck(FrameAck),
|
FrameAck(FrameAck),
|
||||||
|
SessionEvent(SessionEvent),
|
||||||
}
|
}
|
||||||
|
|
||||||
// ── Display channel (server → client, unidirectional) ───────────────
|
// ── Display channel (server → client, unidirectional) ───────────────
|
||||||
|
|
|
||||||
|
|
@ -17,6 +17,7 @@ winit = "0.30"
|
||||||
wgpu = "24"
|
wgpu = "24"
|
||||||
env_logger = "0.11"
|
env_logger = "0.11"
|
||||||
pollster = "0.4"
|
pollster = "0.4"
|
||||||
|
getrandom = "0.3"
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
rcgen.workspace = true
|
rcgen.workspace = true
|
||||||
|
|
|
||||||
|
|
@ -196,6 +196,48 @@ impl ApplicationHandler for App {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Load or generate a persistent session token.
|
||||||
|
///
|
||||||
|
/// Reads from `~/.config/wayray/token`. If the file doesn't exist,
|
||||||
|
/// generates a random hex token and writes it.
|
||||||
|
fn load_or_generate_token() -> String {
|
||||||
|
let config_dir = dirs_path().join("wayray");
|
||||||
|
let token_path = config_dir.join("token");
|
||||||
|
|
||||||
|
if let Ok(token) = std::fs::read_to_string(&token_path) {
|
||||||
|
let token = token.trim().to_string();
|
||||||
|
if !token.is_empty() {
|
||||||
|
return token;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate a random 16-byte hex token.
|
||||||
|
let mut bytes = [0u8; 16];
|
||||||
|
getrandom::fill(&mut bytes).expect("failed to generate random token");
|
||||||
|
let token = bytes.iter().map(|b| format!("{b:02x}")).collect::<String>();
|
||||||
|
|
||||||
|
// Persist it.
|
||||||
|
if let Err(e) = std::fs::create_dir_all(&config_dir) {
|
||||||
|
warn!(error = %e, "failed to create config dir");
|
||||||
|
} else if let Err(e) = std::fs::write(&token_path, &token) {
|
||||||
|
warn!(error = %e, "failed to persist token");
|
||||||
|
} else {
|
||||||
|
info!(path = %token_path.display(), "session token persisted");
|
||||||
|
}
|
||||||
|
|
||||||
|
token
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the user's config directory base path.
|
||||||
|
fn dirs_path() -> std::path::PathBuf {
|
||||||
|
std::env::var("XDG_CONFIG_HOME")
|
||||||
|
.map(std::path::PathBuf::from)
|
||||||
|
.unwrap_or_else(|_| {
|
||||||
|
let home = std::env::var("HOME").unwrap_or_else(|_| "/tmp".to_string());
|
||||||
|
std::path::PathBuf::from(home).join(".config")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
fn main() {
|
fn main() {
|
||||||
// Initialize tracing with RUST_LOG env filtering.
|
// Initialize tracing with RUST_LOG env filtering.
|
||||||
tracing_subscriber::fmt()
|
tracing_subscriber::fmt()
|
||||||
|
|
@ -206,20 +248,28 @@ fn main() {
|
||||||
.init();
|
.init();
|
||||||
|
|
||||||
let args: Vec<String> = std::env::args().collect();
|
let args: Vec<String> = std::env::args().collect();
|
||||||
if args.len() != 2 {
|
|
||||||
eprintln!("Usage: wrclient <host>:<port>");
|
|
||||||
std::process::exit(1);
|
|
||||||
}
|
|
||||||
|
|
||||||
let (server_addr, server_name) = match network::resolve_server_addr(&args[1]) {
|
// Parse: wrclient <host>:<port> [--token <token>]
|
||||||
|
let addr_arg = args.get(1).cloned().unwrap_or_else(|| {
|
||||||
|
eprintln!("Usage: wrclient <host>:<port> [--token <token>]");
|
||||||
|
std::process::exit(1);
|
||||||
|
});
|
||||||
|
|
||||||
|
let token = if let Some(pos) = args.iter().position(|a| a == "--token") {
|
||||||
|
args.get(pos + 1).cloned()
|
||||||
|
} else {
|
||||||
|
Some(load_or_generate_token())
|
||||||
|
};
|
||||||
|
|
||||||
|
let (server_addr, server_name) = match network::resolve_server_addr(&addr_arg) {
|
||||||
Ok(result) => result,
|
Ok(result) => result,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
eprintln!("Invalid server address '{}': {}", args[1], e);
|
eprintln!("Invalid server address '{}': {}", addr_arg, e);
|
||||||
std::process::exit(1);
|
std::process::exit(1);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
info!(server = %server_addr, name = %server_name, "connecting to server");
|
info!(server = %server_addr, name = %server_name, token = ?token, "connecting to server");
|
||||||
|
|
||||||
let (frame_tx, frame_rx) = mpsc::channel::<FrameData>();
|
let (frame_tx, frame_rx) = mpsc::channel::<FrameData>();
|
||||||
let (input_tx, input_rx) = mpsc::channel::<InputMessage>();
|
let (input_tx, input_rx) = mpsc::channel::<InputMessage>();
|
||||||
|
|
@ -250,6 +300,7 @@ fn main() {
|
||||||
server_addr,
|
server_addr,
|
||||||
server_name,
|
server_name,
|
||||||
capabilities: vec!["display".to_string()],
|
capabilities: vec!["display".to_string()],
|
||||||
|
token,
|
||||||
};
|
};
|
||||||
|
|
||||||
let (_endpoint, mut conn) = match network::connect(&config).await {
|
let (_endpoint, mut conn) = match network::connect(&config).await {
|
||||||
|
|
|
||||||
|
|
@ -29,6 +29,8 @@ pub struct ClientConfig {
|
||||||
pub server_name: String,
|
pub server_name: String,
|
||||||
/// Client capabilities to advertise in the hello.
|
/// Client capabilities to advertise in the hello.
|
||||||
pub capabilities: Vec<String>,
|
pub capabilities: Vec<String>,
|
||||||
|
/// Session token for session binding. If None, no token is sent.
|
||||||
|
pub token: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Resolve a "host:port" string to a SocketAddr, supporting both
|
/// Resolve a "host:port" string to a SocketAddr, supporting both
|
||||||
|
|
@ -64,6 +66,7 @@ impl Default for ClientConfig {
|
||||||
server_addr: "127.0.0.1:4433".parse().unwrap(),
|
server_addr: "127.0.0.1:4433".parse().unwrap(),
|
||||||
server_name: "localhost".to_string(),
|
server_name: "localhost".to_string(),
|
||||||
capabilities: vec!["display".to_string()],
|
capabilities: vec!["display".to_string()],
|
||||||
|
token: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -214,6 +217,7 @@ pub async fn connect(
|
||||||
let client_hello = ControlMessage::ClientHello(ClientHello {
|
let client_hello = ControlMessage::ClientHello(ClientHello {
|
||||||
version: wayray_protocol::PROTOCOL_VERSION,
|
version: wayray_protocol::PROTOCOL_VERSION,
|
||||||
capabilities: config.capabilities.clone(),
|
capabilities: config.capabilities.clone(),
|
||||||
|
token: config.token.clone(),
|
||||||
});
|
});
|
||||||
write_message(&mut control_send, &client_hello).await?;
|
write_message(&mut control_send, &client_hello).await?;
|
||||||
info!("sent ClientHello");
|
info!("sent ClientHello");
|
||||||
|
|
@ -227,6 +231,8 @@ pub async fn connect(
|
||||||
session_id = hello.session_id,
|
session_id = hello.session_id,
|
||||||
width = hello.output_width,
|
width = hello.output_width,
|
||||||
height = hello.output_height,
|
height = hello.output_height,
|
||||||
|
resumed = hello.resumed,
|
||||||
|
token = %hello.token,
|
||||||
"received ServerHello"
|
"received ServerHello"
|
||||||
);
|
);
|
||||||
hello
|
hello
|
||||||
|
|
@ -337,6 +343,8 @@ mod tests {
|
||||||
session_id: 99,
|
session_id: 99,
|
||||||
output_width: 1920,
|
output_width: 1920,
|
||||||
output_height: 1080,
|
output_height: 1080,
|
||||||
|
resumed: false,
|
||||||
|
token: String::new(),
|
||||||
});
|
});
|
||||||
write_message(&mut control_send, &server_hello)
|
write_message(&mut control_send, &server_hello)
|
||||||
.await
|
.await
|
||||||
|
|
@ -353,6 +361,7 @@ mod tests {
|
||||||
server_addr: addr,
|
server_addr: addr,
|
||||||
server_name: "localhost".to_string(),
|
server_name: "localhost".to_string(),
|
||||||
capabilities: vec!["test".to_string()],
|
capabilities: vec!["test".to_string()],
|
||||||
|
token: Some("test-token".to_string()),
|
||||||
};
|
};
|
||||||
|
|
||||||
let (_endpoint, mut conn) = connect(&config).await.unwrap();
|
let (_endpoint, mut conn) = connect(&config).await.unwrap();
|
||||||
|
|
@ -381,6 +390,8 @@ mod tests {
|
||||||
session_id: 1,
|
session_id: 1,
|
||||||
output_width: 800,
|
output_width: 800,
|
||||||
output_height: 600,
|
output_height: 600,
|
||||||
|
resumed: false,
|
||||||
|
token: String::new(),
|
||||||
});
|
});
|
||||||
write_message(&mut control_send, &server_hello)
|
write_message(&mut control_send, &server_hello)
|
||||||
.await
|
.await
|
||||||
|
|
@ -406,6 +417,7 @@ mod tests {
|
||||||
server_addr: addr,
|
server_addr: addr,
|
||||||
server_name: "localhost".to_string(),
|
server_name: "localhost".to_string(),
|
||||||
capabilities: vec![],
|
capabilities: vec![],
|
||||||
|
token: None,
|
||||||
};
|
};
|
||||||
|
|
||||||
let (_endpoint, mut conn) = connect(&config).await.unwrap();
|
let (_endpoint, mut conn) = connect(&config).await.unwrap();
|
||||||
|
|
|
||||||
|
|
@ -30,6 +30,7 @@ use wayray_protocol::messages::FrameUpdate;
|
||||||
use crate::errors::WayRayError;
|
use crate::errors::WayRayError;
|
||||||
use crate::handlers::ClientState;
|
use crate::handlers::ClientState;
|
||||||
use crate::network::{CompositorToNet, NetToCompositor, NetworkHandle};
|
use crate::network::{CompositorToNet, NetToCompositor, NetworkHandle};
|
||||||
|
use crate::session::{SessionRegistry, SessionState, SessionToken};
|
||||||
use crate::state::WayRay;
|
use crate::state::WayRay;
|
||||||
|
|
||||||
/// Dark grey clear color for the compositor background.
|
/// Dark grey clear color for the compositor background.
|
||||||
|
|
@ -50,6 +51,10 @@ struct CalloopData {
|
||||||
frame_sequence: u64,
|
frame_sequence: u64,
|
||||||
/// Whether a remote client is currently connected.
|
/// Whether a remote client is currently connected.
|
||||||
client_connected: bool,
|
client_connected: bool,
|
||||||
|
/// Session registry tracking all sessions by token.
|
||||||
|
session_registry: SessionRegistry,
|
||||||
|
/// The currently active session ID (if any).
|
||||||
|
active_session: Option<crate::session::SessionId>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Run the compositor with the headless PixmanRenderer backend.
|
/// Run the compositor with the headless PixmanRenderer backend.
|
||||||
|
|
@ -149,6 +154,8 @@ pub fn run(
|
||||||
previous_frame,
|
previous_frame,
|
||||||
frame_sequence: 0,
|
frame_sequence: 0,
|
||||||
client_connected: false,
|
client_connected: false,
|
||||||
|
session_registry: SessionRegistry::new(),
|
||||||
|
active_session: None,
|
||||||
};
|
};
|
||||||
|
|
||||||
let running = Arc::new(AtomicBool::new(true));
|
let running = Arc::new(AtomicBool::new(true));
|
||||||
|
|
@ -159,10 +166,22 @@ pub fn run(
|
||||||
|
|
||||||
info!("entering headless main event loop");
|
info!("entering headless main event loop");
|
||||||
|
|
||||||
|
let mut cleanup_counter: u32 = 0;
|
||||||
|
|
||||||
while running.load(Ordering::SeqCst) {
|
while running.load(Ordering::SeqCst) {
|
||||||
// Drain network events (input from remote clients, connection state).
|
// Drain network events (input from remote clients, connection state).
|
||||||
drain_network_events(&mut calloop_data);
|
drain_network_events(&mut calloop_data);
|
||||||
|
|
||||||
|
// Periodically clean up expired suspended sessions (~every 60s at 60fps).
|
||||||
|
cleanup_counter += 1;
|
||||||
|
if cleanup_counter >= 3600 {
|
||||||
|
cleanup_counter = 0;
|
||||||
|
let expired = calloop_data.session_registry.cleanup_expired();
|
||||||
|
if !expired.is_empty() {
|
||||||
|
calloop_data.session_registry.purge_destroyed();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Dispatch Wayland clients.
|
// Dispatch Wayland clients.
|
||||||
calloop_data
|
calloop_data
|
||||||
.display
|
.display
|
||||||
|
|
@ -196,15 +215,53 @@ fn drain_network_events(data: &mut CalloopData) {
|
||||||
data.state.inject_network_input(input_msg);
|
data.state.inject_network_input(input_msg);
|
||||||
}
|
}
|
||||||
Ok(NetToCompositor::ClientConnected(hello)) => {
|
Ok(NetToCompositor::ClientConnected(hello)) => {
|
||||||
|
let token_str = hello.token.clone().unwrap_or_else(|| {
|
||||||
|
// Generate a token if the client didn't provide one.
|
||||||
|
format!("auto-{}", data.session_registry.list().count() + 1)
|
||||||
|
});
|
||||||
|
let token = SessionToken::new(token_str);
|
||||||
|
|
||||||
|
// Look up existing session or create new one.
|
||||||
|
let (session_id, resumed) =
|
||||||
|
if let Some(existing) = data.session_registry.find_by_token(&token) {
|
||||||
|
let id = existing.id;
|
||||||
|
let was_suspended = existing.state == SessionState::Suspended;
|
||||||
|
if was_suspended {
|
||||||
|
if let Err(e) = data.session_registry.activate(id) {
|
||||||
|
warn!(error = %e, "failed to resume session");
|
||||||
|
} else {
|
||||||
|
info!(%id, "session resumed");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
(id, was_suspended)
|
||||||
|
} else {
|
||||||
|
let id = data.session_registry.create_session(token.clone());
|
||||||
|
if let Err(e) = data.session_registry.activate(id) {
|
||||||
|
warn!(error = %e, "failed to activate new session");
|
||||||
|
}
|
||||||
|
(id, false)
|
||||||
|
};
|
||||||
|
|
||||||
|
data.active_session = Some(session_id);
|
||||||
|
data.client_connected = true;
|
||||||
|
|
||||||
info!(
|
info!(
|
||||||
version = hello.version,
|
version = hello.version,
|
||||||
capabilities = ?hello.capabilities,
|
%session_id,
|
||||||
|
resumed,
|
||||||
|
%token,
|
||||||
"remote client connected"
|
"remote client connected"
|
||||||
);
|
);
|
||||||
data.client_connected = true;
|
|
||||||
}
|
}
|
||||||
Ok(NetToCompositor::ClientDisconnected) => {
|
Ok(NetToCompositor::ClientDisconnected) => {
|
||||||
info!("remote client disconnected");
|
// Suspend the session instead of destroying it.
|
||||||
|
if let Some(session_id) = data.active_session.take() {
|
||||||
|
if let Err(e) = data.session_registry.suspend(session_id) {
|
||||||
|
warn!(error = %e, "failed to suspend session");
|
||||||
|
} else {
|
||||||
|
info!(%session_id, "session suspended, waiting for reconnect");
|
||||||
|
}
|
||||||
|
}
|
||||||
data.client_connected = false;
|
data.client_connected = false;
|
||||||
}
|
}
|
||||||
Ok(NetToCompositor::Control(ctrl)) => {
|
Ok(NetToCompositor::Control(ctrl)) => {
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,7 @@ mod backend;
|
||||||
mod errors;
|
mod errors;
|
||||||
mod handlers;
|
mod handlers;
|
||||||
pub mod network;
|
pub mod network;
|
||||||
|
pub mod session;
|
||||||
mod state;
|
mod state;
|
||||||
mod wm;
|
mod wm;
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -253,13 +253,15 @@ async fn handle_connection(
|
||||||
return Err(format!("expected ClientHello, got {client_hello:?}").into());
|
return Err(format!("expected ClientHello, got {client_hello:?}").into());
|
||||||
};
|
};
|
||||||
info!(version = hello.version, "received ClientHello");
|
info!(version = hello.version, "received ClientHello");
|
||||||
|
let token = hello.token.clone().unwrap_or_default();
|
||||||
let _ = compositor_tx.send(NetToCompositor::ClientConnected(hello));
|
let _ = compositor_tx.send(NetToCompositor::ClientConnected(hello));
|
||||||
|
|
||||||
let server_hello = ControlMessage::ServerHello(ServerHello {
|
let server_hello = ControlMessage::ServerHello(ServerHello {
|
||||||
version: wayray_protocol::PROTOCOL_VERSION,
|
version: wayray_protocol::PROTOCOL_VERSION,
|
||||||
session_id: 1, // TODO: real session management
|
session_id: 1, // Assigned by session registry in task 3
|
||||||
output_width: config.output_width,
|
output_width: config.output_width,
|
||||||
output_height: config.output_height,
|
output_height: config.output_height,
|
||||||
|
resumed: false,
|
||||||
|
token,
|
||||||
});
|
});
|
||||||
write_message(&mut control_send, &server_hello).await?;
|
write_message(&mut control_send, &server_hello).await?;
|
||||||
info!("sent ServerHello");
|
info!("sent ServerHello");
|
||||||
|
|
@ -473,6 +475,8 @@ mod tests {
|
||||||
session_id: 42,
|
session_id: 42,
|
||||||
output_width: 1920,
|
output_width: 1920,
|
||||||
output_height: 1080,
|
output_height: 1080,
|
||||||
|
resumed: false,
|
||||||
|
token: "test-token".to_string(),
|
||||||
});
|
});
|
||||||
write_message(&mut control_send, &server_hello)
|
write_message(&mut control_send, &server_hello)
|
||||||
.await
|
.await
|
||||||
|
|
@ -497,6 +501,7 @@ mod tests {
|
||||||
let client_hello = ControlMessage::ClientHello(ClientHello {
|
let client_hello = ControlMessage::ClientHello(ClientHello {
|
||||||
version: wayray_protocol::PROTOCOL_VERSION,
|
version: wayray_protocol::PROTOCOL_VERSION,
|
||||||
capabilities: vec!["display".to_string()],
|
capabilities: vec!["display".to_string()],
|
||||||
|
token: Some("test-token".to_string()),
|
||||||
});
|
});
|
||||||
write_message(&mut control_send, &client_hello)
|
write_message(&mut control_send, &client_hello)
|
||||||
.await
|
.await
|
||||||
|
|
@ -538,6 +543,8 @@ mod tests {
|
||||||
session_id: 1,
|
session_id: 1,
|
||||||
output_width: 1280,
|
output_width: 1280,
|
||||||
output_height: 720,
|
output_height: 720,
|
||||||
|
resumed: false,
|
||||||
|
token: String::new(),
|
||||||
});
|
});
|
||||||
write_message(&mut control_send, &server_hello)
|
write_message(&mut control_send, &server_hello)
|
||||||
.await
|
.await
|
||||||
|
|
@ -572,6 +579,7 @@ mod tests {
|
||||||
let client_hello = ControlMessage::ClientHello(ClientHello {
|
let client_hello = ControlMessage::ClientHello(ClientHello {
|
||||||
version: wayray_protocol::PROTOCOL_VERSION,
|
version: wayray_protocol::PROTOCOL_VERSION,
|
||||||
capabilities: vec![],
|
capabilities: vec![],
|
||||||
|
token: None,
|
||||||
});
|
});
|
||||||
write_message(&mut control_send, &client_hello)
|
write_message(&mut control_send, &client_hello)
|
||||||
.await
|
.await
|
||||||
|
|
|
||||||
5
crates/wrsrvd/src/session/mod.rs
Normal file
5
crates/wrsrvd/src/session/mod.rs
Normal file
|
|
@ -0,0 +1,5 @@
|
||||||
|
pub mod registry;
|
||||||
|
pub mod types;
|
||||||
|
|
||||||
|
pub use registry::SessionRegistry;
|
||||||
|
pub use types::{Session, SessionId, SessionState, SessionToken};
|
||||||
232
crates/wrsrvd/src/session/registry.rs
Normal file
232
crates/wrsrvd/src/session/registry.rs
Normal file
|
|
@ -0,0 +1,232 @@
|
||||||
|
use std::collections::HashMap;
|
||||||
|
|
||||||
|
use tracing::{info, warn};
|
||||||
|
|
||||||
|
use super::types::{Session, SessionId, SessionState, SessionToken, SessionTransitionError};
|
||||||
|
|
||||||
|
/// In-memory session registry.
|
||||||
|
///
|
||||||
|
/// Provides O(1) lookup by both session ID and token. Tracks all sessions
|
||||||
|
/// including suspended ones (until they time out and are cleaned up).
|
||||||
|
pub struct SessionRegistry {
|
||||||
|
sessions: HashMap<SessionId, Session>,
|
||||||
|
/// Reverse index: token → session ID for fast lookup on client connect.
|
||||||
|
token_index: HashMap<SessionToken, SessionId>,
|
||||||
|
next_id: u64,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl SessionRegistry {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
sessions: HashMap::new(),
|
||||||
|
token_index: HashMap::new(),
|
||||||
|
next_id: 1,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a new session for the given token. Returns the session ID.
|
||||||
|
pub fn create_session(&mut self, token: SessionToken) -> SessionId {
|
||||||
|
let id = SessionId::from_raw(self.next_id);
|
||||||
|
self.next_id += 1;
|
||||||
|
|
||||||
|
let session = Session::new(id, token.clone());
|
||||||
|
self.token_index.insert(token.clone(), id);
|
||||||
|
self.sessions.insert(id, session);
|
||||||
|
|
||||||
|
info!(%id, %token, "session created");
|
||||||
|
id
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Look up a session by its token.
|
||||||
|
pub fn find_by_token(&self, token: &SessionToken) -> Option<&Session> {
|
||||||
|
let id = self.token_index.get(token)?;
|
||||||
|
self.sessions.get(id)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Look up a session by ID.
|
||||||
|
pub fn get(&self, id: SessionId) -> Option<&Session> {
|
||||||
|
self.sessions.get(&id)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Transition a session to a new state.
|
||||||
|
pub fn transition(
|
||||||
|
&mut self,
|
||||||
|
id: SessionId,
|
||||||
|
next: SessionState,
|
||||||
|
) -> Result<(), SessionTransitionError> {
|
||||||
|
let session = self.sessions.get_mut(&id).ok_or(SessionTransitionError {
|
||||||
|
session_id: id,
|
||||||
|
from: SessionState::Destroyed,
|
||||||
|
to: next,
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let from = session.state;
|
||||||
|
session.transition(next)?;
|
||||||
|
info!(%id, %from, %next, "session state transition");
|
||||||
|
|
||||||
|
// Clean up destroyed sessions from the index.
|
||||||
|
if next == SessionState::Destroyed {
|
||||||
|
self.token_index.remove(&session.token);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Activate a session (Creating → Active or Suspended → Active).
|
||||||
|
pub fn activate(&mut self, id: SessionId) -> Result<(), SessionTransitionError> {
|
||||||
|
self.transition(id, SessionState::Active)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Suspend a session (Active → Suspended).
|
||||||
|
pub fn suspend(&mut self, id: SessionId) -> Result<(), SessionTransitionError> {
|
||||||
|
self.transition(id, SessionState::Suspended)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Destroy a session.
|
||||||
|
pub fn destroy(&mut self, id: SessionId) -> Result<(), SessionTransitionError> {
|
||||||
|
self.transition(id, SessionState::Destroyed)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Set the user for a session (after authentication).
|
||||||
|
pub fn set_user(&mut self, id: SessionId, user: String) {
|
||||||
|
if let Some(session) = self.sessions.get_mut(&id) {
|
||||||
|
session.user = Some(user);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Clean up expired suspended sessions. Returns IDs of destroyed sessions.
|
||||||
|
pub fn cleanup_expired(&mut self) -> Vec<SessionId> {
|
||||||
|
let expired: Vec<SessionId> = self
|
||||||
|
.sessions
|
||||||
|
.values()
|
||||||
|
.filter(|s| s.is_suspend_expired())
|
||||||
|
.map(|s| s.id)
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
for id in &expired {
|
||||||
|
if let Err(e) = self.destroy(*id) {
|
||||||
|
warn!(%id, error = %e, "failed to destroy expired session");
|
||||||
|
} else {
|
||||||
|
info!(%id, "expired suspended session destroyed");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
expired
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Remove destroyed sessions from memory entirely.
|
||||||
|
pub fn purge_destroyed(&mut self) {
|
||||||
|
self.sessions
|
||||||
|
.retain(|_, s| s.state != SessionState::Destroyed);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// List all sessions (for admin queries).
|
||||||
|
pub fn list(&self) -> impl Iterator<Item = &Session> {
|
||||||
|
self.sessions.values()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Count sessions by state.
|
||||||
|
pub fn count_by_state(&self, state: SessionState) -> usize {
|
||||||
|
self.sessions.values().filter(|s| s.state == state).count()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for SessionRegistry {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self::new()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn create_and_lookup() {
|
||||||
|
let mut reg = SessionRegistry::new();
|
||||||
|
let token = SessionToken::new("abc-123");
|
||||||
|
let id = reg.create_session(token.clone());
|
||||||
|
|
||||||
|
assert!(reg.get(id).is_some());
|
||||||
|
assert!(reg.find_by_token(&token).is_some());
|
||||||
|
assert_eq!(reg.find_by_token(&token).unwrap().id, id);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn lifecycle_through_registry() {
|
||||||
|
let mut reg = SessionRegistry::new();
|
||||||
|
let token = SessionToken::new("tok");
|
||||||
|
let id = reg.create_session(token.clone());
|
||||||
|
|
||||||
|
assert_eq!(reg.get(id).unwrap().state, SessionState::Creating);
|
||||||
|
|
||||||
|
reg.activate(id).unwrap();
|
||||||
|
assert_eq!(reg.get(id).unwrap().state, SessionState::Active);
|
||||||
|
|
||||||
|
reg.suspend(id).unwrap();
|
||||||
|
assert_eq!(reg.get(id).unwrap().state, SessionState::Suspended);
|
||||||
|
|
||||||
|
// Resume
|
||||||
|
reg.activate(id).unwrap();
|
||||||
|
assert_eq!(reg.get(id).unwrap().state, SessionState::Active);
|
||||||
|
|
||||||
|
reg.destroy(id).unwrap();
|
||||||
|
// Token index should be cleaned up
|
||||||
|
assert!(reg.find_by_token(&token).is_none());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn cleanup_expired() {
|
||||||
|
let mut reg = SessionRegistry::new();
|
||||||
|
let id = reg.create_session(SessionToken::new("t1"));
|
||||||
|
reg.activate(id).unwrap();
|
||||||
|
reg.suspend(id).unwrap();
|
||||||
|
|
||||||
|
// Set zero timeout so it expires immediately
|
||||||
|
reg.sessions.get_mut(&id).unwrap().suspend_timeout = std::time::Duration::ZERO;
|
||||||
|
|
||||||
|
let expired = reg.cleanup_expired();
|
||||||
|
assert_eq!(expired.len(), 1);
|
||||||
|
assert_eq!(expired[0], id);
|
||||||
|
assert_eq!(reg.get(id).unwrap().state, SessionState::Destroyed);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn purge_destroyed() {
|
||||||
|
let mut reg = SessionRegistry::new();
|
||||||
|
let id = reg.create_session(SessionToken::new("t"));
|
||||||
|
reg.activate(id).unwrap();
|
||||||
|
reg.destroy(id).unwrap();
|
||||||
|
|
||||||
|
assert!(reg.get(id).is_some()); // Still in memory
|
||||||
|
reg.purge_destroyed();
|
||||||
|
assert!(reg.get(id).is_none()); // Gone
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn multiple_sessions() {
|
||||||
|
let mut reg = SessionRegistry::new();
|
||||||
|
let id1 = reg.create_session(SessionToken::new("tok-a"));
|
||||||
|
let id2 = reg.create_session(SessionToken::new("tok-b"));
|
||||||
|
|
||||||
|
assert_ne!(id1, id2);
|
||||||
|
reg.activate(id1).unwrap();
|
||||||
|
reg.activate(id2).unwrap();
|
||||||
|
|
||||||
|
assert_eq!(reg.count_by_state(SessionState::Active), 2);
|
||||||
|
|
||||||
|
reg.suspend(id1).unwrap();
|
||||||
|
assert_eq!(reg.count_by_state(SessionState::Active), 1);
|
||||||
|
assert_eq!(reg.count_by_state(SessionState::Suspended), 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn set_user() {
|
||||||
|
let mut reg = SessionRegistry::new();
|
||||||
|
let id = reg.create_session(SessionToken::new("t"));
|
||||||
|
assert!(reg.get(id).unwrap().user.is_none());
|
||||||
|
|
||||||
|
reg.set_user(id, "alice".to_string());
|
||||||
|
assert_eq!(reg.get(id).unwrap().user.as_deref(), Some("alice"));
|
||||||
|
}
|
||||||
|
}
|
||||||
239
crates/wrsrvd/src/session/types.rs
Normal file
239
crates/wrsrvd/src/session/types.rs
Normal file
|
|
@ -0,0 +1,239 @@
|
||||||
|
use std::time::{Duration, Instant};
|
||||||
|
|
||||||
|
/// Unique identifier for a compositor session.
|
||||||
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||||
|
pub struct SessionId(u64);
|
||||||
|
|
||||||
|
impl SessionId {
|
||||||
|
pub fn from_raw(id: u64) -> Self {
|
||||||
|
Self(id)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn raw(self) -> u64 {
|
||||||
|
self.0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::fmt::Display for SessionId {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
write!(f, "session-{}", self.0)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Opaque token that identifies a session across connections.
|
||||||
|
///
|
||||||
|
/// This is NOT the user's identity — it's a session locator. Different
|
||||||
|
/// token backends (software UUID, smart card, NFC) all produce these.
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
||||||
|
pub struct SessionToken(pub String);
|
||||||
|
|
||||||
|
impl SessionToken {
|
||||||
|
pub fn new(token: impl Into<String>) -> Self {
|
||||||
|
Self(token.into())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn as_str(&self) -> &str {
|
||||||
|
&self.0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::fmt::Display for SessionToken {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
write!(f, "{}", self.0)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Session lifecycle states per ADR-004.
|
||||||
|
///
|
||||||
|
/// ```text
|
||||||
|
/// [Creating] → [Active] ⇄ [Suspended] → [Destroyed]
|
||||||
|
/// ↓ ↑
|
||||||
|
/// [Destroyed] ←─────────────────┘
|
||||||
|
/// ```
|
||||||
|
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||||
|
pub enum SessionState {
|
||||||
|
/// Session is being created (compositor starting, greeter launching).
|
||||||
|
Creating,
|
||||||
|
/// Session is active with a connected client.
|
||||||
|
Active,
|
||||||
|
/// Client disconnected; session preserved, rendering paused.
|
||||||
|
Suspended,
|
||||||
|
/// Session terminated; all resources cleaned up.
|
||||||
|
Destroyed,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl SessionState {
|
||||||
|
/// Check if transitioning from this state to `next` is valid.
|
||||||
|
pub fn can_transition_to(self, next: SessionState) -> bool {
|
||||||
|
matches!(
|
||||||
|
(self, next),
|
||||||
|
(SessionState::Creating, SessionState::Active)
|
||||||
|
| (SessionState::Creating, SessionState::Destroyed)
|
||||||
|
| (SessionState::Active, SessionState::Suspended)
|
||||||
|
| (SessionState::Active, SessionState::Destroyed)
|
||||||
|
| (SessionState::Suspended, SessionState::Active)
|
||||||
|
| (SessionState::Suspended, SessionState::Destroyed)
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::fmt::Display for SessionState {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
match self {
|
||||||
|
SessionState::Creating => write!(f, "creating"),
|
||||||
|
SessionState::Active => write!(f, "active"),
|
||||||
|
SessionState::Suspended => write!(f, "suspended"),
|
||||||
|
SessionState::Destroyed => write!(f, "destroyed"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A compositor session binding a token to a running desktop.
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct Session {
|
||||||
|
pub id: SessionId,
|
||||||
|
pub token: SessionToken,
|
||||||
|
pub state: SessionState,
|
||||||
|
pub created_at: Instant,
|
||||||
|
/// When the session entered Suspended state (for timeout tracking).
|
||||||
|
pub suspended_at: Option<Instant>,
|
||||||
|
/// How long a suspended session survives before being destroyed.
|
||||||
|
pub suspend_timeout: Duration,
|
||||||
|
/// Username (set after authentication).
|
||||||
|
pub user: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Session {
|
||||||
|
/// Create a new session in the Creating state.
|
||||||
|
pub fn new(id: SessionId, token: SessionToken) -> Self {
|
||||||
|
Self {
|
||||||
|
id,
|
||||||
|
token,
|
||||||
|
state: SessionState::Creating,
|
||||||
|
created_at: Instant::now(),
|
||||||
|
suspended_at: None,
|
||||||
|
suspend_timeout: Duration::from_secs(24 * 60 * 60), // 24 hours default
|
||||||
|
user: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Attempt a state transition. Returns Ok(()) if valid, Err with reason if not.
|
||||||
|
pub fn transition(&mut self, next: SessionState) -> Result<(), SessionTransitionError> {
|
||||||
|
if !self.state.can_transition_to(next) {
|
||||||
|
return Err(SessionTransitionError {
|
||||||
|
session_id: self.id,
|
||||||
|
from: self.state,
|
||||||
|
to: next,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
if next == SessionState::Suspended {
|
||||||
|
self.suspended_at = Some(Instant::now());
|
||||||
|
} else {
|
||||||
|
self.suspended_at = None;
|
||||||
|
}
|
||||||
|
|
||||||
|
self.state = next;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Check if a suspended session has exceeded its timeout.
|
||||||
|
pub fn is_suspend_expired(&self) -> bool {
|
||||||
|
if self.state != SessionState::Suspended {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
self.suspended_at
|
||||||
|
.is_some_and(|t| t.elapsed() >= self.suspend_timeout)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Error when an invalid session state transition is attempted.
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct SessionTransitionError {
|
||||||
|
pub session_id: SessionId,
|
||||||
|
pub from: SessionState,
|
||||||
|
pub to: SessionState,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::fmt::Display for SessionTransitionError {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
write!(
|
||||||
|
f,
|
||||||
|
"invalid session transition for {}: {} → {}",
|
||||||
|
self.session_id, self.from, self.to
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::error::Error for SessionTransitionError {}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn valid_lifecycle_transitions() {
|
||||||
|
let mut session = Session::new(SessionId::from_raw(1), SessionToken::new("test-token"));
|
||||||
|
|
||||||
|
assert_eq!(session.state, SessionState::Creating);
|
||||||
|
|
||||||
|
// Creating → Active
|
||||||
|
session.transition(SessionState::Active).unwrap();
|
||||||
|
assert_eq!(session.state, SessionState::Active);
|
||||||
|
assert!(session.suspended_at.is_none());
|
||||||
|
|
||||||
|
// Active → Suspended
|
||||||
|
session.transition(SessionState::Suspended).unwrap();
|
||||||
|
assert_eq!(session.state, SessionState::Suspended);
|
||||||
|
assert!(session.suspended_at.is_some());
|
||||||
|
|
||||||
|
// Suspended → Active (resume)
|
||||||
|
session.transition(SessionState::Active).unwrap();
|
||||||
|
assert_eq!(session.state, SessionState::Active);
|
||||||
|
assert!(session.suspended_at.is_none());
|
||||||
|
|
||||||
|
// Active → Destroyed
|
||||||
|
session.transition(SessionState::Destroyed).unwrap();
|
||||||
|
assert_eq!(session.state, SessionState::Destroyed);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn invalid_transitions_rejected() {
|
||||||
|
let mut session = Session::new(SessionId::from_raw(1), SessionToken::new("t"));
|
||||||
|
|
||||||
|
// Creating → Suspended is invalid
|
||||||
|
assert!(session.transition(SessionState::Suspended).is_err());
|
||||||
|
|
||||||
|
// Creating → Active is valid
|
||||||
|
session.transition(SessionState::Active).unwrap();
|
||||||
|
|
||||||
|
// Active → Creating is invalid
|
||||||
|
assert!(session.transition(SessionState::Creating).is_err());
|
||||||
|
|
||||||
|
// Destroyed is terminal
|
||||||
|
session.transition(SessionState::Destroyed).unwrap();
|
||||||
|
assert!(session.transition(SessionState::Active).is_err());
|
||||||
|
assert!(session.transition(SessionState::Suspended).is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn creating_can_be_destroyed() {
|
||||||
|
let mut session = Session::new(SessionId::from_raw(1), SessionToken::new("t"));
|
||||||
|
// Creating → Destroyed (e.g., creation failed)
|
||||||
|
session.transition(SessionState::Destroyed).unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn suspend_timeout_tracking() {
|
||||||
|
let mut session = Session::new(SessionId::from_raw(1), SessionToken::new("t"));
|
||||||
|
session.transition(SessionState::Active).unwrap();
|
||||||
|
session.transition(SessionState::Suspended).unwrap();
|
||||||
|
|
||||||
|
// With default 24h timeout, should not be expired yet
|
||||||
|
assert!(!session.is_suspend_expired());
|
||||||
|
|
||||||
|
// Set a zero timeout to simulate expiry
|
||||||
|
session.suspend_timeout = Duration::ZERO;
|
||||||
|
assert!(session.is_suspend_expired());
|
||||||
|
}
|
||||||
|
}
|
||||||
Loading…
Add table
Reference in a new issue