diff --git a/crates/common/src/lib.rs b/crates/common/src/lib.rs index a4bcf17..311c814 100644 --- a/crates/common/src/lib.rs +++ b/crates/common/src/lib.rs @@ -6,4 +6,4 @@ pub mod mq; pub use telemetry::{init_tracing, TelemetryGuard}; pub use job::{Workflow, Job, Step, parse_workflow_str, parse_workflow_file}; pub use messages::{JobRequest, SourceSystem}; -pub use mq::{MqConfig, publish_job, consume_jobs}; +pub use mq::{MqConfig, publish_job, consume_jobs, consume_jobs_until}; diff --git a/crates/common/src/mq.rs b/crates/common/src/mq.rs index 460a508..a886cf6 100644 --- a/crates/common/src/mq.rs +++ b/crates/common/src/mq.rs @@ -4,7 +4,7 @@ use futures_util::StreamExt; use lapin::{ options::{ BasicAckOptions, BasicConsumeOptions, BasicNackOptions, BasicPublishOptions, BasicQosOptions, - ConfirmSelectOptions, ExchangeDeclareOptions, QueueBindOptions, QueueDeclareOptions, + ConfirmSelectOptions, ExchangeDeclareOptions, QueueBindOptions, QueueDeclareOptions, BasicCancelOptions, }, types::{AMQPValue, FieldTable, LongString, ShortString}, BasicProperties, Channel, Connection, ConnectionProperties, Consumer, @@ -180,6 +180,17 @@ pub async fn consume_jobs(cfg: &MqConfig, handler: F) -> Result<()> where F: Fn(JobRequest) -> Fut + Send + Sync + 'static, Fut: std::future::Future> + Send + 'static, +{ + // Backward-compatible wrapper that runs until process exit (no cooperative shutdown) + consume_jobs_until(cfg, futures_util::future::pending(), handler).await +} + +#[instrument(skip(cfg, shutdown, handler))] +pub async fn consume_jobs_until(cfg: &MqConfig, shutdown: S, handler: F) -> Result<()> +where + F: Fn(JobRequest) -> Fut + Send + Sync + 'static, + Fut: std::future::Future> + Send + 'static, + S: std::future::Future + Send + 'static, { let conn = connect(cfg).await?; let channel = conn.create_channel().await.into_diagnostic()?; @@ -204,41 +215,55 @@ where info!(queue = %cfg.queue, prefetch = cfg.prefetch, "consumer started"); tokio::pin!(consumer); + let mut shutdown = Box::pin(shutdown); - while let Some(delivery) = consumer.next().await { - match delivery { - Ok(d) => { - let tag = d.delivery_tag; - match serde_json::from_slice::(&d.data) { - Ok(job) => { - let span = tracing::info_span!("handle_job", request_id = %job.request_id, repo = %job.repo_url, sha = %job.commit_sha); - if let Err(err) = handler(job).instrument(span).await { - error!(error = %err, "handler error; nacking to DLQ"); - channel - .basic_nack(tag, BasicNackOptions { requeue: false, multiple: false }) - .await - .into_diagnostic()?; - } else { - channel - .basic_ack(tag, BasicAckOptions { multiple: false }) - .await - .into_diagnostic()?; + loop { + tokio::select! { + _ = &mut shutdown => { + info!("shutdown requested; canceling consumer"); + // Best-effort cancel; ignore errors during shutdown. + let _ = channel.basic_cancel("orchestrator", BasicCancelOptions { nowait: false }).await; + break; + } + maybe_delivery = consumer.next() => { + match maybe_delivery { + Some(Ok(d)) => { + let tag = d.delivery_tag; + match serde_json::from_slice::(&d.data) { + Ok(job) => { + let span = tracing::info_span!("handle_job", request_id = %job.request_id, repo = %job.repo_url, sha = %job.commit_sha); + if let Err(err) = handler(job).instrument(span).await { + error!(error = %err, "handler error; nacking to DLQ"); + channel + .basic_nack(tag, BasicNackOptions { requeue: false, multiple: false }) + .await + .into_diagnostic()?; + } else { + channel + .basic_ack(tag, BasicAckOptions { multiple: false }) + .await + .into_diagnostic()?; + } + } + Err(e) => { + warn!(error = %e, "failed to deserialize JobRequest; dead-lettering"); + channel + .basic_nack(tag, BasicNackOptions { requeue: false, multiple: false }) + .await + .into_diagnostic()?; + } } } - Err(e) => { - warn!(error = %e, "failed to deserialize JobRequest; dead-lettering"); - channel - .basic_nack(tag, BasicNackOptions { requeue: false, multiple: false }) - .await - .into_diagnostic()?; + Some(Err(e)) => { + error!(error = %e, "consumer delivery error; continuing"); + tokio::time::sleep(Duration::from_millis(200)).await; + } + None => { + // Stream ended; exit loop. + break; } } } - Err(e) => { - error!(error = %e, "consumer delivery error; continuing"); - // Backoff briefly to avoid tight loop on errors - tokio::time::sleep(Duration::from_millis(200)).await; - } } } diff --git a/crates/orchestrator/src/main.rs b/crates/orchestrator/src/main.rs index 7b4c63a..4876875 100644 --- a/crates/orchestrator/src/main.rs +++ b/crates/orchestrator/src/main.rs @@ -3,7 +3,7 @@ mod hypervisor; mod scheduler; mod persist; -use std::{collections::HashMap, path::PathBuf}; +use std::{collections::HashMap, path::PathBuf, time::Duration}; use clap::Parser; use miette::{IntoDiagnostic as _, Result}; @@ -69,6 +69,10 @@ struct Opts { /// OTLP endpoint (e.g., http://localhost:4317) #[arg(long, env = "OTEL_EXPORTER_OTLP_ENDPOINT")] otlp: Option, + + /// Placeholder VM run time in seconds (temporary until agent wiring) + #[arg(long, env = "VM_PLACEHOLDER_RUN_SECS", default_value_t = 300)] + vm_placeholder_run_secs: u64, } #[tokio::main(flavor = "multi_thread")] @@ -102,7 +106,13 @@ async fn main() -> Result<()> { }; // Scheduler - let sched = Scheduler::new(router, opts.max_concurrency, &capacity_map, persist.clone()); + let sched = Scheduler::new( + router, + opts.max_concurrency, + &capacity_map, + persist.clone(), + Duration::from_secs(opts.vm_placeholder_run_secs), + ); let sched_tx = sched.sender(); let scheduler_task = tokio::spawn(async move { if let Err(e) = sched.run().await { @@ -115,8 +125,12 @@ async fn main() -> Result<()> { let mq_cfg_clone = mq_cfg.clone(); let tx_for_consumer = sched_tx.clone(); let persist_for_consumer = persist.clone(); + // Start consumer that can be shut down cooperatively on ctrl-c + let (shutdown_tx, mut shutdown_rx) = tokio::sync::oneshot::channel::<()>(); let consumer_task = tokio::spawn(async move { - common::consume_jobs(&mq_cfg_clone, move |job| { + common::consume_jobs_until(&mq_cfg_clone, async move { + let _ = shutdown_rx.await; + }, move |job| { let sched_tx = tx_for_consumer.clone(); let cfg = cfg_clone.clone(); let persist = persist_for_consumer.clone(); @@ -165,13 +179,14 @@ async fn main() -> Result<()> { // Wait for ctrl-c tokio::signal::ctrl_c().await.into_diagnostic()?; - // Cancel consumer to stop accepting new jobs - consumer_task.abort(); + // Signal consumer to stop accepting new jobs + let _ = shutdown_tx.send(()); // Drop sender to let scheduler drain and exit drop(sched_tx); - // Wait for scheduler to finish draining + // Wait for consumer and scheduler to finish + let _ = consumer_task.await; let _ = scheduler_task.await; Ok(()) diff --git a/crates/orchestrator/src/scheduler.rs b/crates/orchestrator/src/scheduler.rs index 5706b8b..05caa53 100644 --- a/crates/orchestrator/src/scheduler.rs +++ b/crates/orchestrator/src/scheduler.rs @@ -15,6 +15,7 @@ pub struct Scheduler { global_sem: Arc, label_sems: Arc, persist: Arc, + placeholder_runtime: Duration, } type DashmapType = DashMap>; @@ -25,7 +26,7 @@ pub struct SchedItem { } impl Scheduler { - pub fn new(hv: H, max_concurrency: usize, capacity_map: &HashMap, persist: Arc) -> Self { + pub fn new(hv: H, max_concurrency: usize, capacity_map: &HashMap, persist: Arc, placeholder_runtime: Duration) -> Self { let (tx, rx) = mpsc::channel::(max_concurrency * 4); let label_sems = DashMap::new(); for (label, cap) in capacity_map.iter() { @@ -38,13 +39,14 @@ impl Scheduler { global_sem: Arc::new(Semaphore::new(max_concurrency)), label_sems: Arc::new(label_sems), persist, + placeholder_runtime, } } pub fn sender(&self) -> mpsc::Sender { self.tx.clone() } pub async fn run(self) -> Result<()> { - let Scheduler { hv, mut rx, global_sem, label_sems, persist, .. } = self; + let Scheduler { hv, mut rx, global_sem, label_sems, persist, placeholder_runtime, .. } = self; let mut handles = Vec::new(); while let Some(item) = rx.recv().await { let hv = hv.clone(); @@ -88,8 +90,8 @@ impl Scheduler { let _ = persist.record_vm_event(item.ctx.request_id, &h.id, overlay, seed, backend, VmPersistState::Running).await; let _ = persist.record_job_state(item.ctx.request_id, &item.ctx.repo_url, &item.ctx.commit_sha, Some(&item.spec.label), JobState::Running).await; info!(request_id = %item.ctx.request_id, label = %label_key, "vm started (workload execution placeholder)"); - // Placeholder job runtime - tokio::time::sleep(Duration::from_secs(1)).await; + // Placeholder job runtime (configurable) + tokio::time::sleep(placeholder_runtime).await; // Stop and destroy if let Err(e) = hv.stop(&h, Duration::from_secs(10)).await { error!(error = %e, request_id = %item.ctx.request_id, label = %label_key, "failed to stop VM"); @@ -134,11 +136,12 @@ mod tests { // per-label current and peak per_curr: Arc>>, per_peak: Arc>>, + id_to_label: Arc>, } impl MockHypervisor { fn new(active_all: Arc, peak_all: Arc) -> Self { - Self { active_all, peak_all, per_curr: Arc::new(DashMap::new()), per_peak: Arc::new(DashMap::new()) } + Self { active_all, peak_all, per_curr: Arc::new(DashMap::new()), per_peak: Arc::new(DashMap::new()), id_to_label: Arc::new(DashMap::new()) } } fn update_peak(peak: &AtomicUsize, current: usize) { let mut prev = peak.load(Ordering::Relaxed); @@ -156,14 +159,17 @@ mod tests { async fn prepare(&self, spec: &VmSpec, ctx: &JobContext) -> miette::Result { let now = self.active_all.fetch_add(1, Ordering::SeqCst) + 1; Self::update_peak(&self.peak_all, now); - // per-label + // per-label current/peak let entry = self.per_curr.entry(spec.label.clone()).or_insert_with(|| Arc::new(AtomicUsize::new(0))); let curr = entry.fetch_add(1, Ordering::SeqCst) + 1; let peak_entry = self.per_peak.entry(spec.label.clone()).or_insert_with(|| Arc::new(AtomicUsize::new(0))).clone(); Self::update_peak(&peak_entry, curr); + let id = format!("job-{}", ctx.request_id); + self.id_to_label.insert(id.clone(), spec.label.clone()); + Ok(VmHandle { - id: format!("job-{}", ctx.request_id), + id, backend: BackendTag::Noop, work_dir: PathBuf::from("/tmp"), overlay_path: None, @@ -172,9 +178,15 @@ mod tests { } async fn start(&self, _vm: &VmHandle) -> miette::Result<()> { Ok(()) } async fn stop(&self, _vm: &VmHandle, _t: Duration) -> miette::Result<()> { Ok(()) } - async fn destroy(&self, _vm: VmHandle) -> miette::Result<()> { + async fn destroy(&self, vm: VmHandle) -> miette::Result<()> { // decrement overall current self.active_all.fetch_sub(1, Ordering::SeqCst); + // decrement per-label current using id->label map + if let Some(label) = self.id_to_label.remove(&vm.id).map(|e| e.1) { + if let Some(entry) = self.per_curr.get(&label) { + let _ = entry.fetch_sub(1, Ordering::SeqCst); + } + } Ok(()) } async fn state(&self, _vm: &VmHandle) -> miette::Result { Ok(VmState::Prepared) } @@ -208,7 +220,7 @@ mod tests { let mut caps = HashMap::new(); caps.insert("x".to_string(), 10); - let sched = Scheduler::new(hv, 2, &caps, persist); + let sched = Scheduler::new(hv, 2, &caps, persist, Duration::from_millis(10)); let tx = sched.sender(); let run = tokio::spawn(async move { let _ = sched.run().await; }); @@ -235,7 +247,7 @@ mod tests { caps.insert("a".to_string(), 1); caps.insert("b".to_string(), 2); - let sched = Scheduler::new(hv, 4, &caps, persist); + let sched = Scheduler::new(hv, 4, &caps, persist, Duration::from_millis(10)); let tx = sched.sender(); let run = tokio::spawn(async move { let _ = sched.run().await; });