solstice-ci/crates/orchestrator/src/scheduler.rs


use std::{collections::HashMap, sync::Arc, time::Duration};
use dashmap::DashMap;
use miette::Result;
use tokio::sync::{mpsc, Semaphore, Notify};
use tracing::{error, info, warn};
use crate::hypervisor::{Hypervisor, VmSpec, JobContext, BackendTag};
use crate::persist::{Persist, VmPersistState, JobState};
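
/// Orchestrates VM-backed CI jobs: items arrive on a bounded mpsc channel and are
/// executed under a global concurrency limit plus per-label capacity limits, with
/// VM and job lifecycle events recorded through [`Persist`].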
pub struct Scheduler<H: Hypervisor + 'static> {
    hv: Arc<H>,
    tx: mpsc::Sender<SchedItem>,
    rx: mpsc::Receiver<SchedItem>,
    global_sem: Arc<Semaphore>,
    label_sems: Arc<LabelSemaphores>,
    persist: Arc<Persist>,
    placeholder_runtime: Duration,
}
type LabelSemaphores = DashMap<String, Arc<Semaphore>>;
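
/// A single scheduling request: the VM to provision plus the job metadata it runs for.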
pub struct SchedItem {
    pub spec: VmSpec,
    pub ctx: JobContext,
}

impl<H: Hypervisor + 'static> Scheduler<H> {
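    /// Builds a scheduler with a queue of `max_concurrency * 4` slots, a global
    /// semaphore of `max_concurrency` permits, and one semaphore per label sized
    /// from `capacity_map`. Labels missing from `capacity_map` fall back to a single
    /// concurrent job when first seen by the run loop. `placeholder_runtime` caps how
    /// long a started VM is monitored before it is stopped and destroyed.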
    pub fn new(
        hv: H,
        max_concurrency: usize,
        capacity_map: &HashMap<String, usize>,
        persist: Arc<Persist>,
        placeholder_runtime: Duration,
    ) -> Self {
        let (tx, rx) = mpsc::channel::<SchedItem>(max_concurrency * 4);
        let label_sems = DashMap::new();
        for (label, cap) in capacity_map.iter() {
            label_sems.insert(label.clone(), Arc::new(Semaphore::new(*cap)));
        }
        Self {
            hv: Arc::new(hv),
            tx,
            rx,
            global_sem: Arc::new(Semaphore::new(max_concurrency)),
            label_sems: Arc::new(label_sems),
            persist,
            placeholder_runtime,
        }
    }
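
    /// Returns a clone of the submission channel; drop all senders to let the run loop finish.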
    pub fn sender(&self) -> mpsc::Sender<SchedItem> { self.tx.clone() }
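
    /// Runs until the submission channel is closed or `shutdown` is notified, then waits
    /// for in-flight jobs; the same `Notify` also interrupts the per-VM placeholder waits.
    /// A minimal usage sketch, assuming `hv` is any value implementing [`Hypervisor`]
    /// (e.g. the noop backend) and that `spec`/`ctx` (a `VmSpec` and `JobContext`) are
    /// built elsewhere:
    ///
    /// ```ignore
    /// let persist = Arc::new(Persist::new(None).await?);
    /// let mut caps = HashMap::new();
    /// caps.insert("linux-x64".to_string(), 4);        // illustrative label/capacity
    /// let sched = Scheduler::new(hv, 8, &caps, persist, Duration::from_secs(600));
    /// let tx = sched.sender();
    /// let shutdown = Arc::new(Notify::new());
    /// let run = tokio::spawn(sched.run_with_shutdown(shutdown.clone()));
    /// tx.send(SchedItem { spec, ctx }).await?;
    /// drop(tx);                                        // close the queue so the run loop can exit
    /// let _ = run.await;
    /// ```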
    pub async fn run_with_shutdown(self, shutdown: Arc<Notify>) -> Result<()> {
        let Scheduler { hv, mut rx, global_sem, label_sems, persist, placeholder_runtime, .. } = self;
        let mut handles = Vec::new();
        'scheduler: loop {
            tokio::select! {
                // On shutdown, stop accepting new work; in-flight tasks are interrupted
                // via the shared `Notify` and awaited below.
                _ = shutdown.notified() => {
                    info!("scheduler: shutdown requested; no longer accepting work, waiting for in-flight jobs");
                    break 'scheduler;
                }
                maybe_item = rx.recv() => {
                    match maybe_item {
                        Some(item) => {
                            let hv = hv.clone();
                            let global = global_sem.clone();
                            let label_sems = label_sems.clone();
                            let persist = persist.clone();
                            let shutdown = shutdown.clone();
                            let handle = tokio::spawn(async move {
                                // Acquire global and label permits (owned permits so they live inside the task)
                                let _g = match global.acquire_owned().await { Ok(p) => p, Err(_) => return };
                                let label_key = item.spec.label.clone();
                                // Labels without a configured capacity default to a single slot; the
                                // entry API avoids racing inserts of duplicate semaphores.
                                let sem_arc = label_sems
                                    .entry(label_key.clone())
                                    .or_insert_with(|| Arc::new(Semaphore::new(1)))
                                    .clone();
                                let _l = match sem_arc.acquire_owned().await { Ok(p) => p, Err(_) => return };
                                // Provision and run
                                match hv.prepare(&item.spec, &item.ctx).await {
                                    Ok(h) => {
                                        // persist prepared VM state
                                        let overlay = h.overlay_path.as_ref().and_then(|p| p.to_str());
                                        let seed = h.seed_iso_path.as_ref().and_then(|p| p.to_str());
                                        let backend = match h.backend {
                                            BackendTag::Noop => Some("noop"),
                                            #[cfg(all(target_os = "linux", feature = "libvirt"))]
                                            BackendTag::Libvirt => Some("libvirt"),
                                            #[cfg(target_os = "illumos")]
                                            BackendTag::Zones => Some("zones"),
                                        };
                                        if let Err(e) = persist.record_vm_event(item.ctx.request_id, &h.id, overlay, seed, backend, VmPersistState::Prepared).await {
                                            warn!(error = %e, request_id = %item.ctx.request_id, domain = %h.id, "persist prepare failed");
                                        }
                                        if let Err(e) = hv.start(&h).await {
                                            error!(error = %e, request_id = %item.ctx.request_id, label = %label_key, "failed to start VM");
                                            let _ = persist.record_job_state(item.ctx.request_id, &item.ctx.repo_url, &item.ctx.commit_sha, Some(&item.spec.label), JobState::Failed).await;
                                            return;
                                        }
                                        let _ = persist.record_vm_event(item.ctx.request_id, &h.id, overlay, seed, backend, VmPersistState::Running).await;
                                        let _ = persist.record_job_state(item.ctx.request_id, &item.ctx.repo_url, &item.ctx.commit_sha, Some(&item.spec.label), JobState::Running).await;
                                        info!(request_id = %item.ctx.request_id, label = %label_key, "vm started (monitoring for completion)");
                                        // Monitor VM state until it stops or until placeholder_runtime elapses; end early on shutdown
                                        let start_time = std::time::Instant::now();
                                        loop {
                                            // Check current state first
                                            if let Ok(crate::hypervisor::VmState::Stopped) = hv.state(&h).await {
                                                info!(request_id = %item.ctx.request_id, label = %label_key, "vm reported stopped");
                                                break;
                                            }
                                            if start_time.elapsed() >= placeholder_runtime { break; }
                                            // Wait either for shutdown signal or a short delay before next poll
                                            tokio::select! {
                                                _ = shutdown.notified() => {
                                                    info!(request_id = %item.ctx.request_id, label = %label_key, "shutdown: ending early");
                                                    break;
                                                }
                                                _ = tokio::time::sleep(Duration::from_secs(2)) => {}
                                            }
                                        }
                                        // Stop and destroy
                                        if let Err(e) = hv.stop(&h, Duration::from_secs(10)).await {
                                            error!(error = %e, request_id = %item.ctx.request_id, label = %label_key, "failed to stop VM");
                                        }
                                        let _ = persist.record_vm_event(item.ctx.request_id, &h.id, overlay, seed, backend, VmPersistState::Stopped).await;
                                        if let Err(e) = hv.destroy(h.clone()).await {
                                            error!(error = %e, request_id = %item.ctx.request_id, label = %label_key, "failed to destroy VM");
                                        }
                                        let _ = persist.record_vm_event(item.ctx.request_id, &h.id, overlay, seed, backend, VmPersistState::Destroyed).await;
                                        let _ = persist.record_job_state(item.ctx.request_id, &item.ctx.repo_url, &item.ctx.commit_sha, Some(&item.spec.label), JobState::Succeeded).await;
                                    }
                                    Err(e) => {
                                        error!(error = %e, request_id = %item.ctx.request_id, label = %label_key, "failed to prepare VM");
                                        let _ = persist.record_job_state(item.ctx.request_id, &item.ctx.repo_url, &item.ctx.commit_sha, Some(&item.spec.label), JobState::Failed).await;
                                        return;
                                    }
                                }
                            });
                            handles.push(handle);
                        }
                        None => break,
                    }
                }
            }
        }
        // Wait for all in-flight tasks to finish
        for h in handles { let _ = h.await; }
        info!("scheduler: completed");
        Ok(())
    }

    pub async fn run(self) -> Result<()> {
        // Backward-compat: run until channel closed (no external shutdown)
        self.run_with_shutdown(Arc::new(Notify::new())).await
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use async_trait::async_trait;
    use dashmap::DashMap;
    use crate::hypervisor::{VmHandle, VmState};
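
    /// Hypervisor stub that never runs anything; it only tracks current and peak
    /// concurrency, globally and per label, so the tests can assert the scheduler's limits.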
    #[derive(Clone)]
    struct MockHypervisor {
        active_all: Arc<AtomicUsize>,
        peak_all: Arc<AtomicUsize>,
        // per-label current and peak
        per_curr: Arc<DashMap<String, Arc<AtomicUsize>>>,
        per_peak: Arc<DashMap<String, Arc<AtomicUsize>>>,
        id_to_label: Arc<DashMap<String, String>>,
    }

    impl MockHypervisor {
        fn new(active_all: Arc<AtomicUsize>, peak_all: Arc<AtomicUsize>) -> Self {
            Self { active_all, peak_all, per_curr: Arc::new(DashMap::new()), per_peak: Arc::new(DashMap::new()), id_to_label: Arc::new(DashMap::new()) }
        }
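
        // Raise `peak` to at least `current` via a compare-exchange loop (a lock-free max).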
        fn update_peak(peak: &AtomicUsize, current: usize) {
            let mut prev = peak.load(Ordering::Relaxed);
            while current > prev {
                match peak.compare_exchange(prev, current, Ordering::Relaxed, Ordering::Relaxed) {
                    Ok(_) => break,
                    Err(p) => prev = p,
                }
            }
        }
    }

    #[async_trait]
    impl Hypervisor for MockHypervisor {
        async fn prepare(&self, spec: &VmSpec, ctx: &JobContext) -> miette::Result<VmHandle> {
            let now = self.active_all.fetch_add(1, Ordering::SeqCst) + 1;
            Self::update_peak(&self.peak_all, now);
            // per-label current/peak
            let entry = self.per_curr.entry(spec.label.clone()).or_insert_with(|| Arc::new(AtomicUsize::new(0)));
            let curr = entry.fetch_add(1, Ordering::SeqCst) + 1;
            let peak_entry = self.per_peak.entry(spec.label.clone()).or_insert_with(|| Arc::new(AtomicUsize::new(0))).clone();
            Self::update_peak(&peak_entry, curr);
            let id = format!("job-{}", ctx.request_id);
            self.id_to_label.insert(id.clone(), spec.label.clone());
            Ok(VmHandle {
                id,
                backend: BackendTag::Noop,
                work_dir: PathBuf::from("/tmp"),
                overlay_path: None,
                seed_iso_path: None,
            })
        }

        async fn start(&self, _vm: &VmHandle) -> miette::Result<()> { Ok(()) }

        async fn stop(&self, _vm: &VmHandle, _t: Duration) -> miette::Result<()> { Ok(()) }

        async fn destroy(&self, vm: VmHandle) -> miette::Result<()> {
            // decrement overall current
            self.active_all.fetch_sub(1, Ordering::SeqCst);
            // decrement per-label current using id->label map
            if let Some(label) = self.id_to_label.remove(&vm.id).map(|e| e.1) {
                if let Some(entry) = self.per_curr.get(&label) {
                    let _ = entry.fetch_sub(1, Ordering::SeqCst);
                }
            }
            Ok(())
        }

        async fn state(&self, _vm: &VmHandle) -> miette::Result<VmState> { Ok(VmState::Prepared) }
    }

    fn make_spec(label: &str) -> VmSpec {
        VmSpec {
            label: label.to_string(),
            image_path: PathBuf::from("/tmp/image"),
            cpu: 1,
            ram_mb: 512,
            disk_gb: 1,
            network: None,
            nocloud: true,
            user_data: None,
        }
    }

    fn make_ctx() -> JobContext {
        JobContext {
            request_id: uuid::Uuid::new_v4(),
            repo_url: "https://example.com/r.git".into(),
            commit_sha: "deadbeef".into(),
            workflow_job_id: None,
        }
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn scheduler_respects_global_concurrency() {
        let active_all = Arc::new(AtomicUsize::new(0));
        let peak_all = Arc::new(AtomicUsize::new(0));
        let hv = MockHypervisor::new(active_all.clone(), peak_all.clone());
        let hv_probe = hv.clone();
        let persist = Arc::new(Persist::new(None).await.unwrap());
        let mut caps = HashMap::new();
        caps.insert("x".to_string(), 10);
        let sched = Scheduler::new(hv, 2, &caps, persist, Duration::from_millis(10));
        let tx = sched.sender();
        let run = tokio::spawn(async move { let _ = sched.run().await; });
        for _ in 0..5 {
            let _ = tx.send(SchedItem { spec: make_spec("x"), ctx: make_ctx() }).await;
        }
        drop(tx);
        // Allow time for tasks to execute under concurrency limits
        tokio::time::sleep(Duration::from_millis(500)).await;
        assert!(hv_probe.peak_all.load(Ordering::SeqCst) <= 2, "peak should not exceed global limit");
        // Stop the scheduler task
        run.abort();
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn scheduler_respects_per_label_capacity() {
        let active_all = Arc::new(AtomicUsize::new(0));
        let peak_all = Arc::new(AtomicUsize::new(0));
        let hv = MockHypervisor::new(active_all.clone(), peak_all.clone());
        let hv_probe = hv.clone();
        let persist = Arc::new(Persist::new(None).await.unwrap());
        let mut caps = HashMap::new();
        caps.insert("a".to_string(), 1);
        caps.insert("b".to_string(), 2);
        let sched = Scheduler::new(hv, 4, &caps, persist, Duration::from_millis(10));
        let tx = sched.sender();
        let run = tokio::spawn(async move { let _ = sched.run().await; });
        for _ in 0..3 { let _ = tx.send(SchedItem { spec: make_spec("a"), ctx: make_ctx() }).await; }
        for _ in 0..3 { let _ = tx.send(SchedItem { spec: make_spec("b"), ctx: make_ctx() }).await; }
        drop(tx);
        tokio::time::sleep(Duration::from_millis(800)).await;
        // read per-label peaks
        let a_peak = hv_probe.per_peak.get("a").map(|p| p.load(Ordering::SeqCst)).unwrap_or(0);
        let b_peak = hv_probe.per_peak.get("b").map(|p| p.load(Ordering::SeqCst)).unwrap_or(0);
        assert!(a_peak <= 1, "label a peak should be <= 1, got {}", a_peak);
        assert!(b_peak <= 2, "label b peak should be <= 2, got {}", b_peak);
        assert!(hv_probe.peak_all.load(Ordering::SeqCst) <= 4, "global peak should respect global limit");
        run.abort();
    }
}