use std::{collections::HashMap, sync::Arc, time::Duration}; use dashmap::DashMap; use miette::Result; use tokio::sync::{mpsc, Semaphore, Notify}; use tracing::{error, info, warn}; use crate::hypervisor::{Hypervisor, VmSpec, JobContext, BackendTag}; use crate::persist::{Persist, VmPersistState, JobState}; pub struct Scheduler { hv: Arc, tx: mpsc::Sender, rx: mpsc::Receiver, global_sem: Arc, label_sems: Arc, persist: Arc, placeholder_runtime: Duration, } type DashmapType = DashMap>; pub struct SchedItem { pub spec: VmSpec, pub ctx: JobContext, } impl Scheduler { pub fn new(hv: H, max_concurrency: usize, capacity_map: &HashMap, persist: Arc, placeholder_runtime: Duration) -> Self { let (tx, rx) = mpsc::channel::(max_concurrency * 4); let label_sems = DashMap::new(); for (label, cap) in capacity_map.iter() { label_sems.insert(label.clone(), Arc::new(Semaphore::new(*cap))); } Self { hv: Arc::new(hv), tx, rx, global_sem: Arc::new(Semaphore::new(max_concurrency)), label_sems: Arc::new(label_sems), persist, placeholder_runtime, } } pub fn sender(&self) -> mpsc::Sender { self.tx.clone() } pub async fn run_with_shutdown(self, shutdown: Arc) -> Result<()> { let Scheduler { hv, mut rx, global_sem, label_sems, persist, placeholder_runtime, .. } = self; let mut handles = Vec::new(); let mut shutting_down = false; 'scheduler: loop { tokio::select! { // Only react once to shutdown and then focus on draining the channel _ = shutdown.notified(), if !shutting_down => { shutting_down = true; info!("scheduler: shutdown requested; will drain queue and interrupt placeholders"); break 'scheduler; } maybe_item = rx.recv() => { match maybe_item { Some(item) => { let hv = hv.clone(); let global = global_sem.clone(); let label_sems = label_sems.clone(); let persist = persist.clone(); let shutdown = shutdown.clone(); let handle = tokio::spawn(async move { // Acquire global and label permits (owned permits so they live inside the task) let _g = match global.acquire_owned().await { Ok(p) => p, Err(_) => return }; let label_key = item.spec.label.clone(); let sem_arc = if let Some(entry) = label_sems.get(&label_key) { entry.clone() } else { let s = Arc::new(Semaphore::new(1)); label_sems.insert(label_key.clone(), s.clone()); s }; let _l = match sem_arc.acquire_owned().await { Ok(p) => p, Err(_) => return }; // Provision and run match hv.prepare(&item.spec, &item.ctx).await { Ok(h) => { // persist prepared VM state let overlay = h.overlay_path.as_ref().and_then(|p| p.to_str()); let seed = h.seed_iso_path.as_ref().and_then(|p| p.to_str()); let backend = match h.backend { BackendTag::Noop => Some("noop"), #[cfg(all(target_os = "linux", feature = "libvirt"))] BackendTag::Libvirt => Some("libvirt"), #[cfg(target_os = "illumos")] BackendTag::Zones => Some("zones") }; if let Err(e) = persist.record_vm_event(item.ctx.request_id, &h.id, overlay, seed, backend, VmPersistState::Prepared).await { warn!(error = %e, request_id = %item.ctx.request_id, domain = %h.id, "persist prepare failed"); } if let Err(e) = hv.start(&h).await { error!(error = %e, request_id = %item.ctx.request_id, label = %label_key, "failed to start VM"); let _ = persist.record_job_state(item.ctx.request_id, &item.ctx.repo_url, &item.ctx.commit_sha, Some(&item.spec.label), JobState::Failed).await; return; } let _ = persist.record_vm_event(item.ctx.request_id, &h.id, overlay, seed, backend, VmPersistState::Running).await; let _ = persist.record_job_state(item.ctx.request_id, &item.ctx.repo_url, &item.ctx.commit_sha, Some(&item.spec.label), JobState::Running).await; info!(request_id = %item.ctx.request_id, label = %label_key, "vm started (monitoring for completion)"); // Monitor VM state until it stops or until placeholder_runtime elapses; end early on shutdown let start_time = std::time::Instant::now(); loop { // Check current state first if let Ok(crate::hypervisor::VmState::Stopped) = hv.state(&h).await { info!(request_id = %item.ctx.request_id, label = %label_key, "vm reported stopped"); break; } if start_time.elapsed() >= placeholder_runtime { break; } // Wait either for shutdown signal or a short delay before next poll tokio::select! { _ = shutdown.notified() => { info!(request_id = %item.ctx.request_id, label = %label_key, "shutdown: ending early"); break; } _ = tokio::time::sleep(Duration::from_secs(2)) => {} } } // Stop and destroy if let Err(e) = hv.stop(&h, Duration::from_secs(10)).await { error!(error = %e, request_id = %item.ctx.request_id, label = %label_key, "failed to stop VM"); } let _ = persist.record_vm_event(item.ctx.request_id, &h.id, overlay, seed, backend, VmPersistState::Stopped).await; if let Err(e) = hv.destroy(h.clone()).await { error!(error = %e, request_id = %item.ctx.request_id, label = %label_key, "failed to destroy VM"); } let _ = persist.record_vm_event(item.ctx.request_id, &h.id, overlay, seed, backend, VmPersistState::Destroyed).await; let _ = persist.record_job_state(item.ctx.request_id, &item.ctx.repo_url, &item.ctx.commit_sha, Some(&item.spec.label), JobState::Succeeded).await; } Err(e) => { error!(error = %e, request_id = %item.ctx.request_id, label = %label_key, "failed to prepare VM"); let _ = persist.record_job_state(item.ctx.request_id, &item.ctx.repo_url, &item.ctx.commit_sha, Some(&item.spec.label), JobState::Failed).await; return; } } }); handles.push(handle); } None => break, } } } } // Wait for all in-flight tasks to finish for h in handles { let _ = h.await; } info!("scheduler: completed"); Ok(()) } pub async fn run(self) -> Result<()> { // Backward-compat: run until channel closed (no external shutdown) self.run_with_shutdown(Arc::new(Notify::new())).await } } #[cfg(test)] mod tests { use super::*; use std::path::PathBuf; use std::sync::atomic::{AtomicUsize, Ordering}; use async_trait::async_trait; use dashmap::DashMap; use crate::hypervisor::{VmHandle, VmState}; #[derive(Clone)] struct MockHypervisor { active_all: Arc, peak_all: Arc, // per-label current and peak per_curr: Arc>>, per_peak: Arc>>, id_to_label: Arc>, } impl MockHypervisor { fn new(active_all: Arc, peak_all: Arc) -> Self { Self { active_all, peak_all, per_curr: Arc::new(DashMap::new()), per_peak: Arc::new(DashMap::new()), id_to_label: Arc::new(DashMap::new()) } } fn update_peak(peak: &AtomicUsize, current: usize) { let mut prev = peak.load(Ordering::Relaxed); while current > prev { match peak.compare_exchange(prev, current, Ordering::Relaxed, Ordering::Relaxed) { Ok(_) => break, Err(p) => prev = p, } } } } #[async_trait] impl Hypervisor for MockHypervisor { async fn prepare(&self, spec: &VmSpec, ctx: &JobContext) -> miette::Result { let now = self.active_all.fetch_add(1, Ordering::SeqCst) + 1; Self::update_peak(&self.peak_all, now); // per-label current/peak let entry = self.per_curr.entry(spec.label.clone()).or_insert_with(|| Arc::new(AtomicUsize::new(0))); let curr = entry.fetch_add(1, Ordering::SeqCst) + 1; let peak_entry = self.per_peak.entry(spec.label.clone()).or_insert_with(|| Arc::new(AtomicUsize::new(0))).clone(); Self::update_peak(&peak_entry, curr); let id = format!("job-{}", ctx.request_id); self.id_to_label.insert(id.clone(), spec.label.clone()); Ok(VmHandle { id, backend: BackendTag::Noop, work_dir: PathBuf::from("/tmp"), overlay_path: None, seed_iso_path: None, }) } async fn start(&self, _vm: &VmHandle) -> miette::Result<()> { Ok(()) } async fn stop(&self, _vm: &VmHandle, _t: Duration) -> miette::Result<()> { Ok(()) } async fn destroy(&self, vm: VmHandle) -> miette::Result<()> { // decrement overall current self.active_all.fetch_sub(1, Ordering::SeqCst); // decrement per-label current using id->label map if let Some(label) = self.id_to_label.remove(&vm.id).map(|e| e.1) { if let Some(entry) = self.per_curr.get(&label) { let _ = entry.fetch_sub(1, Ordering::SeqCst); } } Ok(()) } async fn state(&self, _vm: &VmHandle) -> miette::Result { Ok(VmState::Prepared) } } fn make_spec(label: &str) -> VmSpec { VmSpec { label: label.to_string(), image_path: PathBuf::from("/tmp/image"), cpu: 1, ram_mb: 512, disk_gb: 1, network: None, nocloud: true, user_data: None, } } fn make_ctx() -> JobContext { JobContext { request_id: uuid::Uuid::new_v4(), repo_url: "https://example.com/r.git".into(), commit_sha: "deadbeef".into(), workflow_job_id: None } } #[tokio::test(flavor = "multi_thread")] async fn scheduler_respects_global_concurrency() { let active_all = Arc::new(AtomicUsize::new(0)); let peak_all = Arc::new(AtomicUsize::new(0)); let hv = MockHypervisor::new(active_all.clone(), peak_all.clone()); let hv_probe = hv.clone(); let persist = Arc::new(Persist::new(None).await.unwrap()); let mut caps = HashMap::new(); caps.insert("x".to_string(), 10); let sched = Scheduler::new(hv, 2, &caps, persist, Duration::from_millis(10)); let tx = sched.sender(); let run = tokio::spawn(async move { let _ = sched.run().await; }); for _ in 0..5 { let _ = tx.send(SchedItem { spec: make_spec("x"), ctx: make_ctx() }).await; } drop(tx); // Allow time for tasks to execute under concurrency limits tokio::time::sleep(Duration::from_millis(500)).await; assert!(hv_probe.peak_all.load(Ordering::SeqCst) <= 2, "peak should not exceed global limit"); // Stop the scheduler task run.abort(); } #[tokio::test(flavor = "multi_thread")] async fn scheduler_respects_per_label_capacity() { let active_all = Arc::new(AtomicUsize::new(0)); let peak_all = Arc::new(AtomicUsize::new(0)); let hv = MockHypervisor::new(active_all.clone(), peak_all.clone()); let hv_probe = hv.clone(); let persist = Arc::new(Persist::new(None).await.unwrap()); let mut caps = HashMap::new(); caps.insert("a".to_string(), 1); caps.insert("b".to_string(), 2); let sched = Scheduler::new(hv, 4, &caps, persist, Duration::from_millis(10)); let tx = sched.sender(); let run = tokio::spawn(async move { let _ = sched.run().await; }); for _ in 0..3 { let _ = tx.send(SchedItem { spec: make_spec("a"), ctx: make_ctx() }).await; } for _ in 0..3 { let _ = tx.send(SchedItem { spec: make_spec("b"), ctx: make_ctx() }).await; } drop(tx); tokio::time::sleep(Duration::from_millis(800)).await; // read per-label peaks let a_peak = hv_probe.per_peak.get("a").map(|p| p.load(Ordering::SeqCst)).unwrap_or(0); let b_peak = hv_probe.per_peak.get("b").map(|p| p.load(Ordering::SeqCst)).unwrap_or(0); assert!(a_peak <= 1, "label a peak should be <= 1, got {}", a_peak); assert!(b_peak <= 2, "label b peak should be <= 2, got {}", b_peak); assert!(hv_probe.peak_all.load(Ordering::SeqCst) <= 4, "global peak should respect global limit"); run.abort(); } }