mirror of
https://codeberg.org/Toasterson/solstice-ci.git
synced 2026-04-11 05:40:41 +00:00
This commit introduces a persistence layer to the Orchestrator, enabling it to optionally connect to a Postgres database for recording job and VM states. It includes: - SeaORM integration with support for migrations from the migration crate. - `Persist` module with methods for job and VM state upserts. - No-op fallback when persistence is disabled or unavailable. - Documentation updates and test coverage for persistence functionality.
255 lines
11 KiB
Rust
255 lines
11 KiB
Rust
use std::{collections::HashMap, sync::Arc, time::Duration};
|
|
|
|
use dashmap::DashMap;
|
|
use miette::Result;
|
|
use tokio::sync::{mpsc, Semaphore};
|
|
use tracing::{error, info, warn};
|
|
|
|
use crate::hypervisor::{Hypervisor, VmSpec, JobContext, BackendTag};
|
|
use crate::persist::{Persist, VmPersistState, JobState};
|
|
|
|
/// Queue-based VM scheduler that enforces a global concurrency limit and a
/// per-label capacity limit on top of a hypervisor backend, recording job
/// and VM state transitions through the persistence layer.
pub struct Scheduler<H: Hypervisor + 'static> {
    // Shared hypervisor backend used to prepare, start, stop and destroy VMs.
    hv: Arc<H>,
    // Submission side of the scheduling queue; handed out via `sender()`.
    tx: mpsc::Sender<SchedItem>,
    // Consumption side of the queue, drained by `run()`.
    rx: mpsc::Receiver<SchedItem>,
    // Caps the total number of concurrently running jobs across all labels.
    global_sem: Arc<Semaphore>,
    // Per-label capacity semaphores; labels not pre-configured in `new()`
    // get a semaphore lazily inside `run()`.
    label_sems: Arc<DashmapType>,
    // Persistence layer for job/VM state upserts (may be a no-op backend).
    persist: Arc<Persist>,
}

// Concurrent map from runner label to that label's capacity semaphore.
type DashmapType = DashMap<String, Arc<Semaphore>>;
|
|
|
|
/// One unit of scheduling work: the VM to provision plus the job context
/// (request id, repo URL, commit SHA) it is provisioned for.
pub struct SchedItem {
    // Resource and label requirements for the VM to create.
    pub spec: VmSpec,
    // Identity of the job this VM serves; used for logging and persistence.
    pub ctx: JobContext,
}
|
|
|
|
impl<H: Hypervisor + 'static> Scheduler<H> {
|
|
pub fn new(hv: H, max_concurrency: usize, capacity_map: &HashMap<String, usize>, persist: Arc<Persist>) -> Self {
|
|
let (tx, rx) = mpsc::channel::<SchedItem>(max_concurrency * 4);
|
|
let label_sems = DashMap::new();
|
|
for (label, cap) in capacity_map.iter() {
|
|
label_sems.insert(label.clone(), Arc::new(Semaphore::new(*cap)));
|
|
}
|
|
Self {
|
|
hv: Arc::new(hv),
|
|
tx,
|
|
rx,
|
|
global_sem: Arc::new(Semaphore::new(max_concurrency)),
|
|
label_sems: Arc::new(label_sems),
|
|
persist,
|
|
}
|
|
}
|
|
|
|
pub fn sender(&self) -> mpsc::Sender<SchedItem> { self.tx.clone() }
|
|
|
|
pub async fn run(self) -> Result<()> {
|
|
let Scheduler { hv, mut rx, global_sem, label_sems, persist, .. } = self;
|
|
let mut handles = Vec::new();
|
|
while let Some(item) = rx.recv().await {
|
|
let hv = hv.clone();
|
|
let global = global_sem.clone();
|
|
let label_sems = label_sems.clone();
|
|
let persist = persist.clone();
|
|
let handle = tokio::spawn(async move {
|
|
// Acquire global and label permits (owned permits so they live inside the task)
|
|
let _g = match global.acquire_owned().await {
|
|
Ok(p) => p,
|
|
Err(_) => return,
|
|
};
|
|
let label_key = item.spec.label.clone();
|
|
let sem_arc = if let Some(entry) = label_sems.get(&label_key) {
|
|
entry.clone()
|
|
} else {
|
|
let s = Arc::new(Semaphore::new(1));
|
|
label_sems.insert(label_key.clone(), s.clone());
|
|
s
|
|
};
|
|
let _l = match sem_arc.acquire_owned().await {
|
|
Ok(p) => p,
|
|
Err(_) => return,
|
|
};
|
|
|
|
// Provision and run
|
|
match hv.prepare(&item.spec, &item.ctx).await {
|
|
Ok(h) => {
|
|
// persist prepared VM state
|
|
let overlay = h.overlay_path.as_ref().and_then(|p| p.to_str());
|
|
let seed = h.seed_iso_path.as_ref().and_then(|p| p.to_str());
|
|
let backend = match h.backend { BackendTag::Noop => Some("noop"), #[cfg(all(target_os = "linux", feature = "libvirt"))] BackendTag::Libvirt => Some("libvirt"), #[cfg(target_os = "illumos")] BackendTag::Zones => Some("zones") };
|
|
if let Err(e) = persist.record_vm_event(item.ctx.request_id, &h.id, overlay, seed, backend, VmPersistState::Prepared).await {
|
|
warn!(error = %e, request_id = %item.ctx.request_id, domain = %h.id, "persist prepare failed");
|
|
}
|
|
if let Err(e) = hv.start(&h).await {
|
|
error!(error = %e, request_id = %item.ctx.request_id, label = %label_key, "failed to start VM");
|
|
let _ = persist.record_job_state(item.ctx.request_id, &item.ctx.repo_url, &item.ctx.commit_sha, Some(&item.spec.label), JobState::Failed).await;
|
|
return;
|
|
}
|
|
let _ = persist.record_vm_event(item.ctx.request_id, &h.id, overlay, seed, backend, VmPersistState::Running).await;
|
|
let _ = persist.record_job_state(item.ctx.request_id, &item.ctx.repo_url, &item.ctx.commit_sha, Some(&item.spec.label), JobState::Running).await;
|
|
info!(request_id = %item.ctx.request_id, label = %label_key, "vm started (workload execution placeholder)");
|
|
// Placeholder job runtime
|
|
tokio::time::sleep(Duration::from_secs(1)).await;
|
|
// Stop and destroy
|
|
if let Err(e) = hv.stop(&h, Duration::from_secs(10)).await {
|
|
error!(error = %e, request_id = %item.ctx.request_id, label = %label_key, "failed to stop VM");
|
|
}
|
|
let _ = persist.record_vm_event(item.ctx.request_id, &h.id, overlay, seed, backend, VmPersistState::Stopped).await;
|
|
if let Err(e) = hv.destroy(h.clone()).await {
|
|
error!(error = %e, request_id = %item.ctx.request_id, label = %label_key, "failed to destroy VM");
|
|
}
|
|
let _ = persist.record_vm_event(item.ctx.request_id, &h.id, overlay, seed, backend, VmPersistState::Destroyed).await;
|
|
let _ = persist.record_job_state(item.ctx.request_id, &item.ctx.repo_url, &item.ctx.commit_sha, Some(&item.spec.label), JobState::Succeeded).await;
|
|
}
|
|
Err(e) => {
|
|
error!(error = %e, request_id = %item.ctx.request_id, label = %label_key, "failed to prepare VM");
|
|
let _ = persist.record_job_state(item.ctx.request_id, &item.ctx.repo_url, &item.ctx.commit_sha, Some(&item.spec.label), JobState::Failed).await;
|
|
return;
|
|
}
|
|
}
|
|
});
|
|
handles.push(handle);
|
|
}
|
|
// Wait for all in-flight tasks to finish
|
|
for h in handles { let _ = h.await; }
|
|
Ok(())
|
|
}
|
|
}
|
|
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;

    use std::path::PathBuf;
    use std::sync::atomic::{AtomicUsize, Ordering};

    use async_trait::async_trait;
    use dashmap::DashMap;

    use crate::hypervisor::{VmHandle, VmState};

    /// Hypervisor stub that records concurrent-VM high-water marks, both
    /// globally and per label, so the tests can assert the scheduler's
    /// concurrency limits were respected.
    #[derive(Clone)]
    struct MockHypervisor {
        active_all: Arc<AtomicUsize>,
        peak_all: Arc<AtomicUsize>,
        // per-label current and peak
        per_curr: Arc<DashMap<String, Arc<AtomicUsize>>>,
        per_peak: Arc<DashMap<String, Arc<AtomicUsize>>>,
        // VM id -> label, so destroy() can decrement the matching per-label
        // counter; VmHandle itself does not carry the label.
        vm_labels: Arc<DashMap<String, String>>,
    }

    impl MockHypervisor {
        fn new(active_all: Arc<AtomicUsize>, peak_all: Arc<AtomicUsize>) -> Self {
            Self {
                active_all,
                peak_all,
                per_curr: Arc::new(DashMap::new()),
                per_peak: Arc::new(DashMap::new()),
                vm_labels: Arc::new(DashMap::new()),
            }
        }

        /// Raise `peak` to `current` if it is larger (lock-free CAS loop).
        fn update_peak(peak: &AtomicUsize, current: usize) {
            let mut prev = peak.load(Ordering::Relaxed);
            while current > prev {
                match peak.compare_exchange(prev, current, Ordering::Relaxed, Ordering::Relaxed) {
                    Ok(_) => break,
                    Err(p) => prev = p,
                }
            }
        }
    }

    #[async_trait]
    impl Hypervisor for MockHypervisor {
        async fn prepare(&self, spec: &VmSpec, ctx: &JobContext) -> miette::Result<VmHandle> {
            let now = self.active_all.fetch_add(1, Ordering::SeqCst) + 1;
            Self::update_peak(&self.peak_all, now);
            // per-label
            let entry = self
                .per_curr
                .entry(spec.label.clone())
                .or_insert_with(|| Arc::new(AtomicUsize::new(0)));
            let curr = entry.fetch_add(1, Ordering::SeqCst) + 1;
            let peak_entry = self
                .per_peak
                .entry(spec.label.clone())
                .or_insert_with(|| Arc::new(AtomicUsize::new(0)))
                .clone();
            Self::update_peak(&peak_entry, curr);

            let id = format!("job-{}", ctx.request_id);
            // Remember which label this VM belongs to so destroy() can undo
            // the per-label increment above.
            self.vm_labels.insert(id.clone(), spec.label.clone());

            Ok(VmHandle {
                id,
                backend: BackendTag::Noop,
                work_dir: PathBuf::from("/tmp"),
                overlay_path: None,
                seed_iso_path: None,
            })
        }

        async fn start(&self, _vm: &VmHandle) -> miette::Result<()> { Ok(()) }

        async fn stop(&self, _vm: &VmHandle, _t: Duration) -> miette::Result<()> { Ok(()) }

        async fn destroy(&self, vm: VmHandle) -> miette::Result<()> {
            // decrement overall current
            self.active_all.fetch_sub(1, Ordering::SeqCst);
            // Decrement the per-label current counter too. Previously this
            // was never decremented, so per-label "current" only grew and
            // sequential jobs on one label inflated the recorded peak,
            // breaking the capacity assertions below (e.g. 3 sequential
            // jobs on label "a" produced a peak of 3 despite capacity 1).
            if let Some((_, label)) = self.vm_labels.remove(&vm.id) {
                if let Some(curr) = self.per_curr.get(&label) {
                    curr.fetch_sub(1, Ordering::SeqCst);
                }
            }
            Ok(())
        }

        async fn state(&self, _vm: &VmHandle) -> miette::Result<VmState> { Ok(VmState::Prepared) }
    }

    /// Minimal VmSpec for a given label; other fields are don't-cares for
    /// the mock backend.
    fn make_spec(label: &str) -> VmSpec {
        VmSpec {
            label: label.to_string(),
            image_path: PathBuf::from("/tmp/image"),
            cpu: 1,
            ram_mb: 512,
            disk_gb: 1,
            network: None,
            nocloud: true,
            user_data: None,
        }
    }

    /// Fresh job context with a unique request id per call.
    fn make_ctx() -> JobContext {
        JobContext {
            request_id: uuid::Uuid::new_v4(),
            repo_url: "https://example.com/r.git".into(),
            commit_sha: "deadbeef".into(),
            workflow_job_id: None,
        }
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn scheduler_respects_global_concurrency() {
        let active_all = Arc::new(AtomicUsize::new(0));
        let peak_all = Arc::new(AtomicUsize::new(0));
        let hv = MockHypervisor::new(active_all.clone(), peak_all.clone());
        let hv_probe = hv.clone();
        let persist = Arc::new(Persist::new(None).await.unwrap());

        let mut caps = HashMap::new();
        caps.insert("x".to_string(), 10);

        let sched = Scheduler::new(hv, 2, &caps, persist);
        let tx = sched.sender();
        let run = tokio::spawn(async move { let _ = sched.run().await; });

        for _ in 0..5 {
            let _ = tx.send(SchedItem { spec: make_spec("x"), ctx: make_ctx() }).await;
        }
        drop(tx);
        // Allow time for tasks to execute under concurrency limits.
        tokio::time::sleep(Duration::from_millis(500)).await;
        assert!(hv_probe.peak_all.load(Ordering::SeqCst) <= 2, "peak should not exceed global limit");
        // Stop the scheduler task.
        run.abort();
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn scheduler_respects_per_label_capacity() {
        let active_all = Arc::new(AtomicUsize::new(0));
        let peak_all = Arc::new(AtomicUsize::new(0));
        let hv = MockHypervisor::new(active_all.clone(), peak_all.clone());
        let hv_probe = hv.clone();
        let persist = Arc::new(Persist::new(None).await.unwrap());

        let mut caps = HashMap::new();
        caps.insert("a".to_string(), 1);
        caps.insert("b".to_string(), 2);

        let sched = Scheduler::new(hv, 4, &caps, persist);
        let tx = sched.sender();
        let run = tokio::spawn(async move { let _ = sched.run().await; });

        for _ in 0..3 { let _ = tx.send(SchedItem { spec: make_spec("a"), ctx: make_ctx() }).await; }
        for _ in 0..3 { let _ = tx.send(SchedItem { spec: make_spec("b"), ctx: make_ctx() }).await; }
        drop(tx);
        tokio::time::sleep(Duration::from_millis(800)).await;

        // read per-label peaks
        let a_peak = hv_probe.per_peak.get("a").map(|p| p.load(Ordering::SeqCst)).unwrap_or(0);
        let b_peak = hv_probe.per_peak.get("b").map(|p| p.load(Ordering::SeqCst)).unwrap_or(0);
        assert!(a_peak <= 1, "label a peak should be <= 1, got {}", a_peak);
        assert!(b_peak <= 2, "label b peak should be <= 2, got {}", b_peak);
        assert!(hv_probe.peak_all.load(Ordering::SeqCst) <= 4, "global peak should respect global limit");
        run.abort();
    }
}
|