mirror of
https://codeberg.org/Toasterson/solstice-ci.git
synced 2026-04-10 21:30:41 +00:00
This commit enhances the `Scheduler` to monitor VM states for completion, enabling more accurate termination detection. It introduces periodic polling combined with shutdown signals to halt operations gracefully. Additionally, VM lifecycle management in the hypervisor is updated with `state` retrieval for precise status assessments. The VM domain configuration now includes serial console support.
298 lines
15 KiB
Rust
298 lines
15 KiB
Rust
use std::{collections::HashMap, sync::Arc, time::Duration};
|
|
|
|
use dashmap::DashMap;
|
|
use miette::Result;
|
|
use tokio::sync::{mpsc, Semaphore, Notify};
|
|
use tracing::{error, info, warn};
|
|
|
|
use crate::hypervisor::{Hypervisor, VmSpec, JobContext, BackendTag};
|
|
use crate::persist::{Persist, VmPersistState, JobState};
|
|
|
|
/// Work-queue scheduler that provisions and supervises VMs through a
/// [`Hypervisor`] backend.
///
/// Jobs arrive over an mpsc channel (see `sender()`); each accepted job runs
/// in its own spawned task, bounded globally by `global_sem` and per spec
/// label by `label_sems`. State transitions are recorded via `persist`.
pub struct Scheduler<H: Hypervisor + 'static> {
    // Backend used to prepare/start/stop/destroy VMs; shared with spawned tasks.
    hv: Arc<H>,
    // Sender half handed out via `sender()` so producers can enqueue jobs.
    tx: mpsc::Sender<SchedItem>,
    // Receiver drained by the run loop.
    rx: mpsc::Receiver<SchedItem>,
    // Caps the total number of concurrently running jobs.
    global_sem: Arc<Semaphore>,
    // Per-label capacity semaphores (label -> semaphore), shared across tasks.
    label_sems: Arc<DashmapType>,
    // Persistence layer for recording VM events and job state transitions.
    persist: Arc<Persist>,
    // Upper bound on how long a started VM is monitored before being stopped.
    placeholder_runtime: Duration,
}
|
|
|
|
// Concurrent map of VM label -> capacity semaphore, shared by scheduler tasks.
type DashmapType = DashMap<String, Arc<Semaphore>>;
|
|
|
|
/// A single unit of scheduled work: the VM to provision plus the job metadata
/// it serves.
pub struct SchedItem {
    // Resource/shape description of the VM to provision (includes the label).
    pub spec: VmSpec,
    // Job context (request id, repo, commit) used for logging and persistence.
    pub ctx: JobContext,
}
|
|
|
|
impl<H: Hypervisor + 'static> Scheduler<H> {
|
|
pub fn new(hv: H, max_concurrency: usize, capacity_map: &HashMap<String, usize>, persist: Arc<Persist>, placeholder_runtime: Duration) -> Self {
|
|
let (tx, rx) = mpsc::channel::<SchedItem>(max_concurrency * 4);
|
|
let label_sems = DashMap::new();
|
|
for (label, cap) in capacity_map.iter() {
|
|
label_sems.insert(label.clone(), Arc::new(Semaphore::new(*cap)));
|
|
}
|
|
Self {
|
|
hv: Arc::new(hv),
|
|
tx,
|
|
rx,
|
|
global_sem: Arc::new(Semaphore::new(max_concurrency)),
|
|
label_sems: Arc::new(label_sems),
|
|
persist,
|
|
placeholder_runtime,
|
|
}
|
|
}
|
|
|
|
pub fn sender(&self) -> mpsc::Sender<SchedItem> { self.tx.clone() }
|
|
|
|
pub async fn run_with_shutdown(self, shutdown: Arc<Notify>) -> Result<()> {
|
|
let Scheduler { hv, mut rx, global_sem, label_sems, persist, placeholder_runtime, .. } = self;
|
|
let mut handles = Vec::new();
|
|
let mut shutting_down = false;
|
|
'scheduler: loop {
|
|
tokio::select! {
|
|
// Only react once to shutdown and then focus on draining the channel
|
|
_ = shutdown.notified(), if !shutting_down => {
|
|
shutting_down = true;
|
|
info!("scheduler: shutdown requested; will drain queue and interrupt placeholders");
|
|
break 'scheduler;
|
|
}
|
|
maybe_item = rx.recv() => {
|
|
match maybe_item {
|
|
Some(item) => {
|
|
let hv = hv.clone();
|
|
let global = global_sem.clone();
|
|
let label_sems = label_sems.clone();
|
|
let persist = persist.clone();
|
|
let shutdown = shutdown.clone();
|
|
let handle = tokio::spawn(async move {
|
|
// Acquire global and label permits (owned permits so they live inside the task)
|
|
let _g = match global.acquire_owned().await { Ok(p) => p, Err(_) => return };
|
|
let label_key = item.spec.label.clone();
|
|
let sem_arc = if let Some(entry) = label_sems.get(&label_key) { entry.clone() } else {
|
|
let s = Arc::new(Semaphore::new(1));
|
|
label_sems.insert(label_key.clone(), s.clone());
|
|
s
|
|
};
|
|
let _l = match sem_arc.acquire_owned().await { Ok(p) => p, Err(_) => return };
|
|
|
|
// Provision and run
|
|
match hv.prepare(&item.spec, &item.ctx).await {
|
|
Ok(h) => {
|
|
// persist prepared VM state
|
|
let overlay = h.overlay_path.as_ref().and_then(|p| p.to_str());
|
|
let seed = h.seed_iso_path.as_ref().and_then(|p| p.to_str());
|
|
let backend = match h.backend { BackendTag::Noop => Some("noop"), #[cfg(all(target_os = "linux", feature = "libvirt"))] BackendTag::Libvirt => Some("libvirt"), #[cfg(target_os = "illumos")] BackendTag::Zones => Some("zones") };
|
|
if let Err(e) = persist.record_vm_event(item.ctx.request_id, &h.id, overlay, seed, backend, VmPersistState::Prepared).await {
|
|
warn!(error = %e, request_id = %item.ctx.request_id, domain = %h.id, "persist prepare failed");
|
|
}
|
|
if let Err(e) = hv.start(&h).await {
|
|
error!(error = %e, request_id = %item.ctx.request_id, label = %label_key, "failed to start VM");
|
|
let _ = persist.record_job_state(item.ctx.request_id, &item.ctx.repo_url, &item.ctx.commit_sha, Some(&item.spec.label), JobState::Failed).await;
|
|
return;
|
|
}
|
|
let _ = persist.record_vm_event(item.ctx.request_id, &h.id, overlay, seed, backend, VmPersistState::Running).await;
|
|
let _ = persist.record_job_state(item.ctx.request_id, &item.ctx.repo_url, &item.ctx.commit_sha, Some(&item.spec.label), JobState::Running).await;
|
|
info!(request_id = %item.ctx.request_id, label = %label_key, "vm started (monitoring for completion)");
|
|
// Monitor VM state until it stops or until placeholder_runtime elapses; end early on shutdown
|
|
let start_time = std::time::Instant::now();
|
|
loop {
|
|
// Check current state first
|
|
if let Ok(crate::hypervisor::VmState::Stopped) = hv.state(&h).await {
|
|
info!(request_id = %item.ctx.request_id, label = %label_key, "vm reported stopped");
|
|
break;
|
|
}
|
|
if start_time.elapsed() >= placeholder_runtime { break; }
|
|
// Wait either for shutdown signal or a short delay before next poll
|
|
tokio::select! {
|
|
_ = shutdown.notified() => {
|
|
info!(request_id = %item.ctx.request_id, label = %label_key, "shutdown: ending early");
|
|
break;
|
|
}
|
|
_ = tokio::time::sleep(Duration::from_secs(2)) => {}
|
|
}
|
|
}
|
|
// Stop and destroy
|
|
if let Err(e) = hv.stop(&h, Duration::from_secs(10)).await {
|
|
error!(error = %e, request_id = %item.ctx.request_id, label = %label_key, "failed to stop VM");
|
|
}
|
|
let _ = persist.record_vm_event(item.ctx.request_id, &h.id, overlay, seed, backend, VmPersistState::Stopped).await;
|
|
if let Err(e) = hv.destroy(h.clone()).await {
|
|
error!(error = %e, request_id = %item.ctx.request_id, label = %label_key, "failed to destroy VM");
|
|
}
|
|
let _ = persist.record_vm_event(item.ctx.request_id, &h.id, overlay, seed, backend, VmPersistState::Destroyed).await;
|
|
let _ = persist.record_job_state(item.ctx.request_id, &item.ctx.repo_url, &item.ctx.commit_sha, Some(&item.spec.label), JobState::Succeeded).await;
|
|
}
|
|
Err(e) => {
|
|
error!(error = %e, request_id = %item.ctx.request_id, label = %label_key, "failed to prepare VM");
|
|
let _ = persist.record_job_state(item.ctx.request_id, &item.ctx.repo_url, &item.ctx.commit_sha, Some(&item.spec.label), JobState::Failed).await;
|
|
return;
|
|
}
|
|
}
|
|
});
|
|
handles.push(handle);
|
|
}
|
|
None => break,
|
|
}
|
|
}
|
|
}
|
|
}
|
|
// Wait for all in-flight tasks to finish
|
|
for h in handles { let _ = h.await; }
|
|
info!("scheduler: completed");
|
|
Ok(())
|
|
}
|
|
|
|
pub async fn run(self) -> Result<()> {
|
|
// Backward-compat: run until channel closed (no external shutdown)
|
|
self.run_with_shutdown(Arc::new(Notify::new())).await
|
|
}
|
|
}
|
|
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use async_trait::async_trait;
    use dashmap::DashMap;

    use crate::hypervisor::{VmHandle, VmState};

    /// Hypervisor test double that tracks how many VMs are concurrently
    /// "alive" (between `prepare` and `destroy`), both overall and per label,
    /// so the tests can assert the scheduler's concurrency limits.
    #[derive(Clone)]
    struct MockHypervisor {
        // Currently-alive VM count across all labels.
        active_all: Arc<AtomicUsize>,
        // High-water mark of `active_all`.
        peak_all: Arc<AtomicUsize>,
        // per-label current and peak
        per_curr: Arc<DashMap<String, Arc<AtomicUsize>>>,
        per_peak: Arc<DashMap<String, Arc<AtomicUsize>>>,
        // Maps VM id -> label so `destroy` can decrement the right counter.
        id_to_label: Arc<DashMap<String, String>>,
    }

    impl MockHypervisor {
        fn new(active_all: Arc<AtomicUsize>, peak_all: Arc<AtomicUsize>) -> Self {
            Self { active_all, peak_all, per_curr: Arc::new(DashMap::new()), per_peak: Arc::new(DashMap::new()), id_to_label: Arc::new(DashMap::new()) }
        }
        // Lock-free max-update: raise `peak` to `current` via a CAS retry loop.
        fn update_peak(peak: &AtomicUsize, current: usize) {
            let mut prev = peak.load(Ordering::Relaxed);
            while current > prev {
                match peak.compare_exchange(prev, current, Ordering::Relaxed, Ordering::Relaxed) {
                    Ok(_) => break,
                    // Lost the race; retry against the freshly observed value.
                    Err(p) => prev = p,
                }
            }
        }
    }

    #[async_trait]
    impl Hypervisor for MockHypervisor {
        // Bumps the overall and per-label alive counters, records peaks, and
        // returns a dummy handle keyed by the request id.
        async fn prepare(&self, spec: &VmSpec, ctx: &JobContext) -> miette::Result<VmHandle> {
            let now = self.active_all.fetch_add(1, Ordering::SeqCst) + 1;
            Self::update_peak(&self.peak_all, now);
            // per-label current/peak
            let entry = self.per_curr.entry(spec.label.clone()).or_insert_with(|| Arc::new(AtomicUsize::new(0)));
            let curr = entry.fetch_add(1, Ordering::SeqCst) + 1;
            let peak_entry = self.per_peak.entry(spec.label.clone()).or_insert_with(|| Arc::new(AtomicUsize::new(0))).clone();
            Self::update_peak(&peak_entry, curr);

            let id = format!("job-{}", ctx.request_id);
            self.id_to_label.insert(id.clone(), spec.label.clone());

            Ok(VmHandle {
                id,
                backend: BackendTag::Noop,
                work_dir: PathBuf::from("/tmp"),
                overlay_path: None,
                seed_iso_path: None,
            })
        }
        async fn start(&self, _vm: &VmHandle) -> miette::Result<()> { Ok(()) }
        async fn stop(&self, _vm: &VmHandle, _t: Duration) -> miette::Result<()> { Ok(()) }
        // Undoes the counter bumps made in `prepare` for this VM's label.
        async fn destroy(&self, vm: VmHandle) -> miette::Result<()> {
            // decrement overall current
            self.active_all.fetch_sub(1, Ordering::SeqCst);
            // decrement per-label current using id->label map
            if let Some(label) = self.id_to_label.remove(&vm.id).map(|e| e.1) {
                if let Some(entry) = self.per_curr.get(&label) {
                    let _ = entry.fetch_sub(1, Ordering::SeqCst);
                }
            }
            Ok(())
        }
        // Never reports Stopped, so the scheduler monitors each VM for the
        // full placeholder_runtime window before stopping it.
        async fn state(&self, _vm: &VmHandle) -> miette::Result<VmState> { Ok(VmState::Prepared) }
    }

    // Minimal VmSpec for the given label.
    fn make_spec(label: &str) -> VmSpec {
        VmSpec {
            label: label.to_string(),
            image_path: PathBuf::from("/tmp/image"),
            cpu: 1,
            ram_mb: 512,
            disk_gb: 1,
            network: None,
            nocloud: true,
            user_data: None,
        }
    }

    // Fresh JobContext with a random request id.
    fn make_ctx() -> JobContext {
        JobContext { request_id: uuid::Uuid::new_v4(), repo_url: "https://example.com/r.git".into(), commit_sha: "deadbeef".into(), workflow_job_id: None }
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn scheduler_respects_global_concurrency() {
        let active_all = Arc::new(AtomicUsize::new(0));
        let peak_all = Arc::new(AtomicUsize::new(0));
        let hv = MockHypervisor::new(active_all.clone(), peak_all.clone());
        let hv_probe = hv.clone();
        let persist = Arc::new(Persist::new(None).await.unwrap());

        let mut caps = HashMap::new();
        caps.insert("x".to_string(), 10);

        // Global limit 2, label cap 10 -> the global semaphore is the binding constraint.
        let sched = Scheduler::new(hv, 2, &caps, persist, Duration::from_millis(10));
        let tx = sched.sender();
        let run = tokio::spawn(async move { let _ = sched.run().await; });

        for _ in 0..5 {
            let _ = tx.send(SchedItem { spec: make_spec("x"), ctx: make_ctx() }).await;
        }
        drop(tx);
        // Allow time for tasks to execute under concurrency limits
        tokio::time::sleep(Duration::from_millis(500)).await;
        assert!(hv_probe.peak_all.load(Ordering::SeqCst) <= 2, "peak should not exceed global limit");
        // Stop the scheduler task
        run.abort();
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn scheduler_respects_per_label_capacity() {
        let active_all = Arc::new(AtomicUsize::new(0));
        let peak_all = Arc::new(AtomicUsize::new(0));
        let hv = MockHypervisor::new(active_all.clone(), peak_all.clone());
        let hv_probe = hv.clone();
        let persist = Arc::new(Persist::new(None).await.unwrap());

        // Label "a" may run 1 VM at a time, label "b" may run 2; global cap is 4.
        let mut caps = HashMap::new();
        caps.insert("a".to_string(), 1);
        caps.insert("b".to_string(), 2);

        let sched = Scheduler::new(hv, 4, &caps, persist, Duration::from_millis(10));
        let tx = sched.sender();
        let run = tokio::spawn(async move { let _ = sched.run().await; });

        for _ in 0..3 { let _ = tx.send(SchedItem { spec: make_spec("a"), ctx: make_ctx() }).await; }
        for _ in 0..3 { let _ = tx.send(SchedItem { spec: make_spec("b"), ctx: make_ctx() }).await; }
        drop(tx);
        tokio::time::sleep(Duration::from_millis(800)).await;

        // read per-label peaks
        let a_peak = hv_probe.per_peak.get("a").map(|p| p.load(Ordering::SeqCst)).unwrap_or(0);
        let b_peak = hv_probe.per_peak.get("b").map(|p| p.load(Ordering::SeqCst)).unwrap_or(0);
        assert!(a_peak <= 1, "label a peak should be <= 1, got {}", a_peak);
        assert!(b_peak <= 2, "label b peak should be <= 2, got {}", b_peak);
        assert!(hv_probe.peak_all.load(Ordering::SeqCst) <= 4, "global peak should respect global limit");
        run.abort();
    }
}
|