solstice-ci/crates/orchestrator/src/scheduler.rs
Till Wegmueller 4ca78144f2
Add VM state monitoring and graceful shutdown enhancements
This commit enhances the `Scheduler` to monitor VM states for completion, enabling more accurate termination detection. It introduces periodic polling combined with shutdown signals to halt operations gracefully. Additionally, VM lifecycle management in the hypervisor is updated with `state` retrieval for precise status assessments. The VM domain configuration now includes serial console support.
2025-10-26 21:59:55 +01:00

298 lines
15 KiB
Rust

use std::{collections::HashMap, sync::Arc, time::Duration};
use dashmap::DashMap;
use miette::Result;
use tokio::sync::{mpsc, Semaphore, Notify};
use tracing::{error, info, warn};
use crate::hypervisor::{Hypervisor, VmSpec, JobContext, BackendTag};
use crate::persist::{Persist, VmPersistState, JobState};
/// Orchestrates VM-backed CI job execution.
///
/// Jobs arrive as [`SchedItem`]s over an mpsc channel and run under two layers
/// of admission control: a global concurrency semaphore plus one semaphore per
/// runner label.
pub struct Scheduler<H: Hypervisor + 'static> {
    // Hypervisor backend, shared with each spawned job task.
    hv: Arc<H>,
    // Submission side of the job channel; handed out via `sender()`.
    tx: mpsc::Sender<SchedItem>,
    // Receiving side, consumed by the run loop.
    rx: mpsc::Receiver<SchedItem>,
    // Global cap on concurrently executing jobs.
    global_sem: Arc<Semaphore>,
    // Per-label capacity semaphores, keyed by runner label.
    label_sems: Arc<DashmapType>,
    // Sink for VM/job state-transition records.
    persist: Arc<Persist>,
    // Upper bound on how long a started VM is monitored before teardown.
    placeholder_runtime: Duration,
}
// Concurrent map from runner label to that label's capacity semaphore.
type DashmapType = DashMap<String, Arc<Semaphore>>;
/// One unit of schedulable work: the VM to provision plus the job metadata it
/// runs for.
pub struct SchedItem {
    /// Resource/label specification of the VM to provision.
    pub spec: VmSpec,
    /// Job context (request id, repo URL, commit SHA) used for logging and persistence.
    pub ctx: JobContext,
}
impl<H: Hypervisor + 'static> Scheduler<H> {
    /// Build a scheduler backed by hypervisor `hv`.
    ///
    /// * `max_concurrency` — global cap on simultaneously running jobs; the
    ///   submission channel is sized at 4x this value.
    /// * `capacity_map` — per-label concurrency caps; labels not listed here
    ///   default to a capacity of 1 on first use.
    /// * `persist` — sink for VM and job state-transition events.
    /// * `placeholder_runtime` — upper bound on how long a started VM is
    ///   monitored before it is forcibly stopped.
    pub fn new(hv: H, max_concurrency: usize, capacity_map: &HashMap<String, usize>, persist: Arc<Persist>, placeholder_runtime: Duration) -> Self {
        let (tx, rx) = mpsc::channel::<SchedItem>(max_concurrency * 4);
        let label_sems = DashMap::new();
        for (label, cap) in capacity_map.iter() {
            label_sems.insert(label.clone(), Arc::new(Semaphore::new(*cap)));
        }
        Self {
            hv: Arc::new(hv),
            tx,
            rx,
            global_sem: Arc::new(Semaphore::new(max_concurrency)),
            label_sems: Arc::new(label_sems),
            persist,
            placeholder_runtime,
        }
    }

    /// A clone of the submission side of the scheduling channel.
    pub fn sender(&self) -> mpsc::Sender<SchedItem> { self.tx.clone() }

    /// Run the scheduler until the channel closes or `shutdown` is notified.
    ///
    /// Each received item is spawned as a task that acquires the global and
    /// per-label permits, prepares and starts a VM, monitors it until it
    /// reports `Stopped`, `placeholder_runtime` elapses, or shutdown fires,
    /// then stops/destroys it, recording every transition in `persist`.
    /// On shutdown, intake stops immediately (items still queued in the
    /// channel are not started) and all in-flight tasks are awaited before
    /// returning.
    pub async fn run_with_shutdown(self, shutdown: Arc<Notify>) -> Result<()> {
        let Scheduler { hv, mut rx, global_sem, label_sems, persist, placeholder_runtime, .. } = self;
        let mut handles = Vec::new();
        'scheduler: loop {
            tokio::select! {
                // Shutdown stops intake at once; in-flight tasks are awaited
                // below. (The old message claimed the queue would be drained,
                // but the loop exits without receiving queued items.)
                _ = shutdown.notified() => {
                    info!("scheduler: shutdown requested; stopping intake and waiting for in-flight jobs");
                    break 'scheduler;
                }
                maybe_item = rx.recv() => {
                    match maybe_item {
                        Some(item) => {
                            let hv = hv.clone();
                            let global = global_sem.clone();
                            let label_sems = label_sems.clone();
                            let persist = persist.clone();
                            let shutdown = shutdown.clone();
                            let handle = tokio::spawn(async move {
                                // Owned permits so they live for the whole task;
                                // acquire fails only when the semaphore is closed.
                                let _g = match global.acquire_owned().await { Ok(p) => p, Err(_) => return };
                                let label_key = item.spec.label.clone();
                                // Atomic lookup-or-create. The previous get-then-insert
                                // let two tasks racing on an unknown label each create
                                // (and acquire) a distinct Semaphore::new(1), breaking
                                // the per-label capacity guarantee. The entry guard is
                                // dropped before the await below.
                                let sem_arc = label_sems
                                    .entry(label_key.clone())
                                    .or_insert_with(|| Arc::new(Semaphore::new(1)))
                                    .value()
                                    .clone();
                                let _l = match sem_arc.acquire_owned().await { Ok(p) => p, Err(_) => return };
                                // Provision and run the VM for this job.
                                match hv.prepare(&item.spec, &item.ctx).await {
                                    Ok(h) => {
                                        // Persist the prepared VM state (paths recorded only
                                        // when they are valid UTF-8).
                                        let overlay = h.overlay_path.as_ref().and_then(|p| p.to_str());
                                        let seed = h.seed_iso_path.as_ref().and_then(|p| p.to_str());
                                        let backend = match h.backend {
                                            BackendTag::Noop => Some("noop"),
                                            #[cfg(all(target_os = "linux", feature = "libvirt"))]
                                            BackendTag::Libvirt => Some("libvirt"),
                                            #[cfg(target_os = "illumos")]
                                            BackendTag::Zones => Some("zones"),
                                        };
                                        if let Err(e) = persist.record_vm_event(item.ctx.request_id, &h.id, overlay, seed, backend, VmPersistState::Prepared).await {
                                            warn!(error = %e, request_id = %item.ctx.request_id, domain = %h.id, "persist prepare failed");
                                        }
                                        if let Err(e) = hv.start(&h).await {
                                            error!(error = %e, request_id = %item.ctx.request_id, label = %label_key, "failed to start VM");
                                            let _ = persist.record_job_state(item.ctx.request_id, &item.ctx.repo_url, &item.ctx.commit_sha, Some(&item.spec.label), JobState::Failed).await;
                                            // Best-effort teardown so a prepared-but-unstartable
                                            // VM does not leak its overlay/seed artifacts.
                                            if let Err(e) = hv.destroy(h.clone()).await {
                                                error!(error = %e, request_id = %item.ctx.request_id, label = %label_key, "failed to destroy VM after start failure");
                                            }
                                            let _ = persist.record_vm_event(item.ctx.request_id, &h.id, overlay, seed, backend, VmPersistState::Destroyed).await;
                                            return;
                                        }
                                        let _ = persist.record_vm_event(item.ctx.request_id, &h.id, overlay, seed, backend, VmPersistState::Running).await;
                                        let _ = persist.record_job_state(item.ctx.request_id, &item.ctx.repo_url, &item.ctx.commit_sha, Some(&item.spec.label), JobState::Running).await;
                                        info!(request_id = %item.ctx.request_id, label = %label_key, "vm started (monitoring for completion)");
                                        // Poll VM state every 2s until it reports Stopped,
                                        // placeholder_runtime elapses, or shutdown fires.
                                        let start_time = std::time::Instant::now();
                                        loop {
                                            // Check current state first so an already-stopped
                                            // VM is torn down without an extra sleep.
                                            if let Ok(crate::hypervisor::VmState::Stopped) = hv.state(&h).await {
                                                info!(request_id = %item.ctx.request_id, label = %label_key, "vm reported stopped");
                                                break;
                                            }
                                            if start_time.elapsed() >= placeholder_runtime { break; }
                                            tokio::select! {
                                                _ = shutdown.notified() => {
                                                    info!(request_id = %item.ctx.request_id, label = %label_key, "shutdown: ending early");
                                                    break;
                                                }
                                                _ = tokio::time::sleep(Duration::from_secs(2)) => {}
                                            }
                                        }
                                        // Teardown: graceful stop with a 10s budget, then destroy.
                                        if let Err(e) = hv.stop(&h, Duration::from_secs(10)).await {
                                            error!(error = %e, request_id = %item.ctx.request_id, label = %label_key, "failed to stop VM");
                                        }
                                        let _ = persist.record_vm_event(item.ctx.request_id, &h.id, overlay, seed, backend, VmPersistState::Stopped).await;
                                        if let Err(e) = hv.destroy(h.clone()).await {
                                            error!(error = %e, request_id = %item.ctx.request_id, label = %label_key, "failed to destroy VM");
                                        }
                                        let _ = persist.record_vm_event(item.ctx.request_id, &h.id, overlay, seed, backend, VmPersistState::Destroyed).await;
                                        let _ = persist.record_job_state(item.ctx.request_id, &item.ctx.repo_url, &item.ctx.commit_sha, Some(&item.spec.label), JobState::Succeeded).await;
                                    }
                                    Err(e) => {
                                        error!(error = %e, request_id = %item.ctx.request_id, label = %label_key, "failed to prepare VM");
                                        let _ = persist.record_job_state(item.ctx.request_id, &item.ctx.repo_url, &item.ctx.commit_sha, Some(&item.spec.label), JobState::Failed).await;
                                        return;
                                    }
                                }
                            });
                            handles.push(handle);
                        }
                        // Channel closed: every sender dropped, nothing left to schedule.
                        None => break,
                    }
                }
            }
        }
        // Wait for all in-flight tasks to finish before reporting completion.
        for h in handles { let _ = h.await; }
        info!("scheduler: completed");
        Ok(())
    }

    /// Backward-compatible entry point: run until the channel closes, with no
    /// external shutdown signal.
    pub async fn run(self) -> Result<()> {
        self.run_with_shutdown(Arc::new(Notify::new())).await
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use async_trait::async_trait;
    use dashmap::DashMap;
    use crate::hypervisor::{VmHandle, VmState};

    /// Test double that provisions nothing and instead counts concurrently
    /// "active" VMs (prepared but not yet destroyed), so the tests can assert
    /// that the scheduler's concurrency limits hold.
    #[derive(Clone)]
    struct MockHypervisor {
        // Current active VM count across all labels.
        active_all: Arc<AtomicUsize>,
        // High-water mark of `active_all`.
        peak_all: Arc<AtomicUsize>,
        // per-label current and peak
        per_curr: Arc<DashMap<String, Arc<AtomicUsize>>>,
        per_peak: Arc<DashMap<String, Arc<AtomicUsize>>>,
        // Maps generated VM ids back to their label so `destroy` can decrement
        // the correct per-label counter.
        id_to_label: Arc<DashMap<String, String>>,
    }

    impl MockHypervisor {
        // Shares the global counters with the test so it can probe them after
        // the scheduler runs.
        fn new(active_all: Arc<AtomicUsize>, peak_all: Arc<AtomicUsize>) -> Self {
            Self { active_all, peak_all, per_curr: Arc::new(DashMap::new()), per_peak: Arc::new(DashMap::new()), id_to_label: Arc::new(DashMap::new()) }
        }
        // Lock-free max-update: raise `peak` to `current` via a CAS loop,
        // retrying only while `current` still exceeds the observed value.
        fn update_peak(peak: &AtomicUsize, current: usize) {
            let mut prev = peak.load(Ordering::Relaxed);
            while current > prev {
                match peak.compare_exchange(prev, current, Ordering::Relaxed, Ordering::Relaxed) {
                    Ok(_) => break,
                    Err(p) => prev = p,
                }
            }
        }
    }

    #[async_trait]
    impl Hypervisor for MockHypervisor {
        // "Preparing" bumps the global and per-label counters, records both
        // peaks, and hands back a dummy handle whose id encodes the request.
        async fn prepare(&self, spec: &VmSpec, ctx: &JobContext) -> miette::Result<VmHandle> {
            let now = self.active_all.fetch_add(1, Ordering::SeqCst) + 1;
            Self::update_peak(&self.peak_all, now);
            // per-label current/peak
            let entry = self.per_curr.entry(spec.label.clone()).or_insert_with(|| Arc::new(AtomicUsize::new(0)));
            let curr = entry.fetch_add(1, Ordering::SeqCst) + 1;
            let peak_entry = self.per_peak.entry(spec.label.clone()).or_insert_with(|| Arc::new(AtomicUsize::new(0))).clone();
            Self::update_peak(&peak_entry, curr);
            let id = format!("job-{}", ctx.request_id);
            self.id_to_label.insert(id.clone(), spec.label.clone());
            Ok(VmHandle {
                id,
                backend: BackendTag::Noop,
                work_dir: PathBuf::from("/tmp"),
                overlay_path: None,
                seed_iso_path: None,
            })
        }
        // Start/stop are no-ops for the mock.
        async fn start(&self, _vm: &VmHandle) -> miette::Result<()> { Ok(()) }
        async fn stop(&self, _vm: &VmHandle, _t: Duration) -> miette::Result<()> { Ok(()) }
        async fn destroy(&self, vm: VmHandle) -> miette::Result<()> {
            // decrement overall current
            self.active_all.fetch_sub(1, Ordering::SeqCst);
            // decrement per-label current using id->label map
            if let Some(label) = self.id_to_label.remove(&vm.id).map(|e| e.1) {
                if let Some(entry) = self.per_curr.get(&label) {
                    let _ = entry.fetch_sub(1, Ordering::SeqCst);
                }
            }
            Ok(())
        }
        // Always "Prepared": the scheduler never observes Stopped, so each job
        // is monitored for the full (short) placeholder_runtime.
        async fn state(&self, _vm: &VmHandle) -> miette::Result<VmState> { Ok(VmState::Prepared) }
    }

    // Minimal VmSpec for the given label.
    fn make_spec(label: &str) -> VmSpec {
        VmSpec {
            label: label.to_string(),
            image_path: PathBuf::from("/tmp/image"),
            cpu: 1,
            ram_mb: 512,
            disk_gb: 1,
            network: None,
            nocloud: true,
            user_data: None,
        }
    }

    // Fresh JobContext with a random request id.
    fn make_ctx() -> JobContext {
        JobContext { request_id: uuid::Uuid::new_v4(), repo_url: "https://example.com/r.git".into(), commit_sha: "deadbeef".into(), workflow_job_id: None }
    }

    // With a global limit of 2 and ample label capacity (10), the observed
    // peak of concurrently active VMs must never exceed 2 across 5 jobs.
    #[tokio::test(flavor = "multi_thread")]
    async fn scheduler_respects_global_concurrency() {
        let active_all = Arc::new(AtomicUsize::new(0));
        let peak_all = Arc::new(AtomicUsize::new(0));
        let hv = MockHypervisor::new(active_all.clone(), peak_all.clone());
        let hv_probe = hv.clone();
        let persist = Arc::new(Persist::new(None).await.unwrap());
        let mut caps = HashMap::new();
        caps.insert("x".to_string(), 10);
        let sched = Scheduler::new(hv, 2, &caps, persist, Duration::from_millis(10));
        let tx = sched.sender();
        let run = tokio::spawn(async move { let _ = sched.run().await; });
        for _ in 0..5 {
            let _ = tx.send(SchedItem { spec: make_spec("x"), ctx: make_ctx() }).await;
        }
        // Dropping the sender closes the channel so the run loop can finish.
        drop(tx);
        // Allow time for tasks to execute under concurrency limits
        tokio::time::sleep(Duration::from_millis(500)).await;
        assert!(hv_probe.peak_all.load(Ordering::SeqCst) <= 2, "peak should not exceed global limit");
        // Stop the scheduler task
        run.abort();
    }

    // Per-label caps (a=1, b=2) must each hold while the global limit is 4.
    #[tokio::test(flavor = "multi_thread")]
    async fn scheduler_respects_per_label_capacity() {
        let active_all = Arc::new(AtomicUsize::new(0));
        let peak_all = Arc::new(AtomicUsize::new(0));
        let hv = MockHypervisor::new(active_all.clone(), peak_all.clone());
        let hv_probe = hv.clone();
        let persist = Arc::new(Persist::new(None).await.unwrap());
        let mut caps = HashMap::new();
        caps.insert("a".to_string(), 1);
        caps.insert("b".to_string(), 2);
        let sched = Scheduler::new(hv, 4, &caps, persist, Duration::from_millis(10));
        let tx = sched.sender();
        let run = tokio::spawn(async move { let _ = sched.run().await; });
        for _ in 0..3 { let _ = tx.send(SchedItem { spec: make_spec("a"), ctx: make_ctx() }).await; }
        for _ in 0..3 { let _ = tx.send(SchedItem { spec: make_spec("b"), ctx: make_ctx() }).await; }
        drop(tx);
        tokio::time::sleep(Duration::from_millis(800)).await;
        // read per-label peaks
        let a_peak = hv_probe.per_peak.get("a").map(|p| p.load(Ordering::SeqCst)).unwrap_or(0);
        let b_peak = hv_probe.per_peak.get("b").map(|p| p.load(Ordering::SeqCst)).unwrap_or(0);
        assert!(a_peak <= 1, "label a peak should be <= 1, got {}", a_peak);
        assert!(b_peak <= 2, "label b peak should be <= 2, got {}", b_peak);
        assert!(hv_probe.peak_all.load(Ordering::SeqCst) <= 4, "global peak should respect global limit");
        run.abort();
    }
}