Add configurable placeholder VM runtime and graceful shutdown logic

This commit introduces the ability to configure the placeholder VM runtime via an environment variable (`VM_PLACEHOLDER_RUN_SECS`) and updates the `Scheduler` to accept this duration. Additionally, it implements a graceful shutdown mechanism for the orchestrator, allowing cooperative shutdown of consumers and cleanup of resources.
This commit is contained in:
Till Wegmueller 2025-10-26 19:06:32 +01:00
parent 7918db3468
commit 6ff88529e6
No known key found for this signature in database
4 changed files with 99 additions and 47 deletions

View file

@ -6,4 +6,4 @@ pub mod mq;
pub use telemetry::{init_tracing, TelemetryGuard}; pub use telemetry::{init_tracing, TelemetryGuard};
pub use job::{Workflow, Job, Step, parse_workflow_str, parse_workflow_file}; pub use job::{Workflow, Job, Step, parse_workflow_str, parse_workflow_file};
pub use messages::{JobRequest, SourceSystem}; pub use messages::{JobRequest, SourceSystem};
pub use mq::{MqConfig, publish_job, consume_jobs}; pub use mq::{MqConfig, publish_job, consume_jobs, consume_jobs_until};

View file

@ -4,7 +4,7 @@ use futures_util::StreamExt;
use lapin::{ use lapin::{
options::{ options::{
BasicAckOptions, BasicConsumeOptions, BasicNackOptions, BasicPublishOptions, BasicQosOptions, BasicAckOptions, BasicConsumeOptions, BasicNackOptions, BasicPublishOptions, BasicQosOptions,
ConfirmSelectOptions, ExchangeDeclareOptions, QueueBindOptions, QueueDeclareOptions, ConfirmSelectOptions, ExchangeDeclareOptions, QueueBindOptions, QueueDeclareOptions, BasicCancelOptions,
}, },
types::{AMQPValue, FieldTable, LongString, ShortString}, types::{AMQPValue, FieldTable, LongString, ShortString},
BasicProperties, Channel, Connection, ConnectionProperties, Consumer, BasicProperties, Channel, Connection, ConnectionProperties, Consumer,
@ -180,6 +180,17 @@ pub async fn consume_jobs<F, Fut>(cfg: &MqConfig, handler: F) -> Result<()>
where where
F: Fn(JobRequest) -> Fut + Send + Sync + 'static, F: Fn(JobRequest) -> Fut + Send + Sync + 'static,
Fut: std::future::Future<Output = Result<()>> + Send + 'static, Fut: std::future::Future<Output = Result<()>> + Send + 'static,
{
// Backward-compatible wrapper that runs until process exit (no cooperative shutdown)
consume_jobs_until(cfg, futures_util::future::pending(), handler).await
}
#[instrument(skip(cfg, shutdown, handler))]
pub async fn consume_jobs_until<F, Fut, S>(cfg: &MqConfig, shutdown: S, handler: F) -> Result<()>
where
F: Fn(JobRequest) -> Fut + Send + Sync + 'static,
Fut: std::future::Future<Output = Result<()>> + Send + 'static,
S: std::future::Future<Output = ()> + Send + 'static,
{ {
let conn = connect(cfg).await?; let conn = connect(cfg).await?;
let channel = conn.create_channel().await.into_diagnostic()?; let channel = conn.create_channel().await.into_diagnostic()?;
@ -204,10 +215,19 @@ where
info!(queue = %cfg.queue, prefetch = cfg.prefetch, "consumer started"); info!(queue = %cfg.queue, prefetch = cfg.prefetch, "consumer started");
tokio::pin!(consumer); tokio::pin!(consumer);
let mut shutdown = Box::pin(shutdown);
while let Some(delivery) = consumer.next().await { loop {
match delivery { tokio::select! {
Ok(d) => { _ = &mut shutdown => {
info!("shutdown requested; canceling consumer");
// Best-effort cancel; ignore errors during shutdown.
let _ = channel.basic_cancel("orchestrator", BasicCancelOptions { nowait: false }).await;
break;
}
maybe_delivery = consumer.next() => {
match maybe_delivery {
Some(Ok(d)) => {
let tag = d.delivery_tag; let tag = d.delivery_tag;
match serde_json::from_slice::<JobRequest>(&d.data) { match serde_json::from_slice::<JobRequest>(&d.data) {
Ok(job) => { Ok(job) => {
@ -234,11 +254,16 @@ where
} }
} }
} }
Err(e) => { Some(Err(e)) => {
error!(error = %e, "consumer delivery error; continuing"); error!(error = %e, "consumer delivery error; continuing");
// Backoff briefly to avoid tight loop on errors
tokio::time::sleep(Duration::from_millis(200)).await; tokio::time::sleep(Duration::from_millis(200)).await;
} }
None => {
// Stream ended; exit loop.
break;
}
}
}
} }
} }

View file

@ -3,7 +3,7 @@ mod hypervisor;
mod scheduler; mod scheduler;
mod persist; mod persist;
use std::{collections::HashMap, path::PathBuf}; use std::{collections::HashMap, path::PathBuf, time::Duration};
use clap::Parser; use clap::Parser;
use miette::{IntoDiagnostic as _, Result}; use miette::{IntoDiagnostic as _, Result};
@ -69,6 +69,10 @@ struct Opts {
/// OTLP endpoint (e.g., http://localhost:4317) /// OTLP endpoint (e.g., http://localhost:4317)
#[arg(long, env = "OTEL_EXPORTER_OTLP_ENDPOINT")] #[arg(long, env = "OTEL_EXPORTER_OTLP_ENDPOINT")]
otlp: Option<String>, otlp: Option<String>,
/// Placeholder VM run time in seconds (temporary until agent wiring)
#[arg(long, env = "VM_PLACEHOLDER_RUN_SECS", default_value_t = 300)]
vm_placeholder_run_secs: u64,
} }
#[tokio::main(flavor = "multi_thread")] #[tokio::main(flavor = "multi_thread")]
@ -102,7 +106,13 @@ async fn main() -> Result<()> {
}; };
// Scheduler // Scheduler
let sched = Scheduler::new(router, opts.max_concurrency, &capacity_map, persist.clone()); let sched = Scheduler::new(
router,
opts.max_concurrency,
&capacity_map,
persist.clone(),
Duration::from_secs(opts.vm_placeholder_run_secs),
);
let sched_tx = sched.sender(); let sched_tx = sched.sender();
let scheduler_task = tokio::spawn(async move { let scheduler_task = tokio::spawn(async move {
if let Err(e) = sched.run().await { if let Err(e) = sched.run().await {
@ -115,8 +125,12 @@ async fn main() -> Result<()> {
let mq_cfg_clone = mq_cfg.clone(); let mq_cfg_clone = mq_cfg.clone();
let tx_for_consumer = sched_tx.clone(); let tx_for_consumer = sched_tx.clone();
let persist_for_consumer = persist.clone(); let persist_for_consumer = persist.clone();
// Start consumer that can be shut down cooperatively on ctrl-c
let (shutdown_tx, mut shutdown_rx) = tokio::sync::oneshot::channel::<()>();
let consumer_task = tokio::spawn(async move { let consumer_task = tokio::spawn(async move {
common::consume_jobs(&mq_cfg_clone, move |job| { common::consume_jobs_until(&mq_cfg_clone, async move {
let _ = shutdown_rx.await;
}, move |job| {
let sched_tx = tx_for_consumer.clone(); let sched_tx = tx_for_consumer.clone();
let cfg = cfg_clone.clone(); let cfg = cfg_clone.clone();
let persist = persist_for_consumer.clone(); let persist = persist_for_consumer.clone();
@ -165,13 +179,14 @@ async fn main() -> Result<()> {
// Wait for ctrl-c // Wait for ctrl-c
tokio::signal::ctrl_c().await.into_diagnostic()?; tokio::signal::ctrl_c().await.into_diagnostic()?;
// Cancel consumer to stop accepting new jobs // Signal consumer to stop accepting new jobs
consumer_task.abort(); let _ = shutdown_tx.send(());
// Drop sender to let scheduler drain and exit // Drop sender to let scheduler drain and exit
drop(sched_tx); drop(sched_tx);
// Wait for scheduler to finish draining // Wait for consumer and scheduler to finish
let _ = consumer_task.await;
let _ = scheduler_task.await; let _ = scheduler_task.await;
Ok(()) Ok(())

View file

@ -15,6 +15,7 @@ pub struct Scheduler<H: Hypervisor + 'static> {
global_sem: Arc<Semaphore>, global_sem: Arc<Semaphore>,
label_sems: Arc<DashmapType>, label_sems: Arc<DashmapType>,
persist: Arc<Persist>, persist: Arc<Persist>,
placeholder_runtime: Duration,
} }
type DashmapType = DashMap<String, Arc<Semaphore>>; type DashmapType = DashMap<String, Arc<Semaphore>>;
@ -25,7 +26,7 @@ pub struct SchedItem {
} }
impl<H: Hypervisor + 'static> Scheduler<H> { impl<H: Hypervisor + 'static> Scheduler<H> {
pub fn new(hv: H, max_concurrency: usize, capacity_map: &HashMap<String, usize>, persist: Arc<Persist>) -> Self { pub fn new(hv: H, max_concurrency: usize, capacity_map: &HashMap<String, usize>, persist: Arc<Persist>, placeholder_runtime: Duration) -> Self {
let (tx, rx) = mpsc::channel::<SchedItem>(max_concurrency * 4); let (tx, rx) = mpsc::channel::<SchedItem>(max_concurrency * 4);
let label_sems = DashMap::new(); let label_sems = DashMap::new();
for (label, cap) in capacity_map.iter() { for (label, cap) in capacity_map.iter() {
@ -38,13 +39,14 @@ impl<H: Hypervisor + 'static> Scheduler<H> {
global_sem: Arc::new(Semaphore::new(max_concurrency)), global_sem: Arc::new(Semaphore::new(max_concurrency)),
label_sems: Arc::new(label_sems), label_sems: Arc::new(label_sems),
persist, persist,
placeholder_runtime,
} }
} }
pub fn sender(&self) -> mpsc::Sender<SchedItem> { self.tx.clone() } pub fn sender(&self) -> mpsc::Sender<SchedItem> { self.tx.clone() }
pub async fn run(self) -> Result<()> { pub async fn run(self) -> Result<()> {
let Scheduler { hv, mut rx, global_sem, label_sems, persist, .. } = self; let Scheduler { hv, mut rx, global_sem, label_sems, persist, placeholder_runtime, .. } = self;
let mut handles = Vec::new(); let mut handles = Vec::new();
while let Some(item) = rx.recv().await { while let Some(item) = rx.recv().await {
let hv = hv.clone(); let hv = hv.clone();
@ -88,8 +90,8 @@ impl<H: Hypervisor + 'static> Scheduler<H> {
let _ = persist.record_vm_event(item.ctx.request_id, &h.id, overlay, seed, backend, VmPersistState::Running).await; let _ = persist.record_vm_event(item.ctx.request_id, &h.id, overlay, seed, backend, VmPersistState::Running).await;
let _ = persist.record_job_state(item.ctx.request_id, &item.ctx.repo_url, &item.ctx.commit_sha, Some(&item.spec.label), JobState::Running).await; let _ = persist.record_job_state(item.ctx.request_id, &item.ctx.repo_url, &item.ctx.commit_sha, Some(&item.spec.label), JobState::Running).await;
info!(request_id = %item.ctx.request_id, label = %label_key, "vm started (workload execution placeholder)"); info!(request_id = %item.ctx.request_id, label = %label_key, "vm started (workload execution placeholder)");
// Placeholder job runtime // Placeholder job runtime (configurable)
tokio::time::sleep(Duration::from_secs(1)).await; tokio::time::sleep(placeholder_runtime).await;
// Stop and destroy // Stop and destroy
if let Err(e) = hv.stop(&h, Duration::from_secs(10)).await { if let Err(e) = hv.stop(&h, Duration::from_secs(10)).await {
error!(error = %e, request_id = %item.ctx.request_id, label = %label_key, "failed to stop VM"); error!(error = %e, request_id = %item.ctx.request_id, label = %label_key, "failed to stop VM");
@ -134,11 +136,12 @@ mod tests {
// per-label current and peak // per-label current and peak
per_curr: Arc<DashMap<String, Arc<AtomicUsize>>>, per_curr: Arc<DashMap<String, Arc<AtomicUsize>>>,
per_peak: Arc<DashMap<String, Arc<AtomicUsize>>>, per_peak: Arc<DashMap<String, Arc<AtomicUsize>>>,
id_to_label: Arc<DashMap<String, String>>,
} }
impl MockHypervisor { impl MockHypervisor {
fn new(active_all: Arc<AtomicUsize>, peak_all: Arc<AtomicUsize>) -> Self { fn new(active_all: Arc<AtomicUsize>, peak_all: Arc<AtomicUsize>) -> Self {
Self { active_all, peak_all, per_curr: Arc::new(DashMap::new()), per_peak: Arc::new(DashMap::new()) } Self { active_all, peak_all, per_curr: Arc::new(DashMap::new()), per_peak: Arc::new(DashMap::new()), id_to_label: Arc::new(DashMap::new()) }
} }
fn update_peak(peak: &AtomicUsize, current: usize) { fn update_peak(peak: &AtomicUsize, current: usize) {
let mut prev = peak.load(Ordering::Relaxed); let mut prev = peak.load(Ordering::Relaxed);
@ -156,14 +159,17 @@ mod tests {
async fn prepare(&self, spec: &VmSpec, ctx: &JobContext) -> miette::Result<VmHandle> { async fn prepare(&self, spec: &VmSpec, ctx: &JobContext) -> miette::Result<VmHandle> {
let now = self.active_all.fetch_add(1, Ordering::SeqCst) + 1; let now = self.active_all.fetch_add(1, Ordering::SeqCst) + 1;
Self::update_peak(&self.peak_all, now); Self::update_peak(&self.peak_all, now);
// per-label // per-label current/peak
let entry = self.per_curr.entry(spec.label.clone()).or_insert_with(|| Arc::new(AtomicUsize::new(0))); let entry = self.per_curr.entry(spec.label.clone()).or_insert_with(|| Arc::new(AtomicUsize::new(0)));
let curr = entry.fetch_add(1, Ordering::SeqCst) + 1; let curr = entry.fetch_add(1, Ordering::SeqCst) + 1;
let peak_entry = self.per_peak.entry(spec.label.clone()).or_insert_with(|| Arc::new(AtomicUsize::new(0))).clone(); let peak_entry = self.per_peak.entry(spec.label.clone()).or_insert_with(|| Arc::new(AtomicUsize::new(0))).clone();
Self::update_peak(&peak_entry, curr); Self::update_peak(&peak_entry, curr);
let id = format!("job-{}", ctx.request_id);
self.id_to_label.insert(id.clone(), spec.label.clone());
Ok(VmHandle { Ok(VmHandle {
id: format!("job-{}", ctx.request_id), id,
backend: BackendTag::Noop, backend: BackendTag::Noop,
work_dir: PathBuf::from("/tmp"), work_dir: PathBuf::from("/tmp"),
overlay_path: None, overlay_path: None,
@ -172,9 +178,15 @@ mod tests {
} }
async fn start(&self, _vm: &VmHandle) -> miette::Result<()> { Ok(()) } async fn start(&self, _vm: &VmHandle) -> miette::Result<()> { Ok(()) }
async fn stop(&self, _vm: &VmHandle, _t: Duration) -> miette::Result<()> { Ok(()) } async fn stop(&self, _vm: &VmHandle, _t: Duration) -> miette::Result<()> { Ok(()) }
async fn destroy(&self, _vm: VmHandle) -> miette::Result<()> { async fn destroy(&self, vm: VmHandle) -> miette::Result<()> {
// decrement overall current // decrement overall current
self.active_all.fetch_sub(1, Ordering::SeqCst); self.active_all.fetch_sub(1, Ordering::SeqCst);
// decrement per-label current using id->label map
if let Some(label) = self.id_to_label.remove(&vm.id).map(|e| e.1) {
if let Some(entry) = self.per_curr.get(&label) {
let _ = entry.fetch_sub(1, Ordering::SeqCst);
}
}
Ok(()) Ok(())
} }
async fn state(&self, _vm: &VmHandle) -> miette::Result<VmState> { Ok(VmState::Prepared) } async fn state(&self, _vm: &VmHandle) -> miette::Result<VmState> { Ok(VmState::Prepared) }
@ -208,7 +220,7 @@ mod tests {
let mut caps = HashMap::new(); let mut caps = HashMap::new();
caps.insert("x".to_string(), 10); caps.insert("x".to_string(), 10);
let sched = Scheduler::new(hv, 2, &caps, persist); let sched = Scheduler::new(hv, 2, &caps, persist, Duration::from_millis(10));
let tx = sched.sender(); let tx = sched.sender();
let run = tokio::spawn(async move { let _ = sched.run().await; }); let run = tokio::spawn(async move { let _ = sched.run().await; });
@ -235,7 +247,7 @@ mod tests {
caps.insert("a".to_string(), 1); caps.insert("a".to_string(), 1);
caps.insert("b".to_string(), 2); caps.insert("b".to_string(), 2);
let sched = Scheduler::new(hv, 4, &caps, persist); let sched = Scheduler::new(hv, 4, &caps, persist, Duration::from_millis(10));
let tx = sched.sender(); let tx = sched.sender();
let run = tokio::spawn(async move { let _ = sched.run().await; }); let run = tokio::spawn(async move { let _ = sched.run().await; });