diff --git a/crates/orchestrator/Cargo.toml b/crates/orchestrator/Cargo.toml index 070ef0f..9542b32 100644 --- a/crates/orchestrator/Cargo.toml +++ b/crates/orchestrator/Cargo.toml @@ -12,6 +12,8 @@ libvirt = [] common = { path = "../common" } clap = { version = "4", features = ["derive", "env"] } miette = { version = "7", features = ["fancy"] } +thiserror = "1" +anyhow = "1" tracing = "0.1" tokio = { version = "1", features = ["rt-multi-thread", "macros", "signal", "fs", "io-util", "process", "net"] } serde = { version = "1", features = ["derive"] } diff --git a/crates/orchestrator/src/main.rs b/crates/orchestrator/src/main.rs index ed38051..ffee08b 100644 --- a/crates/orchestrator/src/main.rs +++ b/crates/orchestrator/src/main.rs @@ -3,6 +3,7 @@ mod hypervisor; mod persist; mod scheduler; mod http; +mod error; use std::{collections::HashMap, path::PathBuf, time::Duration}; #[cfg(unix)] @@ -13,6 +14,7 @@ use miette::{IntoDiagnostic as _, Result}; use tracing::{info, warn}; use crate::persist::{JobState, Persist}; +use crate::error::OrchestratorError; use config::OrchestratorConfig; use hypervisor::{JobContext, RouterHypervisor, VmSpec}; use scheduler::{SchedItem, Scheduler, ExecConfig}; @@ -218,11 +220,18 @@ async fn main() -> Result<()> { use ssh_key::{Algorithm, PrivateKey, rand_core::OsRng}; let mut rng = OsRng; let sk = PrivateKey::random(&mut rng, Algorithm::Ed25519) - .map_err(|e| miette::miette!("ssh keygen failed: {e}"))?; + .map_err(|e| OrchestratorError::SshKeygen(e.into())) + .into_diagnostic()?; let pk = sk.public_key(); - let pub_txt = pk.to_openssh().map_err(|e| miette::miette!("encode pubkey: {e}"))?; + let pub_txt = pk + .to_openssh() + .map_err(|e| OrchestratorError::EncodePubKey(e.into())) + .into_diagnostic()?; // Encode private key in OpenSSH format (Zeroizing -> String) - let priv_txt_z = sk.to_openssh(ssh_key::LineEnding::LF).map_err(|e| miette::miette!("encode privkey: {e}"))?; + let priv_txt_z = sk + .to_openssh(ssh_key::LineEnding::LF) + .map_err(|e| OrchestratorError::EncodePrivKey(e.into())) + .into_diagnostic()?; let priv_txt: String = priv_txt_z.to_string(); let _ = persist.save_job_ssh_keys(job.request_id, &pub_txt, &priv_txt).await; info!(request_id = %job.request_id, "generated and saved new per-job SSH keys"); diff --git a/crates/orchestrator/src/persist.rs b/crates/orchestrator/src/persist.rs index a8c86a2..ebe7179 100644 --- a/crates/orchestrator/src/persist.rs +++ b/crates/orchestrator/src/persist.rs @@ -1,5 +1,6 @@ use chrono::Utc; use miette::{IntoDiagnostic as _, Result}; +use crate::error::OrchestratorError; use sea_orm::sea_query::{Expr, OnConflict}; use sea_orm::{ entity::prelude::*, ColumnTrait, Database, DatabaseConnection, QueryFilter, @@ -189,12 +190,13 @@ impl Persist { opts.max_connections(1) .min_connections(1) .sqlx_logging(false); - let db = Database::connect(opts).await.into_diagnostic() - .map_err(|e| miette::miette!("failed to connect to database: {e}"))?; + let db = Database::connect(opts).await + .map_err(|e| OrchestratorError::DbConnect(e.into())) + .into_diagnostic()?; migration::Migrator::up(&db, None) .await - .into_diagnostic() - .map_err(|e| miette::miette!("failed to apply migrations: {e}"))?; + .map_err(|e| OrchestratorError::DbMigrate(e.into())) + .into_diagnostic()?; info!("persistence initialized (connected and migrated)"); Ok(Self { db: Some(db) }) } else { diff --git a/crates/orchestrator/src/scheduler.rs b/crates/orchestrator/src/scheduler.rs index 19ea2a0..4c4716b 100644 --- a/crates/orchestrator/src/scheduler.rs +++ b/crates/orchestrator/src/scheduler.rs @@ -1,6 +1,7 @@ use std::{collections::HashMap, path::PathBuf, sync::Arc, time::Duration}; use dashmap::DashMap; use miette::{IntoDiagnostic as _, Result}; +use crate::error::OrchestratorError; use tokio::sync::{Notify, Semaphore, mpsc}; use tracing::{error, info, warn}; use common::{publish_deadletter, publish_job_result, DeadLetter, JobRequest, messages::JobResult}; @@ -173,7 +174,18 @@ impl Scheduler { let repo_url = item.ctx.repo_url.clone(); let commit_sha = item.ctx.commit_sha.clone(); let request_id = item.ctx.request_id; - match run_job_via_ssh_owned(ip_owned, user, key_mem_opt, per_job_key_path, local_runner, remote_path, repo_url, commit_sha, request_id).await { + match run_job_via_ssh_with_retry( + ip_owned, + user, + key_mem_opt, + per_job_key_path, + local_runner, + remote_path, + repo_url, + commit_sha, + request_id, + Duration::from_secs(exec_cfg.ssh_connect_timeout_secs), + ).await { Ok((ok, code, lines)) => { success = ok; exit_code = code; // Persist lines @@ -287,50 +299,162 @@ async fn tail_console_to_joblog(persist: Arc, request_id: Uuid, console } #[cfg(all(target_os = "linux", feature = "libvirt"))] -async fn discover_guest_ip_virsh(domain: &str, _timeout: Duration) -> Option { - use tokio::task; - // Best-effort: call `virsh domifaddr --source lease` - let dom = domain.to_string(); - let out = task::spawn_blocking(move || { - std::process::Command::new("virsh") - .args(["domifaddr", &dom, "--source", "lease"]) - .output() - }) - .await - .ok()? - .ok()?; - if !out.status.success() { - return None; - } - let s = String::from_utf8_lossy(&out.stdout); - for line in s.lines() { - // crude parse: look for IPv4 address in the line - if line.contains("ipv4") { - // Line format often contains '192.168.x.y/24' +async fn discover_guest_ip_virsh(domain: &str, timeout: Duration) -> Option { + use tokio::{task, time::{sleep, Instant, Duration}}; + use std::process::Command; + use tracing::debug; + + fn parse_ipv4_from_text(s: &str) -> Option { + for line in s.lines() { if let Some(slash) = line.find('/') { - // find last whitespace before slash let start = line[..slash].rfind(' ').map(|i| i + 1).unwrap_or(0); let ip = &line[start..slash]; if ip.chars().all(|c| c.is_ascii_digit() || c == '.') { return Some(ip.to_string()); } } + let mut cur = String::new(); + for ch in line.chars() { + if ch.is_ascii_digit() || ch == '.' { cur.push(ch); } else { + if cur.split('.').count() == 4 { return Some(cur.clone()); } + cur.clear(); + } + } + if cur.split('.').count() == 4 { return Some(cur); } } - // Fallback: extract first a.b.c.d sequence - let mut cur = String::new(); - for ch in line.chars() { - if ch.is_ascii_digit() || ch == '.' { - cur.push(ch); - } else { - if cur.split('.').count() == 4 { return Some(cur.clone()); } - cur.clear(); + None + } + + async fn run_cmd(args: &[&str]) -> Option { + let args_vec = args.iter().map(|s| s.to_string()).collect::>(); + task::spawn_blocking(move || { + Command::new("virsh").args(&args_vec).output() + }).await.ok()?.ok().and_then(|out| { + if out.status.success() { + Some(String::from_utf8_lossy(&out.stdout).to_string()) + } else { None } + }) + } + + let deadline = Instant::now() + timeout; + while Instant::now() < deadline { + // 1) Try domifaddr via agent then lease then default + for source in [Some("agent"), Some("lease"), None] { + let mut args = vec!["domifaddr", domain]; + if let Some(src) = source { args.push("--source"); args.push(src); } + if let Some(out) = run_cmd(&args).await { + if let Some(ip) = parse_ipv4_from_text(&out) { debug!(domain=%domain, method=%format!("domifaddr/{:?}", source), ip=%ip, "discovered IP"); return Some(ip); } } } - if cur.split('.').count() == 4 { return Some(cur); } + + // 2) Try domiflist to get MAC and possibly network name + let mut mac: Option = None; + let mut net_name: Option = None; + if let Some(out) = run_cmd(&["domiflist", domain]).await { + for line in out.lines().skip(2) { // skip header lines + let cols: Vec<&str> = line.split_whitespace().collect(); + if cols.len() >= 5 { + // columns: Interface Type Source Model MAC + net_name = Some(cols[2].to_string()); + mac = Some(cols[4].to_ascii_lowercase()); + break; + } + } + } + // Fallback to env if domiflist didn't give a Source network + if net_name.is_none() { + if let Ok(env_net) = std::env::var("LIBVIRT_NETWORK") { if !env_net.is_empty() { net_name = Some(env_net); } } + } + // 2a) Parse leases file if we have network and MAC + if let (Some(net), Some(mac_s)) = (net_name.clone(), mac.clone()) { + let path = format!("/var/lib/libvirt/dnsmasq/{}.leases", net); + let content_opt = task::spawn_blocking(move || std::fs::read_to_string(path)).await.ok().and_then(|r| r.ok()); + if let Some(content) = content_opt { + for line in content.lines() { + // Format: epoch MAC IP hostname clientid + let parts: Vec<&str> = line.split_whitespace().collect(); + if parts.len() >= 3 && parts[1].eq_ignore_ascii_case(mac_s.as_str()) { + let ip = parts[2].to_string(); + debug!(domain=%domain, method="dnsmasq.leases", ip=%ip, "discovered IP"); + return Some(ip); + } + } + } + } + // 2b) Try virsh net-dhcp-leases + if let Some(net) = net_name { + if let Some(out) = run_cmd(&["net-dhcp-leases", &net]).await { + if let Some(ref mac_s) = mac { + for line in out.lines() { + if line.to_ascii_lowercase().contains(mac_s.as_str()) { + if let Some(ip) = parse_ipv4_from_text(line) { debug!(domain=%domain, method="net-dhcp-leases", ip=%ip, "discovered IP"); return Some(ip); } + } + } + } + if let Some(ip) = parse_ipv4_from_text(&out) { debug!(domain=%domain, method="net-dhcp-leases-any", ip=%ip, "discovered IP"); return Some(ip); } + } + } + + sleep(Duration::from_secs(1)).await; } None } +#[cfg(all(target_os = "linux", feature = "libvirt"))] +async fn run_job_via_ssh_with_retry( + ip: String, + user: String, + key_mem_opt: Option, + per_job_key_path: Option, + local_runner: PathBuf, + remote_path: PathBuf, + repo_url: String, + commit_sha: String, + request_id: uuid::Uuid, + timeout: Duration, +) -> miette::Result<(bool, i32, Vec<(bool, String)>)> { + use tokio::time::{sleep, Instant}; + use tracing::warn; + + let deadline = Instant::now() + timeout; + let mut attempt: u32 = 0; + let mut backoff = Duration::from_secs(1); + loop { + attempt += 1; + match run_job_via_ssh_owned( + ip.clone(), + user.clone(), + key_mem_opt.clone(), + per_job_key_path.clone(), + local_runner.clone(), + remote_path.clone(), + repo_url.clone(), + commit_sha.clone(), + request_id, + ).await { + Ok(r) => return Ok(r), + Err(e) => { + let now = Instant::now(); + if now >= deadline { + return Err(e); // give up with last error + } + warn!( + attempt, + remaining_ms = (deadline - now).as_millis() as u64, + ip = %ip, + request_id = %request_id, + error = %e, + "SSH connect/exec not ready yet; will retry" + ); + // Cap backoff at 5s and don’t sleep past the deadline + let sleep_dur = std::cmp::min(backoff, deadline.saturating_duration_since(now)); + sleep(sleep_dur).await; + backoff = std::cmp::min(backoff.saturating_mul(2), Duration::from_secs(5)); + } + } + } +} + #[cfg(all(target_os = "linux", feature = "libvirt"))] async fn run_job_via_ssh_owned( ip: String, @@ -351,17 +475,21 @@ async fn run_job_via_ssh_owned( let (ok, code, lines) = task::spawn_blocking(move || -> miette::Result<(bool, i32, Vec<(bool, String)>)> { // Connect TCP - let tcp = TcpStream::connect(format!("{}:22", ip)).map_err(|e| miette::miette!("tcp connect failed: {e}"))?; - let mut sess = Session::new().or_else(|_| Err(miette::miette!("ssh session new failed")))?; + let tcp = TcpStream::connect(format!("{}:22", ip)) + .map_err(OrchestratorError::TcpConnect) + .into_diagnostic()?; + let mut sess = Session::new().map_err(|_| OrchestratorError::SshSessionNew).into_diagnostic()?; sess.set_tcp_stream(tcp); - sess.handshake().map_err(|e| miette::miette!("ssh handshake: {e}"))?; + sess.handshake().map_err(OrchestratorError::SshHandshake).into_diagnostic()?; // Authenticate priority: memory key -> per-job file -> global file if let Some(ref key_pem) = key_mem_opt { sess.userauth_pubkey_memory(&user, None, key_pem, None) - .map_err(|e| miette::miette!("ssh auth failed for {user} (memory key): {e}"))?; + .map_err(|e| OrchestratorError::SshAuth { user: user.clone(), source: e }) + .into_diagnostic()?; } else if let Some(ref p) = per_job_key_path { sess.userauth_pubkey_file(&user, None, p, None) - .map_err(|e| miette::miette!("ssh auth failed for {user} (per-job file): {e}"))?; + .map_err(|e| OrchestratorError::SshAuth { user: user.clone(), source: e }) + .into_diagnostic()?; } else { return Err(miette::miette!("no SSH key available for job (neither in-memory nor per-job file)")); } @@ -369,13 +497,19 @@ async fn run_job_via_ssh_owned( return Err(miette::miette!("ssh not authenticated")); } // Upload runner via SFTP - let sftp = sess.sftp().map_err(|e| miette::miette!("sftp init failed: {e}"))?; - let mut local = StdFile::open(&local_runner).map_err(|e| miette::miette!("open local runner: {e}"))?; + let sftp = sess.sftp().map_err(OrchestratorError::SftpInit).into_diagnostic()?; + let mut local = StdFile::open(&local_runner) + .map_err(OrchestratorError::OpenLocalRunner) + .into_diagnostic()?; let mut buf = Vec::new(); - local.read_to_end(&mut buf).map_err(|e| miette::miette!("read runner: {e}"))?; - let mut remote = sftp.create(&remote_path).map_err(|e| miette::miette!("sftp create: {e}"))?; + local.read_to_end(&mut buf).map_err(OrchestratorError::ReadRunner).into_diagnostic()?; + let mut remote = sftp.create(&remote_path) + .map_err(OrchestratorError::SftpCreate) + .into_diagnostic()?; use std::io::Write as _; - remote.write_all(&buf).map_err(|e| miette::miette!("sftp write: {e}"))?; + remote.write_all(&buf) + .map_err(OrchestratorError::SftpWrite) + .into_diagnostic()?; let remote_file_stat = ssh2::FileStat { size: None, uid: None, gid: None, perm: Some(0o755), atime: None, mtime: None }; let _ = sftp.setstat(&remote_path, remote_file_stat); @@ -385,8 +519,12 @@ async fn run_job_via_ssh_owned( repo_url, commit_sha, request_id, remote_path.display() ); - let mut channel = sess.channel_session().map_err(|e| miette::miette!("channel session: {e}"))?; - channel.exec(&format!("sh -lc '{}'", cmd)).map_err(|e| miette::miette!("exec: {e}"))?; + let mut channel = sess.channel_session() + .map_err(OrchestratorError::ChannelSession) + .into_diagnostic()?; + channel.exec(&format!("sh -lc '{}'", cmd)) + .map_err(OrchestratorError::Exec) + .into_diagnostic()?; let mut out = String::new(); let mut err = String::new();