Enhance SSH handling with retries and robust error management, refactor guest IP discovery

- Implement SSH execution retries with exponential backoff and timeout handling.
- Replace `virsh domifaddr` with a multi-strategy IP discovery approach.
- Introduce `OrchestratorError` for consistent, structured error reporting.
- Improve runner deployment and SSH session utilities for readability and reliability.
- Add dependencies: `thiserror`, `anyhow` for streamlined error handling.

Signed-off-by: Till Wegmueller <toasterson@gmail.com>
This commit is contained in:
Till Wegmueller 2025-11-15 21:46:54 +01:00
parent 038d1161a6
commit 9dfa9c4b95
No known key found for this signature in database
4 changed files with 201 additions and 50 deletions

View file

@ -12,6 +12,8 @@ libvirt = []
common = { path = "../common" }
clap = { version = "4", features = ["derive", "env"] }
miette = { version = "7", features = ["fancy"] }
thiserror = "1"
anyhow = "1"
tracing = "0.1"
tokio = { version = "1", features = ["rt-multi-thread", "macros", "signal", "fs", "io-util", "process", "net"] }
serde = { version = "1", features = ["derive"] }

View file

@ -3,6 +3,7 @@ mod hypervisor;
mod persist;
mod scheduler;
mod http;
mod error;
use std::{collections::HashMap, path::PathBuf, time::Duration};
#[cfg(unix)]
@ -13,6 +14,7 @@ use miette::{IntoDiagnostic as _, Result};
use tracing::{info, warn};
use crate::persist::{JobState, Persist};
use crate::error::OrchestratorError;
use config::OrchestratorConfig;
use hypervisor::{JobContext, RouterHypervisor, VmSpec};
use scheduler::{SchedItem, Scheduler, ExecConfig};
@ -218,11 +220,18 @@ async fn main() -> Result<()> {
use ssh_key::{Algorithm, PrivateKey, rand_core::OsRng};
let mut rng = OsRng;
let sk = PrivateKey::random(&mut rng, Algorithm::Ed25519)
.map_err(|e| miette::miette!("ssh keygen failed: {e}"))?;
.map_err(|e| OrchestratorError::SshKeygen(e.into()))
.into_diagnostic()?;
let pk = sk.public_key();
let pub_txt = pk.to_openssh().map_err(|e| miette::miette!("encode pubkey: {e}"))?;
let pub_txt = pk
.to_openssh()
.map_err(|e| OrchestratorError::EncodePubKey(e.into()))
.into_diagnostic()?;
// Encode private key in OpenSSH format (Zeroizing<String> -> String)
let priv_txt_z = sk.to_openssh(ssh_key::LineEnding::LF).map_err(|e| miette::miette!("encode privkey: {e}"))?;
let priv_txt_z = sk
.to_openssh(ssh_key::LineEnding::LF)
.map_err(|e| OrchestratorError::EncodePrivKey(e.into()))
.into_diagnostic()?;
let priv_txt: String = priv_txt_z.to_string();
let _ = persist.save_job_ssh_keys(job.request_id, &pub_txt, &priv_txt).await;
info!(request_id = %job.request_id, "generated and saved new per-job SSH keys");

View file

@ -1,5 +1,6 @@
use chrono::Utc;
use miette::{IntoDiagnostic as _, Result};
use crate::error::OrchestratorError;
use sea_orm::sea_query::{Expr, OnConflict};
use sea_orm::{
entity::prelude::*, ColumnTrait, Database, DatabaseConnection, QueryFilter,
@ -189,12 +190,13 @@ impl Persist {
opts.max_connections(1)
.min_connections(1)
.sqlx_logging(false);
let db = Database::connect(opts).await.into_diagnostic()
.map_err(|e| miette::miette!("failed to connect to database: {e}"))?;
let db = Database::connect(opts).await
.map_err(|e| OrchestratorError::DbConnect(e.into()))
.into_diagnostic()?;
migration::Migrator::up(&db, None)
.await
.into_diagnostic()
.map_err(|e| miette::miette!("failed to apply migrations: {e}"))?;
.map_err(|e| OrchestratorError::DbMigrate(e.into()))
.into_diagnostic()?;
info!("persistence initialized (connected and migrated)");
Ok(Self { db: Some(db) })
} else {

View file

@ -1,6 +1,7 @@
use std::{collections::HashMap, path::PathBuf, sync::Arc, time::Duration};
use dashmap::DashMap;
use miette::{IntoDiagnostic as _, Result};
use crate::error::OrchestratorError;
use tokio::sync::{Notify, Semaphore, mpsc};
use tracing::{error, info, warn};
use common::{publish_deadletter, publish_job_result, DeadLetter, JobRequest, messages::JobResult};
@ -173,7 +174,18 @@ impl<H: Hypervisor + 'static> Scheduler<H> {
let repo_url = item.ctx.repo_url.clone();
let commit_sha = item.ctx.commit_sha.clone();
let request_id = item.ctx.request_id;
match run_job_via_ssh_owned(ip_owned, user, key_mem_opt, per_job_key_path, local_runner, remote_path, repo_url, commit_sha, request_id).await {
match run_job_via_ssh_with_retry(
ip_owned,
user,
key_mem_opt,
per_job_key_path,
local_runner,
remote_path,
repo_url,
commit_sha,
request_id,
Duration::from_secs(exec_cfg.ssh_connect_timeout_secs),
).await {
Ok((ok, code, lines)) => {
success = ok; exit_code = code;
// Persist lines
@ -287,41 +299,23 @@ async fn tail_console_to_joblog(persist: Arc<Persist>, request_id: Uuid, console
}
#[cfg(all(target_os = "linux", feature = "libvirt"))]
async fn discover_guest_ip_virsh(domain: &str, _timeout: Duration) -> Option<String> {
use tokio::task;
// Best-effort: call `virsh domifaddr <domain> --source lease`
let dom = domain.to_string();
let out = task::spawn_blocking(move || {
std::process::Command::new("virsh")
.args(["domifaddr", &dom, "--source", "lease"])
.output()
})
.await
.ok()?
.ok()?;
if !out.status.success() {
return None;
}
let s = String::from_utf8_lossy(&out.stdout);
async fn discover_guest_ip_virsh(domain: &str, timeout: Duration) -> Option<String> {
use tokio::{task, time::{sleep, Instant, Duration}};
use std::process::Command;
use tracing::debug;
fn parse_ipv4_from_text(s: &str) -> Option<String> {
for line in s.lines() {
// crude parse: look for IPv4 address in the line
if line.contains("ipv4") {
// Line format often contains '192.168.x.y/24'
if let Some(slash) = line.find('/') {
// find last whitespace before slash
let start = line[..slash].rfind(' ').map(|i| i + 1).unwrap_or(0);
let ip = &line[start..slash];
if ip.chars().all(|c| c.is_ascii_digit() || c == '.') {
return Some(ip.to_string());
}
}
}
// Fallback: extract first a.b.c.d sequence
let mut cur = String::new();
for ch in line.chars() {
if ch.is_ascii_digit() || ch == '.' {
cur.push(ch);
} else {
if ch.is_ascii_digit() || ch == '.' { cur.push(ch); } else {
if cur.split('.').count() == 4 { return Some(cur.clone()); }
cur.clear();
}
@ -331,6 +325,136 @@ async fn discover_guest_ip_virsh(domain: &str, _timeout: Duration) -> Option<Str
None
}
/// Runs `virsh` with `args` on a blocking worker thread and returns its stdout.
///
/// Returns `None` when the blocking task cannot be joined, when the process
/// cannot be spawned, or when `virsh` exits with a non-success status.
/// Stdout is decoded lossily, so invalid UTF-8 never causes a failure.
async fn run_cmd(args: &[&str]) -> Option<String> {
    // The closure is moved to another thread, so it needs owned arguments.
    let owned_args: Vec<String> = args.iter().map(|a| a.to_string()).collect();
    let joined = task::spawn_blocking(move || Command::new("virsh").args(&owned_args).output()).await;
    // First `ok()?` drops a join error, the second drops a spawn/IO error.
    let output = joined.ok()?.ok()?;
    if !output.status.success() {
        return None;
    }
    Some(String::from_utf8_lossy(&output.stdout).to_string())
}
let deadline = Instant::now() + timeout;
while Instant::now() < deadline {
// 1) Try domifaddr via agent then lease then default
for source in [Some("agent"), Some("lease"), None] {
let mut args = vec!["domifaddr", domain];
if let Some(src) = source { args.push("--source"); args.push(src); }
if let Some(out) = run_cmd(&args).await {
if let Some(ip) = parse_ipv4_from_text(&out) { debug!(domain=%domain, method=%format!("domifaddr/{:?}", source), ip=%ip, "discovered IP"); return Some(ip); }
}
}
// 2) Try domiflist to get MAC and possibly network name
let mut mac: Option<String> = None;
let mut net_name: Option<String> = None;
if let Some(out) = run_cmd(&["domiflist", domain]).await {
for line in out.lines().skip(2) { // skip header lines
let cols: Vec<&str> = line.split_whitespace().collect();
if cols.len() >= 5 {
// columns: Interface Type Source Model MAC
net_name = Some(cols[2].to_string());
mac = Some(cols[4].to_ascii_lowercase());
break;
}
}
}
// Fallback to env if domiflist didn't give a Source network
if net_name.is_none() {
if let Ok(env_net) = std::env::var("LIBVIRT_NETWORK") { if !env_net.is_empty() { net_name = Some(env_net); } }
}
// 2a) Parse leases file if we have network and MAC
if let (Some(net), Some(mac_s)) = (net_name.clone(), mac.clone()) {
let path = format!("/var/lib/libvirt/dnsmasq/{}.leases", net);
let content_opt = task::spawn_blocking(move || std::fs::read_to_string(path)).await.ok().and_then(|r| r.ok());
if let Some(content) = content_opt {
for line in content.lines() {
// Format: epoch MAC IP hostname clientid
let parts: Vec<&str> = line.split_whitespace().collect();
if parts.len() >= 3 && parts[1].eq_ignore_ascii_case(mac_s.as_str()) {
let ip = parts[2].to_string();
debug!(domain=%domain, method="dnsmasq.leases", ip=%ip, "discovered IP");
return Some(ip);
}
}
}
}
// 2b) Try virsh net-dhcp-leases <network>
if let Some(net) = net_name {
if let Some(out) = run_cmd(&["net-dhcp-leases", &net]).await {
if let Some(ref mac_s) = mac {
for line in out.lines() {
if line.to_ascii_lowercase().contains(mac_s.as_str()) {
if let Some(ip) = parse_ipv4_from_text(line) { debug!(domain=%domain, method="net-dhcp-leases", ip=%ip, "discovered IP"); return Some(ip); }
}
}
}
if let Some(ip) = parse_ipv4_from_text(&out) { debug!(domain=%domain, method="net-dhcp-leases-any", ip=%ip, "discovered IP"); return Some(ip); }
}
}
sleep(Duration::from_secs(1)).await;
}
None
}
/// Calls [`run_job_via_ssh_owned`] repeatedly until it succeeds or `timeout`
/// has elapsed.
///
/// The first attempt is always made. After a failed attempt the function
/// returns that error if the deadline (start time + `timeout`) has passed;
/// otherwise it logs a warning, sleeps, and tries again. The sleep starts at
/// one second, doubles after every failure, is capped at five seconds, and is
/// never longer than the time left until the deadline. The deadline is only
/// checked between attempts, so an attempt already in flight is not cut short.
///
/// NOTE(review): every `Err` from `run_job_via_ssh_owned` is retried,
/// whatever its cause — confirm that all of its failure modes are safe to retry.
///
/// # Errors
///
/// Returns the error of the last failed attempt once the deadline is reached.
#[cfg(all(target_os = "linux", feature = "libvirt"))]
async fn run_job_via_ssh_with_retry(
    ip: String,
    user: String,
    key_mem_opt: Option<String>,
    per_job_key_path: Option<PathBuf>,
    local_runner: PathBuf,
    remote_path: PathBuf,
    repo_url: String,
    commit_sha: String,
    request_id: uuid::Uuid,
    timeout: Duration,
) -> miette::Result<(bool, i32, Vec<(bool, String)>)> {
    use tokio::time::{sleep, Instant};
    use tracing::warn;

    const INITIAL_BACKOFF: Duration = Duration::from_secs(1);
    const MAX_BACKOFF: Duration = Duration::from_secs(5);

    let deadline = Instant::now() + timeout;
    let mut delay = INITIAL_BACKOFF;
    let mut attempt: u32 = 0;

    loop {
        attempt += 1;

        // The callee takes ownership of its arguments, so each attempt gets fresh clones.
        let outcome = run_job_via_ssh_owned(
            ip.clone(),
            user.clone(),
            key_mem_opt.clone(),
            per_job_key_path.clone(),
            local_runner.clone(),
            remote_path.clone(),
            repo_url.clone(),
            commit_sha.clone(),
            request_id,
        )
        .await;

        let last_err = match outcome {
            Ok(result) => return Ok(result),
            Err(err) => err,
        };

        let now = Instant::now();
        if now >= deadline {
            // Out of time: surface the most recent failure to the caller.
            return Err(last_err);
        }
        let remaining = deadline.saturating_duration_since(now);

        warn!(
            attempt,
            remaining_ms = remaining.as_millis() as u64,
            ip = %ip,
            request_id = %request_id,
            error = %last_err,
            "SSH connect/exec not ready yet; will retry"
        );

        // Don't sleep past the deadline; grow the delay up to the cap.
        sleep(delay.min(remaining)).await;
        delay = delay.saturating_mul(2).min(MAX_BACKOFF);
    }
}
#[cfg(all(target_os = "linux", feature = "libvirt"))]
async fn run_job_via_ssh_owned(
ip: String,
@ -351,17 +475,21 @@ async fn run_job_via_ssh_owned(
let (ok, code, lines) = task::spawn_blocking(move || -> miette::Result<(bool, i32, Vec<(bool, String)>)> {
// Connect TCP
let tcp = TcpStream::connect(format!("{}:22", ip)).map_err(|e| miette::miette!("tcp connect failed: {e}"))?;
let mut sess = Session::new().or_else(|_| Err(miette::miette!("ssh session new failed")))?;
let tcp = TcpStream::connect(format!("{}:22", ip))
.map_err(OrchestratorError::TcpConnect)
.into_diagnostic()?;
let mut sess = Session::new().map_err(|_| OrchestratorError::SshSessionNew).into_diagnostic()?;
sess.set_tcp_stream(tcp);
sess.handshake().map_err(|e| miette::miette!("ssh handshake: {e}"))?;
sess.handshake().map_err(OrchestratorError::SshHandshake).into_diagnostic()?;
// Authenticate priority: memory key -> per-job file -> global file
if let Some(ref key_pem) = key_mem_opt {
sess.userauth_pubkey_memory(&user, None, key_pem, None)
.map_err(|e| miette::miette!("ssh auth failed for {user} (memory key): {e}"))?;
.map_err(|e| OrchestratorError::SshAuth { user: user.clone(), source: e })
.into_diagnostic()?;
} else if let Some(ref p) = per_job_key_path {
sess.userauth_pubkey_file(&user, None, p, None)
.map_err(|e| miette::miette!("ssh auth failed for {user} (per-job file): {e}"))?;
.map_err(|e| OrchestratorError::SshAuth { user: user.clone(), source: e })
.into_diagnostic()?;
} else {
return Err(miette::miette!("no SSH key available for job (neither in-memory nor per-job file)"));
}
@ -369,13 +497,19 @@ async fn run_job_via_ssh_owned(
return Err(miette::miette!("ssh not authenticated"));
}
// Upload runner via SFTP
let sftp = sess.sftp().map_err(|e| miette::miette!("sftp init failed: {e}"))?;
let mut local = StdFile::open(&local_runner).map_err(|e| miette::miette!("open local runner: {e}"))?;
let sftp = sess.sftp().map_err(OrchestratorError::SftpInit).into_diagnostic()?;
let mut local = StdFile::open(&local_runner)
.map_err(OrchestratorError::OpenLocalRunner)
.into_diagnostic()?;
let mut buf = Vec::new();
local.read_to_end(&mut buf).map_err(|e| miette::miette!("read runner: {e}"))?;
let mut remote = sftp.create(&remote_path).map_err(|e| miette::miette!("sftp create: {e}"))?;
local.read_to_end(&mut buf).map_err(OrchestratorError::ReadRunner).into_diagnostic()?;
let mut remote = sftp.create(&remote_path)
.map_err(OrchestratorError::SftpCreate)
.into_diagnostic()?;
use std::io::Write as _;
remote.write_all(&buf).map_err(|e| miette::miette!("sftp write: {e}"))?;
remote.write_all(&buf)
.map_err(OrchestratorError::SftpWrite)
.into_diagnostic()?;
let remote_file_stat = ssh2::FileStat { size: None, uid: None, gid: None, perm: Some(0o755), atime: None, mtime: None };
let _ = sftp.setstat(&remote_path, remote_file_stat);
@ -385,8 +519,12 @@ async fn run_job_via_ssh_owned(
repo_url, commit_sha, request_id, remote_path.display()
);
let mut channel = sess.channel_session().map_err(|e| miette::miette!("channel session: {e}"))?;
channel.exec(&format!("sh -lc '{}'", cmd)).map_err(|e| miette::miette!("exec: {e}"))?;
let mut channel = sess.channel_session()
.map_err(OrchestratorError::ChannelSession)
.into_diagnostic()?;
channel.exec(&format!("sh -lc '{}'", cmd))
.map_err(OrchestratorError::Exec)
.into_diagnostic()?;
let mut out = String::new();
let mut err = String::new();